diff --git a/EdgeFLite/.DS_Store b/EdgeFLite/.DS_Store
new file mode 100644
index 0000000..fd8c702
Binary files /dev/null and b/EdgeFLite/.DS_Store differ
diff --git a/EdgeFLite/README.md b/EdgeFLite/README.md
new file mode 100644
index 0000000..7e9e246
--- /dev/null
+++ b/EdgeFLite/README.md
@@ -0,0 +1,41 @@
+# EdgeFLite: Edge Federated Learning for Improved Training Efficiency
+
+- EdgeFLite is a framework designed to address the memory limitations of federated learning (FL) on resource-constrained edge devices. It partitions large convolutional neural networks (CNNs) into smaller sub-models and distributes their training across local clients, enabling efficient learning while preserving data privacy. Clients within a cluster collaborate by sharing learned representations, which a central server aggregates to refine the global model. Experiments on medical imaging and natural image datasets show that EdgeFLite consistently outperforms competing FL frameworks.
+
+- Within 6G-enabled mobile edge computing (MEC) networks, EdgeFLite addresses the challenges posed by client heterogeneity and resource constraints by jointly optimizing local models and resource allocation. A convergence analysis establishes the relationship between training loss and resource usage, and the Intelligent Frequency Band Allocation (IFBA) algorithm reduces latency and improves training efficiency by 5-10%, making EdgeFLite a robust solution for federated learning across a wide range of edge environments.
+
+## Preparation
+### Dataset Setup
+- The CIFAR-10 and CIFAR-100 datasets, both derived from the Tiny Images dataset, are downloaded automatically. CIFAR-10 contains 60,000 32x32 color images in 10 classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. Each class has 6,000 images, split into 5,000 for training and 1,000 for testing.
+
+- CIFAR-100 is more challenging, with 100 classes and correspondingly fewer images per class. Both datasets are standard benchmarks for image classification and provide a solid evaluation environment for the models in this repository.
+
+### Dependency Installation
+
+```bash
+# Required dependencies
+PyTorch 1.10.2
+OpenCV  4.5.5
+```
+
+## Running Experiments
+*Top-1 accuracy (%) of FedDCT compared to state-of-the-art FL methods on the CIFAR-10 and CIFAR-100 test datasets.*
+
+1. **Specify Experiment Name:**
+   Add `--spid=<experiment_name>` to set the experiment name in each training script, for example:
+   ```bash
+   python run_gkt.py --spid=my_experiment --is_fed=1 --fixed_cluster=0 --split_factor=1 --num_clusters=20 --num_selected=20 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --num_rounds=300 --fed_epochs=1
+   ```
+
+2. 
**Training Scripts for CIFAR-10:** + + - **Centralized Training:** + ```bash + python run_local.py --is_fed=0 --split_factor=1 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --epochs=300 + ``` + + - **FedDCT:** + ```bash + python train_EdgeFLite.py --is_fed=1 --fixed_cluster=0 --split_factor=4 --num_clusters=5 --num_selected=5 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --num_rounds=300 --fed_epochs=1 + ``` +--- \ No newline at end of file diff --git a/EdgeFLite/architecture/.DS_Store b/EdgeFLite/architecture/.DS_Store new file mode 100644 index 0000000..3dfe470 Binary files /dev/null and b/EdgeFLite/architecture/.DS_Store differ diff --git a/EdgeFLite/architecture/coremodel.py b/EdgeFLite/architecture/coremodel.py new file mode 100644 index 0000000..f3f5bce --- /dev/null +++ b/EdgeFLite/architecture/coremodel.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +from sklearn import ensemble +import torch +import torch.nn as nn +import torch.nn.functional as F +from .mixup import mixup_loss_criterion, combine_mixup_data +from . import resnet, resnet_sl + +# Exported members of the module +__all__ = ['coremodelSL'] + +def _retrieve_networkwork(arch='resnet_model_110sl'): + """Retrieve the specific network architecture based on the provided name.""" + available_networks = { + 'resnet_model_110sl': resnet_sl.resnet_model_110sl, + 'wide_resnetsl50_2': resnet_sl.wide_resnetsl50_2, + 'wide_resnetsl16_8': resnet_sl.wide_resnetsl16_8, + } + # Ensure the architecture requested exists in the available networks + assert arch in available_networks, f"Architecture '{arch}' is not supported." + return available_networks[arch] + +class CoreModelClient(nn.Module): + """Main client model for training and inference, managing multiple sub-networks.""" + + def __init__(self, args, norm_layer=None, criterion=None, progress=True): + super(CoreModelClient, self).__init__() + + # Parameters and configurations for the client model + self.split_factor = args.split_factor + self.arch = args.arch + self.loop_factor = args.loop_factor + self.is_train_sep = args.is_train_sep + self.epochs = args.epochs + self.num_classes = args.num_classes + self.is_diff_data_train = args.is_diff_data_train + self.is_mixup = args.is_mixup + self.mix_alpha = args.mix_alpha + + # Model arguments + model_kwargs = { + 'num_classes': self.num_classes, + 'norm_layer': norm_layer, + 'dataset': args.dataset, + 'split_factor': self.split_factor, + 'output_stride': args.output_stride + } + + # Initialize multiple instances of the network architecture for the main client + if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']: + self.main_client_models = nn.ModuleList( + [_retrieve_networkwork(self.arch)(models_pretrained=args.models_pretrained, **model_kwargs)[0] + for _ in range(self.loop_factor)] + ) + else: + raise NotImplementedError(f"Architecture '{self.arch}' not implemented.") + + # Identical initialization of the model if specified + if args.is_identical_init: + print("INFO:PyTorch: Using identical initialization.") + self._identical_init() + + def forward(self, x, target=None, mode='train', epoch=0, streams=None): + """Forward pass for the main client. 
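+        In training mode the (optionally mixed-up) input is distributed across the client's sub-models, and each sub-model's feature output is returned detached with gradients re-enabled so that the proxy-client models can continue the split forward/backward pass.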
Handles both training and evaluation modes.""" + main_client_outputs = [] + + if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']: + if mode == 'train': + # Apply mixup augmentation if enabled + if self.is_mixup: + x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha) + + # Split input data across multiple sub-networks during training + all_x = torch.chunk(x, chunks=self.loop_factor, dim=1) if self.is_diff_data_train else [x] * self.loop_factor + + for i in range(self.loop_factor): + fx = self.main_client_models[i](all_x[i]) + main_client_outputs.append(fx.clone().detach().requires_grad_(True)) + + return main_client_outputs, y_a, y_b, lam + elif mode in ['val', 'test']: + # Forward pass during evaluation or testing + for i in range(self.loop_factor): + fx = self.main_client_models[i](x) + main_client_outputs.append(fx.clone().detach().requires_grad_(True)) + + return main_client_outputs + else: + # Return a dummy tensor if the mode is unsupported + return torch.ones(1) + else: + raise NotImplementedError(f"Mode '{mode}' not supported for architecture '{self.arch}'.") + + def _identical_init(self): + """Ensure identical initialization of weights for sub-networks.""" + with torch.no_grad(): + # Copy weights from the first model to all subsequent models + for i in range(1, self.split_factor): + for (name1, param1), (name2, param2) in zip(self.main_client_models[i].named_parameters(), + self.main_client_models[0].named_parameters()): + if 'weight' in name1: + param1.data.copy_(param2.data) + +class coremodelProxyClient(nn.Module): + """Proxy client model to handle downstream processing and training logic.""" + + def __init__(self, args, norm_layer=None, criterion=None, progress=True): + super(coremodelProxyClient, self).__init__() + + # Parameters and configurations for the proxy client model + self.split_factor = args.split_factor + self.arch = args.arch + self.loop_factor = args.loop_factor + self.epochs = args.epochs + self.num_classes = args.num_classes + self.criterion = criterion + self.is_mixup = args.is_mixup + self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False + self.ensembled_loss_weight = args.ensembled_loss_weight + self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False + self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False + self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False + self.cot_weight = args.cot_weight + self.is_cot_weight_warm_up = args.is_cot_weight_warm_up + self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs + self.cot_loss_choose = args.cot_loss_choose + + # Model arguments for the proxy client + model_kwargs = { + 'num_classes': self.num_classes, + 'norm_layer': norm_layer, + 'dataset': args.dataset, + 'split_factor': self.split_factor, + 'output_stride': args.output_stride + } + + # Initialize multiple instances of the network architecture for the proxy client + if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']: + self.proxy_clients_models = nn.ModuleList( + [_retrieve_networkwork(self.arch)(models_pretrained=args.models_pretrained, **model_kwargs)[1] + for _ in range(self.loop_factor)] + ) + else: + raise NotImplementedError(f"Architecture '{self.arch}' not implemented.") + + # Identical initialization of the model if specified + if args.is_identical_init: + print("INFO:PyTorch: Using identical initialization.") + self._identical_init() + + def 
forward(self, main_client_outputs, y_a=None, y_b=None, lam=None, target=None, mode='train', epoch=0, streams=None): + """Forward pass for the proxy client. Manages multiple sub-networks and ensemble outputs.""" + outputs = [] + ce_losses = [] + + if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']: + if mode == 'train': + # Calculate loss and forward pass during training + for i in range(self.loop_factor): + output = self.proxy_clients_models[i](main_client_outputs[i]) + loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target) + outputs.append(output) + ce_losses.append(loss) + + ensemble_output = self._collect_ensemble_output(outputs) + ce_loss = torch.sum(torch.stack(ce_losses, dim=0)) + + # Calculate co-training loss if enabled + if self.is_cot_loss: + cot_loss = self._calculate_co_training_loss(outputs, epoch) + else: + cot_loss = torch.zeros_like(ce_loss) + + return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss + + elif mode in ['val', 'test']: + # Forward pass during evaluation or testing + for i in range(self.loop_factor): + output = self.proxy_clients_models[i](main_client_outputs[i]) + loss = self.criterion(output, target) if self.criterion else torch.zeros(1) + outputs.append(output) + ce_losses.append(loss) + + ensemble_output = self._collect_ensemble_output(outputs) + ce_loss = torch.sum(torch.stack(ce_losses, dim=0)) + return ensemble_output, torch.stack(outputs, dim=0), ce_loss + else: + # Return a dummy tensor if the mode is unsupported + return torch.ones(1) + else: + raise NotImplementedError(f"Mode '{mode}' not supported for architecture '{self.arch}'.") + + def _collect_ensemble_output(self, outputs): + """Calculate the ensemble output from multiple sub-networks.""" + stacked_outputs = torch.stack(outputs, dim=0) + + # Apply softmax to the outputs before ensembling if specified + if self.is_ensembled_after_softmax: + if self.is_max_ensemble: + ensemble_output, _ = torch.max(F.softmax(stacked_outputs, dim=-1), dim=0) + else: + ensemble_output = torch.mean(F.softmax(stacked_outputs, dim=-1), dim=0) + else: + if self.is_max_ensemble: + ensemble_output, _ = torch.max(stacked_outputs, dim=0) + else: + ensemble_output = torch.mean(stacked_outputs, dim=0) + + return ensemble_output + + def _calculate_co_training_loss(self, outputs, epoch): + """Calculate the co-training loss between outputs of different sub-networks.""" + # Adjust the weight of the co-training loss during warm-up epochs + weight_now = self.cot_weight if not self.is_cot_weight_warm_up or epoch >= self.cot_weight_warm_up_epochs else max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005) + + # Different methods of calculating co-training loss + if self.cot_loss_choose == 'js_divergence': + outputs_all = torch.stack(outputs, dim=0) + p_all = F.softmax(outputs_all, dim=-1) + p_mean = torch.mean(p_all, dim=0) + H_mean = (-p_mean * torch.log(p_mean)).sum(-1).mean() + H_sep = (-p_all * F.log_softmax(outputs_all, dim=-1)).sum(-1).mean() + return weight_now * (H_mean - H_sep) + elif self.cot_loss_choose == 'kl_separate': + outputs_all = torch.stack(outputs, dim=0) + outputs_r1 = torch.repeat_interleave(outputs_all, self.split_factor - 1, dim=0) + index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i] + outputs_r2 = torch.index_select(outputs_all, dim=0, index=torch.tensor(index_list, dtype=torch.long).cuda()) + kl_loss = F.kl_div(F.log_softmax(outputs_r1, 
dim=-1), F.softmax(outputs_r2, dim=-1).detach(), reduction='none') + return weight_now * kl_loss.sum(-1).mean(-1).sum() / (self.split_factor - 1) + else: + raise NotImplementedError(f"Co-training loss '{self.cot_loss_choose}' not implemented.") + + def _identical_init(self): + """Ensure identical initialization of weights for sub-networks.""" + with torch.no_grad(): + # Copy weights from the first model to all subsequent models + for i in range(1, self.split_factor): + for (name1, param1), (name2, param2) in zip(self.proxy_clients_models[i].named_parameters(), + self.proxy_clients_models[0].named_parameters()): + if 'weight' in name1: + param1.data.copy_(param2.data) diff --git a/EdgeFLite/architecture/mixup.py b/EdgeFLite/architecture/mixup.py new file mode 100644 index 0000000..6ce96f3 --- /dev/null +++ b/EdgeFLite/architecture/mixup.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import numpy as np + +@torch.no_grad() +def combine_mixup_data(x, y, alpha=1.0, use_cuda=True): + """ + Perform the mixup operation on input data. + + Args: + x (Tensor): Input features, typically from the dataset. + y (Tensor): Input labels corresponding to the features. + alpha (float): Mixup interpolation coefficient. The default value is 1.0. + A higher value results in more mixing between samples. + use_cuda (bool): Boolean flag to indicate whether CUDA should be used if available. + + Returns: + mixed_x (Tensor): Mixed input features, a linear combination of x and a permuted version of x. + y_a (Tensor): Original input labels corresponding to x. + y_b (Tensor): Permuted input labels corresponding to the mixed samples. + lam (float): The lambda value used for interpolation between samples. + """ + # Draw lambda value from the Beta distribution if alpha > 0, otherwise set lam to 1 (no mixup) + lam = np.random.beta(alpha, alpha) if alpha > 0 else 1 + + # Get the batch size from the input tensor + batch_size = x.size(0) + + # Generate a random permutation of indices for mixing + # Use CUDA if available, otherwise stick with CPU + index = torch.randperm(batch_size).cuda() if use_cuda else torch.randperm(batch_size) + + # Mix the features of the original and permuted samples using the lambda value + mixed_x = lam * x + (1 - lam) * x[index, :] + + # Assign original and permuted labels to y_a and y_b, respectively + y_a, y_b = y, y[index] + + # Return mixed features, original and permuted labels, and the lambda value + return mixed_x, y_a, y_b, lam + + +def mixup_loss_criterion(criterion, pred, y_a, y_b, lam): + """ + Compute the mixup loss using the provided criterion. + + Args: + criterion (function): The loss function used to compute the error (e.g., CrossEntropyLoss). + pred (Tensor): The model predictions, typically the output of a neural network. + y_a (Tensor): The original labels corresponding to the original input features. + y_b (Tensor): The permuted labels corresponding to the mixed input features. + lam (float): The lambda value for mixup, used to interpolate between the two losses. + + Returns: + loss (Tensor): The final mixup loss, computed as a weighted sum of the two losses. 
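+
+        Example (illustrative sketch; `model`, `images`, and `labels` are placeholders
+        for a classifier and a data batch, not names defined in this repository):
+            >>> criterion = torch.nn.CrossEntropyLoss()
+            >>> mixed_x, y_a, y_b, lam = combine_mixup_data(images, labels, alpha=1.0)
+            >>> loss = mixup_loss_criterion(criterion, model(mixed_x), y_a, y_b, lam)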
+ """ + # Compute the mixup loss by combining the loss from the original and permuted labels + return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) diff --git a/EdgeFLite/architecture/resnet.py b/EdgeFLite/architecture/resnet.py new file mode 100644 index 0000000..99ab705 --- /dev/null +++ b/EdgeFLite/architecture/resnet.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + +# Try to import the method to load model weights from a URL, with a fallback in case of ImportError +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# List of available ResNet architectures +__all__ = ['resnet_model_18', 'resnet_model_34', 'resnet_model_50', + 'resnet_model_101', 'resnet_model_152', 'resnet_model_200', + 'resnet110', 'resnet164', + 'resnext29_8x64d', 'resnext29_16x64d', + 'resnext50_32x4d', 'resnext101_32x4d', + 'resnext101_32x8d', 'resnext101_64x4d', + 'wide_resnet_model_50_2', 'wide_resnet_model_50_3', 'wide_resnet_model_101_2', + 'wide_resnet16_8', 'wide_resnet52_8', 'wide_resnet16_12', + 'wide_resnet28_10', 'wide_resnet40_10'] + +# Pre-trained model URLs for various ResNet variants +model_urls = { + 'resnet_model_18': 'https://download.pytorch.org/models/resnet_model_18-5c106cde.pth', + 'resnet_model_34': 'https://download.pytorch.org/models/resnet_model_34-333f7ec4.pth', + 'resnet_model_50': 'https://download.pytorch.org/models/resnet_model_50-19c8e357.pth', + 'resnet_model_101': 'https://download.pytorch.org/models/resnet_model_101-5d3b4d8f.pth', + 'resnet_model_152': 'https://download.pytorch.org/models/resnet_model_152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet_model_50_2': 'https://download.pytorch.org/models/wide_resnet_model_50_2-95faca4d.pth', + 'wide_resnet_model_101_2': 'https://download.pytorch.org/models/wide_resnet_model_101_2-32ee1156.pth', +} + +# Function for a 3x3 convolution with padding +def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +# Function for a 1x1 convolution +def apply_1x1_convolution(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +# BasicBlock class for the ResNet architecture +class BasicBlock(nn.Module): + expansion = 1 # Expansion factor for the output channels + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + # If norm_layer is not provided, use BatchNorm2d as the default + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # Ensure BasicBlock is restricted to specific parameters + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock is restricted to groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("BasicBlock does not support dilation greater than 1") + + # Define the layers for the BasicBlock + self.conv1 = apply_3x3_convolution(inplanes, planes, stride) # First 3x3 convolution + self.bn1 = norm_layer(planes) # First BatchNorm layer + self.relu = nn.ReLU(inplace=True) # 
ReLU activation + self.conv2 = apply_3x3_convolution(planes, planes) # Second 3x3 convolution + self.bn2 = norm_layer(planes) # Second BatchNorm layer + self.downsample = downsample # Optional downsample layer + self.stride = stride + + # Define the forward pass for BasicBlock + def forward(self, x): + identity = x # Save the input for the skip connection + + out = self.conv1(x) # First convolution + out = self.bn1(out) # BatchNorm after first convolution + out = self.relu(out) # ReLU activation + + out = self.conv2(out) # Second convolution + out = self.bn2(out) # BatchNorm after second convolution + + # Apply downsample if defined + if self.downsample is not None: + identity = self.downsample(x) + + out += identity # Add the skip connection + out = self.relu(out) # Apply ReLU activation again + + return out + +# Bottleneck class for the ResNet architecture, a more complex block used in deeper ResNet models +class Bottleneck(nn.Module): + expansion = 4 # Expansion factor for the output channels + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups # Calculate width based on base width and groups + + # Define the layers for the Bottleneck block + self.conv1 = apply_1x1_convolution(inplanes, width) # 1x1 convolution to reduce the dimensions + self.bn1 = norm_layer(width) # BatchNorm after 1x1 convolution + self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation) # 3x3 convolution + self.bn2 = norm_layer(width) # BatchNorm after 3x3 convolution + self.conv3 = apply_1x1_convolution(width, planes * self.expansion) # 1x1 convolution to expand the dimensions + self.bn3 = norm_layer(planes * self.expansion) # BatchNorm after final 1x1 convolution + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.downsample = downsample # Optional downsample layer + self.stride = stride + + # Define the forward pass for Bottleneck + def forward(self, x): + identity = x # Save the input for the skip connection + + out = self.conv1(x) # First convolution + out = self.bn1(out) # BatchNorm after first convolution + out = self.relu(out) # ReLU activation + + out = self.conv2(out) # Second convolution + out = self.bn2(out) # BatchNorm after second convolution + out = self.relu(out) # ReLU activation + + out = self.conv3(out) # Third convolution + out = self.bn3(out) # BatchNorm after third convolution + + # Apply downsample if defined + if self.downsample is not None: + identity = self.downsample(x) + + out += identity # Add the skip connection + out = self.relu(out) # Apply ReLU activation again + + return out + +# Main ResNet class, a customizable deep learning model architecture +class ResNet(nn.Module): + + def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d # Default normalization layer + self._norm_layer = norm_layer + + self.groups = groups # Number of groups in convolutions + self.inplanes = 16 if dataset in ['cifar10', 'cifar100'] else 64 # Adjust initial planes for CIFAR + + # First layer: a combination of convolution, normalization, and ReLU + self.layer0 = nn.Sequential( + nn.Conv2d(3, self.inplanes, 
kernel_size=3, stride=1, padding=1, bias=False), + norm_layer(self.inplanes), + nn.ReLU(inplace=True), + ) + + # Subsequent ResNet layers using the _create_model_layer method + self.layer1 = self._create_model_layer(block, 16, layers[0]) + self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2) + self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Global average pooling + self.fc = nn.Linear(64 * block.expansion, num_classes) # Fully connected layer for classification + + # Initialization for model weights + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 1e-3) + + # Zero-initialize the last BatchNorm in residual connections if required + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # Helper function to create layers in ResNet + def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer # Set normalization layer + downsample = None + # If the stride is not 1 or input/output planes do not match, create a downsample layer + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + layers = [block(self.inplanes, planes, stride, downsample)] # Create the first block with downsampling + self.inplanes = planes * block.expansion # Update inplanes for next blocks + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) # Add subsequent blocks without downsampling + return nn.Sequential(*layers) + + # Forward pass through the ResNet architecture + def forward(self, x): + x = self.layer0(x) # Pass input through the first layer + x = self.layer1(x) # First ResNet layer + x = self.layer2(x) # Second ResNet layer + x = self.layer3(x) # Third ResNet layer + x = self.avgpool(x) # Global average pooling + x = torch.flatten(x, 1) # Flatten the output for the fully connected layer + x = self.fc(x) # Pass through the fully connected layer + return x + +# Helper function to instantiate ResNet with pretrained weights if available +def _resnet(arch, block, layers, models_pretrained, progress, **kwargs): + model = ResNet(arch, block, layers, **kwargs) # Create a ResNet model + if models_pretrained: # Load pretrained weights if requested + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + +# Functions to create specific ResNet variants +def resnet_model_18(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_18', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs) + +def resnet_model_34(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_34', BasicBlock, [3, 4, 6, 3], models_pretrained, progress, **kwargs) + +def resnet_model_50(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_50', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs) + +def resnet_model_101(models_pretrained=False, progress=True, **kwargs): + return 
_resnet('resnet_model_101', Bottleneck, [3, 4, 23, 3], models_pretrained, progress, **kwargs) + +def resnet_model_152(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_152', Bottleneck, [3, 8, 36, 3], models_pretrained, progress, **kwargs) + +def resnet_model_200(models_pretrained=False, progress=True, **kwargs): + return _resnet('resnet_model_200', Bottleneck, [3, 24, 36, 3], models_pretrained, progress, **kwargs) diff --git a/EdgeFLite/architecture/resnet_sl.py b/EdgeFLite/architecture/resnet_sl.py new file mode 100644 index 0000000..2f9bdd9 --- /dev/null +++ b/EdgeFLite/architecture/resnet_sl.py @@ -0,0 +1,312 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Importing necessary PyTorch libraries +import torch +import torch.nn as nn + +# Attempt to import model loading utilities from torch.hub; fall back to torch.utils.model_zoo if unavailable +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Specify all the modules and functions to export +__all__ = ['resnet110_sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8'] + +# Function for 3x3 convolution with padding +def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +# Function for 1x1 convolution, typically used to change the number of channels +def apply_1x1_convolution(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +# Basic Block class for ResNet (used in smaller networks like resnet_model_18/resnet_model_34) +class BasicBlock(nn.Module): + expansion = 1 # Expansion factor for output channels + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + # BasicBlock only supports groups=1 and base_width=64 + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("BasicBlock does not support dilation greater than 1") + + # Define two 3x3 convolution layers with batch normalization and ReLU activation + self.conv1 = apply_3x3_convolution(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = apply_3x3_convolution(planes, planes) + self.bn2 = norm_layer(planes) + # Optional downsample layer for changing the dimensions + self.downsample = downsample + self.stride = stride + + # Forward function defining the data flow through the block + def forward(self, x): + identity = x # Save the input for residual connection + + # First convolution, batch norm, and ReLU + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # Second convolution, batch norm + out = self.conv2(out) + out = self.bn2(out) + + # Apply downsample if needed to match dimensions for residual addition + if self.downsample is not None: + identity = self.downsample(x) + + # Residual connection (add identity to output) + out += identity + out = self.relu(out) + + return out + +# Bottleneck block class for deeper ResNet architectures (e.g., resnet_model_50/resnet_model_101) +class Bottleneck(nn.Module): + expansion = 4 # Expansion factor 
for output channels (output = input * 4)
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        norm_layer = norm_layer or nn.BatchNorm2d
+        # Width of the block based on base_width and groups
+        width = int(planes * (base_width / 64.)) * groups
+
+        # Define 1x1, 3x3, and 1x1 convolutions with batch norm and ReLU activation
+        self.conv1 = apply_1x1_convolution(inplanes, width)  # First 1x1 convolution
+        self.bn1 = norm_layer(width)
+        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)  # Main 3x3 convolution
+        self.bn2 = norm_layer(width)
+        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)  # Final 1x1 convolution
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample  # Downsample layer for dimension adjustment
+        self.stride = stride
+
+    # Forward function defining the data flow through the bottleneck block
+    def forward(self, x):
+        identity = x  # Save the input for residual connection
+
+        # First 1x1 convolution, batch norm, and ReLU
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        # Second 3x3 convolution, batch norm, and ReLU (operates on the previous output, not the block input)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        # Third 1x1 convolution, batch norm
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        # Apply downsample if needed for residual connection
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        # Residual connection (add identity to output)
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+# ResNet model for the main client (usually the primary model)
+class PrimaryResNetClient(nn.Module):
+
+    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
+        super(PrimaryResNetClient, self).__init__()
+        norm_layer = norm_layer or nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        # Initialize the number of input channels based on the dataset and split factor
+        inplanes_dict = {
+            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
+            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
+            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
+            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
+            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
+        }
+        self.inplanes = inplanes_dict[dataset][split_factor]
+
+        # Adjust input planes if using a wide ResNet
+        if 'wide_resnet' in arch:
+            widen_factor = int(arch.split('_')[-1])
+            self.inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0))
+
+        self.base_width = width_per_group
+        self.dilation = 1
+        replace_stride_with_dilation = replace_stride_with_dilation or [False, False, False]
+
+        # Check if replace_stride_with_dilation is properly defined
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation must either be None or a tuple with three elements")
+
+        # Initialize input layer depending on the dataset (small or large)
+        if dataset in ['skin_dataset', 'pill_base', 'medical_images']:
+            self.layer0 = self._initialize_primary_layer_large()
+        else:
+            self.layer0 = self._init_layer0_small()
+
+        # Initialize model weights
+        self._init_model_weights(zero_init_residual)
+
+    # Define the large initial convolution layer for large datasets
+    def _initialize_primary_layer_large(self):
+        return nn.Sequential(
nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1, bias=False), + self._norm_layer(self.inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ) + + # Define the small initial convolution layer for smaller datasets like CIFAR + def _init_layer0_small(self): + return nn.Sequential( + nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), + self._norm_layer(self.inplanes), + nn.ReLU(inplace=True), + ) + + # Function to initialize weights in the network + def _init_model_weights(self, zero_init_residual): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=1e-3) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # Initialize residual weights for Bottleneck and BasicBlock if specified + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # Define forward pass for the model + def forward(self, x): + x = self.layer0(x) + return x + +# ResNet model for proxy clients (usually assisting the main model) +class ResNetProxies(nn.Module): + + def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None): + super(ResNetProxies, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + self._norm_layer = norm_layer + + # Set input channels based on architecture, dataset, and split factor + self.inplanes = self._set_input_planes(arch, dataset, split_factor, width_per_group) + self.base_width = width_per_group + + # Define layers of the network (layer1, layer2, layer3) + self.layer1 = self._create_model_layer(block, self.inplanes, layers[0], stride=1) + self.layer2 = self._create_model_layer(block, self.inplanes * 2, layers[1], stride=2) + self.layer3 = self._create_model_layer(block, self.inplanes * 4, layers[2], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Adaptive average pooling layer + self.fc = nn.Linear(self.inplanes * 4 * block.expansion, num_classes) + + # Initialize model weights + self._init_model_weights(zero_init_residual) + + # Set input channels based on dataset and split factor + def _set_input_planes(self, arch, dataset, split_factor, width_per_group): + inplanes_dict = { + 'cifar10': {1: 16, 2: 12, 4: 8, 8: 6}, + 'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24}, + } + inplanes = inplanes_dict[dataset][split_factor] + + # Adjust input planes for wide ResNet + if 'wide_resnet' in arch: + widen_factor = float(arch.split('_')[-1]) + inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0)) + + return inplanes + + # Function to create layers of the network (consisting of blocks) + def _create_model_layer(self, block, planes, blocks, stride=1): + layers = [block(self.inplanes, planes, stride)] # First block + self.inplanes = planes * block.expansion # Update input planes + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) # Additional blocks + return nn.Sequential(*layers) + + # Initialize weights in the network + def _init_model_weights(self, zero_init_residual): + for m in 
self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=1e-3) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # Initialize residual weights for Bottleneck and BasicBlock if specified + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # Define forward pass for the model + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x + +# Helper function to create the main ResNet client +def _resnetsl_primary_client_(arch, block, layers, models_pretrained, progress, **kwargs): + return PrimaryResNetClient(arch, block, layers, **kwargs) + +# Helper function to create the proxy ResNet client +def _resnetsl_secondary_client_(arch, block, layers, models_pretrained, progress, **kwargs): + return ResNetProxies(arch, block, layers, **kwargs) + +# Function to define a ResNet-110 model for main and proxy clients +def resnet_model_110sl(models_pretrained=False, progress=True, **kwargs): + assert 'cifar' in kwargs['dataset'] # Ensure that CIFAR dataset is used + return _resnetsl_primary_client_('resnet110_sl', Bottleneck, [12, 12, 12, 12], models_pretrained, progress, **kwargs), \ + _resnetsl_secondary_client_('resnet110_sl', Bottleneck, [12, 12, 12, 12], models_pretrained, progress, **kwargs) + +# Function to define a Wide ResNet-50-2 model for main and proxy clients +def wide_resnetsl50_2(models_pretrained=False, progress=True, **kwargs): + kwargs['width_per_group'] = 64 * 2 # Adjust width for Wide ResNet + return _resnetsl_primary_client_('wide_resnetsl50_2', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs), \ + _resnetsl_secondary_client_('wide_resnetsl50_2', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs) + +# Function to define a Wide ResNet-16-8 model for main and proxy clients +def wide_resnetsl16_8(models_pretrained=False, progress=True, **kwargs): + kwargs['width_per_group'] = 64 # Adjust width for Wide ResNet + return _resnetsl_primary_client_('wide_resnetsl16_8', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs), \ + _resnetsl_secondary_client_('wide_resnetsl16_8', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs) diff --git a/EdgeFLite/architecture/splitnet.py b/EdgeFLite/architecture/splitnet.py new file mode 100644 index 0000000..d1f8821 --- /dev/null +++ b/EdgeFLite/architecture/splitnet.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from sklearn import ensemble +from .mixup import mixup_loss_criterion, combine_mixup_data +from . import resnet, resnet_sl + +__all__ = ['coremodel'] + +def _retrieve_network(arch='wide_resnet28_10'): + """ + Get the network architecture based on the provided name. + + Args: + arch (str): Name of the architecture. + + Returns: + Callable: The network class or function corresponding to the given architecture. 
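+
+    Example (illustrative; assumes the listed architectures are exported by `resnet.py`):
+        >>> net_fn = _retrieve_network('wide_resnet16_8')
+        >>> # net_fn(...) is then called with the model kwargs built in coremodel.__init__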
+ """ + networks = { + 'wide_resnet28_10': resnet.wide_resnet28_10, + 'wide_resnet16_8': resnet.wide_resnet16_8, + 'resnet110': resnet.resnet110, + 'wide_resnet_model_50_2': resnet.wide_resnet_model_50_2 + } + if arch not in networks: + raise ValueError(f"Architecture {arch} is not supported.") + return networks[arch] + +class coremodel(nn.Module): + def __init__(self, args, norm_layer=None, criterion=None, progress=True): + """ + Initialize the coremodel model with multiple sub-networks. + + Args: + args (argparse.Namespace): Configuration arguments. + norm_layer (callable, optional): Normalization layer. + criterion (callable, optional): Loss function. + progress (bool): Whether to show progress. + """ + super(coremodel, self).__init__() + + # Configuration parameters + self.split_factor = args.split_factor + self.arch = args.arch + self.loop_factor = args.loop_factor + self.is_train_sep = args.is_train_sep + self.epochs = args.epochs + self.criterion = criterion + self.is_diff_data_train = args.is_diff_data_train + self.is_mixup = args.is_mixup + self.mix_alpha = args.mix_alpha + + # Define model architectures + valid_archs = [ + 'resnet_model_50', 'resnet_model_101', 'resnet_model_152', 'resnet_model_200', + 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', + 'resnext29_8x64d', 'resnext29_16x64d', 'resnet110', 'resnet164', + 'wide_resnet16_8', 'wide_resnet16_12', 'wide_resnet28_10', 'wide_resnet40_10', + 'wide_resnet52_8', 'wide_resnet_model_50_2', 'wide_resnet_model_50_3', 'wide_resnet_model_101_2' + ] + + if self.arch not in valid_archs: + raise NotImplementedError(f"Architecture {self.arch} is not implemented.") + + model_args = { + 'num_classes': args.num_classes, + 'norm_layer': norm_layer, + 'dataset': args.dataset, + 'split_factor': self.split_factor, + 'output_stride': args.output_stride + } + # Initialize multiple sub-models based on the loop factor + self.models = nn.ModuleList([_retrieve_network(self.arch)(models_models_pretrained=args.models_models_pretrained, **model_args) for _ in range(self.loop_factor)]) + + if args.is_identical_init: + print("INFO: Using identical initialization.") + self._identical_init() + + # Ensemble settings + self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False + self.ensembled_loss_weight = args.ensembled_loss_weight + self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False + self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False + + # Co-training settings + self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False + self.cot_weight = args.cot_weight + self.is_cot_weight_warm_up = args.is_cot_weight_warm_up + self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs + self.cot_loss_choose = args.cot_loss_choose + print(f"INFO: The co-training loss is {self.cot_loss_choose}.") + self.num_classes = args.num_classes + + def forward(self, x, target=None, mode='train', epoch=0, streams=None): + """ + Forward pass through the model with optional mixup and co-training loss. + + Args: + x (Tensor): Input tensor. + target (Tensor, optional): Target tensor for loss computation. + mode (str): Mode of operation ('train', 'val', or 'test'). + epoch (int): Current epoch. + streams (optional): Additional data streams. + + Returns: + Tuple: + - ensemble_output (Tensor): The ensemble output of shape [batch_size, num_classes]. + - outputs (Tensor): Stack of individual outputs of shape [split_factor, batch_size, num_classes]. 
+            - ce_loss (Tensor): Sum of cross-entropy losses for each model.
+            - cot_loss (Tensor): Co-training loss if applicable.
+        """
+        outputs, ce_losses = [], []
+
+        if 'train' in mode:
+            if self.is_mixup:
+                x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha)
+
+            # Split input data based on the loop factor (replicate the batch if not using different data per sub-network)
+            all_x = torch.chunk(x, chunks=self.loop_factor, dim=1) if self.is_diff_data_train else [x] * self.loop_factor
+
+            for i in range(self.loop_factor):
+                x_input = all_x[i]
+                output = self.models[i](x_input)
+                loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target)
+                outputs.append(output)
+                ce_losses.append(loss)
+
+        elif mode in ['val', 'test']:
+            for i in range(self.loop_factor):
+                output = self.models[i](x)
+                loss = self.criterion(output, target) if self.criterion else torch.zeros(1)
+                outputs.append(output)
+                ce_losses.append(loss)
+
+        else:
+            return torch.ones(1), None, None, None
+
+        # Calculate ensemble output and losses
+        ensemble_output = self._collect_ensemble_output(outputs)
+        ce_loss = torch.sum(torch.stack(ce_losses))
+
+        if mode in ['val', 'test']:
+            return ensemble_output, torch.stack(outputs, dim=0), ce_loss
+
+        if self.is_cot_loss:
+            cot_loss = self._calculate_co_training_loss(outputs, self.cot_loss_choose, epoch)
+        else:
+            cot_loss = torch.zeros_like(ce_loss)
+
+        return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss
+
+    def _collect_ensemble_output(self, outputs):
+        """
+        Calculate the ensemble output from a list of tensors.
+
+        Args:
+            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].
+
+        Returns:
+            Tensor: The ensemble output with shape [batch_size, num_classes].
+        """
+        stacked_outputs = torch.stack(outputs, dim=0)
+
+        if self.is_ensembled_after_softmax:
+            softmax_outputs = F.softmax(stacked_outputs, dim=-1)
+            if self.is_max_ensemble:
+                ensemble_output, _ = torch.max(softmax_outputs, dim=0)
+            else:
+                ensemble_output = torch.mean(softmax_outputs, dim=0)
+        else:
+            if self.is_max_ensemble:
+                ensemble_output, _ = torch.max(stacked_outputs, dim=0)
+            else:
+                ensemble_output = torch.mean(stacked_outputs, dim=0)
+
+        return ensemble_output
+
+    def _calculate_co_training_loss(self, outputs, loss_choose, epoch=0):
+        """
+        Calculate the co-training loss between outputs of different networks.
+
+        Args:
+            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].
+            loss_choose (str): Type of co-training loss to compute ('js_divergence' or 'kl_seperate').
+            epoch (int): Current epoch.
+
+        Returns:
+            Tensor: The computed co-training loss.
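+
+        Note (restating the 'js_divergence' branch below): with p_k = softmax(output_k)
+        and p_bar the mean of the p_k over the K sub-networks, the loss is
+        weight * (H(p_bar) - (1/K) * sum_k H(p_k)), a generalized Jensen-Shannon
+        divergence between the sub-networks' predictive distributions.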
+ """ + weight_now = self.cot_weight + if self.is_cot_weight_warm_up and epoch < self.cot_weight_warm_up_epochs: + weight_now = max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005) + + stacked_outputs = torch.stack(outputs, dim=0) + + if loss_choose == 'js_divergence': + p_all = F.softmax(stacked_outputs, dim=-1) + p_mean = torch.mean(p_all, dim=0) + H_mean = (-p_mean * torch.log(p_mean + 1e-8)).sum(-1).mean() + H_sep = (-p_all * F.log_softmax(stacked_outputs, dim=-1)).sum(-1).mean() + cot_loss = weight_now * (H_mean - H_sep) + + elif loss_choose == 'kl_seperate': + outputs_r1 = torch.repeat_interleave(stacked_outputs, self.split_factor - 1, dim=0) + index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i] + outputs_r2 = torch.index_select(stacked_outputs, dim=0, index=torch.tensor(index_list, dtype=torch.long, device=stacked_outputs.device)) + kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1), F.softmax(outputs_r2,” diff --git a/EdgeFLite/configurations/.DS_Store b/EdgeFLite/configurations/.DS_Store new file mode 100644 index 0000000..32fb741 Binary files /dev/null and b/EdgeFLite/configurations/.DS_Store differ diff --git a/EdgeFLite/configurations/training_config.py b/EdgeFLite/configurations/training_config.py new file mode 100644 index 0000000..28ff6c0 --- /dev/null +++ b/EdgeFLite/configurations/training_config.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +from __future__ import absolute_import, division, print_function +import json +import torch +from config import * + +# Function to save hyperparameters into a JSON file +def store_hyperparameters_json(args): + """Save hyperparameters to a JSON file.""" + # Create the model directory if it does not exist + os.makedirs(args.model_dir, exist_ok=True) + # Determine the filename based on whether it's evaluation or training mode + filename = os.path.join(args.model_dir, 'hparams_eval.json' if args.evaluate else 'hparams_train.json') + # Convert the arguments to a dictionary + hparams = vars(args) + # Write the hyperparameters to a JSON file with indentation and sorted keys + with open(filename, 'w') as f: + json.dump(hparams, f, indent=4, sort_keys=True) + +# Function to add parser arguments for command-line interface +def add_parser_arguments(parser): + # Dataset and model settings + parser.add_argument('--data', type=str, default=f"{data_dir}/dataset_hub/", help='Path to dataset') # Path to the dataset + parser.add_argument('--model_dir', type=str, default="EdgeFLite", help='Directory to save the model') # Directory where the model is saved + parser.add_argument('--arch', type=str, default='wide_resnet16_8', choices=[ + 'resnet110', 'resnet_model_110sl', 'wide_resnet16_8', 'wide_resnetsl16_8', + 'wide_resnet_model_50_2', 'wide_resnetsl50_2'], help='Neural architecture name') # Neural architecture options + + # Normalization and training settings + parser.add_argument('--norm_mode', type=str, default='batch', choices=['batch', 'group', 'layer', 'instance', 'none'], help='Batch normalization style') # Type of normalization used + parser.add_argument('--is_syncbn', default=0, type=int, help='Use nn.SyncBatchNorm or not') # Whether to use synchronized batch normalization + parser.add_argument('--workers', default=16, type=int, help='Number of data loading workers') # Number of workers for data loading + parser.add_argument('--epochs', default=650, type=int, help='Total epochs to run') # Total number of training epochs + parser.add_argument('--start_epoch', 
default=0, type=int, help='Manual epoch number for restarts') # Starting epoch number for restarting training + parser.add_argument('--eval_per_epoch', default=1, type=int, help='Evaluation frequency per epoch') # Frequency of evaluation during training + parser.add_argument('--spid', default="EdgeFLite", type=str, help='Experiment name') # Name of the experiment + parser.add_argument('--save_weight', default=False, type=bool, help='Save model weights') # Whether to save model weights + + # Data augmentation settings + parser.add_argument('--batch_size', default=128, type=int, help='Mini-batch size for training') # Batch size for training + parser.add_argument('--eval_batch_size', default=100, type=int, help='Mini-batch size for evaluation') # Batch size for evaluation + parser.add_argument('--crop_size', default=32, type=int, help='Crop size for images') # Size of the image crops + parser.add_argument('--output_stride', default=8, type=int, help='Output stride for model') # Output stride for the model + parser.add_argument('--padding', default=4, type=int, help='Padding size for images') # Padding size for image processing + + # Learning rate settings + parser.add_argument('--lr_mode', type=str, default='cos', choices=['cos', 'step', 'poly', 'HTD', 'exponential'], help='Learning rate strategy') # Strategy for adjusting learning rate + parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, help='Initial learning rate') # Initial learning rate value + parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'AdamW', 'RMSprop', 'RMSpropTF'], help='Optimizer choice') # Choice of optimizer + parser.add_argument('--lr_milestones', nargs='+', type=int, default=[100, 200], help='Epochs for learning rate steps') # Epochs where learning rate adjustments occur + parser.add_argument('--lr_step_multiplier', default=0.1, type=float, help='Multiplier at learning rate milestones') # Multiplier applied at learning rate steps + parser.add_argument('--end_lr', type=float, default=1e-4, help='Ending learning rate') # Final learning rate value + + # Additional hyperparameters + parser.add_argument('--weight_decay', default=1e-4, type=float, help='Weight decay for regularization') # Weight decay for L2 regularization + parser.add_argument('--momentum', default=0.9, type=float, help='Optimizer momentum') # Momentum for optimizers like SGD + parser.add_argument('--print_freq', default=20, type=int, help='Print frequency for logging') # Frequency for printing logs during training + + # Federated learning settings + parser.add_argument('--is_fed', default=1, type=int, help='Enable federated learning') # Enable or disable federated learning + parser.add_argument('--num_clusters', default=20, type=int, help='Number of clusters for federated learning') # Number of clusters in federated learning + parser.add_argument('--num_selected', default=20, type=int, help='Number of clients selected for training per round') # Number of clients selected each round + parser.add_argument('--num_rounds', default=300, type=int, help='Total number of training rounds') # Total number of federated learning rounds + + # Processing and decentralized training settings + parser.add_argument('--gpu', default=None, type=int, help='GPU ID to use') # GPU ID to be used for training + parser.add_argument('--no_cuda', action='store_true', default=False, help='Disable CUDA training') # Whether to disable CUDA + parser.add_argument('--gpu_ids', type=str, default='0', help='Comma-separated list of GPU IDs for 
training') # Comma-separated GPU IDs for multi-GPU training + + # Parse command-line arguments + args = parser.parse_args() + + # Additional configurations + args.cuda = not args.no_cuda and torch.cuda.is_available() # Enable CUDA if not disabled and available + if args.cuda: + args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] # Parse GPU IDs from comma-separated string + args.num_gpus = len(args.gpu_ids) # Count number of GPUs being used + + return args diff --git a/EdgeFLite/data_collection/.DS_Store b/EdgeFLite/data_collection/.DS_Store new file mode 100644 index 0000000..43eeefe Binary files /dev/null and b/EdgeFLite/data_collection/.DS_Store differ diff --git a/EdgeFLite/data_collection/augment_auto.py b/EdgeFLite/data_collection/augment_auto.py new file mode 100644 index 0000000..cefa20d --- /dev/null +++ b/EdgeFLite/data_collection/augment_auto.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + + +from PIL import Image, ImageEnhance, ImageOps + +import numpy as np +import random + + +class CIFAR10Policy(object): + def __init__(self, fillcolor=(128, 128, 128)): + + self.policies = [ + SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), + SubPolicy(0.7, "rotate_image", 2, 0.3, "translateX", 9, fillcolor), + SubPolicy(0.8, "adjust_image_sharpness", 1, 0.9, "adjust_image_sharpness", 3, fillcolor), + SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), + SubPolicy(0.2, "shearY", 7, 0.3, "apply_posterization", 7, fillcolor), + SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), + SubPolicy(0.3, "adjust_image_sharpness", 9, 0.7, "brightness", 9, fillcolor), + SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), + SubPolicy(0.6, "contrast", 7, 0.6, "adjust_image_sharpness", 5, fillcolor), + SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), + SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), + SubPolicy(0.4, "translateY", 3, 0.2, "adjust_image_sharpness", 6, fillcolor), + SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), + SubPolicy(0.5, "apply_solarize", 2, 0.0, "invert", 3, fillcolor), + SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), + SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), + SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), + SubPolicy(0.8, "autocontrast", 4, 0.2, "apply_solarize", 8, fillcolor), + SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), + SubPolicy(0.4, "apply_solarize", 5, 0.9, "autocontrast", 3, fillcolor), + SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), + SubPolicy(0.9, "autocontrast", 2, 0.8, "apply_solarize", 3, fillcolor), + SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), + SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor), + ] + + + def __call__(self, img): + + policy = random.choice(self.policies) + return policy(img) + + + def __repr__(self): + return "AutoAugment CIFAR-10 Policy" + +class SubPolicy(object): + + def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): + + ranges = { + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 150 / 331, 10), + "translateY": np.linspace(0, 150 / 331, 10), + "rotate_image": np.linspace(0, 30, 10), + "color": np.linspace(0.0, 0.9, 10), + "apply_posterization": np.round(np.linspace(8, 4, 10), 0).astype(np.int), + "apply_solarize": np.linspace(256, 0, 
10), + "contrast": np.linspace(0.0, 0.9, 10), + "adjust_image_sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.9, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10 + } + + + self.fillcolor = fillcolor + + self.p1 = p1 + self.operation1 = operation1 + self.magnitude1 = ranges[operation1][magnitude_idx1] + self.p2 = p2 + self.operation2 = operation2 + self.magnitude2 = ranges[operation2][magnitude_idx2] + + + def __call__(self, img): + + if random.random() < self.p1: + img = self._perform_operation(self.operation1, img, self.magnitude1) + + if random.random() < self.p2: + img = self._perform_operation(self.operation2, img, self.magnitude2) + return img + + + def _perform_operation(self, operation, img, magnitude): + + if operation == "shearX": + img = img.apply_transformation(img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, fillcolor=self.fillcolor) + elif operation == "shearY": + img = img.apply_transformation(img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, fillcolor=self.fillcolor) + elif operation == "translateX": + img = img.apply_transformation(img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, diff --git a/EdgeFLite/data_collection/augment_rand.py b/EdgeFLite/data_collection/augment_rand.py new file mode 100644 index 0000000..48b46a0 --- /dev/null +++ b/EdgeFLite/data_collection/augment_rand.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import random +import math +from PIL import Image, ImageOps, ImageEnhance, ImageChops +import PIL + +# Constants and defaults for image augmentation +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) # Get the version of the PIL library +_FILL = (128, 128, 128) # Default fill color used in some apply_transformationations (gray) +_MAX_LEVEL = 10.0 # Maximum level for augmentations +_HPARAMS_DEFAULT = { + 'translate_const': 250, # Default translation constant + 'img_mean': _FILL, # Default fill color +} +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) # Random interpolation modes + +# Function to randomly choose interpolation method +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + return random.choice(interpolation) if isinstance(interpolation, (list, tuple)) else interpolation + +# Check if the PIL version is compatible with fillcolor argument +def _validate_tensorflow_args(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') # Remove fillcolor if PIL version is below 5.0 + kwargs['resample'] = _interpolation(kwargs) # Add resample method + +# Shear image along the x-axis +def apply_apply_shear_x_axis_axis(img, factor, **kwargs): + _validate_tensorflow_args(kwargs) + return img.apply_transformation(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + +# Shear image along the y-axis +def shear_y(img, factor, **kwargs): + _validate_tensorflow_args(kwargs) + return img.apply_transformation(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + +# Translate image horizontally by a percentage of the image width +def translate_image_x_relative(img, pct, **kwargs): + pixels = pct * img.size[0] # Calculate pixels to translate + _validate_tensorflow_args(kwargs) + return img.apply_transformation(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + +# Translate image vertically by a percentage of the image height +def 
translate_image_y_relative(img, pct, **kwargs): + pixels = pct * img.size[1] # Calculate pixels to translate + _validate_tensorflow_args(kwargs) + return img.apply_transformation(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + +# Translate image horizontally by a fixed number of pixels +def translate_image_x_absolute(img, pixels, **kwargs): + _validate_tensorflow_args(kwargs) + return img.apply_transformation(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + +# Translate image vertically by a fixed number of pixels +def translate_image_y_absolute(img, pixels, **kwargs): + _validate_tensorflow_args(kwargs) + return img.apply_transformation(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + +# rotate_image image by a specified number of degrees +def rotate_image(img, degrees, **kwargs): + _validate_tensorflow_args(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate_image(degrees, **kwargs) # Use rotate_image if PIL version is >= 5.2 + elif _PIL_VER >= (5, 0): + # Manually rotate_image the image for older versions of PIL + w, h = img.size + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), round(math.sin(angle), 15), 0.0, + round(-math.sin(angle), 15), round(math.cos(angle), 15), 0.0, + ] + + def apply_transformation(x, y, matrix): + return matrix[0] * x + matrix[1] * y + matrix[2], matrix[3] * x + matrix[4] * y + matrix[5] + + matrix[2], matrix[5] = apply_transformation(-rotn_center[0], -rotn_center[1], matrix) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.apply_transformation(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate_image(degrees, resample=kwargs['resample']) + +# Auto contrast image +def apply_auto_contrast(img, **kwargs): + return ImageOps.autocontrast(img) + +# Invert image colors +def invert(img, **kwargs): + return ImageOps.invert(img) + +# Equalize image histogram +def equalize(img, **kwargs): + return ImageOps.equalize(img) + +# Apply solarization effect +def apply_solarize(img, thresh, **kwargs): + return ImageOps.apply_solarize(img, thresh) + +# Apply solarization effect with an additional value +def apply_apply_solarize_addition(img, add, thresh=128, **kwargs): + lut = [min(255, i + add) if i < thresh else i for i in range(256)] + if img.mode in ("L", "RGB"): + lut = lut + lut + lut if img.mode == "RGB" else lut + return img.point(lut) + else: + return img + +# apply_posterization image (reduce color depth) +def apply_posterization(img, bits_to_keep, **kwargs): + return img if bits_to_keep >= 8 else ImageOps.apply_posterization(img, bits_to_keep) + +# Adjust image contrast +def contrast(img, factor, **kwargs): + return ImageEnhance.Contrast(img).enhance(factor) + +# Adjust image color +def color(img, factor, **kwargs): + return ImageEnhance.Color(img).enhance(factor) + +# Adjust image brightness +def brightness(img, factor, **kwargs): + return ImageEnhance.Brightness(img).enhance(factor) + +# Adjust image adjust_image_sharpness +def adjust_image_sharpness(img, factor, **kwargs): + return ImageEnhance.adjust_image_sharpness(img).enhance(factor) + +# Randomly negate a value with a 50% probability +def _apply_random_negation(v): + """With 50% probability, negate the value.""" + return -v if random.random() > 0.5 else v + +# Convert augmentation level to argument value +def _map_level_to_argument(level, max_value, hparams): + level = (level / _MAX_LEVEL) * max_value + return _apply_random_negation(level), + +# Convert translation level 
to argument value +def _map_absolute_map_level_to_argument(level, hparams): + translate_const = hparams['translate_const'] + level = (level / _MAX_LEVEL) * float(translate_const) + return _apply_random_negation(level), + +# Convert enhancement level to argument value +def _enhance_map_level_to_argument(level, _hparams): + return (level / _MAX_LEVEL) * 1.8 + 0.1, + +# Mapping of augmentation levels to argument converters +map_level_to_argument = { + 'AutoContrast': None, + 'Equalize': None, + 'Invert': None, + 'rotate_image': lambda level, _: _map_level_to_argument(level, 30, None), + 'apply_posterization': lambda level, _: int((level / _MAX_LEVEL) * 4), + 'apply_solarize': lambda level, _: int((level / _MAX_LEVEL) * 256), + 'Color': _enhance_map_level_to_argument, + 'Contrast': _enhance_map_level_to_argument, + 'Brightness': _enhance_map_level_to_argument, + 'adjust_image_sharpness': _enhance_map_level_to_argument, + 'ShearX': lambda level, _: _map_level_to_argument(level, 0.3, None), + 'ShearY': lambda level, _: _map_level_to_argument(level, 0.3, None), + 'TranslateX': _map_absolute_map_level_to_argument, + 'TranslateY': _map_absolute_map_level_to_argument, +} + +# Mapping of augmentation names to functions +NAME_TO_OP = { + 'AutoContrast': apply_auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'rotate_image': rotate_image, + 'apply_posterization': apply_posterization, + 'apply_solarize': apply_solarize, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'adjust_image_sharpness': adjust_image_sharpness, + 'ShearX': apply_apply_shear_x_axis_axis, + 'ShearY': shear_y, + 'TranslateX': translate_image_x_absolute, + 'TranslateY': translate_image_y_absolute, +} + +# Class for applying augmentations to an image +class AugmentOp: + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] # Get the augmentation function + self.level_fn = map_level_to_argument[name] # Get the level function + self.prob = prob # Probability of applying the augmentation + self.magnitude = magnitude # Magnitude of the augmentation + self.hparams = hparams.copy() + self.kwargs = { + 'fillcolor': hparams.get('img_mean', _FILL), # Set the fill color + ' diff --git a/EdgeFLite/data_collection/cifar100_noniid.py b/EdgeFLite/data_collection/cifar100_noniid.py new file mode 100644 index 0000000..5a64d68 --- /dev/null +++ b/EdgeFLite/data_collection/cifar100_noniid.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +#### Get CIFAR-100 dataset in X and Y form +import torchvision +import numpy as np +import random +import torch +from torchvision import apply_transformations +from torch.utils.data import DataLoader, Dataset +from .cifar10_non_iid import * + +# Set random seeds for reproducibility +np.random.seed(68) +random.seed(68) + +def get_cifar100(data_dir): + ''' + Load and return CIFAR-100 train/test data and labels as numpy arrays. + + Parameters: + data_dir (str): Directory where the CIFAR-100 dataset will be downloaded/saved. + + Returns: + x_train (ndarray): Training data. + y_train (ndarray): Training labels. + x_test (ndarray): Test data. + y_test (ndarray): Test labels. 
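+
+    Example (illustrative sketch; './data' is a placeholder path):
+        >>> x_train, y_train, x_test, y_test = get_cifar100('./data')
+        # x_train: (50000, 3, 32, 32) uint8, y_train: (50000,) with labels 0..99
+        # x_test:  (10000, 3, 32, 32) uint8, y_test:  (10000,)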
+ ''' + # Download CIFAR-100 training and test datasets + data_train = torchvision.datasets.CIFAR100(data_dir, train=True, download=True) + data_test = torchvision.datasets.CIFAR100(data_dir, train=False, download=True) + + # Transpose data for proper channel order and convert labels to numpy arrays + x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets) + x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets) + + return x_train, y_train, x_test, y_test + +def split_cf100_real_world_images(data, labels, n_clients=100, verbose=True): + ''' + Splits data and labels among n_clients to simulate a non-IID distribution. + + Parameters: + data (ndarray): Dataset images [n_data x shape]. + labels (ndarray): Dataset labels [n_data]. + n_clients (int): Number of clients to split the data among. + verbose (bool): Print detailed information if True. + + Returns: + clients_split (ndarray): Split data and labels for each client. + ''' + n_labels = np.max(labels) + 1 # Number of unique labels/classes + + def divide_into_sections(n, m): + '''Return m random integers that sum up to n.''' + result = [1] * m + for _ in range(n - m): + result[random.randint(0, m - 1)] += 1 + return result + + # Shuffle and partition classes + n_classes = len(set(labels)) # Number of unique classes + classes = list(range(n_classes)) + np.random.shuffle(classes) # Shuffle class indices + label_indices = [list(np.where(labels == class_)[0]) for class_ in classes] # Indices of each class in labels + + # Define number of classes for each client (randomized) + tmp = [np.random.randint(1, 100) for _ in range(n_clients)] + total_partition = sum(tmp) + class_partition = divide_into_sections(total_partition, len(classes)) # Partition classes randomly + + # Split class indices among clients + class_partition = sorted(class_partition, reverse=True) + class_partition_split = {} + + for idx, class_ in enumerate(classes): + # Split each class' indices according to the partition + class_partition_split[class_] = [list(i) for i in np.array_split(label_indices[idx], class_partition[idx])] + + clients_split = [] + for i in range(n_clients): + n = tmp[i] # Number of classes for this client + indices = [] + j = 0 + + # Assign class data to the client + while n > 0: + class_ = classes[j] + if class_partition_split[class_]: + indices.extend(class_partition_split[class_].pop()) # Add indices of the class to the client + n -= 1 + j += 1 + + clients_split.append([data[indices], labels[indices]]) # Add client's data split + + # Re-sort classes based on available data to balance further splits + classes = sorted(classes, key=lambda x: len(class_partition_split[x]), reverse=True) + + # Raise error if client partition criteria cannot be met + if n > 0: + raise ValueError("Unable to fulfill the client partition criteria.") + + # Verbose option to print split information + if verbose: + display_data_split(clients_split) + + return np.array(clients_split) + +def display_data_split(clients_split): + '''Print the split information of the dataset for each client.''' + print("Data split:") + for i, client in enumerate(clients_split): + split = np.sum(client[1].reshape(1, -1) == np.arange(np.max(client[1]) + 1).reshape(-1, 1), axis=1) + print(f" - Client {i}: {split}") + print() + +def get_default_data_apply_transformations_cf100(train=True, verbose=True): + ''' + Return default data apply_transformationations for CIFAR-100. 
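+
+    Note: the Normalize() constants used below are the commonly used CIFAR-10
+    channel statistics. If CIFAR-100-specific normalization were preferred, a
+    drop-in replacement would look roughly like
+        apply_transformations.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))
+    where the values are the commonly cited CIFAR-100 mean/std, quoted here only
+    as an approximation.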
+ + Parameters: + train (bool): Whether to apply apply_transformationations for training data. + verbose (bool): Print apply_transformationation details if True. + + Returns: + apply_transformations_train (Compose): Training apply_transformationations. + apply_transformations_eval (Compose): Evaluation (test) apply_transformationations. + ''' + # Define apply_transformationations for training data + apply_transformations_train = { + 'cifar100': apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.RandomCrop(32, padding=4), + apply_transformations.RandomHorizontalFlip(), + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ]) + } + + # Define apply_transformationations for test data + apply_transformations_eval = { + 'cifar100': apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ]) + } + + # Verbose option to print apply_transformationation steps + if verbose: + print("\nData preprocessing:") + for apply_transformationation in apply_transformations_train['cifar100'].apply_transformations: + print(f' - {apply_transformationation}') + print() + + return apply_transformations_train['cifar100'], apply_transformations_eval['cifar100'] + +def obtain_data_loaders_train_cf100(data_dir, n_clients, batch_size, classes_per_client=10, verbose=True, + apply_transformations_train=None, apply_transformations_eval=None, non_iid=None, split_factor=1): + ''' + Return data loaders for training on CIFAR-100. + + Parameters: + data_dir (str): Directory where the CIFAR-100 dataset will be saved. + n_clients (int): Number of clients for splitting the dataset. + batch_size (int): Batch size for each data loader. + classes_per_client (int): Number of classes per client. + verbose (bool): Print detailed information if True. + apply_transformations_train (Compose): apply_transformationations for training data. + apply_transformations_eval (Compose): apply_transformationations for evaluation data. + non_iid (str): Strategy to create a non-IID dataset split. + split_factor (float): Factor to control the degree of splitting. + + Returns: + client_loaders (list): Data loaders for each client. + ''' + x_train, y_train, _, _ = get_cifar100(data_dir) + + # Verbose option to print dataset statistics + if verbose: + print_image_data_stats_train(x_train, y_train) + + # Split data according to non-IID strategy (e.g., quantity_skew) + split = None + if non_iid == 'quantity_skew': + split = split_cf100_real_world_images(x_train, y_train, n_clients=n_clients, verbose=verbose) + + split_tmp = shuffle_list(split) + + # Create DataLoaders for each client + client_loaders = [DataLoader(CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor), + batch_size=batch_size, shuffle=True) for x, y in split_tmp] + + return client_loaders + +def obtain_data_loaders_test_cf100(data_dir, batch_size, verbose=True, apply_transformations_eval=None): + ''' + Return data loaders for testing on CIFAR-100. + + Parameters: + data_dir (str): Directory where the CIFAR-100 dataset will be saved. + batch_size (int): Batch size for the test data loader. + verbose (bool): Print detailed information if True. + apply_transformations_eval (Compose): apply_transformationations for evaluation data. + + Returns: + test_loader (DataLoader): Test data loader. 
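+
+    Example (illustrative sketch; the data directory, client count and batch
+    sizes are placeholders):
+        >>> train_tf, eval_tf = get_default_data_apply_transformations_cf100(verbose=False)
+        >>> client_loaders = obtain_data_loaders_train_cf100(
+        ...     './data', n_clients=20, batch_size=128,
+        ...     apply_transformations_train=train_tf, non_iid='quantity_skew')
+        >>> test_loader = obtain_data_loaders_test_cf100(
+        ...     './data', batch_size=100, apply_transformations_eval=eval_tf)
+        # client_loaders is a list with one DataLoader per simulated client;
+        # test_loader serves the shared CIFAR-100 test split.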
+ ''' + _, _, x_test, y_test = get_cifar100(data_dir) + + # Verbose option to print dataset statistics + if verbose: + print_image_data_stats_test(x_test, y_test) + + # Create DataLoader for the test dataset + test_loader = DataLoader(CustomImageDataset(x_test, y_test, apply_transformations_eval, split_factor=1), + batch_size=100, shuffle=False) + + return test_loader diff --git a/EdgeFLite/data_collection/cifar10_noniid.py b/EdgeFLite/data_collection/cifar10_noniid.py new file mode 100644 index 0000000..28da974 --- /dev/null +++ b/EdgeFLite/data_collection/cifar10_noniid.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +#### Load CIFAR-10 dataset and preprocess it +import torchvision +import numpy as np +import random +import torch +from torchvision import apply_transformations +from torch.utils.data import DataLoader, Dataset + +# Set random seed for reproducibility +np.random.seed(68) # Ensures that the random operations have consistent outputs +random.seed(68) + +def get_cifar10(data_dir): + """Return CIFAR-10 train/test data and labels as numpy arrays""" + # Download CIFAR-10 dataset + data_train = torchvision.datasets.CIFAR10(data_dir, train=True, download=True) + data_test = torchvision.datasets.CIFAR10(data_dir, train=False, download=True) + + # Preprocess the train and test data to the correct format (channels first) + x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets) + x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets) + + return x_train, y_train, x_test, y_test + +def display_data_statistics(data, labels, dataset_type): + """Print statistics of the dataset""" + print(f"\n{dataset_type} Set: ({data.shape}, {labels.shape}), Range: [{np.min(data):.3f}, {np.max(data):.3f}], " + f"Labels: {np.min(labels)},..,{np.max(labels)}") + +def randomize_client_distributiony(train_len, n_clients): + """ + Distribute data among clients with a random distribution + Returns a list with the number of samples for each client + """ + # Randomly assign a number of samples to each client, ensuring the total matches the train_len + client_sizes = [random.randint(10, 100) for _ in range(n_clients - 1)] + total = sum(client_sizes) + client_sizes = np.array(client_sizes) + client_distributions = ((client_sizes / total) * train_len).astype(int) # Normalize to match the train_len + client_distributions = list(client_distributions) + client_distributions.append(train_len - sum(client_distributions)) # Ensure all data is allocated + return client_distributions + +def divide_into_sections(n, m): + """Return 'm' random integers that sum to 'n'""" + # Break the number 'n' into 'm' random parts that sum to 'n' + partitions = [1] * m + for _ in range(n - m): + partitions[random.randint(0, m - 1)] += 1 + return partitions + +def split_data_real_world_scenario(data, labels, n_clients=100): + """Split data among clients simulating real-world non-IID distribution""" + n_classes = len(set(labels)) # Determine number of unique classes + class_indices = [np.where(labels == class_)[0] for class_ in range(n_classes)] # Indices for each class + + client_classes = [np.random.randint(1, 10) for _ in range(n_clients)] # Random number of classes per client + total_partitions = sum(client_classes) + + class_partition = divide_into_sections(total_partitions, len(class_indices)) # Partition classes to distribute + class_partition_split = {cls: np.array_split(class_indices[cls], n) for cls, n in enumerate(class_partition)} + + 
clients_split = [] + for client in client_classes: + selected_indices = [] + for class_ in range(n_classes): + if class_partition_split[class_]: + selected_indices.extend(class_partition_split[class_].pop()) + client -= 1 + if client <= 0: + break + clients_split.append([data[selected_indices], labels[selected_indices]]) + + return np.array(clients_split) + +def split_data_iid(data, labels, n_clients=100, classes_per_client=10, shuffle=True): + """Split data among clients with IID (Independent and Identically Distributed) distribution""" + data_per_client = randomize_client_distributiony(len(data), n_clients) + label_indices = [np.where(labels == label)[0] for label in range(np.max(labels) + 1)] + + if shuffle: + for indices in label_indices: + np.random.shuffle(indices) + + clients_split = [] + for client_data in data_per_client: + client_indices = [] + class_ = np.random.randint(len(label_indices)) + while client_data > 0: + take = min(client_data, len(label_indices[class_])) + client_indices.extend(label_indices[class_][:take]) + label_indices[class_] = label_indices[class_][take:] + client_data -= take + class_ = (class_ + 1) % len(label_indices) + + clients_split.append([data[client_indices], labels[client_indices]]) + + return np.array(clients_split) + +def randomize_data_order(data): + """Shuffle data while maintaining the mapping between inputs and labels""" + for i in range(len(data)): + index = np.arange(len(data[i][0])) + np.random.shuffle(index) + data[i][0], data[i][1] = data[i][0][index], data[i][1][index] + return data + +class CustomImageDataset(Dataset): + """Custom Dataset class for image data""" + def __init__(self, inputs, labels, apply_transformations=None, split_factor=1): + # Convert input data to torch tensors and apply apply_transformationations if provided + self.inputs = torch.Tensor(inputs) + self.labels = labels + self.apply_transformations = apply_transformations + self.split_factor = split_factor + + def __getitem__(self, index): + img, label = self.inputs[index], self.labels[index] + # Apply apply_transformationations to the image multiple times if split_factor > 1 + imgs = [self.apply_transformations(img) for _ in range(self.split_factor)] if self.apply_transformations else [img] + return torch.cat(imgs, dim=0), label + + def __len__(self): + return len(self.inputs) + +def get_default_apply_transformations(verbose=True): + """Return default apply_transformationations for training and evaluation""" + apply_transformations_train = apply_transformations.Compose([ + apply_transformations.ToPILImage(), # Convert numpy array to PIL image + apply_transformations.RandomCrop(32, padding=4), # Randomly crop to 32x32 with padding + apply_transformations.RandomHorizontalFlip(), # Randomly flip images horizontally + apply_transformations.ToTensor(), # Convert image to tensor + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # Normalize with CIFAR-10 mean and std + ]) + + apply_transformations_eval = apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # Same normalization for evaluation + ]) + + if verbose: + print("\nData preprocessing steps:") + for apply_transformationation in apply_transformations_train.apply_transformations: + print(f" - {apply_transformationation}") + + return apply_transformations_train, apply_transformations_eval + +def obtain_data_loaders(data_dir, n_clients, batch_size, 
classes_per_client=10, non_iid=None, split_factor=1): + """Return DataLoader objects for clients with either IID or non-IID data split""" + x_train, y_train, _, _ = get_cifar10(data_dir) + display_data_statistics(x_train, y_train, "Train") + + # Split data based on non-IID method specified (either 'quantity_skew' or 'label_skew') + if non_iid == 'quantity_skew': + clients_data = split_data_real_world_scenario(x_train, y_train, n_clients) + elif non_iid == 'label_skew': + clients_data = split_data_iid(x_train, y_train, n_clients, classes_per_client) + + shuffled_clients_data = randomize_data_order(clients_data) + + apply_transformations_train, apply_transformations_eval = get_default_apply_transformations(verbose=False) + client_loaders = [DataLoader(CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor), + batch_size=batch_size, shuffle=True) for x, y in shuffled_clients_data] + + return client_loaders + +def get_test_data_loader(data_dir, batch_size): + """Return DataLoader for test data""" + _, _, x_test, y_test = get_cifar10(data_dir) + display_data_statistics(x_test, y_test, "Test") + + _, apply_transformations_eval = get_default_apply_transformations(verbose=False) + test_loader = DataLoader(CustomImageDataset(x_test, y_test, apply_transformations_eval), batch_size=batch_size, shuffle=False) + + return test_loader diff --git a/EdgeFLite/data_collection/data_cutout.py b/EdgeFLite/data_collection/data_cutout.py new file mode 100644 index 0000000..3a81219 --- /dev/null +++ b/EdgeFLite/data_collection/data_cutout.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import numpy as np + + +class Cutout: + """Applies random cutout augmentation by masking patches in an image. + + This technique randomly cuts out square patches from the image to + augment the dataset, helping the model become invariant to occlusions. + + Args: + n_holes (int): Number of patches to remove from the image. + length (int): Side length (in pixels) of each square patch. + """ + + def __init__(self, n_holes, length): + """ + Initializes the Cutout class with the number of patches to be removed + and the size of each patch. + + Args: + n_holes (int): Number of patches (holes) to cut out from the image. + length (int): Size of each square patch. + """ + self.n_holes = n_holes # Number of holes (patches) to remove. + self.length = length # Side length of each square patch. + + def __call__(self, img): + """ + Applies the cutout augmentation on the input image. + + Args: + img (Tensor): The input image tensor with shape (C, H, W), + where C is the number of channels, H is the height, + and W is the width of the image. + + Returns: + Tensor: The augmented image tensor with `n_holes` patches of size + `length x length` cut out, filled with zeros. + """ + # Get the height and width of the image (ignoring the channel dimension) + height, width = img.size(1), img.size(2) + + # Create a mask initialized with ones, same height and width as the image + # (each pixel is set to 1, representing no masking initially) + mask = np.ones((height, width), dtype=np.float32) + + # Randomly remove `n_holes` patches from the image + for _ in range(self.n_holes): + # Randomly choose the center of a patch (x_center, y_center) + y_center = np.random.randint(height) + x_center = np.random.randint(width) + + # Define the coordinates of the patch based on the center + # and ensure the patch stays within the image boundaries. 
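+            # Clarifying note: because the patch corners are clipped to the image
+            # bounds below, a center sampled near a border yields a truncated hole,
+            # so the effective masked area can be smaller than length x length
+            # (this matches the standard Cutout formulation).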
+ y1 = np.clip(y_center - self.length // 2, 0, height) + y2 = np.clip(y_center + self.length // 2, 0, height) + x1 = np.clip(x_center - self.length // 2, 0, width) + x2 = np.clip(x_center + self.length // 2, 0, width) + + # Set the mask to 0 for the patch (mark the patch as cut out) + mask[y1:y2, x1:x2] = 0.0 + + # Convert the mask from numpy array to a PyTorch tensor + mask_tensor = torch.from_numpy(mask).expand_as(img) + + # Multiply the input image by the mask (cut out the selected patches) + return img * mask_tensor diff --git a/EdgeFLite/data_collection/dataset_cifar.py b/EdgeFLite/data_collection/dataset_cifar.py new file mode 100644 index 0000000..154f2d6 --- /dev/null +++ b/EdgeFLite/data_collection/dataset_cifar.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Import necessary libraries +from PIL import Image # For image handling +import os # For file path operations +import numpy as np # For numerical operations +import pickle # For loading serialized data +import torch # For PyTorch operations + +# Import custom classes and functions from the current package +from .vision import VisionDataset +from .utils import validate_integrity, fetch_and_extract_archive + +# CIFAR10 dataset class +class CIFAR10(VisionDataset): + """ + CIFAR10 Dataset class that handles the CIFAR-10 dataset loading, processing, and apply_transformationations. + + Args: + root (str): Directory where the dataset is stored or will be downloaded to. + train (bool, optional): If True, load the training set. Otherwise, load the test set. + apply_transformation (callable, optional): A function/apply_transformation that takes a PIL image and returns a apply_transformationed version. + target_apply_transformation (callable, optional): A function/apply_transformation that takes the target and apply_transformations it. + download (bool, optional): If True, download the dataset if it's not found locally. + split_factor (int, optional): Number of apply_transformationations applied to each image. Default is 1. 
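+
+    Example (illustrative sketch; the root path and transform are placeholders):
+        >>> from torchvision import transforms
+        >>> ds = CIFAR10('./data', train=True, apply_transformation=transforms.ToTensor(),
+        ...              download=True, split_factor=2)
+        >>> img, target = ds[0]
+        # img has shape (3 * split_factor, 32, 32): __getitem__ applies the
+        # transform split_factor times and concatenates the views along the
+        # channel dimension.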
+ """ + # Directory and URL details for downloading the CIFAR-10 dataset + base_folder = 'cifar-10-batches-py' + url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + filename = "cifar-10-python.tar.gz" + tgz_md5 = 'c58f30108f718f92721af3b95e74349a' # MD5 checksum to verify the file's integrity + + # List of training batches with their corresponding MD5 checksums + train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'] + ] + + # List of test batches with their corresponding MD5 checksums + test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'] + ] + + # Info map to hold label names and their checksum + info_map = { + 'filename': 'batches.info_map', + 'key': 'label_names', + 'md5': '5ff9c542aee3614f3951f8cda6e48888' + } + + # Initialization method + def __init__(self, root, train=True, apply_transformation=None, target_apply_transformation=None, download=False, split_factor=1): + super(CIFAR10, self).__init__(root, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation) + self.train = train # Whether to load the training set or test set + self.split_factor = split_factor # Number of apply_transformationations to apply + + # Download dataset if necessary + if download: + self.download() + + # Check if the dataset is already downloaded and valid + if not self._validate_integrity(): + raise RuntimeError('Dataset not found or corrupted. Use download=True to download it.') + + # Load the dataset + self.data, self.targets = self._load_data() + + # Load the label info map (to get class names) + self._load_info_map() + + # Load dataset from the files + def _load_data(self): + data, targets = [], [] # Initialize lists to hold data and labels + files = self.train_list if self.train else self.test_list # Choose train or test files + + # Load each file, deserialize with pickle, and append data and labels + for file_name, _ in files: + file_path = os.path.join(self.root, self.base_folder, file_name) + with open(file_path, 'rb') as f: + entry = pickle.load(f, encoding='latin1') # Load file + data.append(entry['data']) # Append image data + targets.extend(entry.get('labels', entry.get('fine_labels', []))) # Append labels + + # Reshape and format the data to (num_samples, height, width, channels) + data = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1)) # Reshape to HWC format + return data, targets + + # Load label names (info map) + def _load_info_map(self): + info_map_path = os.path.join(self.root, self.base_folder, self.info_map['filename']) # Path to info map + if not validate_integrity(info_map_path, self.info_map['md5']): # Check integrity of info map + raise RuntimeError('info_mapdata file not found or corrupted. Use download=True to download it.') + + # Load the label names + with open(info_map_path, 'rb') as info_map_file: + info_map_data = pickle.load(info_map_file, encoding='latin1') # Load label names + self.classes = info_map_data[self.info_map['key']] # Extract class labels + self.class_to_idx = {label: idx for idx, label in enumerate(self.classes)} # Map class names to indices + + # Get item (image and target) by index + def __getitem__(self, index): + """ + Get the item (image, target) at the specified index. + Args: + index (int): Index of the data. 
+ + Returns: + tuple: apply_transformationed image and the target class. + """ + img, target = self.data[index], self.targets[index] # Get image and target label + img = Image.fromarray(img) # Convert numpy array to PIL image + + # Apply the apply_transformation multiple times based on split_factor + imgs = [self.apply_transformation(img) for _ in range(self.split_factor)] if self.apply_transformation else None + if imgs is None: + raise NotImplementedError('apply_transformation must be provided.') + + # Apply target apply_transformationation if available + if self.target_apply_transformation: + target = self.target_apply_transformation(target) + + return torch.cat(imgs, dim=0), target # Return concatenated apply_transformationed images and the target + + # Return the number of items in the dataset + def __len__(self): + return len(self.data) + + # Check if the dataset files are valid and downloaded + def _validate_integrity(self): + files = self.train_list + self.test_list # All files to check + for file_name, md5 in files: + file_path = os.path.join(self.root, self.base_folder, file_name) + if not validate_integrity(file_path, md5): # Verify integrity using MD5 + return False + return True + + # Download the dataset if it's not available + def download(self): + if self._validate_integrity(): + print('Files already downloaded and verified') + else: + fetch_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) + + # Representation string to include the split type (Train/Test) + def extra_repr(self): + return f"Split: {'Train' if self.train else 'Test'}" + + +# CIFAR100 is a subclass of CIFAR10, with minor modifications +class CIFAR100(CIFAR10): + """ + CIFAR100 Dataset, a subclass of CIFAR10. + """ + # Directory and URL details for downloading CIFAR-100 dataset + base_folder = 'cifar-100-vision' + url = "https://www.cs.toronto.edu/~kriz/cifar-100-vision.tar.gz" + filename = "cifar-100-vision.tar.gz" + tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' # MD5 checksum + + # Training and test lists with their corresponding MD5 checksums for CIFAR-100 + train_list = [ + ['train', '16019d7e3df5f24257cddd939b257f8d'] + ] + + test_list = [ + ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'] + ] + + # Info map to hold fine label names and their checksum + info_map = { + 'filename': 'info_map', + 'key': 'fine_label_names', + 'md5': '7973b15100ade9c7d40fb424638fde48' + } diff --git a/EdgeFLite/data_collection/dataset_factory.py b/EdgeFLite/data_collection/dataset_factory.py new file mode 100644 index 0000000..c23a396 --- /dev/null +++ b/EdgeFLite/data_collection/dataset_factory.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +from torchvision import apply_transformations +from .cifar import CIFAR10, CIFAR100 # Import CIFAR10 and CIFAR100 datasets +from .autoaugment import CIFAR10Policy # Import CIFAR10 augmentation policy + +__all__ = ['obtain_data_loader'] # Define the public API of this module + +def obtain_data_loader( + data_dir, # Directory where the data is stored + split_factor=1, # Used for data partitioning, especially in federated learning + batch_size=128, # Batch size for loading data + crop_size=32, # Size to crop the input images + dataset='cifar10', # Dataset to use (CIFAR-10 by default) + split="train", # The split type: 'train', 'val', or 'test' + is_decentralized=False, # Whether to use decentralized training + is_autoaugment=1, # Use AutoAugment or not + randaa=None, # Placeholder for randomized augmentations + is_cutout=True, 
# Whether to apply cutout (random erasing) + erase_p=0.5, # Probability of applying random erasing + num_workers=8, # Number of workers to load data + pin_memory=True, # Use pinned memory for better GPU transfer + is_fed=False, # Whether to use federated learning + num_clusters=20, # Number of clients in federated learning + cifar10_non_iid=False, # Non-IID option for CIFAR-10 dataset + cifar100_non_iid=False # Non-IID option for CIFAR-100 dataset +): + """Get the dataset loader""" + assert not (is_autoaugment and randaa is not None) # Autoaugment and randaa cannot be used together + + # Loader settings based on multiprocessing + kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory} + assert split in ['train', 'val', 'test'] # Ensure valid split + + # For CIFAR-10 dataset + if dataset == 'cifar10': + # Handle non-IID 'quantity skew' case for CIFAR-10 + if cifar10_non_iid == 'quantity_skew': + non_iid = 'quantity_skew' + # If in training split + if 'train' in split: + print(f"INFO:PyTorch: Using quantity_skew CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.") + traindir = data_dir # Set data directory + # Define data apply_transformationations for training + train_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.RandomCrop(32, padding=4), + apply_transformations.RandomHorizontalFlip(), + CIFAR10Policy(), # AutoAugment policy + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 normalization + apply_transformations.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False), + ]) + train_sampler = None + print('INFO:PyTorch: creating quantity_skew CIFAR10 train dataloader...') + + # For federated learning, create loaders for each client + if is_fed: + train_loader = obtain_data_loaders_train( + traindir, + nclients=num_clusters * split_factor, # Number of clients in federated learning + batch_size=batch_size, + verbose=True, + apply_transformations_train=train_apply_transformation, + non_iid=non_iid, # Specify non-IID type + split_factor=split_factor + ) + else: + assert is_fed # Ensure that is_fed is True + return train_loader, train_sampler + else: + # If in validation or test split + valdir = data_dir # Set validation data directory + # Define data apply_transformationations for validation/testing + val_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 normalization + ]) + # Create the test loader + val_loader = obtain_data_loaders_test( + valdir, + nclients=num_clusters * split_factor, # Number of clients in federated learning + batch_size=batch_size, + verbose=True, + apply_transformations_eval=val_apply_transformation, + non_iid=non_iid, + split_factor=1 + ) + return val_loader + else: + # For standard IID CIFAR-10 case + if 'train' in split: + print(f"INFO:PyTorch: Using CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.") + traindir = data_dir # Set training data directory + # Define data apply_transformationations for training + train_apply_transformation = apply_transformations.Compose([ + apply_transformations.RandomCrop(32, padding=4), + apply_transformations.RandomHorizontalFlip(), + CIFAR10Policy(), # AutoAugment policy + apply_transformations.ToTensor(), + 
apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 normalization + apply_transformations.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False), + ]) + # Create the CIFAR-10 dataset object + train_dataset = CIFAR10( + traindir, train=True, apply_transformation=train_apply_transformation, target_apply_transformation=None, download=True, split_factor=split_factor + ) + train_sampler = None # No sampler by default + + # Decentralized training setup + if is_decentralized: + train_sampler = torch.utils.data.decentralized.decentralizedSampler(train_dataset, shuffle=True) + + print('INFO:PyTorch: creating CIFAR10 train dataloader...') + if is_fed: + # Federated learning setup + images_per_client = int(train_dataset.data.shape[0] / (num_clusters * split_factor)) + print(f"Images per client: {images_per_client}") + data_split = [images_per_client for _ in range(num_clusters * split_factor - 1)] + data_split.append(len(train_dataset) - images_per_client * (num_clusters * split_factor - 1)) + # Split dataset for each client + traindata_split = torch.utils.data.random_split(train_dataset, data_split, generator=torch.Generator().manual_seed(68)) + # Create data loaders for each client + train_loader = [torch.utils.data.DataLoader( + x, batch_size=batch_size, shuffle=(train_sampler is None), drop_last=True, sampler=train_sampler, **kwargs + ) for x in traindata_split] + else: + # Standard data loader + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), drop_last=True, sampler=train_sampler, **kwargs + ) + return train_loader, train_sampler + else: + # For validation or test split + valdir = data_dir # Set validation data directory + # Define data apply_transformationations for validation/testing + val_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToTensor(), + apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 normalization + ]) + # Create CIFAR-10 dataset object for validation + val_dataset = CIFAR10(valdir, train=False, apply_transformation=val_apply_transformation, target_apply_transformation=None, download=True, split_factor=1) + print('INFO:PyTorch: creating CIFAR10 validation dataloader...') + # Create data loader for validation + val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **kwargs) + return val_loader + # Additional dataset logic for CIFAR-100, decentralized setups, or other datasets can be added similarly. 
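+    # Illustrative sketch: a CIFAR-100 branch would mirror the CIFAR-10 branch
+    # above, roughly:
+    #
+    #   elif dataset == 'cifar100':
+    #       train_dataset = CIFAR100(data_dir, train=True,
+    #                                apply_transformation=train_apply_transformation,
+    #                                download=True, split_factor=split_factor)
+    #       # ...then either wrap it in a single DataLoader, or, when is_fed is set,
+    #       # shard it into num_clusters * split_factor subsets with
+    #       # torch.utils.data.random_split exactly as done for CIFAR-10 above.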
+ else: + raise NotImplementedError(f"The DataLoader for {dataset} is not implemented.") diff --git a/EdgeFLite/data_collection/dataset_imagenet.py b/EdgeFLite/data_collection/dataset_imagenet.py new file mode 100644 index 0000000..f26a4f5 --- /dev/null +++ b/EdgeFLite/data_collection/dataset_imagenet.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import warnings +from contextlib import contextmanager +import os +import shutil +import tempfile +import torch +from .folder import ImageFolder +from .utils import validate_integrity, extract_archive, verify_str_arg + +# Dictionary that maps the dataset split (train/val/devkit) to its corresponding archive filename and checksum (md5 hash) +ARCHIVE_info_map = { + 'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'), + 'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'), + 'devkit': ('ILSVRC2012_devkit_t12.tar', 'fa75699e90414af021442c21a62c3abf') +} + +# File name where the information map (class info, wnid, etc.) is stored +info_map_FILE = "info_map.bin" + +class ImageNet(ImageFolder): + """`ImageNet `_ 2012 Classification Dataset. + + Args: + root (str): Root directory of the ImageNet Dataset. + split (str, optional): Dataset split, either ``train`` or ``val``. + apply_transformation (callable, optional): A function/apply_transformation to apply to the PIL image. + target_apply_transformation (callable, optional): A function/apply_transformation to apply to the target. + loader (callable, optional): Function to load an image from its path. + + Attributes: + classes (list): List of class name tuples. + class_to_idx (dict): Mapping of class names to indices. + wnids (list): List of WordNet IDs. + wnid_to_idx (dict): Mapping of WordNet IDs to class indices. + imgs (list): List of image path and class index tuples. + targets (list): Class index values for each image in the dataset. + """ + + def __init__(self, root, split='train', download=None, **kwargs): + # Check if download flag is used, raise warnings since dataset is no longer publicly accessible + if download is True: + raise RuntimeError("The dataset is no longer publicly accessible. 
Please download archives externally and place them in the root directory.") + elif download is False: + warnings.warn("The download flag is deprecated, as the dataset is no longer publicly accessible.", RuntimeWarning) + + # Expand the root directory path + root = self.root = os.path.expanduser(root) + + # Validate the dataset split (should be either 'train' or 'val') + self.split = verify_str_arg(split, "split", ("train", "val")) + + # Parse dataset archives (train/val/devkit) and prepare the dataset + self.extract_archives() + + # Load WordNet ID to class mappings from the info_map file + wnid_to_classes = load_information_map_file(self.root)[0] + + # Initialize the ImageFolder with the split folder (train/val directory) + super().__init__(self.divide_folder_contents, **kwargs) + + # Set class-related attributes + self.root = root + self.wnids = self.classes + self.wnid_to_idx = self.class_to_idx + + # Update classes to human-readable names and adjust the class_to_idx mapping + self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] + self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss} + + def extract_archives(self): + # Check if the info_map file exists and is valid, otherwise parse the devkit archive + if not validate_integrity(os.path.join(self.root, info_map_FILE)): + extract_devkit_archive(self.root) + + # If the dataset folder (train/val) does not exist, extract the respective archive + if not os.path.isdir(self.divide_folder_contents): + if self.split == 'train': + process_train_archive(self.root) + elif self.split == 'val': + process_validation_archive(self.root) + + @property + def divide_folder_contents(self): + # Return the path of the folder containing the images (train/val) + return os.path.join(self.root, self.split) + + def extra_repr(self): + # Additional representation for the dataset object (showing the split) + return f"Split: {self.split}" + +def load_information_map_file(root, file=None): + # Load the info_map file from the root directory + file = os.path.join(root, file or info_map_FILE) + if validate_integrity(file): + return torch.load(file) + else: + raise RuntimeError(f"The info_map file {file} is either missing or corrupted. Please ensure it exists in the root directory.") + +def _validate_archive_file(root, file, md5): + # Verify if the archive file is present and its checksum matches + if not validate_integrity(os.path.join(root, file), md5): + raise RuntimeError(f"The archive {file} is either missing or corrupted. Please download it and place it in {root}.") + +def extract_devkit_archive(root, file=None): + """Extract and process the ImageNet 2012 devkit archive to generate info_map information. + + Args: + root (str): Root directory with the devkit archive. + file (str, optional): Archive filename. Defaults to 'ILSVRC2012_devkit_t12.tar'. 
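+
+    After parsing, the (wnid_to_classes, val_wnids) mappings are written with
+    torch.save to the info_map file in ``root`` and can be read back through
+    load_information_map_file. Illustrative sketch (the path is a placeholder):
+
+        wnid_to_classes, val_wnids = load_information_map_file('/datasets/imagenet')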
+ """ + import scipy.io as sio + + # Parse info_map.mat from the devkit, containing class and WordNet ID information + def read_info_map_mat_file(devkit_root): + info_map_path = os.path.join(devkit_root, "data", "info_map.mat") + info_map = sio.loadmat(info_map_path, squeeze_me=True)['synsets'] + info_map = [info_map[idx] for idx, num_children in enumerate(info_map[4]) if num_children == 0] + idcs, wnids, classes = zip(*info_map)[:3] + classes = [tuple(clss.split(', ')) for clss in classes] + return {idx: wnid for idx, wnid in zip(idcs, wnids)}, {wnid: clss for wnid, clss in zip(wnids, classes)} + + # Parse the validation ground truth file for image class labels + def process_val_groundtruth_txt(devkit_root): + file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt") + with open(file) as f: + return [int(line.strip()) for line in f] + + # Context manager to handle temporary directories for archive extraction + @contextmanager + def get_tmp_dir(): + tmp_dir = tempfile.mkdtemp() + try: + yield tmp_dir + finally: + shutil.rmtree(tmp_dir) + + # Extract and process the devkit archive + file, md5 = ARCHIVE_info_map["devkit"] + _validate_archive_file(root, file, md5) + + with get_tmp_dir() as tmp_dir: + extract_archive(os.path.join(root, file), tmp_dir) + devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12") + idx_to_wnid, wnid_to_classes = read_info_map_mat_file(devkit_root) + val_idcs = process_val_groundtruth_txt(devkit_root) + val_wnids = [idx_to_wnid[idx] for idx in val_idcs] + + # Save the mappings to the info_map file + torch.save((wnid_to_classes, val_wnids), os.path.join(root, info_map_FILE)) + +def process_train_archive(root, file=None, folder="train"): + """Extract and organize the ImageNet 2012 train dataset. + + Args: + root (str): Root directory containing the train dataset archive. + file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_train.tar'. + folder (str, optional): Destination folder. Defaults to 'train'. + """ + file, md5 = ARCHIVE_info_map["train"] + _validate_archive_file(root, file, md5) + + train_root = os.path.join(root, folder) + extract_archive(os.path.join(root, file), train_root) + + # Extract each class-specific archive in the train dataset + for archive in os.listdir(train_root): + extract_archive(os.path.join(train_root, archive), os.path.splitext(archive)[0], remove_finished=True) + +def process_validation_archive(root, file=None, wnids=None, folder="val"): + """Extract and organize the ImageNet 2012 validation dataset. + + Args: + root (str): Root directory containing the validation dataset archive. + file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_val.tar'. + wnids (list, optional): WordNet IDs for validation images. Defaults to None (loaded from info_map file). + folder (str, optional): Destination folder. Defaults to 'val'. 
+ """ + file, md5 = ARCHIVE_info_map["val"] + if wnids is None: + wnids = load_information_map_file(root)[1] + + _validate_archive_file(root, file, md5) + + val_root = os.path.join(root, folder) + extract_archive(os.path.join(root, file), val_root) + + # Create directories for each WordNet ID (class) and move validation images into their respective folders + for wnid in set(wnids): + os.mkdir(os.path.join(val_root, wnid)) + + for wnid, img in zip(wnids, sorted(os diff --git a/EdgeFLite/data_collection/directory_utils.py b/EdgeFLite/data_collection/directory_utils.py new file mode 100644 index 0000000..e9d45cf --- /dev/null +++ b/EdgeFLite/data_collection/directory_utils.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Import necessary modules +from .vision import VisionDataset # Import the base VisionDataset class +from PIL import Image # Import PIL for image loading and processing +import os # For interacting with the file system +import torch # PyTorch for tensor operations + +# Function to check if a file has an allowed extension +def validate_file_extension(filename, extensions): + """ + Check if a file has an allowed extension. + + Args: + filename (str): Path to the file. + extensions (tuple of str): Extensions to consider (in lowercase). + + Returns: + bool: True if the filename ends with one of the given extensions. + """ + return filename.lower().endswith(extensions) + +# Function to check if a file is an image +def is_image_file(filename): + """ + Check if a file is an image based on its extension. + + Args: + filename (str): Path to the file. + + Returns: + bool: True if the filename is a known image format. + """ + return validate_file_extension(filename, IMG_EXTENSIONS) + +# Function to create a dataset of file paths and their corresponding class indices +def generate_dataset(directory, class_to_idx, extensions=None, is_valid_file=None): + """ + Creates a list of file paths and their corresponding class indices. + + Args: + directory (str): Root directory. + class_to_idx (dict): Mapping of class names to class indices. + extensions (tuple, optional): Allowed file extensions. + is_valid_file (callable, optional): Function to validate files. + + Returns: + list: A list of (file_path, class_index) tuples. 
+ """ + instances = [] + directory = os.path.expanduser(directory) # Expand user directory path if needed + + # Ensure only one of extensions or is_valid_file is specified + if (extensions is None and is_valid_file is None) or (extensions is not None and is_valid_file is not None): + raise ValueError("Specify either 'extensions' or 'is_valid_file', but not both.") + + # Define the validation function if extensions are provided + if extensions is not None: + def is_valid_file(x): + return validate_file_extension(x, extensions) + + # Iterate through the directory, searching for valid image files + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] # Get the class index + target_dir = os.path.join(directory, target_class) # Define the target class folder + if not os.path.isdir(target_dir): # Skip if it's not a directory + continue + # Walk through the directory and subdirectories + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) # Full file path + if is_valid_file(path): # Check if it's a valid file + instances.append((path, class_index)) # Append file path and class index to the list + + return instances # Return the dataset + +# DatasetFolder class: Generic data loader for samples arranged in subdirectories by class +class DatasetFolder(VisionDataset): + """ + A generic data loader where samples are arranged in subdirectories by class. + + Args: + root (str): Root directory path. + loader (callable): Function to load a sample from its file path. + extensions (tuple[str]): Allowed file extensions. + apply_transformation (callable, optional): apply_transformation applied to each sample. + target_apply_transformation (callable, optional): apply_transformation applied to each target. + is_valid_file (callable, optional): Function to validate files. + split_factor (int, optional): Number of times to apply the apply_transformation. + + Attributes: + classes (list): Sorted list of class names. + class_to_idx (dict): Mapping of class names to class indices. + samples (list): List of (sample_path, class_index) tuples. + targets (list): List of class indices corresponding to each sample. + """ + + def __init__(self, root, loader, extensions=None, apply_transformation=None, + target_apply_transformation=None, is_valid_file=None, split_factor=1): + super().__init__(root, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation) + self.classes, self.class_to_idx = self._discover_classes(self.root) # Discover classes in the root directory + self.samples = generate_dataset(self.root, self.class_to_idx, extensions, is_valid_file) # Create dataset from files + + # Raise an error if no valid files are found + if len(self.samples) == 0: + raise RuntimeError(f"Found 0 files in subfolders of: {self.root}. " + f"Supported extensions are: {','.join(extensions)}") + + self.loader = loader # Function to load a sample + self.extensions = extensions # Allowed file extensions + self.targets = [s[1] for s in self.samples] # List of target class indices + self.split_factor = split_factor # Number of apply_transformationations to apply + + # Function to find class subdirectories in the root directory + def _discover_classes(self, dir): + """ + Discover class subdirectories in the root directory. + + Args: + dir (str): Root directory. 
+ + Returns: + tuple: (classes, class_to_idx) where classes are subdirectories of 'dir', + and class_to_idx is a mapping of class names to indices. + """ + classes = sorted([d.name for d in os.scandir(dir) if d.is_dir()]) # List of subdirectory names (classes) + class_to_idx = {classes[i]: i for i in range(len(classes))} # Map class names to indices + return classes, class_to_idx + + # Function to get a sample and its target by index + def __getitem__(self, index): + """ + Retrieve a sample and its target by index. + + Args: + index (int): Index of the sample. + + Returns: + tuple: (sample, target), where the sample is the apply_transformationed image and + the target is the class index. + """ + path, target = self.samples[index] # Get the file path and target class index + sample = self.loader(path) # Load the sample (image) + + # Apply apply_transformationation to the sample 'split_factor' times + imgs = [self.apply_transformation(sample) for _ in range(self.split_factor)] if self.apply_transformation else NotImplementedError + + # Apply target apply_transformationation if specified + if self.target_apply_transformation: + target = self.target_apply_transformation(target) + + return torch.cat(imgs, dim=0), target # Return concatenated apply_transformationed images and the target + + # Return the number of samples in the dataset + def __len__(self): + return len(self.samples) + +# List of supported image file extensions +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') + +# Function to load an image using PIL +def load_image_pil(path): + """ + Load an image from the given path using PIL. + + Args: + path (str): Path to the image. + + Returns: + Image: RGB image. + """ + with open(path, 'rb') as f: + img = Image.open(f) # Open the image file + return img.convert('RGB') # Convert the image to RGB format + +# Function to load an image using accimage library with fallback to PIL +def load_accimage(path): + """ + Load an image using the accimage library, falling back to PIL on failure. + + Args: + path (str): Path to the image. + + Returns: + Image: Image loaded with accimage or PIL. + """ + import accimage # accimage is a faster image loading library + try: + return accimage.Image(path) # Try loading with accimage + except IOError: + return load_image_pil(path) # Fall back to PIL on error + +# Function to load an image using the default backend (accimage or PIL) +def basic_loader(path): + """ + Load an image using the default image backend (accimage or PIL). + + Args: + path (str): Path to the image. + + Returns: + Image: Loaded image. + """ + from torchvision import get_image_backend # Get the default image backend + return load_accimage(path) if get_image_backend() == 'accimage' else load_image_pil(path) # Load using the appropriate backend + +# ImageFolder class: A dataset loader for images arranged in subdirectories by class +class ImageFolder(DatasetFolder): + """ + A dataset loader for images arranged in subdirectories by class. + + Args: + root (str): Root directory path. + apply_transformation (callable, optional): apply_transformation applied to each image. + target_apply_transformation (callable, optional): apply_transformation applied to each target. + loader (callable, optional): Function to load an image from its path. + is_valid_file (callable, optional): Function to validate files. + + Attributes: + classes (list): Sorted list of class names. + class_to_idx (dict): Mapping of class names to class indices. 
+ imgs (list): List of (image_path, class_index) tuples. + """ + + def __init__(self, root, apply_transformation=None, target_apply_transformation=None, loader=basic_loader, is_valid_file=None, split_factor=1): + super().__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, + apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation, + is_valid_file=is diff --git a/EdgeFLite/data_collection/helper_utils.py b/EdgeFLite/data_collection/helper_utils.py new file mode 100644 index 0000000..169a823 --- /dev/null +++ b/EdgeFLite/data_collection/helper_utils.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import hashlib +import gzip +import tarfile +import zipfile +import urllib.request +from torch.utils.model_zoo import tqdm + +def generate_update_progress_barr(): + """Generates a progress bar for tracking download progress.""" + pbar = tqdm(total=None) + + def update_progress_bar(count, block_size, total_size): + """Updates the progress bar based on the downloaded data size.""" + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return update_progress_bar + +def compute_md5_checksum(fpath, chunk_size=1024 * 1024): + """Calculates the MD5 checksum for a given file.""" + md5 = hashlib.md5() + with open(fpath, 'rb') as f: + for chunk in iter(lambda: f.read(chunk_size), b''): + md5.update(chunk) + return md5.hexdigest() + +def verify_md5_checksum(fpath, md5): + """Checks if the MD5 of a file matches the given checksum.""" + return md5 == compute_md5_checksum(fpath) + +def validate_integrity(fpath, md5=None): + """Checks the integrity of a file by verifying its existence and MD5 checksum.""" + if not os.path.isfile(fpath): + return False + return md5 is None or verify_md5_checksum(fpath, md5) + +def download_url(url, root, filename=None, md5=None): + """Download a file from a URL and save it in the specified directory.""" + root = os.path.expanduser(root) + filename = filename or os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if validate_integrity(fpath, md5): + print('Using downloaded and verified file: ' + fpath) + return + + try: + print('Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_barr()) + except (urllib.error.URLError, IOError) as e: + if url.startswith('https'): + url = url.replace('https:', 'http:') + print('Failed download. 
Retrying with http.') + urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_barr()) + else: + raise e + + if not validate_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + +def list_dir(root, prefix=False): + """List all directories at the specified root.""" + root = os.path.expanduser(root) + directories = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))] + + return [os.path.join(root, d) for d in directories] if prefix else directories + +def list_files(root, suffix, prefix=False): + """List all files with a specific suffix in the specified root.""" + root = os.path.expanduser(root) + files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f)) and f.endswith(suffix)] + + return [os.path.join(root, f) for f in files] if prefix else files + +def fetch_file_google_drive(file_id, root, filename=None, md5=None): + """Download a file from Google Drive and save it in the specified directory.""" + url = "https://docs.google.com/uc?export=download" + root = os.path.expanduser(root) + filename = filename or file_id + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if os.path.isfile(fpath) and validate_integrity(fpath, md5): + print('Using downloaded and verified file: ' + fpath) + return + + session = requests.Session() + response = session.get(url, params={'id': file_id}, stream=True) + token = _get_confirm_token(response) + + if token: + params = {'id': file_id, 'confirm': token} + response = session.get(url, params=params, stream=True) + + _store_response_content(response, fpath) + +def _get_confirm_token(response): + """Extract the download token from Google Drive cookies.""" + return next((value for key, value in response.cookies.items() if key.startswith('download_warning')), None) + +def _store_response_content(response, destination, chunk_size=32768): + """Save the response content to a file in chunks.""" + with open(destination, "wb") as f: + pbar = tqdm(total=None) + progress = 0 + for chunk in response.iter_content(chunk_size): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + progress += len(chunk) + pbar.update(progress - pbar.n) + pbar.close() + +def extract_archive(from_path, to_path=None, remove_finished=False): + """Extract an archive file (tar, zip, gz) to the specified path.""" + if to_path is None: + to_path = os.path.dirname(from_path) + + if from_path.endswith((".tar", ".tar.gz", ".tgz", ".tar.xz")): + mode = 'r' + ('.gz' if from_path.endswith(('.tar.gz', '.tgz')) else + '.xz' if from_path.endswith('.tar.xz') else '') + with tarfile.open(from_path, mode) as tar: + tar.extractall(path=to_path) + elif from_path.endswith(".gz"): + to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) + with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: + out_f.write(zip_f.read()) + elif from_path.endswith(".zip"): + with zipfile.ZipFile(from_path, 'r') as z: + z.extractall(to_path) + else: + raise ValueError("Extraction of {} not supported".format(from_path)) + + if remove_finished: + os.remove(from_path) + +def fetch_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False): + """Download and extract an archive file from a URL.""" + download_root = os.path.expanduser(download_root) + extract_root = extract_root or download_root + filename = filename or os.path.basename(url) + + download_url(url, download_root, filename, md5) + archive = os.path.join(download_root, 
filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) + +def iterable_to_str(iterable): + """Convert an iterable to a string representation.""" + return "'" + "', '".join(map(str, iterable)) + "'" + +def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None): + """Verify that a string argument is valid and raise an error if not.""" + if not isinstance(value, str): + msg = f"Expected type str" + (f" for argument {arg}" if arg else "") + f", but got type {type(value)}." + raise ValueError(msg) + + if valid_values is None: + return value + + if value not in valid_values: + msg = custom_msg or f"Unknown value '{value}' for argument {arg}. Valid values are {{{iterable_to_str(valid_values)}}}." + raise ValueError(msg) + + return value diff --git a/EdgeFLite/data_collection/pill_data_base.py b/EdgeFLite/data_collection/pill_data_base.py new file mode 100644 index 0000000..46308c1 --- /dev/null +++ b/EdgeFLite/data_collection/pill_data_base.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +from PIL import Image +from torch.utils.data import DataLoader, Dataset +import torch +import os + +# Importing the HOME configuration +from config import HOME + +class PillDataBase(Dataset): + def __init__(self, data_dir=HOME + '/dataset_hub/pill_base', train=True, apply_transformation=None, split_factor=1): + """ + Initialize the dataset. + + Args: + data_dir (str): Directory where the dataset is stored. + train (bool): Flag to indicate if it's a training or testing dataset. + apply_transformation (callable): Optional apply_transformationation applied to images (e.g., resizing, normalization). + split_factor (int): Number of times each image is split into parts for augmentation purposes. + """ + self.train = train + self.apply_transformation = apply_transformation + self.split_factor = split_factor + self.data_dir = data_dir + '/pill_base' + self.dataset = self._load_data() + + def __len__(self): + """Return the number of samples in the dataset.""" + return len(self.dataset) + + def _load_data(self): + """ + Load the dataset by reading the corresponding text file (train.txt or test.txt). + + The dataset text file contains the image file paths and corresponding labels. + + Returns: + dataset (list): List of image file paths and their respective labels. + """ + dataset = [] + txt_path = os.path.join(self.data_dir, 'train.txt' if self.train else 'test.txt') + + with open(txt_path, 'r') as file: + lines = file.readlines() + for line in lines: + # Each line contains an image path and a label separated by space + filename, label = line.strip().split(' ') + # Adjust the image path to the correct directory structure + filename = filename.replace('/home/tung/Tung/research/Open-Pill/FACIL/data/Pill_Base_X', self.data_dir) + # Append the image file path and label as an integer + dataset.append([filename, int(label)]) + + return dataset + + def __getitem__(self, index): + """ + Retrieve a specific sample from the dataset at the given index. + + Args: + index (int): Index of the image and label to retrieve. + + Returns: + tuple: A tensor of concatenated apply_transformationed images and the corresponding label. 
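+
+        Example (illustrative sketch; the transform shown here is an assumption,
+        not something this module defines):
+            >>> from torchvision import transforms
+            >>> ds = PillDataBase(train=True,
+            ...                   apply_transformation=transforms.ToTensor(),
+            ...                   split_factor=2)
+            >>> imgs, label = ds[0]  # two transformed copies stacked along dim 0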
+ """ + images = [] + image_path = self.dataset[index][0] + label = torch.tensor(int(self.dataset[index][1])) + + # Open the image file + image = Image.open(image_path) + + # Apply apply_transformationations to the image if provided and split into parts as specified by split_factor + if self.apply_transformation: + for _ in range(self.split_factor): + images.append(self.apply_transformation(image)) + + # Concatenate all apply_transformationed image splits into a single tensor + return torch.cat(images, dim=0), label + +if __name__ == "__main__": + # Example of how to instantiate and use the dataset + dataset = PillDataBase() diff --git a/EdgeFLite/data_collection/pill_data_large.py b/EdgeFLite/data_collection/pill_data_large.py new file mode 100644 index 0000000..f5fd1db --- /dev/null +++ b/EdgeFLite/data_collection/pill_data_large.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import glob +from PIL import Image +import torch +from torch.utils.data import Dataset + +# Define the folder paths for training and testing datasets +FOLDER_PATHS = [ + '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/train_images', + '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/test_images' +] + +# Custom dataset class inheriting from PyTorch's Dataset class +class PillDataLarge(Dataset): + def __init__(self, train=True, apply_transformation=None, split_factor=1): + """ + Initializes the dataset object. + + Args: + - train (bool): If True, load the training dataset, otherwise load the test dataset. + - apply_transformation (callable, optional): Optional apply_transformationations to be applied on an image sample. + - split_factor (int): Number of times to apply the apply_transformationations to the image. + """ + self.train = train # Flag to determine if the dataset is for training or testing + self.apply_transformation = apply_transformation # apply_transformationation to apply to the images + self.split_factor = split_factor # Number of times to apply the apply_transformationation + self.dataset = self._load_data() # Load the dataset + + def __len__(self): + """ + Returns the total number of samples in the dataset. + """ + return len(self.dataset) + + def _load_data(self): + """ + Loads the data from the dataset folders. + + Returns: + - dataset (list): A list containing image file paths and their corresponding class IDs. + """ + folder_path = FOLDER_PATHS[0] if self.train else FOLDER_PATHS[1] # Use train or test folder path + class_names = sorted(os.listdir(folder_path)) # Get class names from folder + class_map = {name: idx for idx, name in enumerate(class_names)} # Map class names to IDs + + dataset = [] + for class_name, class_id in class_map.items(): + folder_class = os.path.join(folder_path, class_name) # Path to class folder + files_jpg = glob.glob(os.path.join(folder_class, '**', '*.jpg'), recursive=True) # Get all jpg files + for file_path in files_jpg: + dataset.append([file_path, class_id]) # Append file path and class ID to the dataset + + return dataset + + def __getitem__(self, index): + """ + Returns a sample and its corresponding label from the dataset. + + Args: + - index (int): Index of the sample. + + Returns: + - tuple: A tuple of the image tensor and the label tensor. 
+ """ + Xs = [] # List to store apply_transformationed images + image_path = self.dataset[index][0] # Get image path from dataset + label = torch.tensor(int(self.dataset[index][1])) # Get class label as tensor + + X = Image.open(image_path) # Open the image using PIL + + if self.apply_transformation: + for _ in range(self.split_factor): + Xs.append(self.apply_transformation(X)) # Apply apply_transformationation multiple times + + return torch.cat(Xs, dim=0), label # Concatenate all apply_transformationed images and return with the label + +if __name__ == "__main__": + dataset = PillDataLarge() # Create an instance of the dataset + print(len(dataset)) # Print the size of the dataset + print(dataset[0]) # Print the first sample of the dataset diff --git a/EdgeFLite/data_collection/skin_dataset.py b/EdgeFLite/data_collection/skin_dataset.py new file mode 100644 index 0000000..3e0107c --- /dev/null +++ b/EdgeFLite/data_collection/skin_dataset.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Import necessary libraries for image processing and handling datasets. +from PIL import Image # Used for opening and manipulating images. +from cv2 import split # A function from OpenCV, though it's not used here. It may have been intended for something else. +from torch.utils.data import DataLoader, Dataset # These are PyTorch utilities for managing datasets and data loading. +import torch # PyTorch library for tensor operations and deep learning. + +# Define a custom dataset class named 'SkinData' which inherits from PyTorch's Dataset class. +class SkinData(Dataset): + # Initialize the dataset with a DataFrame (df), an optional apply_transformationation (apply_transformation), and a split factor (split_factor). + def __init__(self, df, apply_transformation=None, split_factor=1): + self.df = df # Store the DataFrame containing image paths and target labels. + self.apply_transformation = apply_transformation # Optional image apply_transformationations to apply (e.g., resizing, normalizing). + self.split_factor = split_factor # A factor determining how many times to split or augment the image. + self.test_same_view = False # A flag indicating whether to return multiple augmentations of the same image. + + # Return the number of samples in the dataset, which corresponds to the number of rows in the DataFrame. + def __len__(self): + return len(self.df) + + # Retrieve the image and corresponding label at a specific index. + def __getitem__(self, index): + Xs = [] # Create an empty list to store apply_transformationed versions of the image. + + # Open the image located at the 'path' specified by the index in the DataFrame, then resize it to 64x64. + X = Image.open(self.df['path'][index]).resize((64, 64)) + + # Retrieve the target label (as a tensor) from the 'target' column of the DataFrame and convert it to a PyTorch tensor. + y = torch.tensor(int(self.df['target'][index])) + + # If 'test_same_view' is set to True, apply the same apply_transformationation multiple times and store the augmented images. + if self.test_same_view: + if self.apply_transformation: + aug = self.apply_transformation(X) # Apply the apply_transformationation once to the image. + # Store the same augmented image multiple times in the list 'Xs' (repeated 'split_factor' times). + Xs = [aug for _ in range(self.split_factor)] + else: + # If 'test_same_view' is False, apply the apply_transformationation independently to create different augmentations. 
+ if self.apply_transformation: + # Store different augmentations of the image in the list 'Xs', each apply_transformationed independently. + Xs = [self.apply_transformation(X) for _ in range(self.split_factor)] + + # Concatenate the list of images into a single tensor along the first dimension (batch) and return it along with the label. + return torch.cat(Xs, dim=0), y diff --git a/EdgeFLite/data_collection/vision_utils.py b/EdgeFLite/data_collection/vision_utils.py new file mode 100644 index 0000000..84b32ee --- /dev/null +++ b/EdgeFLite/data_collection/vision_utils.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import torch +import torch.utils.data as data + +# VisionDataset is a custom dataset class inheriting from PyTorch's Dataset class. +# It handles the initialization and representation of a vision-related dataset, +# including optional apply_transformationation of input data and targets. +class VisionDataset(data.Dataset): + _repr_indent = 4 # Defines the indentation level for dataset representation + + def __init__(self, root, apply_transformations=None, apply_transformation=None, target_apply_transformation=None): + # Initializes the dataset by setting root directory and optional apply_transformationations + # If root is a string, expand any user directory shortcuts like "~" + self.root = os.path.expanduser(root) if isinstance(root, str) else root + + # Check if either 'apply_transformations' or 'apply_transformation/target_apply_transformation' is provided (but not both) + has_apply_transformations = apply_transformations is not None + has_separate_apply_transformation = apply_transformation is not None or target_apply_transformation is not None + + if has_apply_transformations and has_separate_apply_transformation: + raise ValueError("Only one of 'apply_transformations' or 'apply_transformation/target_apply_transformation' can be provided.") + + # Set apply_transformationations + self.apply_transformation = apply_transformation + self.target_apply_transformation = target_apply_transformation + + # If separate apply_transformations are provided, wrap them in a StandardTransform + if has_separate_apply_transformation: + apply_transformations = StandardTransform(apply_transformation, target_apply_transformation) + self.apply_transformations = apply_transformations + + # Placeholder for the method to retrieve an item by index + def __getitem__(self, index): + raise NotImplementedError + + # Placeholder for the method to return dataset length + def __len__(self): + raise NotImplementedError + + # Representation of the dataset including number of datapoints, root directory, and apply_transformations + def __repr__(self): + head = f"Dataset {self.__class__.__name__}" + body = [f"Number of datapoints: {self.__len__()}"] + if self.root is not None: + body.append(f"Root location: {self.root}") + body += self.extra_repr().splitlines() # Include any additional representation details + if hasattr(self, "apply_transformations") and self.apply_transformations is not None: + body.append(repr(self.apply_transformations)) # Include apply_transformationation details if applicable + lines = [head] + [" " * self._repr_indent + line for line in body] + return '\n'.join(lines) + + # Utility to format the representation of the apply_transformation and target_apply_transformation attributes + def _format_apply_transformation_repr(self, apply_transformation, head): + lines = apply_transformation.__repr__().splitlines() + return [f"{head}{lines[0]}"] + [f"{' ' * 
len(head)}{line}" for line in lines[1:]] + + # Hook for adding extra dataset-specific information in the representation + def extra_repr(self): + return "" + + +# StandardTransform class handles the application of the apply_transformation and target_apply_transformation +# during dataset iteration or data loading. +class StandardTransform: + def __init__(self, apply_transformation=None, target_apply_transformation=None): + # Initialize with optional input and target apply_transformationations + self.apply_transformation = apply_transformation + self.target_apply_transformation = target_apply_transformation + + # Calls the appropriate apply_transformations on the input and target when invoked + def __call__(self, input, target): + if self.apply_transformation is not None: + input = self.apply_transformation(input) + if self.target_apply_transformation is not None: + target = self.target_apply_transformation(target) + return input, target + + # Utility to format the apply_transformationation representation + def _format_apply_transformation_repr(self, apply_transformation, head): + lines = apply_transformation.__repr__().splitlines() + return [f"{head}{lines[0]}"] + [f"{' ' * len(head)}{line}" for line in lines[1:]] + + # Representation of the StandardTransform including both input and target apply_transformationations + def __repr__(self): + body = [self.__class__.__name__] + if self.apply_transformation is not None: + body += self._format_apply_transformation_repr(self.apply_transformation, "apply_transformation: ") + if self.target_apply_transformation is not None: + body += self._format_apply_transformation_repr(self.target_apply_transformation, "Target apply_transformation: ") + + return '\n'.join(body) diff --git a/EdgeFLite/debug_tool.py b/EdgeFLite/debug_tool.py new file mode 100644 index 0000000..e925f1d --- /dev/null +++ b/EdgeFLite/debug_tool.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Import necessary libraries +import torch # PyTorch for tensor computations and neural networks +from torch import nn # Neural network module +# "decentralized" is not a valid import in PyTorch, possibly a typo. Removed for now. + +# Check for available device (CPU or GPU) +# If a GPU is available (CUDA), the code will use it; otherwise, it falls back to CPU. 
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# Define normalization layer and the number of initial input channels for the convolutional layers +batch_norm_layer = nn.BatchNorm2d # 2D Batch Normalization to stabilize training +initial_channels = 32 # Number of channels for the first convolutional layer + +# Define the convolutional neural network (CNN) architecture using nn.Sequential +network = nn.Sequential( + # 1st convolutional layer: takes 3 input channels (RGB image), outputs 'initial_channels' feature maps + # Uses kernel size 3, stride 2 for downsampling, and padding 1 to maintain spatial dimensions + nn.Conv2d(in_channels=3, out_channels=initial_channels, kernel_size=3, stride=2, padding=1, bias=False), + batch_norm_layer(initial_channels), # Apply Batch Normalization to the output + nn.ReLU(inplace=True), # ReLU activation function to introduce non-linearity + + # 2nd convolutional layer: takes 'initial_channels' input, outputs the same number of feature maps + # No downsampling here (stride 1) + nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels, kernel_size=3, stride=1, padding=1, bias=False), + batch_norm_layer(initial_channels), # Batch normalization for better convergence + nn.ReLU(inplace=True), # ReLU activation + + # 3rd convolutional layer: doubles the number of output channels (for deeper features) + # Again, no downsampling (stride 1) + nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels * 2, kernel_size=3, stride=1, padding=1, bias=False), + batch_norm_layer(initial_channels * 2), # Batch normalization for the increased feature maps + nn.ReLU(inplace=True), # ReLU activation + + # Max pooling layer to further downsample the feature maps (reduces spatial dimensions) + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Pooling with kernel size 3 and stride 2 +) + +# Create a dummy input tensor simulating a batch of 128 images with 3 channels (RGB), each of size 64x64 +sample_input = torch.randn(128, 3, 64, 64) + +# Print the defined network architecture and the shape of the output after a forward pass +print(network) +# Perform a forward pass with the sample input and print the resulting output shape +print(network(sample_input).shape) diff --git a/EdgeFLite/fedml_service/.DS_Store b/EdgeFLite/fedml_service/.DS_Store new file mode 100644 index 0000000..ade4daf Binary files /dev/null and b/EdgeFLite/fedml_service/.DS_Store differ diff --git a/EdgeFLite/fedml_service/architecture/.DS_Store b/EdgeFLite/fedml_service/architecture/.DS_Store new file mode 100644 index 0000000..81a948a Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/.DS_Store differ diff --git a/EdgeFLite/fedml_service/architecture/cv/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/.DS_Store new file mode 100644 index 0000000..852bec4 Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/.DS_Store differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/.DS_Store new file mode 100644 index 0000000..35918f6 Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/.DS_Store differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/.DS_Store new file mode 100644 index 0000000..efe3d1d Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/.DS_Store 
differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/test_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/test_metrics new file mode 100644 index 0000000..face2ed Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/test_metrics differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/train_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/train_metrics new file mode 100644 index 0000000..3e3004c Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/resnet56/train_metrics differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/.DS_Store new file mode 100644 index 0000000..efe3d1d Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/.DS_Store differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/test_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/test_metrics new file mode 100644 index 0000000..93595ae Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/test_metrics differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/train_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/train_metrics new file mode 100644 index 0000000..bf59a31 Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/resnet56/train_metrics differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/.DS_Store new file mode 100644 index 0000000..efe3d1d Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/.DS_Store differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/test_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/test_metrics new file mode 100644 index 0000000..98f322b Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/test_metrics differ diff --git a/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/train_metrics b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/train_metrics new file mode 100644 index 0000000..b80014d Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/resnet56/train_metrics differ diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/.DS_Store new file mode 100644 index 0000000..e21dabd Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/.DS_Store differ diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/net_server.py b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/net_server.py new file mode 100644 index 0000000..8ee2a5a --- /dev/null +++ b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/net_server.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import torch +import torch.nn as nn + 
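+# Server-side ResNet for the federated resnet56 setup: resnet56_server() at the bottom
+# of this file builds the network for CIFAR-style 32x32 inputs and can optionally restore
+# weights from a checkpoint whose keys carry a "module." prefix (as saved by DataParallel).
+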
+__all__ = ['ResNet', 'resnet110'] + +def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1): + """3x3 Convolution with padding.""" + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +def apply_1x1_convolution(in_channels, out_channels, stride=1): + """1x1 Convolution.""" + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + """Basic Block used in ResNet. Consists of two 3x3 convolutions.""" + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + + self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride) + self.bn1 = norm_layer(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = apply_3x3_convolution(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + """Defines the forward pass through the block.""" + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + """Bottleneck block used in ResNet. 
Has three layers: 1x1, 3x3, and 1x1 convolutions.""" + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(out_channels * (base_width / 64.)) * groups + + self.conv1 = apply_1x1_convolution(in_channels, width) + self.bn1 = norm_layer(width) + self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion) + self.bn3 = norm_layer(out_channels * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + """Defines the forward pass through the bottleneck block.""" + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class ResNet(nn.Module): + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1, + width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False): + """Defines the ResNet architecture.""" + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 16 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None or a 3-element tuple.") + + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._create_model_layer(block, 16, layers[0]) + self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2) + self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(64 * block.expansion, num_classes) + self.KD = KD + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False): + """Creates a layer in ResNet using the specified block type.""" + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) 
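+
+    # Channel bookkeeping (illustrative): with Bottleneck blocks (expansion=4) and
+    # layers=[6, 6, 6], layer1 ends at 16*4=64 channels, layer2 at 32*4=128, and
+    # layer3 at 64*4=256 channels, matching the fc input size of 64 * block.expansion.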
+ + def forward(self, x): + """Defines the forward pass of the ResNet model.""" + x = self.layer1(x) # Output: B x 16 x 32 x 32 + x = self.layer2(x) # Output: B x 32 x 16 x 16 + x = self.layer3(x) # Output: B x 64 x 8 x 8 + + x = self.avgpool(x) # Output: B x 64 x 1 x 1 + x_f = x.view(x.size(0), -1) # Flatten: B x 64 + x = self.fc(x_f) # Output: B x num_classes + return x + +def resnet56_server(num_classes, models_pretrained=False, path=None, **kwargs): + """ + Constructs a ResNet-110 model. + + Args: + num_classes (int): Number of output classes. + models_pretrained (bool): If True, returns a model pre-trained on ImageNet. + path (str): Path to the pre-trained model. + """ + logging.info("Loading model with path: " + str(path)) + model = ResNet(Bottleneck, [6, 6, 6], num_classes=num_classes, **kwargs) + + if models_pretrained: + checkpoint = torch.load(path) + state_dict = checkpoint['state_dict'] + new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + model.load_state_dict(new_state_dict) + + return model diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/pretrained_weights.py b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/pretrained_weights.py new file mode 100644 index 0000000..8055cb0 --- /dev/null +++ b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/pretrained_weights.py @@ -0,0 +1,326 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import torch +import torch.nn as nn + +def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1): + """ + Creates a 3x3 convolutional layer with padding. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int, optional): Stride of the convolution. Default is 1. + groups (int, optional): Number of blocked connections from input to output. Default is 1. + dilation (int, optional): Spacing between kernel elements. Default is 1. + + Returns: + nn.Conv2d: A 3x3 convolutional layer. + """ + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +def apply_1x1_convolution(in_channels, out_channels, stride=1): + """ + Creates a 1x1 convolutional layer. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int, optional): Stride of the convolution. Default is 1. + + Returns: + nn.Conv2d: A 1x1 convolutional layer. + """ + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + +class BasicBlock(nn.Module): + """ + A basic block for ResNet. + + This block consists of two convolutional layers with batch normalization and ReLU activation. + + Attributes: + expansion (int): The expansion factor of the block. + conv1 (nn.Conv2d): First convolutional layer. + bn1 (nn.BatchNorm2d): First batch normalization layer. + conv2 (nn.Conv2d): Second convolutional layer. + bn2 (nn.BatchNorm2d): Second batch normalization layer. + downsample (nn.Module): Downsample layer if input and output dimensions differ. + """ + expansion = 1 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None): + """ + Initializes the BasicBlock. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int, optional): Stride for the convolutional layers. Default is 1. + downsample (nn.Module, optional): Downsample layer if input dimensions differ. 
Default is None. + norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d. + """ + super(BasicBlock, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride) + self.bn1 = norm_layer(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = apply_3x3_convolution(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + self.downsample = downsample + + def forward(self, x): + """ + Defines the forward pass for the block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the block. + """ + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class Bottleneck(nn.Module): + """ + A bottleneck block for ResNet. + + This block reduces the number of input channels before performing convolution and then expands it back. + + Attributes: + expansion (int): The expansion factor of the block. + conv1 (nn.Conv2d): First 1x1 convolutional layer. + conv2 (nn.Conv2d): 3x3 convolutional layer. + conv3 (nn.Conv2d): Second 1x1 convolutional layer. + downsample (nn.Module): Downsample layer if input and output dimensions differ. + """ + expansion = 4 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None): + """ + Initializes the Bottleneck block. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int, optional): Stride for the convolutional layers. Default is 1. + downsample (nn.Module, optional): Downsample layer if input dimensions differ. Default is None. + norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d. + """ + super(Bottleneck, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + width = int(out_channels * (64 / 64)) # Base width + self.conv1 = apply_1x1_convolution(in_channels, width) + self.bn1 = norm_layer(width) + self.conv2 = apply_3x3_convolution(width, width, stride) + self.bn2 = norm_layer(width) + self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion) + self.bn3 = norm_layer(out_channels * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + """ + Defines the forward pass for the bottleneck block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the block. + """ + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + +class ResNet(nn.Module): + """ + ResNet architecture. + + This class constructs a ResNet model with a specified block type and layer configuration. + + Attributes: + conv1 (nn.Conv2d): Initial convolutional layer. + bn1 (nn.BatchNorm2d): Initial batch normalization layer. + layer1 (nn.Sequential): First residual layer. + layer2 (nn.Sequential): Second residual layer. + layer3 (nn.Sequential): Third residual layer. + fc (nn.Linear): Fully connected output layer. 
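+
+    Example (illustrative):
+        >>> net = ResNet(BasicBlock, [5, 5, 5], num_classes=10)  # a ResNet-32-style stack
+        >>> logits, features = net(torch.randn(2, 3, 32, 32))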
+ """ + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, norm_layer=None): + """ + Initializes the ResNet architecture. + + Args: + block (nn.Module): The block type (BasicBlock or Bottleneck). + layers (list of int): Number of blocks per layer. + num_classes (int, optional): Number of output classes. Default is 10. + zero_init_residual (bool, optional): Whether to zero-initialize residual layers. Default is False. + norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d. + """ + super(ResNet, self).__init__() + norm_layer = norm_layer or nn.BatchNorm2d + self.in_channels = 16 + + self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.in_channels) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._create_model_layer(block, 16, layers[0]) + self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2) + self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2) + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(64 * block.expansion, num_classes) + + self._init_model_weights(zero_init_residual) + + def _create_model_layer(self, block, out_channels, blocks, stride=1): + """ + Creates a residual layer. + + Args: + block (nn.Module): The block type. + out_channels (int): Number of output channels. + blocks (int): Number of blocks in the layer. + stride (int, optional): Stride for the first block. Default is 1. + + Returns: + nn.Sequential: A sequence of residual blocks. + """ + downsample = None + if stride != 1 or self.in_channels != out_channels * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.in_channels, out_channels * block.expansion, stride), + nn.BatchNorm2d(out_channels * block.expansion), + ) + + layers = [block(self.in_channels, out_channels, stride, downsample)] + self.in_channels = out_channels * block.expansion + layers.extend(block(self.in_channels, out_channels) for _ in range(1, blocks)) + return nn.Sequential(*layers) + + def _init_model_weights(self, zero_init_residual): + """ + Initializes the weights of the model. + + Args: + zero_init_residual (bool): If True, initializes residual layers to zero. + """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if zero_init_residual and isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif zero_init_residual and isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def forward(self, x): + """ + Defines the forward pass of the ResNet. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + tuple: Logits and extracted features. + """ + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + extracted_features = x + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.avgpool(x) + x_f = x.view(x.size(0), -1) + logits = self.fc(x_f) + return logits, extracted_features + +def resnet32_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs): + """ + Constructs a ResNet-32 model. + + Args: + num_classes (int): Number of output classes. + models_pretrained (bool, optional): If True, loads pretrained weights. Default is False. + path (str, optional): Path to the pretrained weights. Default is None. + + Returns: + ResNet: A ResNet-32 model. 
+ """ + model = ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes, **kwargs) + if models_pretrained: + model.load_state_dict(_load_models_pretrained_weights(path)) + return model + +def resnet56_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs): + """ + Constructs a ResNet-56 model. + + Args: + num_classes (int): Number of output classes. + models_pretrained (bool, optional): If True, loads pretrained weights. Default is False. + path (str, optional): Path to the pretrained weights. Default is None. + + Returns: + ResNet: A ResNet-56 model. + """ + logging.info("Loading pretrained model from: " + str(path)) + model = ResNet(Bottleneck, [6, 6, 6], num_classes=num_classes, **kwargs) + if models_pretrained: + model.load_state_dict(_load_models_pretrained_weights(path)) + return model + +def _load_models_pretrained_weights(path): + """ + Loads pretrained weights from a checkpoint. + + Args: + path (str): Path to the checkpoint file. + + Returns: + dict: State dictionary with the loaded weights. + """ + checkpoint = torch.load(path, map_location=torch.device('cpu')) + state_dict = checkpoint['state_dict'] + from collections import OrderedDict + new_state_dict = OrderedDict() + + for k, v in state_dict.items(): + new_state_dict[k.replace("module.", "")] = v + + return new_state_dict diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/resnet_client.py b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/resnet_client.py new file mode 100644 index 0000000..cc340b8 --- /dev/null +++ b/EdgeFLite/fedml_service/architecture/cv/resnet56_federated/resnet_client.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + +__all__ = ['ResNet'] + +# Function to define a 3x3 convolution layer with padding +def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +# Function to define a 1x1 convolution layer +def apply_1x1_convolution(in_channels, out_channels, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) + +# BasicBlock class for ResNet architecture +class BasicBlock(nn.Module): + expansion = 1 # Expansion factor + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d # Default normalization layer is BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + + # First convolution and batch normalization layer + self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride) + self.bn1 = norm_layer(out_channels) + self.relu = nn.ReLU(inplace=True) # ReLU activation + # Second convolution and batch normalization layer + self.conv2 = apply_3x3_convolution(out_channels, out_channels) + self.bn2 = norm_layer(out_channels) + self.downsample = downsample # If downsample is provided, use it + + def forward(self, x): + identity = x # Keep original input as identity for residual connection + + # Forward pass through first convolution, batch norm, and ReLU + out = self.conv1(x) + out = self.bn1(out) + 
out = self.relu(out) + + # Forward pass through second convolution and batch norm + out = self.conv2(out) + out = self.bn2(out) + + # Downsample the identity if downsample is provided + if self.downsample is not None: + identity = self.downsample(x) + + # Add residual connection (identity) + out += identity + out = self.relu(out) # Apply ReLU activation after addition + + return out + +# Bottleneck class for deeper ResNet architectures +class Bottleneck(nn.Module): + expansion = 4 # Expansion factor + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d # Default normalization layer is BatchNorm2d + width = int(out_channels * (base_width / 64.)) * groups # Calculate width based on group size + + # First 1x1 convolution + self.conv1 = apply_1x1_convolution(in_channels, width) + self.bn1 = norm_layer(width) + # Second 3x3 convolution + self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + # Third 1x1 convolution to match output channels + self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion) + self.bn3 = norm_layer(out_channels * self.expansion) + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.downsample = downsample # Downsample if provided + + def forward(self, x): + identity = x # Keep original input as identity for residual connection + + # First 1x1 convolution and ReLU + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # Second 3x3 convolution and ReLU + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + # Third 1x1 convolution + out = self.conv3(out) + out = self.bn3(out) + + # Add downsampled identity if necessary + if self.downsample is not None: + identity = self.downsample(x) + + # Add residual connection (identity) + out += identity + out = self.relu(out) # Apply ReLU activation after addition + + return out + +# ResNet class to build the entire ResNet model +class ResNet(nn.Module): + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1, + width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d # Default normalization layer + self._norm_layer = norm_layer + + self.inplanes = 16 # Initial number of channels + self.dilation = 1 # Dilation factor + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] # Default stride behavior + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + + self.groups = groups # Number of groups for convolutions + self.base_width = width_per_group # Base width for groups + + # Initial convolutional layer with 3 input channels (RGB image) + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes) # Batch normalization + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Max pooling layer + self.layer1 = self._create_model_layer(block, 16, layers[0]) # First block layer + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Adaptive average pooling + self.fc = nn.Linear(16 * block.expansion, num_classes) # Fully 
connected layer + + self.KD = KD # Knowledge Distillation flag + for m in self.modules(): + # Initialize convolutional weights using He initialization + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + # Initialize batch normalization weights + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last batch norm layer if zero_init_residual is True + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + # Helper function to create layers of blocks + def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + apply_1x1_convolution(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + # Forward pass of the ResNet model + def forward(self, x): + x = self.conv1(x) # Initial convolution + x = self.bn1(x) # Batch normalization + x = self.relu(x) # ReLU activation + extracted_features = x # Feature extraction point + x = self.layer1(x) # Pass through the first layer + x = self.avgpool(x) # Adaptive average pooling + x_f = x.view(x.size(0), -1) # Flatten the features + logits = self.fc(x_f) # Fully connected layer for classification + return logits, extracted_features # Return logits and extracted features + +# Function to create ResNet-5 model +def resnet5_56(num_classes, models_pretrained=False, path=None, **kwargs): + """Constructs a ResNet-5 model.""" + model = ResNet(BasicBlock, [1, 2, 2], num_classes=num_classes, **kwargs) + if models_pretrained: + checkpoint = torch.load(path) + state_dict = checkpoint['state_dict'] + + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k.replace("module.", "") + new_state_dict[name] = v + + model.load_state_dict(new_state_dict) + return model + +# Function to create ResNet-8 model +def resnet8_56(num_classes, models_pretrained=False, path=None, **kwargs): + """Constructs a ResNet-8 model.""" + model = ResNet(Bottleneck, [2, 2, 2], num_classes=num_classes, **kwargs) + if models_pretrained: + checkpoint = torch.load(path) + state_dict = checkpoint['state_dict'] + + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k.replace("module.", "") + new_state_dict[name] = v + + model.load_state_dict(new_state_dict) + return model diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet_federated/.DS_Store b/EdgeFLite/fedml_service/architecture/cv/resnet_federated/.DS_Store new file mode 100644 index 0000000..e21dabd Binary files /dev/null and b/EdgeFLite/fedml_service/architecture/cv/resnet_federated/.DS_Store differ diff --git a/EdgeFLite/fedml_service/architecture/cv/resnet_federated/net.py 
b/EdgeFLite/fedml_service/architecture/cv/resnet_federated/net.py new file mode 100644 index 0000000..cc6b939 --- /dev/null +++ b/EdgeFLite/fedml_service/architecture/cv/resnet_federated/net.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + +# Try to import load_state_dict_from_url from torch.hub. +# If it fails (due to older versions), fall back to load_url from torch.utils.model_zoo. +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# List of all exportable models +__all__ = ['resnet110_sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8'] + + +def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding.""" + return nn.Conv2d( + in_planes, # Number of input channels + out_planes, # Number of output channels + kernel_size=3, # Size of the filter + stride=stride, # Stride of the convolution + padding=dilation, # Padding for the convolution + groups=groups, # Group convolution + bias=False, # No bias in convolution + dilation=dilation # Dilation rate for dilated convolutions + ) + + +def apply_1x1_convolution(in_planes, out_planes, stride=1): + """1x1 convolution.""" + return nn.Conv2d( + in_planes, # Number of input channels + out_planes, # Number of output channels + kernel_size=1, # Filter size is 1x1 + stride=stride, # Stride of the convolution + bias=False # No bias in convolution + ) + + +class BasicBlock(nn.Module): + """Basic block for ResNet.""" + + expansion = 1 # No expansion in BasicBlock + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + + self.conv1 = apply_3x3_convolution(inplanes, planes, stride) # First 3x3 convolution + self.bn1 = norm_layer(planes) # First batch normalization + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.conv2 = apply_3x3_convolution(planes, planes) # Second 3x3 convolution + self.bn2 = norm_layer(planes) # Second batch normalization + self.downsample = downsample # If there's downsampling (e.g., stride mismatch) + + def forward(self, x): + identity = x # Preserve the input as identity for skip connection + out = self.conv1(x) # Apply the first convolution + out = self.bn1(out) # Apply first batch normalization + out = self.relu(out) # Apply ReLU activation + out = self.conv2(out) # Apply the second convolution + out = self.bn2(out) # Apply second batch normalization + + # If downsample exists, apply it to the identity + if self.downsample is not None: + identity = self.downsample(x) + + out += identity # Add skip connection + out = self.relu(out) # Final ReLU activation + + return out # Return the result + + +class Bottleneck(nn.Module): + """Bottleneck block for ResNet.""" + + expansion = 4 # Bottleneck expands the channels by a factor of 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups # Width of the block + + # 1x1 convolution (bottleneck) + 
self.conv1 = apply_1x1_convolution(inplanes, width) + self.bn1 = norm_layer(width) # Batch normalization after 1x1 convolution + # 3x3 convolution (main block) + self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) # Batch normalization after 3x3 convolution + # 1x1 convolution (bottleneck exit) + self.conv3 = apply_1x1_convolution(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) # Batch normalization after 1x1 exit + self.relu = nn.ReLU(inplace=True) # ReLU activation + self.downsample = downsample # Downsampling for skip connection, if needed + + def forward(self, x): + identity = x # Store input as identity for the skip connection + out = self.conv1(x) # Apply first 1x1 convolution + out = self.bn1(out) # Apply batch normalization + out = self.relu(out) # Apply ReLU + out = self.conv2(out) # Apply 3x3 convolution + out = self.bn2(out) # Apply batch normalization + out = self.relu(out) # Apply ReLU + out = self.conv3(out) # Apply 1x1 convolution + out = self.bn3(out) # Apply batch normalization + + # If downsample exists, apply it to the identity + if self.downsample is not None: + identity = self.downsample(x) + + out += identity # Add skip connection + out = self.relu(out) # Final ReLU activation + + return out # Return the result + + +class PrimaryResNetClient(nn.Module): + """Main ResNet model for client.""" + + def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None): + super(PrimaryResNetClient, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Global average pooling before fully connected layer + + # Dictionary to store input channel size based on dataset and split factor + inplanes_dict = { + 'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3}, + 'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4}, + 'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24}, + 'pill_base': {1: 64, 2: 44, 4: 32, 8: 24}, + 'medical_images': {1: 64, 2: 44, 4: 32, 8: 24}, + } + self.inplanes = inplanes_dict[dataset][split_factor] # Set initial input channels + + self.fc = nn.Linear(self.inplanes * 4 * block.expansion, num_classes) # Fully connected layer for classification + + # Initialize all layers + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=1e-3) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + # Optionally initialize the last batch normalization layer to zero + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False): + """Create a residual layer consisting of several blocks.""" + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = 
nn.Sequential( + apply_1x1_convolution(self.inplanes, planes * block.expansion, stride), # Adjust input size for downsampling + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) # Add the first block with downsample + self.inplanes = planes * block.expansion # Update inplanes for the next block + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) # Add the remaining blocks + + return nn.Sequential(*layers) # Return the stacked blocks + + def _forward_impl(self, x): + """Implementation of the forward pass.""" + x = self.layer0(x) # Initial layer + extracted_features = x # Save features after the initial layer + x = self.layer1(x) # First layer + x = self.avgpool(x) # Global average pooling + x = torch.flatten(x, 1) # Flatten the features into a 1D tensor + logits = self.fc(x) # Pass through the fully connected layer + return logits, extracted_features # Return logits and extracted features + + def forward(self, x): + """Standard forward method.""" + return self._forward_impl(x) diff --git a/EdgeFLite/fedml_service/data_cleaning/.DS_Store b/EdgeFLite/fedml_service/data_cleaning/.DS_Store new file mode 100644 index 0000000..2dea013 Binary files /dev/null and b/EdgeFLite/fedml_service/data_cleaning/.DS_Store differ diff --git a/EdgeFLite/fedml_service/data_cleaning/cifar10/.DS_Store b/EdgeFLite/fedml_service/data_cleaning/cifar10/.DS_Store new file mode 100644 index 0000000..e21dabd Binary files /dev/null and b/EdgeFLite/fedml_service/data_cleaning/cifar10/.DS_Store differ diff --git a/EdgeFLite/fedml_service/data_cleaning/cifar10/bulk_data_import.py b/EdgeFLite/fedml_service/data_cleaning/cifar10/bulk_data_import.py new file mode 100644 index 0000000..8b49d99 --- /dev/null +++ b/EdgeFLite/fedml_service/data_cleaning/cifar10/bulk_data_import.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import numpy as np +import torch +import torch.utils.data as data +import torchvision.apply_transformations as apply_transformations + +from .datasets import CIFAR10_truncated + +# Configure logging +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +# Function to load non-IID data distribution +def load_data_distribution(file_path='./data_cleaning/non-iid-distribution/CIFAR10/data_map.txt'): + """ + Load data distribution for non-IID data. + Reads from a text file that maps data classes to the clients in a decentralized manner. + """ + distribution = {} + with open(file_path, 'r') as file: + for line in file.readlines(): + if '{' != line[0] and '}' != line[0]: + key, value = line.split(':') + if '{' == value.strip(): + distribution[int(key)] = {} + else: + sub_key = int(key) + distribution[int(key)][sub_key] = int(value.strip().replace(',', '')) + return distribution + +# Function to load network data index map +def load_net_dataidx_map(file_path='./data_cleaning/non-iid-distribution/CIFAR10/index_map.txt'): + """ + Load index mapping between data samples and clients. + Reads from a text file that assigns data indices to different clients. 
+ """ + net_dataidx_map = {} + with open(file_path, 'r') as file: + for line in file.readlines(): + if '{' != line[0] and '}' != line[0] and ']' != line[0]: + key, value = line.split(':') + if '[' == value.strip(): + net_dataidx_map[int(key)] = [] + else: + indices = [int(i.strip()) for i in line.split(',')] + net_dataidx_map[int(key)] = indices + return net_dataidx_map + +# Function to record and log data statistics for each client +def log_net_data_stats(y_train, net_dataidx_map): + """ + Log the data statistics for each client by calculating class distribution. + """ + net_cls_counts = {} + for net_id, dataidx in net_dataidx_map.items(): + unique, counts = np.unique(y_train[dataidx], return_counts=True) + net_cls_counts[net_id] = dict(zip(unique, counts)) + logging.debug('Data statistics: %s', net_cls_counts) + return net_cls_counts + +# Cutout augmentation class for image data +class Cutout: + """ + Apply the Cutout augmentation technique to images. + Randomly masks out a square region in the image. + """ + def __init__(self, length): + self.length = length + + def __call__(self, img): + h, w = img.size(1), img.size(2) + mask = np.ones((h, w), np.float32) + y, x = np.random.randint(h), np.random.randint(w) + + y1, y2 = np.clip([y - self.length // 2, y + self.length // 2], 0, h) + x1, x2 = np.clip([x - self.length // 2, x + self.length // 2], 0, w) + + mask[y1:y2, x1:x2] = 0. + mask = torch.from_numpy(mask).expand_as(img) + img *= mask + return img + +# Function to define CIFAR-10 data apply_transformationations +def cifar10_data_apply_transformations(): + """ + Define data apply_transformationations for CIFAR-10 dataset. + Includes random cropping, horizontal flipping, normalization, and Cutout for training. + """ + CIFAR_MEAN = [0.4914, 0.4822, 0.4465] + CIFAR_STD = [0.2470, 0.2435, 0.2616] + + train_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.RandomCrop(32, padding=4), + apply_transformations.RandomHorizontalFlip(), + apply_transformations.ToTensor(), + apply_transformations.Normalize(CIFAR_MEAN, CIFAR_STD), + Cutout(16) + ]) + + valid_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToTensor(), + apply_transformations.Normalize(CIFAR_MEAN, CIFAR_STD), + ]) + + return train_apply_transformation, valid_apply_transformation + +# Function to load CIFAR-10 data +def load_cifar10(datadir): + """ + Load the CIFAR-10 dataset with apply_transformationations for training and testing. + """ + train_apply_transformation, test_apply_transformation = cifar10_data_apply_transformations() + + cifar10_train = CIFAR10_truncated(datadir, train=True, download=True, apply_transformation=train_apply_transformation) + cifar10_test = CIFAR10_truncated(datadir, train=False, download=True, apply_transformation=test_apply_transformation) + + X_train, y_train = cifar10_train.data, cifar10_train.target + X_test, y_test = cifar10_test.data, cifar10_test.target + + return X_train, y_train, X_test, y_test + +# Function to partition CIFAR-10 data across clients +def partition_cifar10_data(dataset, datadir, partition_type, n_nets, alpha): + """ + Partition the CIFAR-10 dataset across clients for federated learning. + Supports homogeneous and heterogeneous partitions. 
+ """ + logging.info("Partitioning CIFAR-10 data...") + X_train, y_train, X_test, y_test = load_cifar10(datadir) + n_train = X_train.shape[0] + + if partition_type == "homo": + # Homogeneous partitioning (equal distribution across clients) + idxs = np.random.permutation(n_train) + net_dataidx_map = {i: batch for i, batch in enumerate(np.array_split(idxs, n_nets))} + elif partition_type == "hetero": + # Heterogeneous partitioning (non-IID distribution) + K, N = 10, y_train.shape[0] + net_dataidx_map = {} + min_size = 0 + while min_size < 10: + idx_batch = [[] for _ in range(n_nets)] + for k in range(K): + idx_k = np.where(y_train == k)[0] + np.random.shuffle(idx_k) + proportions = np.random.dirichlet(np.repeat(alpha, n_nets)) + proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)]) + proportions = np.cumsum(proportions / proportions.sum()) * len(idx_k) + split_idx = proportions.astype(int)[:-1] + idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, split_idx))] + min_size = min([len(idx_j) for idx_j in idx_batch]) + net_dataidx_map = {i: np.random.permutation(batch) for i, batch in enumerate(idx_batch)} + elif partition_type == "hetero-fix": + # Fixed heterogeneous partitioning (predefined distribution) + net_dataidx_map = load_net_dataidx_map() + + # Load data distribution for 'hetero-fix' partition, otherwise calculate it + if partition_type == "hetero-fix": + traindata_cls_counts = load_data_distribution() + else: + traindata_cls_counts = log_net_data_stats(y_train, net_dataidx_map) + + return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts + +# Function to create data loaders +def get_cifar10_dataloader(datadir, train_bs, test_bs, dataidxs=None): + """ + Create data loaders for CIFAR-10 with the option to load only specific data indices. + """ + train_apply_transformation, test_apply_transformation = cifar10_data_apply_transformations() + + train_ds = CIFAR10_truncated(datadir, dataidxs=dataidxs, train=True, apply_transformation=train_apply_transformation, download=True) + test_ds = CIFAR10_truncated(datadir, train=False, apply_transformation=test_apply_transformation, download=True) + + train_loader = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=True) + test_loader = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=True) + + return train_loader, test_loader + +# Function to load decentralized CIFAR-10 data for a specific client +def load_decentralized_cifar10(process_id, dataset, datadir, partition_method, partition_alpha, client_num, batch_size): + """ + Load decentralized CIFAR-10 data based on the partitioning method and client number. + Returns either global data loaders or local data loaders depending on the process ID. 
+ """ + X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_cifar10_data( + dataset, datadir, partition_method, client_num, partition_alpha) + class_num = len(np.unique(y_train)) + logging.info("Class distribution: %s", traindata_cls_counts) + + if process_id == 0: + # Global data loaders + train_global, test_global = get_cifar10_dataloader(datadir, batch_size, batch_size) + return sum(len(net_dataidx_map[r]) for r in range(client_num)), train_global, test_global, 0, None, None, class_num + else: + # Local data loaders for the specific client + dataidxs = net_dataidx_map[process_id - 1] + train_local, test_local = get_cifar10_dataloader(datadir, batch_size, batch_size, dataidxs) + return len(dataidxs), None, None, len(dataidxs), train_local, test_local, class_num + +# Function to load and partition CIFAR-10 dataset +def load_cifar10_partitioned(dataset, datadir, partition_method, partition_alpha, client_num, batch_size): + """ + Load and partition the CIFAR-10 dataset and prepare data loaders for all clients. + """ + X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_cifar10_data( + dataset, datadir, partition_method, client_num, partition_alpha) + class_num = len(np.unique(y_train)) + logging.info("Global data statistics: %s", traindata_cls_counts) + + # Global data loaders + train_global, test_global = get_cifar10_dataloader(datadir, batch_size, batch_size) + + # Local data loaders for each client + data_local_num_dict, train_local_dict, test_local_dict = {}, {}, {} + for client_idx in range(client_num): + dataidxs = net_dataidx_map[client_idx] + local_data_num = len(dataidxs) + data_local_num_dict[client_idx] = local_data_num + logging.info("Client %d: Local sample count = %d", client_idx, local_data_num) + + train_local, test_local = get_cifar10_dataloader(datadir, batch_size, batch_size, dataidxs) + train_local_dict[client_idx], test_local_dict[client_idx] = train_local, test_local + + return sum(len(net_dataidx_map[r]) for r in range(client_num)), len(test_global), train_global, test_global, data_local_num_dict, train_local_dict, test_local_dict, class_num diff --git a/EdgeFLite/fedml_service/data_cleaning/cifar10/dataset_hub.py b/EdgeFLite/fedml_service/data_cleaning/cifar10/dataset_hub.py new file mode 100644 index 0000000..d179087 --- /dev/null +++ b/EdgeFLite/fedml_service/data_cleaning/cifar10/dataset_hub.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import numpy as np +import torch.utils.data as data +from PIL import Image +from torchvision.datasets import CIFAR10 + +# Set up logging +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +# Supported image extensions +# These are the file extensions that the loaders will support for image formats +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') + +# Loader using accimage, a faster image loading library than PIL +def load_accimage(path): + import accimage + try: + # Try to load the image with accimage + return accimage.Image(path) + except IOError: + # If there's an error, fallback to PIL for image loading + return load_image_pil(path) + +# Loader using PIL (Python Imaging Library) +def load_image_pil(path): + # Open the file in binary mode to avoid resource warnings + with open(path, 'rb') as f: + img = Image.open(f) + # Convert the image to RGB mode (3 channels) + return img.convert('RGB') + +# Default image loader that chooses accimage if 
available, otherwise PIL +def basic_loader(path): + from torchvision import get_image_backend + # Check if the image backend is accimage + if get_image_backend() == 'accimage': + return load_accimage(path) + # Otherwise, fallback to PIL + return load_image_pil(path) + +# Custom CIFAR10 dataset with truncation capabilities +# This class extends the torch.utils.data.Dataset to support CIFAR10 with truncation of data +class CIFAR10Truncated(data.Dataset): + + def __init__(self, root, dataidxs=None, train=True, apply_transformation=None, target_apply_transformation=None, download=False): + self.root = root # Root directory for the dataset + self.dataidxs = dataidxs # Subset of data indices (optional) + self.train = train # Boolean flag indicating if the dataset is for training + self.apply_transformation = apply_transformation # apply_transformationations to apply to the images (optional) + self.target_apply_transformation = target_apply_transformation # apply_transformationations to apply to the labels (optional) + self.download = download # Boolean flag to download the dataset if not available + + # Build the truncated dataset based on the provided indices + self.data, self.target = self._build_truncated_dataset() + + def _build_truncated_dataset(self): + # Log whether the dataset is being downloaded + logger.info(f"Download: {self.download}") + + # Load the CIFAR10 dataset from torchvision + cifar_data = CIFAR10(self.root, self.train, apply_transformation=self.apply_transformation, + target_apply_transformation=self.target_apply_transformation, download=self.download) + + # Extract data (images) and targets (labels) from the CIFAR10 dataset + data = cifar_data.data + target = np.array(cifar_data.targets) + + # If data indices are provided, filter the data and targets accordingly + if self.dataidxs is not None: + data = data[self.dataidxs] + target = target[self.dataidxs] + + # Return the truncated data and targets + return data, target + + def truncate_channel(self, indices): + # Zero out the second and third channels (green and blue) for selected images + for idx in indices: + self.data[idx, :, :, 1] = 0.0 # Zero out the green channel + self.data[idx, :, :, 2] = 0.0 # Zero out the blue channel + + def __getitem__(self, index): + """ + Args: + index (int): Index of the image + + Returns: + tuple: (image, target) where target is the class label. 
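The truncated dataset above simply filters `data` and `targets` by the supplied `dataidxs`, which is how each client ends up with only its own shard (note that the loaders import the class as `CIFAR10_truncated` from a `datasets` module, so the exported name may differ from the definition shown here). Functionally, the same effect can be obtained with `torch.utils.data.Subset`; a self-contained sketch of that equivalent pattern, with a placeholder index shard standing in for `net_dataidx_map[client_id]`:

```python
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10

# Placeholder shard: in the repository these indices come from net_dataidx_map[client_id]
dataidxs = list(range(5_000))

full_train = CIFAR10("./data", train=True, download=True, transform=transforms.ToTensor())
client_train = Subset(full_train, dataidxs)            # client sees only its own indices
loader = DataLoader(client_train, batch_size=64, shuffle=True, drop_last=True)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])
```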
+ """ + img, target = self.data[index], self.target[index] + + # Apply image apply_transformationations if any are specified + if self.apply_transformation is not None: + img = self.apply_transformation(img) + # Apply target apply_transformationations if any are specified + if self.target_apply_transformation is not None: + target = self.target_apply_transformation(target) + + # Return the apply_transformationed image and its corresponding target + return img, target + + def __len__(self): + # Return the total number of images in the dataset + return len(self.data) diff --git a/EdgeFLite/fedml_service/data_cleaning/cifar100/.DS_Store b/EdgeFLite/fedml_service/data_cleaning/cifar100/.DS_Store new file mode 100644 index 0000000..e21dabd Binary files /dev/null and b/EdgeFLite/fedml_service/data_cleaning/cifar100/.DS_Store differ diff --git a/EdgeFLite/fedml_service/data_cleaning/cifar100/bulk_data_import.py b/EdgeFLite/fedml_service/data_cleaning/cifar100/bulk_data_import.py new file mode 100644 index 0000000..ba11e96 --- /dev/null +++ b/EdgeFLite/fedml_service/data_cleaning/cifar100/bulk_data_import.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import numpy as np +import torch +import torch.utils.data as data +import torchvision.apply_transformations as apply_transformations + +from .datasets import CIFAR100_truncated + +# Set up logging configuration to log information level events +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +# Function to read non-IID distribution data from a file +def read_data_distribution(filename='./data_cleaning/non-iid-distribution/CIFAR10/data_map.txt'): + distribution = {} + # Open the file and read the distribution map + with open(filename, 'r') as file: + for line in file.readlines(): + # Skip lines that do not contain distribution data + if '{' != line[0] and '}' != line[0]: + key, value = line.split(':') + if '{' == value.strip(): + distribution[int(key)] = {} + current_key = int(key) + else: + sub_key, sub_value = key, value.strip().replace(',', '') + distribution[current_key][int(sub_key)] = int(sub_value) + return distribution + +# Function to read net data index map from a file +def read_net_dataidx_map(filename='./data_cleaning/non-iid-distribution/CIFAR10/index_map.txt'): + net_dataidx_map = {} + # Open the file and read the index map for the dataset + with open(filename, 'r') as file: + for line in file.readlines(): + # Skip lines that do not contain index map data + if '{' != line[0] and '}' != line[0] and ']' != line[0]: + key, value = line.split(':') + if '[' == value.strip(): + net_dataidx_map[int(key)] = [] + else: + net_dataidx_map[int(key)] = [int(i.strip()) for i in value.split(',')] + return net_dataidx_map + +# Function to calculate and record statistics of the net's data +def record_net_data_stats(y_train, net_dataidx_map): + net_cls_counts = {} + # For each net, count the unique classes and their frequencies in the training data + for net_id, dataidx in net_dataidx_map.items(): + unique, counts = np.unique(y_train[dataidx], return_counts=True) + net_cls_counts[net_id] = {unique[i]: counts[i] for i in range(len(unique))} + logging.debug(f'Data statistics: {net_cls_counts}') + return net_cls_counts + +# Custom Cutout data augmentation class to apply a random mask to an image +class Cutout: + def __init__(self, length): + self.length = length + + def __call__(self, img): + h, w = img.size(1), img.size(2) + mask = np.ones((h, w), np.float32) + y, x = 
np.random.randint(h), np.random.randint(w) + + # Define the region to apply the mask + y1, y2 = np.clip([y - self.length // 2, y + self.length // 2], 0, h) + x1, x2 = np.clip([x - self.length // 2, x + self.length // 2], 0, w) + + # Apply the mask and return the augmented image + mask[y1:y2, x1:x2] = 0 + mask = torch.from_numpy(mask).expand_as(img) + img *= mask + return img + +# Function to define CIFAR-100 data apply_transformationation pipelines for training and validation +def _data_apply_transformations_cifar100(): + # Define normalization constants for CIFAR-100 + CIFAR_MEAN = [0.5071, 0.4865, 0.4409] + CIFAR_STD = [0.2673, 0.2564, 0.2762] + + # Data augmentation and apply_transformationation pipeline for training data + train_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToPILImage(), + apply_transformations.RandomCrop(32, padding=4), + apply_transformations.RandomHorizontalFlip(), + apply_transformations.ToTensor(), + apply_transformations.Normalize(CIFAR_MEAN, CIFAR_STD), + Cutout(16) # Apply the Cutout augmentation + ]) + + # apply_transformationation pipeline for validation data + valid_apply_transformation = apply_transformations.Compose([ + apply_transformations.ToTensor(), + apply_transformations.Normalize(CIFAR_MEAN, CIFAR_STD) + ]) + + return train_apply_transformation, valid_apply_transformation + +# Function to load CIFAR-100 dataset with the specified apply_transformationations +def load_cifar100_data(datadir): + train_apply_transformation, test_apply_transformation = _data_apply_transformations_cifar100() + + # Load training and testing datasets + cifar_train = CIFAR100_truncated(datadir, train=True, download=True, apply_transformation=train_apply_transformation) + cifar_test = CIFAR100_truncated(datadir, train=False, download=True, apply_transformation=test_apply_transformation) + + return cifar_train.data, cifar_train.target, cifar_test.data, cifar_test.target + +# Function to partition data based on IID (Independent and Identically Distributed) or non-IID methods +def partition_data(dataset, datadir, partition, n_nets, alpha): + logging.info("********* Partitioning Data ***************") + X_train, y_train, X_test, y_test = load_cifar100_data(datadir) + n_train = X_train.shape[0] + + # IID partitioning: randomly split the data across the clients + if partition == "homo": + idxs = np.random.permutation(n_train) + net_dataidx_map = {i: idxs_split for i, idxs_split in enumerate(np.array_split(idxs, n_nets))} + + # Non-IID partitioning using Dirichlet distribution + elif partition == "hetero": + min_size, K, N = 0, 100, y_train.shape[0] + net_dataidx_map = {} + + # Ensure each client has at least 10 samples + while min_size < 10: + idx_batch = [[] for _ in range(n_nets)] + for k in range(K): + idx_k = np.where(y_train == k)[0] + np.random.shuffle(idx_k) + proportions = np.random.dirichlet(np.repeat(alpha, n_nets)) + proportions = np.array([p * (len(batch) < N / n_nets) for p, batch in zip(proportions, idx_batch)]) + proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] + idx_batch = [batch + idx.tolist() for batch, idx in zip(idx_batch, np.split(idx_k, proportions))] + min_size = min([len(batch) for batch in idx_batch]) + + # Randomly shuffle the data batches for each client + net_dataidx_map = {i: np.random.permutation(batch) for i, batch in enumerate(idx_batch)} + + # Non-IID fixed partition: read the distribution from a predefined file + elif partition == "hetero-fix": + net_dataidx_map = 
read_net_dataidx_map('./data_cleaning/non-iid-distribution/CIFAR100/index_map.txt') + + # Record class counts for the partitioned training data + traindata_cls_counts = read_data_distribution('./data_cleaning/non-iid-distribution/CIFAR100/data_map.txt') \ + if partition == "hetero-fix" else record_net_data_stats(y_train, net_dataidx_map) + + return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts + +# Function to get data loaders for centralized and local training +def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): + return get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs) + +# Function to get data loaders for test data during decentralized training +def get_dataloader_test(dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test): + return get_dataloader_test_CIFAR100(datadir, train_bs, test_bs, dataidxs_train, dataidxs_test) + +# Function to load CIFAR-100 data into PyTorch data loaders for training and testing +def get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs=None): + apply_transformation_train, apply_transformation_test = _data_apply_transformations_cifar100() + + train_ds = CIFAR100_truncated(datadir, dataidxs=dataidxs, train=True, apply_transformation=apply_transformation_train, download=True) + test_ds = CIFAR100_truncated(datadir, train=False, apply_transformation=apply_transformation_test, download=True) + + train_dl = data.DataLoader(train_ds, batch_size=train_bs, shuffle=True, drop_last=True) + test_dl = data.DataLoader(test_ds, batch_size=test_bs, shuffle=False, drop_last=True) + + return train_dl, test_dl + +# Function to get data loaders for test data during decentralized training (same as above but with test data indexes) +def get_dataloader_test_CIFAR100(datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None): + apply_transformation_train, apply_transformation_test = _data_apply_transformations_cifar100() + + train_ds = CIFAR100_truncated(datadir, dataidxs=dataidxs_train, train=True, apply_transformation=apply_transformation_train, download=True) + test_ds = CIFAR100_truncated(datadir, dataidxs=dataidxs_test, train=False, apply_transformation=apply_transformation_test, download=True) + + train_dl = data.DataLoader(train_ds, batch_size=train_bs, shuffle=True, drop_last=True) + test diff --git a/EdgeFLite/fedml_service/data_cleaning/cifar100/dataset_hub.py b/EdgeFLite/fedml_service/data_cleaning/cifar100/dataset_hub.py new file mode 100644 index 0000000..2dcb71c --- /dev/null +++ b/EdgeFLite/fedml_service/data_cleaning/cifar100/dataset_hub.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import numpy as np +import torch.utils.data as data +from PIL import Image +from torchvision.datasets import CIFAR100 + +# Configure logging +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +# Supported image extensions for loading images +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') + +def load_accimage(path): + """ + Attempts to load an image using the accimage backend. + If accimage fails, it falls back to using the PIL image loader. + + Args: + path (str): Path to the image file. + + Returns: + accimage.Image: The loaded image if successful, otherwise a PIL image. 
+ """ + import accimage + try: + return accimage.Image(path) + except IOError: + # If accimage fails, use PIL to load the image + return load_image_pil(path) + +def load_image_pil(path): + """ + Loads an image using PIL, ensuring that file handles are properly closed to prevent warnings. + + Args: + path (str): Path to the image file. + + Returns: + Image: The image loaded using PIL, converted to RGB format. + """ + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + +def basic_loader(path): + """ + Selects the appropriate image loader based on the backend configured by torchvision. + If the backend is 'accimage', it uses load_accimage; otherwise, it uses PIL. + + Args: + path (str): Path to the image file. + + Returns: + Image: The loaded image. + """ + from torchvision import get_image_backend + if get_image_backend() == 'accimage': + return load_accimage(path) + else: + return load_image_pil(path) + +class CIFAR100_truncated(data.Dataset): + """ + Custom dataset class for CIFAR100 with optional data truncation. + It allows selecting a subset of the data by index and also enables modification of image channels. + """ + + def __init__(self, root, dataidxs=None, train=True, apply_transformation=None, target_apply_transformation=None, download=False): + """ + Initializes the CIFAR100_truncated dataset. + + Args: + root (str): The root directory where the dataset is stored. + dataidxs (list or None): List of indices for truncating the dataset, if applicable. + train (bool): Whether to load the training set (True) or the test set (False). + apply_transformation (callable, optional): apply_transformationation function applied to images. + target_apply_transformation (callable, optional): apply_transformationation function applied to targets (labels). + download (bool): Whether to download the dataset if it is not found in the root directory. + """ + self.root = root # Root directory where dataset is stored + self.dataidxs = dataidxs # List of indices for truncating the dataset + self.train = train # Specifies whether to load the training set + self.apply_transformation = apply_transformation # Optional apply_transformationations on images + self.target_apply_transformation = target_apply_transformation # Optional apply_transformationations on labels + self.download = download # Specifies whether to download the dataset if missing + + # Build the truncated dataset based on the provided indices + self.data, self.target = self.__build_truncated_dataset__() + + def __build_truncated_dataset__(self): + """ + Constructs the truncated dataset based on the provided data indices. + + Returns: + tuple: The truncated data and corresponding target labels. + """ + cifar_dataobj = CIFAR100(self.root, self.train, self.apply_transformation, self.target_apply_transformation, self.download) + + # Load all data and targets + data = cifar_dataobj.data + target = np.array(cifar_dataobj.targets) + + # If specific indices are provided, truncate the dataset accordingly + if self.dataidxs is not None: + data = data[self.dataidxs] + target = target[self.dataidxs] + + return data, target + + def truncate_channel(self, index): + """ + Modifies the selected images by zeroing out the green and blue channels, + effectively converting them to grayscale-like images. + + Args: + index (np.array): The indices of images to modify. 
+ """ + for i in range(index.shape[0]): + gs_index = index[i] + self.data[gs_index, :, :, 1] = 0.0 # Set the green channel to 0 + self.data[gs_index, :, :, 2] = 0.0 # Set the blue channel to 0 + + def __getitem__(self, index): + """ + Retrieves an image and its corresponding target (label) at the given index. + + Args: + index (int): Index of the data point to retrieve. + + Returns: + tuple: (image, target) where the image is apply_transformationed (if specified), and the target is the label. + """ + img, target = self.data[index], self.target[index] + + # Apply any specified apply_transformationations to the image + if self.apply_transformation is not None: + img = self.apply_transformation(img) + + # Apply any specified apply_transformationations to the target label + if self.target_apply_transformation is not None: + target = self.target_apply_transformation(target) + + return img, target + + def __len__(self): + """ + Returns the total number of data points in the dataset. + + Returns: + int: The number of samples in the dataset. + """ + return len(self.data) diff --git a/EdgeFLite/fedml_service/data_cleaning/pillbase/.DS_Store b/EdgeFLite/fedml_service/data_cleaning/pillbase/.DS_Store new file mode 100644 index 0000000..e21dabd Binary files /dev/null and b/EdgeFLite/fedml_service/data_cleaning/pillbase/.DS_Store differ diff --git a/EdgeFLite/fedml_service/data_cleaning/pillbase/bulk_data_import.py b/EdgeFLite/fedml_service/data_cleaning/pillbase/bulk_data_import.py new file mode 100644 index 0000000..c70743f --- /dev/null +++ b/EdgeFLite/fedml_service/data_cleaning/pillbase/bulk_data_import.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import random +import pickle +from PIL import Image + +import numpy as np +import torch +import torch.utils.data as data +import torchvision.apply_transformations as apply_transformations + +from dataset.pill_dataset_base import PillDataBase # Custom dataset class for handling Pill data +from config import HOME # Configuration file for defining the home directory + +# Configure logging to capture information during the execution of the script +logging.basicConfig() +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +# Function to load and partition pill base data +def load_partition_data_pillbase(dataset, data_dir, partition_method, partition_alpha, client_number, batch_size): + # Define the number of samples in the training and testing datasets + train_data_num = 8161 + test_data_num = 1619 + + # Normalization parameters (mean and standard deviation) for each channel (RGB) based on the dataset's characteristics + mean, std = [0.4550, 0.5239, 0.5653], [0.2460, 0.2446, 0.2252] + + # Define apply_transformationations for training data, including: + # 1. Randomly resized crops to augment the data. + # 2. Horizontal flipping for data augmentation. + # 3. Conversion to tensor and normalization with the provided mean and std. + # 4. Random erasing to simulate occlusion as part of augmentation. 
+ train_apply_transformation = apply_transformations.Compose([ + apply_transformations.RandomResizedCrop(224, scale=(0.1, 1.0), interpolation=Image.BILINEAR), + apply_transformations.RandomHorizontalFlip(), + apply_transformations.ToTensor(), + apply_transformations.Normalize(mean, std), + apply_transformations.RandomErasing(p=0.5, scale=(0.05, 0.12), ratio=(0.5, 1.5), value=0) + ]) + + # Create a training dataset using the PillDataBase class and apply the apply_transformationation + train_dataset = PillDataBase(data_dir, train=True, apply_transformation=train_apply_transformation, split_factor=1) + + # Create a DataLoader for the global training dataset, with shuffling enabled and dropping the last incomplete batch + train_data_global = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + + # Define apply_transformationations for validation data, including: + # 1. Resizing to a larger scale for testing. + # 2. Center cropping to ensure the image size is 224x224. + # 3. Conversion to tensor and normalization with the same mean and std as the training data. + val_apply_transformation = apply_transformations.Compose([ + apply_transformations.Resize(int(224 * 1.15), interpolation=Image.BILINEAR), + apply_transformations.CenterCrop(224), + apply_transformations.ToTensor(), + apply_transformations.Normalize(mean, std) + ]) + + # Create a validation dataset using the PillDataBase class with the validation apply_transformationations + val_dataset = PillDataBase(data_dir, train=False, apply_transformation=val_apply_transformation, split_factor=1) + + # Calculate how many images each client will receive for the validation dataset + images_per_client = len(val_dataset) // client_number + logger.info(f"Images per client: {images_per_client}") # Log the number of images assigned to each client + + # Split the validation data among the clients evenly, ensuring the last client gets any remaining images + data_split = [images_per_client] * (client_number - 1) + [len(val_dataset) - images_per_client * (client_number - 1)] + + # Perform the actual data splitting using torch's random_split function and a fixed random seed for reproducibility + testdata_split = torch.utils.data.random_split(val_dataset, data_split, generator=torch.Generator().manual_seed(68)) + + # Create a DataLoader for each client from their respective validation dataset splits + test_data_local_dict = [ + torch.utils.data.DataLoader(x, batch_size=16, shuffle=True, drop_last=True) + for x in testdata_split + ] + + # Return all necessary data structures, including the number of classes and the training/test data loaders + class_num = 98 # Total number of classes in the dataset + return train_data_num, test_data_num, train_data_global, None, None, None, test_data_local_dict, class_num diff --git a/EdgeFLite/fedml_service/data_cleaning/skin_dataset/.DS_Store b/EdgeFLite/fedml_service/data_cleaning/skin_dataset/.DS_Store new file mode 100644 index 0000000..e21dabd Binary files /dev/null and b/EdgeFLite/fedml_service/data_cleaning/skin_dataset/.DS_Store differ diff --git a/EdgeFLite/fedml_service/data_cleaning/skin_dataset/bulk_data_import.py b/EdgeFLite/fedml_service/data_cleaning/skin_dataset/bulk_data_import.py new file mode 100644 index 0000000..b9a56ab --- /dev/null +++ b/EdgeFLite/fedml_service/data_cleaning/skin_dataset/bulk_data_import.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import numpy as np +import torch +import torch.utils.data as data +import 
torchvision.apply_transformations as apply_transformations +import random +from dataset.skin_dataset import SkinData +from config import HOME +import pickle + +# Set up logging +logging.basicConfig() # Configures the basic logging setup +logger = logging.getLogger() # Gets the root logger +logger.setLevel(logging.INFO) # Sets the logging level to INFO + +def load_partition_data_skin_dataset(dataset, data_dir, partition_method, partition_alpha, client_number, batch_size): + # Predefined dataset sizes for training and testing + train_data_num = 8012 # Number of training samples + test_data_num = 2003 # Number of testing samples + + # Normalization parameters used for preprocessing + mean = [0.485, 0.456, 0.406] # Mean values for normalization (standard for ImageNet) + std = [0.229, 0.224, 0.225] # Standard deviation for normalization + + # Load the training data from the pre-saved pickle file + with open(HOME + '/dataset_hub/skin_dataset/skin_dataset_train.pickle', 'rb') as train_file: + train_data = pickle.load(train_file) # Loading training data from pickle file + + # Data augmentation and preprocessing apply_transformationations for training data + train_apply_transformations = apply_transformations.Compose([ + apply_transformations.RandomHorizontalFlip(), # Randomly flip the image horizontally + apply_transformations.RandomVerticalFlip(), # Randomly flip the image vertically + apply_transformations.RandomHorizontalFlip(), # Repeated horizontal flip (may be intentional) + apply_transformations.RandomAdjustadjust_image_sharpness(random.uniform(0, 4.0)), # Adjust image adjust_image_sharpness + apply_transformations.RandomAutocontrast(), # Automatically adjust image contrast + apply_transformations.Pad(3), # Pad image by 3 pixels + apply_transformations.RandomRotation(10), # Random rotation by 10 degrees + apply_transformations.CenterCrop(64), # Crop the center to a size of 64x64 + apply_transformations.ToTensor(), # Convert the image to a tensor + apply_transformations.Normalize(mean=mean, std=std) # Normalize using the predefined mean and std + ]) + + # Create the training dataset with augmentation and apply_transformationation + train_dataset = SkinData(train_data, apply_transformation=train_apply_transformations, split_factor=1) + # Create a DataLoader for the training dataset + train_data_global = data.DataLoader( + dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True # Shuffle data and drop incomplete batches + ) + + # Load the test data from the pre-saved pickle file + with open(HOME + "/dataset_hub/skin_dataset/skin_dataset_test.pickle", 'rb') as test_file: + test_data = pickle.load(test_file) # Loading test data from pickle file + + # Preprocessing apply_transformationations for validation/testing data (without augmentation) + val_apply_transformations = apply_transformations.Compose([ + apply_transformations.Pad(3), # Pad the image by 3 pixels + apply_transformations.CenterCrop(64), # Crop the center to a size of 64x64 + apply_transformations.ToTensor(), # Convert the image to a tensor + apply_transformations.Normalize(mean=mean, std=std) # Normalize using the predefined mean and std + ]) + + # Create the validation/test dataset with the preprocessing apply_transformationations + val_dataset = SkinData(test_data, apply_transformation=val_apply_transformations, split_factor=1) + + # Split test data across clients. Each client gets approximately equal data. 
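Like the pill-base loader above, the skin-lesion loader splits its validation set evenly across clients with `torch.utils.data.random_split`, handing the last client the remainder and fixing the generator seed so every run produces identical shards. A self-contained sketch of that pattern (dataset size, client count, and seed are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in validation set of 2003 samples, the size used by the skin-lesion loader
val_dataset = TensorDataset(torch.randn(2003, 3, 64, 64), torch.randint(0, 7, (2003,)))

client_number = 4
per_client = len(val_dataset) // client_number
split_sizes = [per_client] * (client_number - 1)
split_sizes.append(len(val_dataset) - per_client * (client_number - 1))  # last client takes the remainder

# A fixed seed makes the per-client shards reproducible across runs
shards = random_split(val_dataset, split_sizes, generator=torch.Generator().manual_seed(68))
test_loaders = [DataLoader(s, batch_size=32, shuffle=False, drop_last=True) for s in shards]

print([len(s) for s in shards])  # [500, 500, 500, 503]
```

The sketch simply disables shuffling for evaluation; the repository's version conditions `shuffle` on a `train_sampler` variable that is not defined inside this loader, so that flag should be treated with caution.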
+ images_per_client = len(val_dataset) // client_number # Number of images each client will get + logger.info(f"Images per client: {images_per_client}") # Log the number of images per client + + # Create a list that determines the size of the data splits for each client + data_split = [images_per_client] * (client_number - 1) # Distribute data equally to all but the last client + data_split.append(len(val_dataset) - images_per_client * (client_number - 1)) # The last client gets the remaining data + logger.info(f"Data split: {data_split}") # Log the data split + + # Randomly split test data for each client using the data_split list + testdata_split = torch.utils.data.random_split( + val_dataset, data_split, generator=torch.Generator().manual_seed(68) # Set the random seed for reproducibility + ) + + # Create a DataLoader for each client's test data + test_data_local_dict = [ + torch.utils.data.DataLoader( + x, batch_size=32, shuffle=(True if train_sampler is None else False), drop_last=True # Create DataLoader for each client's split + ) for x in testdata_split + ] + + # Other variables that are currently unused (placeholders for future implementation) + class_num = 7 # Number of classes in the dataset + test_data_global = None # Placeholder for global test data (currently not used) + data_local_num_dict = None # Placeholder for storing the number of samples per client (currently not used) + train_data_local_dict = None # Placeholder for storing local training data for each client (currently not used) + + # Return key values including the number of samples, data loaders, and class number + return train_data_num, test_data_num, train_data_global, test_data_global, \ + data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num diff --git a/EdgeFLite/fedml_service/decentralized/.DS_Store b/EdgeFLite/fedml_service/decentralized/.DS_Store new file mode 100644 index 0000000..afbd066 Binary files /dev/null and b/EdgeFLite/fedml_service/decentralized/.DS_Store differ diff --git a/EdgeFLite/fedml_service/decentralized/federated_gkt/.DS_Store b/EdgeFLite/fedml_service/decentralized/federated_gkt/.DS_Store new file mode 100644 index 0000000..e21dabd Binary files /dev/null and b/EdgeFLite/fedml_service/decentralized/federated_gkt/.DS_Store differ diff --git a/EdgeFLite/fedml_service/decentralized/federated_gkt/client_coach.py b/EdgeFLite/fedml_service/decentralized/federated_gkt/client_coach.py new file mode 100644 index 0000000..05eb1a5 --- /dev/null +++ b/EdgeFLite/fedml_service/decentralized/federated_gkt/client_coach.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import torch +from torch import nn, optim +from fedml_service.decentralized.federated_gkt import utils + +# Class for training a GKT client in a federated learning setup +class GKTTrainer: + def __init__(self, client_index, local_training_data, local_test_data, device, client_model, args): + # Initialize the client trainer with various parameters + self.client_index = client_index # Index for the current client + self.local_training_data = local_training_data[client_index] # Local training dataset specific to the client + self.local_test_data = local_test_data[client_index] # Local test dataset specific to the client + self.device = device # Device (CPU/GPU) where the computation will take place + self.client_model = client_model.to(self.device) # Model assigned to the client + self.args = args # Arguments passed for configuring the training process + + logging.info(f"Client device 
= {self.device}") + + # Model parameters used for optimization + self.model_params = self.master_params = self.client_model.parameters() + optim_params = self.master_params + + # Configure optimizer based on the provided arguments + if self.args.optimizer == "SGD": + # Using SGD optimizer with learning rate, momentum, and weight decay + self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9, nesterov=True, weight_decay=self.args.wd) + elif self.args.optimizer == "Adam": + # Using Adam optimizer with learning rate, weight decay, and AMSGrad variant + self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True) + + # Define loss functions: CrossEntropy for true label prediction, KL divergence for knowledge distillation + self.criterion_CE = nn.CrossEntropyLoss() + self.criterion_KL = utils.KL_Loss(self.args.temperature) + + # Dictionary to hold logits received from the server (used for knowledge distillation) + self.server_logits_dict = {} + + logging.info(f"Client device = {self.device} - Initialization Complete") + + # Update server logits for knowledge distillation + def update_large_model_logits(self, logits): + self.server_logits_dict = logits + + # Main training function for the client + def train(self): + # Dictionaries to store extracted features, logits, and labels during training and testing + extracted_feature_dict, logits_dict, labels_dict = {}, {}, {} + extracted_feature_dict_test, labels_dict_test = {}, {} + + # Only train if training on client is enabled + if self.args.whether_training_on_client: + self.client_model.train() # Set model to training mode + epoch_loss = [] # Track loss for each epoch + + # Loop over the specified number of federated epochs + for epoch in range(self.args.fed_epochs): + batch_loss = [] # Track loss for each batch + + # Loop through the local training data in batches + for batch_idx, (images, labels) in enumerate(self.local_training_data): + # Move images and labels to the specified device + images, labels = images.to(self.device), labels.to(self.device) + + # Forward pass through the client model + log_probs, _ = self.client_model(images) + + # Compute the loss with respect to the true labels + loss_true = self.criterion_CE(log_probs, labels) + + # If server logits are available, calculate the distillation loss using KL divergence + if self.server_logits_dict: + large_model_logits = torch.from_numpy(self.server_logits_dict[batch_idx]).to(self.device) + loss_kd = self.criterion_KL(log_probs, large_model_logits) + # Combine true label loss and distillation loss + loss = loss_true + self.args.alpha * loss_kd + else: + # Use only the true label loss if no server logits are available + loss = loss_true + + # Perform backpropagation and optimization step + self.optimizer.zero_grad() # Reset gradients + loss.backward() # Backpropagate the loss + self.optimizer.step() # Update model parameters + + # Logging progress for each batch + logging.info(f'Client {self.client_index} - Update Epoch: {epoch} ' + f'[{batch_idx * len(images)}/{len(self.local_training_data.dataset)} ' + f'({100. 
* batch_idx / len(self.local_training_data):.0f}%)]') + batch_loss.append(loss.item()) # Store the loss for the current batch + + # Calculate and store average loss for the epoch + epoch_loss.append(sum(batch_loss) / len(batch_loss)) + + # Switch to evaluation mode after training + self.client_model.eval() + + # Extract features, logits, and labels from the training data for evaluation + for batch_idx, (images, labels) in enumerate(self.local_training_data): + images, labels = images.to(self.device), labels.to(self.device) + log_probs, extracted_features = self.client_model(images) + + # Store the extracted features, logits, and labels for this batch + extracted_feature_dict[batch_idx] = extracted_features.cpu().detach().numpy() + logits_dict[batch_idx] = log_probs.cpu().detach().numpy() + labels_dict[batch_idx] = labels.cpu().detach().numpy() + + # Extract features and labels from the test data for evaluation + for batch_idx, (images, labels) in enumerate(self.local_test_data): + test_images, test_labels = images.to(self.device), labels.to(self.device) + _, extracted_features_test = self.client_model(test_images) + + # Store the extracted test features and labels for this batch + extracted_feature_dict_test[batch_idx] = extracted_features_test.cpu().detach().numpy() + labels_dict_test[batch_idx] = test_labels.cpu().detach().numpy() + + # Return the extracted features, logits, and labels from both training and test datasets + return extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test diff --git a/EdgeFLite/fedml_service/decentralized/federated_gkt/helper_utils.py b/EdgeFLite/fedml_service/decentralized/federated_gkt/helper_utils.py new file mode 100644 index 0000000..e3c7c6d --- /dev/null +++ b/EdgeFLite/fedml_service/decentralized/federated_gkt/helper_utils.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import json +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def load_state_dict(file): + """Load a state dict from a file, handling any potential location issues.""" + try: + return torch.load(file) + except AssertionError: + return torch.load(file, map_location=lambda storage, location: storage) + + +def flatten_parameters(model): + """Flatten the parameters of the model into a single tensor.""" + return torch.cat([param.data.view(-1) for param in model.parameters()]) + + +def set_flattened_parameters(model, flat_params): + """Set the model's parameters from a flattened tensor.""" + prev_ind = 0 + for param in model.parameters(): + flat_size = int(np.prod(param.size())) + param.data.copy_(flat_params[prev_ind:prev_ind + flat_size].view(param.size())) + prev_ind += flat_size + + +class RollingAverage: + """Class to maintain a running average of a quantity.""" + def __init__(self): + self.steps = 0 + self.total = 0 + + def update(self, val): + self.total += val + self.steps += 1 + + def value(self): + return self.total / float(self.steps) if self.steps > 0 else 0 + + +def compute_accuracy(output, target, topk=(1,)): + """Compute the precision@k for the specified values of k.""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, dim=1, largest=True, sorted=True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size) for k in topk] + + +class KLDivergenceLoss(nn.Module): + """Kullback-Leibler Divergence Loss.""" + def __init__(self, temperature=1): + 
super(KLDivergenceLoss, self).__init__() + self.temperature = temperature + + def forward(self, output_batch, teacher_outputs): + output_batch = F.log_softmax(output_batch / self.temperature, dim=1) + teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1) + 1e-7 + return self.temperature ** 2 * nn.KLDivLoss(reduction='batchmean')(output_batch, teacher_outputs) + + +class CELoss(nn.Module): + """Cross-Entropy Loss.""" + def __init__(self, temperature=1): + super(CELoss, self).__init__() + self.temperature = temperature + + def forward(self, output_batch, teacher_outputs): + output_batch = F.log_softmax(output_batch / self.temperature, dim=1) + teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1) + return -self.temperature ** 2 * torch.sum(output_batch * teacher_outputs) / teacher_outputs.size(0) + + +def save_dict_to_json(data, json_path): + """Save a dictionary of floats to a JSON file.""" + with open(json_path, 'w') as f: + json.dump({k: float(v) for k, v in data.items()}, f, indent=4) + + +def get_optimized_params(model, model_params, master_params): + """Filter out batch norm parameters from weight decay to improve accuracy.""" + bn_params, remaining_params = split_bn_params(model, model_params, master_params) + return [{'params': bn_params, 'weight_decay': 0}, {'params': remaining_params}] + + +def split_bn_params(model, model_params, master_params): + """Split parameters into batch norm and non-batch norm.""" + def get_bn_params(module): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + return set(module.parameters()) + return {p for child in module.children() for p in get_bn_params(child)} + + mod_bn_params = get_bn_params(model) + zipped_params = zip(model_params, master_params) + + mas_bn_params = [p_mast for p_mod, p_mast in zipped_params if p_mod in mod_bn_params] + mas_rem_params = [p_mast for p_mod, p_mast in zipped_params if p_mod not in mod_bn_params] + + return mas_bn_params, mas_rem_params diff --git a/EdgeFLite/fedml_service/decentralized/federated_gkt/server_coach.py b/EdgeFLite/fedml_service/decentralized/federated_gkt/server_coach.py new file mode 100644 index 0000000..c066bf4 --- /dev/null +++ b/EdgeFLite/fedml_service/decentralized/federated_gkt/server_coach.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import logging +import os +import shutil + +import torch +from torch import nn, optim +from torch.optim.lr_scheduler import ReduceLROnPlateau +from utils import metric +from fedml_service.decentralized.federated_gkt import utils + +# List to store filenames of saved checkpoints +saved_ckpt_filenames = [] + +class GKTServerTrainer: + def __init__(self, client_num, device, server_model, args, writer): + # Initialize the trainer with the number of clients, device (CPU/GPU), global server model, training arguments, and a writer for logging + self.client_num = client_num + self.device = device + self.args = args + self.writer = writer + + """ + Notes: Using data parallelism requires adjusting the batch size accordingly. + For example, with a single GPU (batch_size = 64), an epoch takes 1:03; + using 4 GPUs (batch_size = 256), it takes 38 seconds, and with 4 GPUs (batch_size = 64), it takes 1:00. + If batch size is not adjusted, the communication between CPU and GPU may slow down training. 
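`KLDivergenceLoss` above implements the temperature-scaled distillation term (the trainers construct it under the name `utils.KL_Loss`). The GKT client adds `alpha * loss_kd` to the cross-entropy on the true labels, while the server adds `alpha * loss_true` to the KL term, so both sides optimize a weighted sum of the two losses. A self-contained sketch of the client-side form of that objective (shapes and hyper-parameters are illustrative):

```python
import torch
import torch.nn.functional as F


def gkt_distillation_loss(student_logits, teacher_logits, labels, alpha=1.0, temperature=3.0):
    """Cross-entropy on true labels plus temperature-scaled KL against the peer/server logits."""
    ce = F.cross_entropy(student_logits, labels)
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    ) * temperature ** 2  # rescale so the gradient magnitude is comparable across temperatures
    return ce + alpha * kl


# Toy batch: 8 samples, 10 classes
student = torch.randn(8, 10, requires_grad=True)
teacher = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = gkt_distillation_loss(student, teacher, labels)
loss.backward()
print(loss.item())
```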
+ """ + + # Server model setup + self.model_global = server_model + self.model_global.train() # Set model to training mode + self.model_global.to(self.device) # Move model to the specified device (CPU or GPU) + + # Model parameters for optimization + self.model_params = self.master_params = self.model_global.parameters() + optim_params = self.master_params + + # Choose optimizer based on arguments (SGD or Adam) + if self.args.optimizer == "SGD": + self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9, nesterov=True, weight_decay=self.args.wd) + elif self.args.optimizer == "Adam": + self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True) + + # Learning rate scheduler to reduce the learning rate when the accuracy plateaus + self.scheduler = ReduceLROnPlateau(self.optimizer, 'max') + + # Loss functions: CrossEntropy for classification, KL for knowledge distillation + self.criterion_CE = nn.CrossEntropyLoss() + self.criterion_KL = utils.KL_Loss(self.args.temperature) + + # Best accuracy tracking + self.best_acc = 0.0 + + # Client data dictionaries to store features, logits, and labels + self.client_extracted_feature_dict = {} + self.client_logits_dict = {} + self.client_labels_dict = {} + self.server_logits_dict = {} + + # Testing data dictionaries + self.client_extracted_feature_dict_test = {} + self.client_labels_dict_test = {} + + # Miscellaneous dictionaries to store model info, sample numbers, training accuracy, and loss + self.model_dict = {} + self.sample_num_dict = {} + self.train_acc_dict = {} + self.train_loss_dict = {} + self.test_acc_avg = 0.0 + self.test_loss_avg = 0.0 + + # Dictionary to track if the client model has been uploaded + self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)} + + # Add results from a local client model after training + def add_local_trained_result(self, index, extracted_feature_dict, logits_dict, labels_dict, + extracted_feature_dict_test, labels_dict_test): + logging.info(f"Adding model for client index = {index}") + self.client_extracted_feature_dict[index] = extracted_feature_dict + self.client_logits_dict[index] = logits_dict + self.client_labels_dict[index] = labels_dict + self.client_extracted_feature_dict_test[index] = extracted_feature_dict_test + self.client_labels_dict_test[index] = labels_dict_test + self.flag_client_model_uploaded_dict[index] = True + + # Check if all clients have uploaded their models + def check_whether_all_receive(self): + if all(self.flag_client_model_uploaded_dict.values()): + self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)} + return True + return False + + # Get logits from the global model for a specific client + def get_global_logits(self, client_index): + return self.server_logits_dict.get(client_index) + + # Main training function based on the round index + def train(self, round_idx): + if self.args.sweep == 1: # Sweep mode + self.sweep(round_idx) + else: # Normal training process + if self.args.whether_training_on_client == 1: # Check if training occurs on client + self.train_and_distill_on_client(round_idx) + else: # No training on client, just evaluate + self.do_not_train_on_client(round_idx) + + # Training and knowledge distillation on client side + def train_and_distill_on_client(self, round_idx): + # Set the number of server epochs (based on testing mode) + epochs_server = 1 if not self.args.test else self.get_server_epoch_strategy_test()[0] + self.train_and_eval(round_idx, 
epochs_server, self.writer, self.args) # Train and evaluate + self.scheduler.step(self.best_acc, epoch=round_idx) # Update learning rate scheduler + + # Skip client-side training + def do_not_train_on_client(self, round_idx): + self.train_and_eval(round_idx, 1) + self.scheduler.step(self.best_acc, epoch=round_idx) + + # Training with sweeping strategy + def sweep(self, round_idx): + self.train_and_eval(round_idx, self.args.epochs_server) + self.scheduler.step(self.best_acc, epoch=round_idx) + + # Strategy for determining the number of epochs (used in testing) + def get_server_epoch_strategy_test(self): + return 1, True + + # Different strategies for determining the number of epochs based on training round + def get_server_epoch_strategy_reset56(self, round_idx): + epochs = 20 if round_idx < 20 else 15 if round_idx < 30 else 10 if round_idx < 40 else 5 if round_idx < 50 else 3 if round_idx < 150 else 1 + whether_distill_back = round_idx < 150 + return epochs, whether_distill_back + + # Another variant of epoch strategy + def get_server_epoch_strategy_reset56_2(self, round_idx): + return self.args.epochs_server, True + + # Main training and evaluation loop + def train_and_eval(self, round_idx, epochs, val_writer, args): + for epoch in range(epochs): + logging.info(f"Train and evaluate. Round = {round_idx}, Epoch = {epoch}") + train_metrics = self.train_large_model_on_the_server() # Training step + + if epoch == epochs - 1: + # Log metrics for the final epoch + val_writer.add_scalar('average training loss', train_metrics['train_loss'], global_step=round_idx) + test_metrics = self.eval_large_model_on_the_server() # Evaluation step + test_acc = test_metrics['test_accTop1'] + + val_writer.add_scalar('test loss', test_metrics['test_loss'], global_step=round_idx) + val_writer.add_scalar('test acc', test_metrics['test_accTop1'], global_step=round_idx) + + # Save best accuracy model + if test_acc >= self.best_acc: + logging.info("- Found better accuracy") + self.best_acc = test_acc + + val_writer.add_scalar('best_acc1', self.best_acc, global_step=round_idx) + + # Save model checkpoints + if args.save_weight: + filename = f"checkpoint_{round_idx}.pth.tar" + saved_ckpt_filenames.append(filename) + if len(saved_ckpt_filenames) > args.max_ckpt_nums: + os.remove(os.path.join(args.model_dir, saved_ckpt_filenames.pop(0))) + + ckpt_dict = { + 'round': round_idx + 1, + 'arch': args.arch, + 'state_dict': self.model_global.state_dict(), + 'best_acc1': self.best_acc, + 'optimizer': self.optimizer.state_dict(), + } + metric.save_checkpoint(ckpt_dict, test_acc >= self.best_acc, args.model_dir, filename=filename) + + # Print metrics for the current round + print(f"{round_idx}-th round | Train Loss: {train_metrics['train_loss']:.3g} | Test Loss: {test_metrics['test_loss']:.3g} | Test Acc: {test_metrics['test_accTop1']:.3f}") + + # Function to train the model on the server side + def train_large_model_on_the_server(self): + # Clear the logits dictionary and set model to training mode + self.server_logits_dict.clear() + self.model_global.train() + + # Track loss and accuracy + loss_avg = utils.RollingAverage() + accTop1_avg = utils.RollingAverage() + accTop5_avg = utils.RollingAverage() + + # Iterate over clients' extracted features + for client_index, extracted_feature_dict in self.client_extracted_feature_dict.items(): + logits_dict = self.client_logits_dict[client_index] + labels_dict = self.client_labels_dict[client_index] + + s_logits_dict = {} + self.server_logits_dict[client_index] = s_logits_dict + + # 
Iterate over batches of features for each client + for batch_index, batch_feature_map_x in extracted_feature_dict.items(): + batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device) + batch_logits = torch.from_numpy(logits_dict[batch_index]).float().to(self.device) + batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device) + + # Forward pass + output_batch = self.model_global(batch_feature_map_x) + + # Knowledge distillation loss + if self.args.whether_distill_on_the_server == 1: + loss_kd = self.criterion_KL(output_batch, batch_logits).to(self.device) + loss_true = self.criterion_CE(output_batch, batch_labels).to(self.device) + loss = loss_kd + self.args.alpha * loss_true + else: + # Standard cross-entropy loss + loss = self.criterion_CE(output_batch, batch_labels).to(self.device) + + # Backward pass and optimization + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # Compute accuracy metrics + metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5)) + accTop1_avg.update(metrics[0].item()) + accTop5_avg.update(metrics[1].item()) + loss_avg.update(loss.item()) + + # Store logits for the batch + s_logits_dict[batch_index] = output_batch.cpu().detach().numpy() + + # Aggregate and log training metrics + train_metrics = {'train_loss': loss_avg.value(), + 'train_accTop1': accTop1_avg.value(), + 'train_accTop5': accTop5_avg.value()} + logging.info(f"- Train metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in train_metrics.items())}") + return train_metrics + + # Function to evaluate the model on the server side + def eval_large_model_on_the_server(self): + # Set model to evaluation mode + self.model_global.eval() + loss_avg = utils.RollingAverage() + accTop1_avg = utils.RollingAverage() + accTop5_avg = utils.RollingAverage() + + # Disable gradient computation for evaluation + with torch.no_grad(): + # Iterate over clients' extracted features for testing + for client_index, extracted_feature_dict in self.client_extracted_feature_dict_test.items(): + labels_dict = self.client_labels_dict_test[client_index] + + # Iterate over batches for each client + for batch_index, batch_feature_map_x in extracted_feature_dict.items(): + batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device) + batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device) + + # Forward pass + output_batch = self.model_global(batch_feature_map_x) + loss = self.criterion_CE(output_batch, batch_labels) + + # Compute accuracy metrics + metrics = utils.accuracy(output_batch, batch_labels, topk=(1, 5)) + accTop1_avg.update(metrics[0].item()) + accTop5_avg.update(metrics[1].item()) + loss_avg.update(loss.item()) + + # Aggregate and log test metrics + test_metrics = {'test_loss': loss_avg.value(), + 'test_accTop1': accTop1_avg.value(), + 'test_accTop5': accTop5_avg.value()} + logging.info(f"- Test metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in test_metrics.items())}") + return test_metrics diff --git a/EdgeFLite/helpers/.DS_Store b/EdgeFLite/helpers/.DS_Store new file mode 100644 index 0000000..e8a5f76 Binary files /dev/null and b/EdgeFLite/helpers/.DS_Store differ diff --git a/EdgeFLite/helpers/evaluation_metrics.py b/EdgeFLite/helpers/evaluation_metrics.py new file mode 100644 index 0000000..51febc6 --- /dev/null +++ b/EdgeFLite/helpers/evaluation_metrics.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import shutil +import torch + +def store_model(state, best_model, directory, 
filename='checkpoint.pth'): + """ + Stores the model checkpoint in the specified directory. If it's the best model, + it saves another copy named 'best_model.pth'. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + filename (str): Name of the file to save the checkpoint (default 'checkpoint.pth'). + """ + save_path = os.path.join(directory, filename) + torch.save(state, save_path) + if best_model: + # If the current model is the best, save another copy as 'best_model.pth' + shutil.copy(save_path, os.path.join(directory, 'best_model.pth')) + +def save_main_client_model(state, best_model, directory): + """ + Saves the model for the main client if it's the best one. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + """ + if best_model: + print("Saving the best main client model") + torch.save(state, os.path.join(directory, 'main_client_best.pth')) + +def save_proxy_clients_model(state, best_model, directory): + """ + Saves the model for proxy clients if it's the best one. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + """ + if best_model: + print("Saving the best proxy client model") + torch.save(state, os.path.join(directory, 'proxy_clients_best.pth')) + +def save_individual_client_model(state, best_model, directory): + """ + Saves the model for individual clients if it's the best one. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + """ + if best_model: + print("Saving the best client model") + torch.save(state, os.path.join(directory, 'client_best.pth')) + +def save_server_model(state, best_model, directory): + """ + Saves the model for the server if it's the best one. + + Args: + state (dict): Model's state dictionary. + best_model (bool): Flag indicating if the current model is the best. + directory (str): Directory where the model is saved. + """ + if best_model: + print("Saving the best server model") + torch.save(state, os.path.join(directory, 'server_best.pth')) + +class MetricTracker(object): + """ + A helper class to track and compute the average of a given metric. + + Args: + metric_name (str): Name of the metric to track. + fmt (str): Format for printing metric values (default ':f'). + """ + def __init__(self, metric_name, fmt=':f'): + self.metric_name = metric_name + self.fmt = fmt + self.reset() + + def reset(self): + """Resets all metric counters.""" + self.current_value = 0 + self.total_sum = 0 + self.count = 0 + self.average = 0 + + def update(self, value, n=1): + """ + Updates the metric value. + + Args: + value (float): New value of the metric. + n (int): Weight or count for the value (default 1). + """ + self.current_value = value + self.total_sum += value * n + self.count += n + self.average = self.total_sum / self.count + + def __str__(self): + """Returns the formatted metric string showing current value and average.""" + return f'{self.metric_name} {self.current_value{self.fmt}} ({self.average{self.fmt}})' + +class ProgressLogger(object): + """ + A class to log and display the progress of training/testing over multiple batches. 
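    A small illustrative sketch of the intended use (the metric names, batch count,
    and values below are hypothetical):

        loss_meter = MetricTracker('loss', ':.4f')
        acc_meter = MetricTracker('acc@1', ':6.2f')
        progress = ProgressLogger(100, loss_meter, acc_meter, prefix='Epoch [3]')
        for batch_idx in range(100):
            loss_meter.update(0.42, n=64)   # running loss, batch size 64
            acc_meter.update(87.5, n=64)    # running top-1 accuracy
            if batch_idx % 20 == 0:
                progress.log(batch_idx)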
+ + Args: + total_batches (int): Total number of batches. + *metrics (MetricTracker): Metrics to log during the process. + prefix (str): Prefix for the progress log (default "Progress:"). + """ + def __init__(self, total_batches, *metrics, prefix="Progress:"): + self.batch_format = self._get_batch_format(total_batches) + self.metrics = metrics + self.prefix = prefix + + def log(self, batch_idx): + """ + Logs the current progress of training/testing. + + Args: + batch_idx (int): The current batch index. + """ + output = [self.prefix + self.batch_format.format(batch_idx)] + output += [str(metric) for metric in self.metrics] + print(' | '.join(output)) + + def _get_batch_format(self, total_batches): + """Creates a format string to display the batch index.""" + num_digits = len(str(total_batches)) + return '[{:' + str(num_digits) + 'd}/{}]'.format(total_batches) + +def compute_accuracy(prediction, target, top_k=(1,)): + """ + Computes the accuracy for the top-k predictions. + + Args: + prediction (Tensor): Model predictions. + target (Tensor): Ground truth labels. + top_k (tuple): Tuple of top-k values to consider for accuracy (default (1,)). + + Returns: + List[Tensor]: List of accuracies for each top-k value. + """ + with torch.no_grad(): + max_k = max(top_k) + batch_size = target.size(0) + + # Get the top-k predictions + _, top_predictions = prediction.topk(max_k, 1, largest=True, sorted=True) + top_predictions = top_predictions.t() + + # Compare top-k predictions with targets + correct_predictions = top_predictions.eq(target.view(1, -1).expand_as(top_predictions)) + + accuracy_results = [] + for k in top_k: + # Count the number of correct predictions within the top-k + correct_k = correct_predictions[:k].view(-1).float().sum(0, keepdim=True) + accuracy_results.append(correct_k.mul_(100.0 / batch_size)) + return accuracy_results + +def count_model_parameters(model, trainable_only=False): + """ + Counts the total number of parameters in the model. + + Args: + model (nn.Module): The PyTorch model. + trainable_only (bool): Whether to count only trainable parameters (default False). + + Returns: + int: Total number of parameters in the model. + """ + if trainable_only: + # Count only the parameters that require gradients (trainable parameters) + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + # Count all parameters (trainable and non-trainable) + return sum(p.numel() for p in model.parameters()) diff --git a/EdgeFLite/helpers/normalization.py b/EdgeFLite/helpers/normalization.py new file mode 100644 index 0000000..c5f238f --- /dev/null +++ b/EdgeFLite/helpers/normalization.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + + +class PassThrough(nn.Module): + """ + A placeholder module that simply returns the input tensor unchanged. + """ + def __init__(self, **kwargs): + super(PassThrough, self).__init__() + + def forward(self, input_tensor): + return input_tensor + + +class LayerNormalization2D(nn.Module): + """ + A custom layer normalization module for 2D inputs (typically used for + convolutional layers). It optionally applies learned scaling (weight) + and shifting (bias) parameters. + + Arguments: + epsilon: A small value to avoid division by zero. + use_weight: Whether to learn and apply weight parameters. + use_bias: Whether to learn and apply bias parameters. 
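    A brief illustrative sketch of the intended use (the tensor shape is hypothetical):

        import torch
        ln = LayerNormalization2D(epsilon=1e-5, use_weight=True, use_bias=True)
        x = torch.randn(4, 32, 8, 8)   # (batch, channels, height, width)
        y = ln(x)                      # weight/bias are created lazily on the first call
        assert y.shape == x.shape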
+ """ + def __init__(self, epsilon=1e-05, use_weight=True, use_bias=True, **kwargs): + super(LayerNormalization2D, self).__init__() + + self.epsilon = epsilon + self.use_weight = use_weight + self.use_bias = use_bias + + def forward(self, input_tensor): + # Initialize weight and bias parameters if they are not nn.Parameter instances + if (not isinstance(self.use_weight, nn.parameter.Parameter) and + not isinstance(self.use_bias, nn.parameter.Parameter) and + (self.use_weight or self.use_bias)): + self._initialize_parameters(input_tensor) + + # Apply layer normalization + return nn.functional.layer_norm(input_tensor, input_tensor.shape[1:], + weight=self.use_weight, bias=self.use_bias, + eps=self.epsilon) + + def _initialize_parameters(self, input_tensor): + """ + Initialize weight and bias parameters for layer normalization. + Arguments: + input_tensor: The input tensor to the normalization layer. + """ + channels, height, width = input_tensor.shape[1:] + param_shape = [channels, height, width] + + # Initialize weight parameter if applicable + if self.use_weight: + self.use_weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + else: + self.register_parameter('use_weight', None) + + # Initialize bias parameter if applicable + if self.use_bias: + self.use_bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + else: + self.register_parameter('use_bias', None) + + +class NormalizationLayer(nn.Module): + """ + A flexible normalization layer that supports different types of normalization + (batch, group, layer, instance, or none). This class is a wrapper that selects + the appropriate normalization technique based on the norm_type argument. + + Arguments: + norm_type: The type of normalization to apply ('batch', 'group', 'layer', 'instance', or 'none'). + epsilon: A small value to avoid division by zero (Default: 1e-05). + momentum: Momentum for updating running statistics (Default: 0.1, applicable for batch norm). + use_weight: Whether to learn weight parameters (Default: True). + use_bias: Whether to learn bias parameters (Default: True). + track_stats: Whether to track running statistics (Default: True, applicable for batch norm). + group_norm_groups: Number of groups to use for group normalization (Default: 32). + """ + def __init__(self, norm_type='batch', epsilon=1e-05, momentum=0.1, + use_weight=True, use_bias=True, track_stats=True, group_norm_groups=32, **kwargs): + super(NormalizationLayer, self).__init__() + + if norm_type not in ['batch', 'group', 'layer', 'instance', 'none']: + raise ValueError('Unsupported norm_type: {}. Supported options: ' + '"batch" | "group" | "layer" | "instance" | "none".'.format(norm_type)) + + self.norm_type = norm_type + self.epsilon = epsilon + self.momentum = momentum + self.use_weight = use_weight + self.use_bias = use_bias + self.affine = self.use_weight and self.use_bias # Check if affine apply_transformationation is needed + self.track_stats = track_stats + self.group_norm_groups = group_norm_groups + + def forward(self, num_features): + """ + Select and apply the appropriate normalization technique based on the norm_type. + + Arguments: + num_features: The number of input channels or features. + Returns: + A normalization layer corresponding to the norm_type. 
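        Illustrative sketch of the intended use (the channel count and group count
        are hypothetical):

            import torch
            norm_factory = NormalizationLayer(norm_type='group', group_norm_groups=8)
            norm = norm_factory(num_features=64)     # builds nn.GroupNorm(8, 64)
            out = norm(torch.randn(2, 64, 16, 16))   # normalized activations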
+ """ + if self.norm_type == 'batch': + # Apply Batch Normalization + normalizer = nn.BatchNorm2d(num_features=num_features, eps=self.epsilon, + momentum=self.momentum, affine=self.affine, + track_running_stats=self.track_stats) + elif self.norm_type == 'group': + # Apply Group Normalization + normalizer = nn.GroupNorm(self.group_norm_groups, num_features, + eps=self.epsilon, affine=self.affine) + elif self.norm_type == 'layer': + # Apply Layer Normalization + normalizer = LayerNormalization2D(epsilon=self.epsilon, use_weight=self.use_weight, use_bias=self.use_bias) + elif self.norm_type == 'instance': + # Apply Instance Normalization + normalizer = nn.InstanceNorm2d(num_features, eps=self.epsilon, affine=self.affine) + else: + # No normalization applied, just pass the input through + normalizer = PassThrough() + + return normalizer diff --git a/EdgeFLite/helpers/optimizer_rmsprop.py b/EdgeFLite/helpers/optimizer_rmsprop.py new file mode 100644 index 0000000..52792ca --- /dev/null +++ b/EdgeFLite/helpers/optimizer_rmsprop.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +from torch.optim import Optimizer + + +class CustomRMSprop(Optimizer): + """ + Implements a modified version of the RMSprop algorithm with TensorFlow-style epsilon handling. + + Main differences in this implementation: + 1. Epsilon is incorporated within the square root operation. + 2. The moving average of squared gradients is initialized to 1. + 3. The momentum buffer accumulates updates scaled by the learning rate. + """ + + def __init__(self, params, lr=0.01, alpha=0.99, eps=1e-8, momentum=0, weight_decay=0, centered=False, decoupled_decay=False, lr_in_momentum=True): + """ + Initializes the optimizer with the provided parameters. + + Arguments: + - params: iterable of parameters to optimize or dicts defining parameter groups + - lr: learning rate (default: 0.01) + - alpha: smoothing constant for the moving average (default: 0.99) + - eps: small value to prevent division by zero (default: 1e-8) + - momentum: momentum factor (default: 0) + - weight_decay: weight decay (L2 penalty) (default: 0) + - centered: if True, compute centered RMSprop (default: False) + - decoupled_decay: if True, decouples weight decay from gradient update (default: False) + - lr_in_momentum: if True, applies learning rate within the momentum buffer (default: True) + """ + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if eps < 0.0: + raise ValueError(f"Invalid epsilon value: {eps}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight decay: {weight_decay}") + if alpha < 0.0: + raise ValueError(f"Invalid alpha value: {alpha}") + + # Store the optimizer defaults + defaults = { + 'lr': lr, + 'alpha': alpha, + 'eps': eps, + 'momentum': momentum, + 'centered': centered, + 'weight_decay': weight_decay, + 'decoupled_decay': decoupled_decay, + 'lr_in_momentum': lr_in_momentum + } + super().__init__(params, defaults) + + def step(self, closure=None): + """ + Performs a single optimization step. + + Arguments: + - closure: A closure that reevaluates the model and returns the loss. 
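        Illustrative sketch of one optimization step (the model, data, and
        hyperparameter values are hypothetical):

            import torch
            from torch import nn
            model = nn.Linear(10, 2)
            optimizer = CustomRMSprop(model.parameters(), lr=0.01, momentum=0.9)
            inputs = torch.randn(4, 10)
            targets = torch.randint(0, 2, (4,))
            loss = nn.functional.cross_entropy(model(inputs), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()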
+ """ + # Get the loss value if a closure is provided + loss = closure() if closure is not None else None + + # Iterate over parameter groups + for group in self.param_groups: + lr = group['lr'] + momentum = group['momentum'] + weight_decay = group['weight_decay'] + alpha = group['alpha'] + eps = group['eps'] + + # Iterate over parameters in the group + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data # Get gradient data + if grad.is_sparse: + raise RuntimeError("RMSprop does not support sparse gradients.") + + # Get the state of the parameter + state = self.state[p] + + # Initialize state if it doesn't exist + if not state: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p.data) # Initialize moving average of squared gradients to 1 + if momentum > 0: + state['momentum_buffer'] = torch.zeros_like(p.data) # Initialize momentum buffer + if group['centered']: + state['grad_avg'] = torch.zeros_like(p.data) # Initialize moving average of gradients if centered + + square_avg = state['square_avg'] + one_minus_alpha = 1 - alpha + state['step'] += 1 # Update the step count + + # Apply weight decay + if weight_decay != 0: + if group['decoupled_decay']: + p.data.mul_(1 - lr * weight_decay) # Apply decoupled weight decay + else: + grad.add_(p.data, alpha=weight_decay) # Apply traditional weight decay + + # Update the moving average of squared gradients + square_avg.add_((grad ** 2) - square_avg, alpha=one_minus_alpha) + + # Compute the denominator for gradient update + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha) + avg = (square_avg - grad_avg ** 2).add_(eps).sqrt_() # Centered RMSprop + else: + avg = square_avg.add_(eps).sqrt_() # Standard RMSprop + + # Apply momentum if needed + if momentum > 0: + buf = state['momentum_buffer'] + if group['lr_in_momentum']: + buf.mul_(momentum).addcdiv_(grad, avg, value=lr) # Apply learning rate inside momentum buffer + p.data.add_(-buf) + else: + buf.mul_(momentum).addcdiv_(grad, avg) # Standard momentum update + p.data.add_(buf, alpha=-lr) + else: + p.data.addcdiv_(grad, avg, value=-lr) # Update parameter without momentum + + return loss # Return the loss if closure was provided diff --git a/EdgeFLite/helpers/pace_controller.py b/EdgeFLite/helpers/pace_controller.py new file mode 100644 index 0000000..1e84654 --- /dev/null +++ b/EdgeFLite/helpers/pace_controller.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import math + +class CustomScheduler: + def __init__(self, mode='cosine', + initial_lr=0.1, + num_epochs=100, + iters_per_epoch=300, + lr_milestones=None, + lr_step=100, + step_multiplier=0.1, + slow_start_epochs=0, + slow_start_lr=1e-4, + min_lr=1e-3, + multiplier=1.0, + lower_bound=-6.0, + upper_bound=3.0, + decay_factor=0.97, + decay_epochs=0.8, + staircase=True): + """ + Initialize the learning rate scheduler. + + Parameters: + mode (str): Mode for learning rate adjustment ('cosine', 'poly', 'HTD', 'step', 'exponential'). + initial_lr (float): Initial learning rate. + num_epochs (int): Total number of epochs. + iters_per_epoch (int): Number of iterations per epoch. + lr_milestones (list): Epoch milestones for learning rate decay in 'step' mode. + lr_step (int): Epoch step size for learning rate reduction in 'step' mode. + step_multiplier (float): Multiplication factor for learning rate reduction in 'step' mode. + slow_start_epochs (int): Number of slow start epochs for warm-up. 
+ slow_start_lr (float): Learning rate during warm-up. + min_lr (float): Minimum learning rate limit. + multiplier (float): Multiplication factor for applying to different parameter groups. + lower_bound (float): Lower bound for the tanh function in 'HTD' mode. + upper_bound (float): Upper bound for the tanh function in 'HTD' mode. + decay_factor (float): Factor by which learning rate decays in 'exponential' mode. + decay_epochs (float): Number of epochs over which learning rate decays in 'exponential' mode. + staircase (bool): If True, apply step-wise learning rate decay in 'exponential' mode. + """ + # Ensure valid mode selection + assert mode in ['cosine', 'poly', 'HTD', 'step', 'exponential'], "Invalid mode." + + # Initialize learning rate settings + self.initial_lr = initial_lr + self.current_lr = initial_lr + self.min_lr = min_lr + self.mode = mode + self.num_epochs = num_epochs + self.iters_per_epoch = iters_per_epoch + self.total_iterations = (num_epochs - slow_start_epochs) * iters_per_epoch + self.slow_start_iters = slow_start_epochs * iters_per_epoch + self.slow_start_lr = slow_start_lr + self.multiplier = multiplier + self.lr_step = lr_step + self.lr_milestones = lr_milestones + self.step_multiplier = step_multiplier + self.lower_bound = lower_bound + self.upper_bound = upper_bound + self.decay_factor = decay_factor + self.decay_steps = decay_epochs * iters_per_epoch + self.staircase = staircase + + print(f"INFO: Using {self.mode} learning rate scheduler with {slow_start_epochs} warm-up epochs.") + + def update_lr(self, optimizer, iteration, epoch): + """Update the learning rate based on the current iteration and epoch.""" + current_iter = epoch * self.iters_per_epoch + iteration + + # During slow start, linearly increase the learning rate + if current_iter <= self.slow_start_iters: + lr = self.slow_start_lr + (self.initial_lr - self.slow_start_lr) * (current_iter / self.slow_start_iters) + else: + # After slow start, calculate learning rate based on the selected mode + lr = self._calculate_lr(current_iter - self.slow_start_iters) + + # Ensure learning rate does not fall below the minimum limit + self.current_lr = max(lr, self.min_lr) + self._apply_lr(optimizer, self.current_lr) + + def _calculate_lr(self, adjusted_iter): + """Calculate the learning rate based on the selected scheduling mode.""" + if self.mode == 'cosine': + # Cosine annealing schedule + return 0.5 * self.initial_lr * (1 + math.cos(math.pi * adjusted_iter / self.total_iterations)) + elif self.mode == 'poly': + # Polynomial decay schedule + return self.initial_lr * (1 - adjusted_iter / self.total_iterations) ** 0.9 + elif self.mode == 'HTD': + # Hyperbolic tangent decay schedule + ratio = adjusted_iter / self.total_iterations + return 0.5 * self.initial_lr * (1 - math.tanh(self.lower_bound + (self.upper_bound - self.lower_bound) * ratio)) + elif self.mode == 'step': + # Step decay schedule + return self._step_lr(adjusted_iter) + elif self.mode == 'exponential': + # Exponential decay schedule + power = math.floor(adjusted_iter / self.decay_steps) if self.staircase else adjusted_iter / self.decay_steps + return self.initial_lr * (self.decay_factor ** power) + else: + raise NotImplementedError("Unknown learning rate mode.") + + def _step_lr(self, adjusted_iter): + """Calculate the learning rate for the 'step' mode.""" + epoch = adjusted_iter // self.iters_per_epoch + # Count how many milestones or steps have passed + if self.lr_milestones: + num_steps = sum([1 for milestone in self.lr_milestones if epoch >= 
milestone]) + else: + num_steps = epoch // self.lr_step + return self.initial_lr * (self.step_multiplier ** num_steps) + + def _apply_lr(self, optimizer, lr): + """Apply the calculated learning rate to the optimizer.""" + for i, param_group in enumerate(optimizer.param_groups): + # Apply multiplier to parameter groups beyond the first one + param_group['lr'] = lr * (self.multiplier if i > 1 else 1.0) + + +def adjust_hyperparameters(args): + """Adjust the learning rate and momentum based on the batch size.""" + print(f'Adjusting LR and momentum. Original LR: {args.lr}, Original momentum: {args.momentum}') + # Set standard batch size for scaling + standard_batch_size = 128 if 'cifar' in args.dataset else NotImplementedError + # Scale momentum and learning rate + args.momentum = args.momentum ** (args.batch_size / standard_batch_size) + args.lr *= (args.batch_size / standard_batch_size) + print(f'Adjusted LR: {args.lr}, Adjusted momentum: {args.momentum}') + return args + + +def separate_parameters(model, weight_decay_for_norm=0): + """Separate the model parameters into two groups: regular parameters and norm-based parameters.""" + regular_params, norm_params = [], [] + for name, param in model.named_parameters(): + if param.requires_grad: + # Parameters related to normalization and biases are treated separately + if 'norm' in name or 'bias' in name: + norm_params.append(param) + else: + regular_params.append(param) + # Return parameter groups with corresponding weight decay for norm parameters + return [{'params': regular_params}, {'params': norm_params, 'weight_decay': weight_decay_for_norm}] diff --git a/EdgeFLite/helpers/preloader_module.py b/EdgeFLite/helpers/preloader_module.py new file mode 100644 index 0000000..e771a80 --- /dev/null +++ b/EdgeFLite/helpers/preloader_module.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch + +class DataPrefetcher: + def __init__(self, dataloader): + # Initialize with the dataloader and create an iterator + self.dataloader = iter(dataloader) + # Create a CUDA stream for asynchronous data transfer + self.cuda_stream = torch.cuda.Stream() + # Load the next batch of data + self._load_next_batch() + + def _load_next_batch(self): + try: + # Fetch the next batch from the dataloader iterator + self.batch_input, self.batch_target = next(self.dataloader) + except StopIteration: + # If no more data, set inputs and targets to None + self.batch_input, self.batch_target = None, None + return + + # Transfer data to GPU asynchronously using the created CUDA stream + with torch.cuda.stream(self.cuda_stream): + self.batch_input = self.batch_input.cuda(non_blocking=True) + self.batch_target = self.batch_target.cuda(non_blocking=True) + + def get_next_batch(self): + # Synchronize the current stream with the prefetching stream to ensure data is ready + torch.cuda.current_stream().wait_stream(self.cuda_stream) + + # Return the preloaded batch of input and target data + current_input, current_target = self.batch_input, self.batch_target + + # Preload the next batch in the background while the current batch is processed + self._load_next_batch() + + return current_input, current_target diff --git a/EdgeFLite/helpers/report_summary.py b/EdgeFLite/helpers/report_summary.py new file mode 100644 index 0000000..e186210 --- /dev/null +++ b/EdgeFLite/helpers/report_summary.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +__all__ = ['model_summary'] + +import torch +import torch.nn as nn +import numpy as np +import os +import 
json +from collections import OrderedDict + +# Format FLOPs value with appropriate unit (T, G, M, K) +def format_flops(flops): + units = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')] + for scale, suffix in units: + if flops >= scale: + return f"{flops / scale:.1f}{suffix}" + return f"{flops:.1f}" + +# Calculate the number of trainable or non-trainable parameters +def calculate_grad_params(param_count, param): + if param.requires_grad: + return param_count, 0 + else: + return 0, param_count + +# Compute FLOPs and parameters for a convolutional layer +def compute_conv_flops(layer, input, output): + oh, ow = output.shape[-2:] # Output height and width + kh, kw = layer.kernel_size # Kernel height and width + ic, oc = layer.in_channels, layer.out_channels # Input/output channels + groups = layer.groups # Number of groups for grouped convolution + + total_trainable = 0 + total_non_trainable = 0 + flops = 0 + + # Compute parameters and FLOPs for the weight + if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'): + param_count = np.prod(layer.weight.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.weight) + total_trainable += trainable + total_non_trainable += non_trainable + flops += (2 * ic * kh * kw - 1) * oh * ow * (oc // groups) + + # Compute parameters and FLOPs for the bias + if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'): + param_count = np.prod(layer.bias.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.bias) + total_trainable += trainable + total_non_trainable += non_trainable + flops += oh * ow * (oc // groups) + + return total_trainable, total_non_trainable, flops + +# Compute FLOPs and parameters for normalization layers (BatchNorm, GroupNorm) +def compute_norm_flops(layer, input, output): + total_trainable = 0 + total_non_trainable = 0 + + if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'): + param_count = np.prod(layer.weight.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.weight) + total_trainable += trainable + total_non_trainable += non_trainable + + if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'): + param_count = np.prod(layer.bias.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.bias) + total_trainable += trainable + total_non_trainable += non_trainable + + if hasattr(layer, 'running_mean'): + total_non_trainable += np.prod(layer.running_mean.shape) + + if hasattr(layer, 'running_var'): + total_non_trainable += np.prod(layer.running_var.shape) + + # FLOPs for normalization operations + flops = np.prod(input[0].shape) + if layer.affine: + flops *= 2 + + return total_trainable, total_non_trainable, flops + +# Compute FLOPs and parameters for linear (fully connected) layers +def compute_linear_flops(layer, input, output): + ic, oc = layer.in_features, layer.out_features # Input/output features + total_trainable = 0 + total_non_trainable = 0 + flops = 0 + + # Compute parameters and FLOPs for the weight + if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'): + param_count = np.prod(layer.weight.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.weight) + total_trainable += trainable + total_non_trainable += non_trainable + flops += (2 * ic - 1) * oc + + # Compute parameters and FLOPs for the bias + if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'): + param_count = np.prod(layer.bias.shape) + trainable, non_trainable = calculate_grad_params(param_count, layer.bias) + 
total_trainable += trainable + total_non_trainable += non_trainable + flops += oc + + return total_trainable, total_non_trainable, flops + +# Model summary function: calculates the total parameters and FLOPs for a model +@torch.no_grad() +def model_summary(model, input_data, target_data=None, is_coremodel=True, return_data=False): + model.eval() + + summary_info = OrderedDict() + hooks = [] + + # Hook function to register layer and compute its parameters/FLOPs + def register_layer_hook(layer): + def hook(layer, input, output): + layer_name = f"{layer.__class__.__name__}-{len(summary_info) + 1}" + summary_info[layer_name] = OrderedDict() + summary_info[layer_name]['input_shape'] = list(input[0].shape) + summary_info[layer_name]['output_shape'] = list(output.shape) if not isinstance(output, (list, tuple)) else [list(o.shape) for o in output] + + if isinstance(layer, nn.Conv2d): + trainable, non_trainable, flops = compute_conv_flops(layer, input, output) + elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)): + trainable, non_trainable, flops = compute_norm_flops(layer, input, output) + elif isinstance(layer, nn.Linear): + trainable, non_trainable, flops = compute_linear_flops(layer, input, output) + else: + trainable, non_trainable, flops = 0, 0, 0 + + summary_info[layer_name]['trainable_params'] = trainable + summary_info[layer_name]['non_trainable_params'] = non_trainable + summary_info[layer_name]['total_params'] = trainable + non_trainable + summary_info[layer_name]['flops'] = flops + + if not isinstance(layer, (nn.Sequential, nn.ModuleList, nn.Identity)): + hooks.append(layer.register_forward_hook(hook)) + + model.apply(register_layer_hook) + + if is_coremodel: + model(input_data, target=target_data, mode='summary') + else: + model(input_data) + + for hook in hooks: + hook.remove() + + total_params, trainable_params, total_flops = 0, 0, 0 + for layer_name, layer_info in summary_info.items(): + total_params += layer_info['total_params'] + trainable_params += layer_info['trainable_params'] + total_flops += layer_info['flops'] + + param_size_mb = total_params * 4 / (1024 ** 2) + print(f"Total parameters: {total_params:,} ({format_flops(total_params)})") + print(f"Trainable parameters: {trainable_params:,}") + print(f"Non-trainable parameters: {total_params - trainable_params:,}") + print(f"Total FLOPs: {total_flops:,} ({format_flops(total_flops)})") + print(f"Model size: {param_size_mb:.2f} MB") + + if return_data: + return total_params, total_flops + +# Example usage with a convolutional layer +if __name__ == '__main__': + conv_layer = nn.Conv2d(50, 10, 3, padding=1, groups=5, bias=True) + model_summary(conv_layer, torch.rand((1, 50, 10, 10)), target_data=torch.ones(1, dtype=torch.long), is_coremodel=False) + + for name, param in conv_layer.named_parameters(): + print(f"{name}: {param.size()}") + +# Save the model's summary details as a JSON file +def save_model_as_json(args, model_content): + """Save the model's details to a JSON file.""" + os.makedirs(args.model_dir, exist_ok=True) + filename = os.path.join(args.model_dir, f"model_{args.split_factor}.txt") + + with open(filename, 'w') as f: + f.write(str(model_content)) diff --git a/EdgeFLite/helpers/smoothing_labels.py b/EdgeFLite/helpers/smoothing_labels.py new file mode 100644 index 0000000..965f358 --- /dev/null +++ b/EdgeFLite/helpers/smoothing_labels.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Define the SmoothEntropyLoss class, 
which inherits from nn.Module +class SmoothEntropyLoss(nn.Module): + def __init__(self, smoothing=0.1, reduction='mean'): + # Initialize the parent class (nn.Module) and set the smoothing factor and reduction method + super(SmoothEntropyLoss, self).__init__() + self.smoothing = smoothing # Label smoothing factor + self.reduction_method = reduction # Reduction method to apply to the loss + + def forward(self, predictions, targets): + # Ensure that the batch sizes of predictions and targets match + if predictions.shape[0] != targets.shape[0]: + raise ValueError(f"Batch size of predictions ({predictions.shape[0]}) does not match targets ({targets.shape[0]}).") + + # Ensure that the predictions tensor has at least 2 dimensions (batch_size x num_classes) + if predictions.dim() < 2: + raise ValueError(f"Predictions should have at least 2 dimensions, got {predictions.dim()}.") + + # Get the number of classes from the last dimension of predictions (num_classes) + num_classes = predictions.size(-1) + + # Convert targets (class indices) to one-hot encoded format + target_one_hot = F.one_hot(targets, num_classes=num_classes).type_as(predictions) + + # Apply label smoothing: smooth the one-hot encoded targets by distributing some probability mass across all classes + smooth_targets = target_one_hot * (1.0 - self.smoothing) + (self.smoothing / num_classes) + + # Compute the log probabilities of predictions using softmax (log-softmax for numerical stability) + log_probabilities = F.log_softmax(predictions, dim=-1) + + # Compute the per-sample loss by multiplying log probabilities with the smoothed targets and summing across classes + loss_per_sample = -torch.sum(log_probabilities * smooth_targets, dim=-1) + + # Apply the specified reduction method to the computed loss + if self.reduction_method == 'none': + return loss_per_sample # Return the unreduced loss for each sample + elif self.reduction_method == 'sum': + return torch.sum(loss_per_sample) # Return the sum of the losses over all samples + elif self.reduction_method == 'mean': + return torch.mean(loss_per_sample) # Return the mean loss over all samples + else: + raise ValueError(f"Invalid reduction option: {self.reduction_method}. 
Expected 'none', 'sum', or 'mean'.") diff --git a/EdgeFLite/info_map.csv b/EdgeFLite/info_map.csv new file mode 100644 index 0000000..6ecea73 --- /dev/null +++ b/EdgeFLite/info_map.csv @@ -0,0 +1,68 @@ +import pandas as pd +import os +from glob import glob +from PIL import Image +import torch +from sklearn.model_selection import train_test_split +import pickle +from torch.utils.data import Dataset, DataLoader +from torch import nn +from torchvision import apply_transformations + +# Loading the info_mapdata for the skin_dataset dataset +info_mapdata = pd.read_csv('dataset_hub/skin_dataset/data/skin_info_map.csv') +print(info_mapdata.head()) + +# Mapping lesion abbreviations to their full names +lesion_labels = { + 'nv': 'Melanocytic nevi', + 'mel': 'Melanoma', + 'bkl': 'Benign keratosis-like lesions', + 'bcc': 'Basal cell carcinoma', + 'akiec': 'Actinic keratoses', + 'vasc': 'Vascular lesions', + 'df': 'Dermatofibroma' +} + +# Combine images from both dataset parts into one dictionary +image_paths = {os.path.splitext(os.path.basename(img))[0]: img + for img in glob(os.path.join("dataset_hub/skin_dataset/data", '*', '*.jpg'))} + +# Mapping the image paths and cell types to the DataFrame +info_mapdata['image_path'] = info_mapdata['image_id'].map(image_paths.get) +info_mapdata['cell_type'] = info_mapdata['dx'].map(lesion_labels.get) +info_mapdata['label'] = pd.Categorical(info_mapdata['cell_type']).workspaces + +# Display the count of each cell type and their enworkspaced labels +print(info_mapdata['cell_type'].value_counts()) +print(info_mapdata['label'].value_counts()) + +# Custom Dataset class for PyTorch +class SkinDataset(Dataset): + def __init__(self, dataframe, apply_transformation=None): + self.dataframe = dataframe + self.apply_transformation = apply_transformation + + def __len__(self): + return len(self.dataframe) + + def __getitem__(self, idx): + img = Image.open(self.dataframe.loc[idx, 'image_path']).resize((64, 64)) + label = torch.tensor(self.dataframe.loc[idx, 'label'], dtype=torch.long) + + if self.apply_transformation: + img = self.apply_transformation(img) + + return img, label + +# Splitting the data into train and test sets +train_data, test_data = train_test_split(info_mapdata, test_size=0.2, random_state=42) +train_data = train_data.reset_index(drop=True) +test_data = test_data.reset_index(drop=True) + +# Save the train and test data to pickle files +with open("skin_dataset_train.pkl", "wb") as train_file: + pickle.dump(train_data, train_file) + +with open("skin_dataset_test.pkl", "wb") as test_file: + pickle.dump(test_data, test_file) diff --git a/EdgeFLite/process_data.py b/EdgeFLite/process_data.py new file mode 100644 index 0000000..155e18a --- /dev/null +++ b/EdgeFLite/process_data.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import glob +import numpy as np + +# Define paths to the training and testing datasets +train_data_path = '/media/skydata/alpha0012/workspace/EdgeFLite/dataset_hub/pill/train_images' +test_data_path = '/media/skydata/alpha0012/workspace/EdgeFLite/dataset_hub/pill/test_images' + +def list_image_files_by_class(directory): + """ + Returns a list of image file paths and their corresponding class indices. + + Args: + directory (str): The path to the directory containing class folders. + + Returns: + list: A list of image file paths and their class indices. 
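        Illustrative sketch (the folder layout, file names, and class names below
        are hypothetical):

            train_images/
                aspirin/    img_001.jpg, img_002.jpg
                ibuprofen/  img_003.jpg

            list_image_files_by_class('train_images') would then return:
                [['train_images/aspirin/img_001.jpg', 0],
                 ['train_images/aspirin/img_002.jpg', 0],
                 ['train_images/ibuprofen/img_003.jpg', 1]]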
+ """ + # Get the sorted list of class labels (folder names) + class_labels = sorted(os.listdir(directory)) + # Create a mapping from class names to indices + class_to_idx = {class_name: idx for idx, class_name in enumerate(class_labels)} + + image_dataset = [] # Initialize an empty list to store image data + + # Iterate through each class + for class_name in class_labels: + class_folder = os.path.join(directory, class_name) # Path to the class folder + # Find all JPG images in the class folder and its subfolders + image_files = glob.glob(os.path.join(class_folder, '**', '*.jpg'), recursive=True) + + # Append image file paths and their class indices to the dataset + for image_file in image_files: + image_dataset.append([image_file, class_to_idx[class_name]]) + + return image_dataset + +if __name__ == "__main__": + # Retrieve and print the number of files in the training and testing datasets + train_images = list_image_files_by_class(train_data_path) + test_images = list_image_files_by_class(test_data_path) + + print(f"Training dataset size: {len(train_images)}") # Output the size of the training dataset + print(f"Testing dataset size: {len(test_images)}") # Output the size of the testing dataset diff --git a/EdgeFLite/resnet_federated.py b/EdgeFLite/resnet_federated.py new file mode 100644 index 0000000..4459937 --- /dev/null +++ b/EdgeFLite/resnet_federated.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import argparse +import os +import torch +from dataset import factory +from params import train_params +from fedml_service.data_cleaning.cifar10.data_loader import load_partition_data_cifar10 +from fedml_service.data_cleaning.cifar100.data_loader import load_partition_data_cifar100 +from fedml_service.data_cleaning.skin_dataset.data_loader import load_partition_data_skin_dataset +from fedml_service.data_cleaning.pillbase.data_loader import load_partition_data_pillbase +from fedml_service.model.cv.resnet_gkt.resnet import wide_resnet16_8_gkt, wide_resnet_model_50_2_gkt, resnet110_gkt +from fedml_service.decentralized.fedgkt.GKTTrainer import GKTTrainer +from fedml_service.decentralized.fedgkt.GKTServerTrainer import GKTServerTrainer +from params.train_params import save_hp_to_json +from config import HOME +from tensorboardX import SummaryWriter + +# Set CUDA device to be used for training +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +device = torch.device("cuda:0") + +# Initialize TensorBoard writers for logging +def initialize_writers(args): + log_dir = os.path.join(args.model_dir, 'val') # Create a log directory inside the model directory + return SummaryWriter(log_dir=log_dir) # Initialize SummaryWriter for TensorBoard logging + +# Initialize dataset and data loaders +def initialize_dataset(args, data_split_factor): + # Fetch training data and sampler based on various input parameters + train_data_local_dict, train_sampler = factory.obtain_data_loader( + args.data, + split_factor=data_split_factor, + batch_size=args.batch_size, + crop_size=args.crop_size, + dataset=args.dataset, + split="train", # Split data for training + is_decentralized=args.is_decentralized, + is_autoaugment=args.is_autoaugment, + randaa=args.randaa, + is_cutout=args.is_cutout, + erase_p=args.erase_p, + num_workers=args.workers, + is_fed=args.is_fed, + num_clusters=args.num_clusters, + cifar10_non_iid=args.cifar10_non_iid, + cifar100_non_iid=args.cifar100_non_iid + ) + + # Fetch global test data + test_data_global = factory.obtain_data_loader( + args.data, + batch_size=args.eval_batch_size, + 
crop_size=args.crop_size, + dataset=args.dataset, + split="val", # Split data for validation + num_workers=args.workers, + cifar10_non_iid=args.cifar10_non_iid, + cifar100_non_iid=args.cifar100_non_iid + ) + return train_data_local_dict, test_data_global # Return both train and test data loaders + +# Setup models based on the dataset +def setup_models(args): + if args.dataset == "cifar10": + return load_partition_data_cifar10, wide_resnet16_8_gkt() # Model for CIFAR-10 + elif args.dataset == "cifar100": + return load_partition_data_cifar100, resnet110_gkt() # Model for CIFAR-100 + elif args.dataset == "skin_dataset": + return load_partition_data_skin_dataset, wide_resnet_model_50_2_gkt() # Model for skin dataset + elif args.dataset == "pill_base": + return load_partition_data_pillbase, wide_resnet_model_50_2_gkt() # Model for pill base dataset + else: + raise ValueError(f"Unsupported dataset: {args.dataset}") # Raise error for unsupported dataset + +# Initialize trainers for each client in the federated learning setup +def initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict): + client_trainers = [] + # Initialize a trainer for each client + for i in range(client_number): + trainer = GKTTrainer( + client_idx=i, + train_data_local_dict=train_data_local_dict, + test_data_local_dict=test_data_local_dict, + device=device, + model_client=model_client, + args=args + ) + client_trainers.append(trainer) # Add client trainer to the list + return client_trainers + +# Main function to initialize and run the federated learning process +def main(args): + args.model_dir = os.path.join(str(HOME), "models/coremodel", str(args.spid)) # Set model directory based on home directory and spid + + # Save hyperparameters if not in summary or evaluation mode + if not args.is_summary and not args.evaluate: + save_hp_to_json(args) + + # Initialize the TensorBoard writer for logging + val_writer = initialize_writers(args) + data_split_factor = args.loop_factor if args.is_diff_data_train else 1 # Set data split factor based on training mode + args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized # Check if decentralized learning is needed + + print(f"INFO: PyTorch: => The number of views of train data is '{data_split_factor}'") + + # Load dataset and initialize data loaders + train_data_local_dict, test_data_global = initialize_dataset(args, data_split_factor) + + # Setup models for the clients and server + data_loader, (model_client, model_server) = setup_models(args) + client_number = args.num_clusters * args.split_factor # Calculate the number of clients + + # Load data for federated learning + train_data_num, test_data_num, train_data_global, _, _, _, test_data_local_dict, class_num = data_loader( + args.dataset, args.data, 'homo', 0.5, client_number, args.batch_size + ) + + dataset_info = [train_data_num, test_data_num, train_data_global, test_data_global, train_data_local_dict, test_data_local_dict, class_num] + + print("Server and clients initialized.") + round_idx = 0 # Initialize the training round index + + # Initialize client trainers and server trainer + client_trainers = initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict) + server_trainer = GKTServerTrainer(client_number, device, model_server, args, val_writer) + + # Start federated training rounds + for current_round in range(args.num_rounds): + # For each client, perform local training and send results to the server + 
for client_idx in range(client_number): + extracted_features, logits, labels, test_features, test_labels = client_trainers[client_idx].train() + print(f"Client {client_idx} finished training.") + server_trainer.add_local_trained_result(client_idx, extracted_features, logits, labels, test_features, test_labels) + + # Check if the server has received all clients' results + if server_trainer.check_whether_all_receive(): + print("All clients' results received by server.") + server_trainer.train(round_idx) # Server performs training using the aggregated results + round_idx += 1 + + # Send global model updates back to clients + for client_idx in range(client_number): + global_logits = server_trainer.get_global_logits(client_idx) + client_trainers[client_idx].update_large_model_logits(global_logits) + print("Server sent updated logits back to clients.") + +# Entry point of the script +if __name__ == "__main__": + parser = argparse.ArgumentParser() + args = train_params.add_parser_params(parser) + + # Ensure that federated learning mode is enabled + assert args.is_fed == 1, "Federated learning requires 'args.is_fed' to be set to 1." + + # Create the model directory if it does not exist + os.makedirs(args.model_dir, exist_ok=True) + + print(args) # Print the parsed arguments for verification + main(args) # Start the main process diff --git a/EdgeFLite/run_federated.py b/EdgeFLite/run_federated.py new file mode 100644 index 0000000..7e83319 --- /dev/null +++ b/EdgeFLite/run_federated.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn +import torch.decentralized as dist +import torch.multiprocessing as mp +import torch.cuda.amp as amp +from torch.backends import cudnn +from tensorboardX import SummaryWriter +import warnings +import argparse +import os +import numpy as np +from tqdm import tqdm +from dataset import factory +from model import coremodel +from utils import metric, label_smoothing, lr_scheduler, prefetch +from params.train_params import save_hp_to_json +from params import train_params + +# Global variable to track the best accuracy +best_accuracy = 0 + +def calculate_average(values): + """Calculate the average of a list of values""" + return sum(values) / len(values) + +def initialize_processes(rank, world_size, args): + """ + Initialize decentralized processes. + This function is used to set up distributed training across multiple GPUs. + """ + ngpus = torch.cuda.device_count() + args.ngpus = ngpus + args.is_decentralized = world_size > 1 + + if args.multiprocessing_decentralized: + # If running decentralized with multiple GPUs, spawn processes for each GPU + mp.spawn(train_single_worker, nprocs=ngpus, args=(ngpus, args)) + else: + print(f"INFO:PyTorch: Using {ngpus} GPUs") + # If single GPU, start the training worker directly + train_single_worker(args.gpu, ngpus, args) + +def client_training_step(args, current_round, model, optimizer, scheduler, dataloader, epochs=5, scaler=None): + """ + Perform training for a single client model in the federated learning setup. + This method will train the model for a given number of epochs. 
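    Illustrative arithmetic for the gradient accumulation used below (the numbers
    are hypothetical): with accumulation_steps = 4 and a loader batch size of 32,
    each batch loss is scaled by 1/4 and gradients are accumulated, so a single
    optimizer.step() corresponds to an effective batch of 4 * 32 = 128 samples.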
+ """ + model.train() # Set model to training mode + for epoch in range(epochs): + # Prefetch data to improve efficiency + prefetcher = prefetch.data_prefetcher(dataloader) + images, targets = prefetcher.next() + step = 0 + + while images is not None: + # Update the learning rate using the scheduler + scheduler(optimizer, step, current_round) + optimizer.zero_grad() # Clear the gradients + + # Enable mixed precision training to optimize memory and computation speed + with amp.autocast(enabled=args.is_amp): + outputs, ce_loss, cot_loss = model(images, target=targets, mode='train') + + # Combine losses and normalize by accumulation steps + loss = (ce_loss + cot_loss) / args.accumulation_steps + loss.backward() # Backpropagate the gradients + + # Perform optimization step after enough accumulation + if step % args.accumulation_steps == 0: + optimizer.step() + optimizer.zero_grad() # Clear gradients after the step + + images, targets = prefetcher.next() # Get the next batch of images and targets + step += 1 + + return loss.item() # Return the final loss value + +def combine_model_parameters(global_model, client_models): + """ + Aggregate the weights of multiple client models to update the global model. + This is the core of the Federated Averaging (FedAvg) algorithm. + """ + global_state = global_model.state_dict() + for key in global_state.keys(): + # Average the weights of the corresponding layers from all client models + global_state[key] = torch.stack([client.state_dict()[key].float() for client in client_models], dim=0).mean(dim=0) + + # Load the averaged weights into the global model + global_model.load_state_dict(global_state) + # Update the client models with the new global model weights + for client in client_models: + client.load_state_dict(global_model.state_dict()) + +def validate_model(validation_loader, model, args): + """ + Perform model validation on the validation dataset. + Calculate and return the average accuracy across the dataset. + """ + model.eval() # Set the model to evaluation mode + accuracy_values = [] + + with torch.no_grad(): + for images, targets in validation_loader: + if args.gpu is not None: + images, targets = images.cuda(args.gpu), targets.cuda(args.gpu) + + # Use mixed precision for inference + with amp.autocast(enabled=args.is_amp): + ensemble_output, outputs, ce_loss = model(images, target=targets, mode='val') + + # Calculate the top-1 accuracy for the current batch + avg_acc1 = metric.accuracy(ensemble_output, targets, topk=(1,)) + accuracy_values.append(avg_acc1) + + return calculate_average(accuracy_values) # Return the average accuracy + +def train_single_worker(gpu, ngpus, args): + """ + Training worker function that runs on a single GPU. + This function handles the entire federated learning workflow for the assigned GPU. 
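    Illustrative sketch of the per-round flow implemented below (the client count
    is hypothetical):

        # round r:
        #   1. each of the num_clients client models runs client_training_step()
        #   2. combine_model_parameters() averages their state_dicts into the
        #      global model (FedAvg) and copies the result back to every client
        #   3. validate_model() evaluates the refreshed global model; the best
        #      top-1 accuracy seen so far is tracked in best_accuracy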
+ """ + global best_accuracy + args.gpu = gpu + cudnn.performance_test = True # Enable performance optimization for CuDNN + + # Optionally, resume from a checkpoint if provided + if args.resume: + checkpoint = torch.load(args.resume) + args.start_round = checkpoint['round'] + best_accuracy = checkpoint['best_acc1'] + + # Initialize global and client models + model = coremodel.coremodel(args).cuda() + client_models = [coremodel.coremodel(args).cuda() for _ in range(args.num_clients)] + optimizers = [torch.optim.SGD(client.parameters(), lr=args.lr) for client in client_models] + + # Training and validation loop + for round_num in range(args.start_round, args.num_rounds): + # Perform training for each client model + for client_num in range(args.num_clients): + client_training_step(args, round_num, client_models[client_num], optimizers[client_num], lr_scheduler, args.train_loader) + + # Aggregate client models to update the global model + combine_model_parameters(model, client_models) + + # Validate the updated global model and track the best accuracy + validation_accuracy = validate_model(args.val_loader, model, args) + best_accuracy = max(best_accuracy, validation_accuracy) + print(f"Round {round_num}: Best Accuracy: {best_accuracy:.2f}") + +if __name__ == "__main__": + # Parse command-line arguments + parser = argparse.ArgumentParser(description='FedAvg decentralized Training') + args = train_params.add_parser_params(parser) + initialize_processes(0, args.world_size, args) # Initialize distributed training diff --git a/EdgeFLite/run_local.py b/EdgeFLite/run_local.py new file mode 100644 index 0000000..460173c --- /dev/null +++ b/EdgeFLite/run_local.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import argparse +import warnings +import setproctitle +from torch import nn, decentralized # Used for decentralized training +from torch.backends import cudnn # Optimizes performance for convolutional networks +from tensorboardX import SummaryWriter # For logging metrics and results to TensorBoard +import torch.cuda.amp as amp # For mixed precision training +from config import * # Custom configuration module +from params import train_params # Training parameters +from utils import label_smoothing, norm, summary, metric, lr_scheduler, prefetch # Utility functions +from model import coremodel # Core model implementation +from dataset import factory # Dataset and data loader factory +from params.train_params import save_hp_to_json # Function to save hyperparameters to JSON + +# Global variable to store the best accuracy obtained during training +best_acc1 = 0 + +def main(args): + # Warn if a specific GPU is chosen, as this will disable data parallelism + if args.gpu is not None: + warnings.warn("Selecting a specific GPU will disable data parallelism.") + + # Adjust loop factor based on specific training configurations + args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor + # Check if decentralized training is needed + args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized + + # Get the number of available GPUs on the machine + num_gpus = torch.cuda.device_count() + args.ngpus_per_node = num_gpus + print(f"INFO:PyTorch: GPUs available on this node: {num_gpus}") + + # If multiprocessing is needed for decentralized training + if args.multiprocessing_decentralized: + # Adjust world size to account for multiple GPUs + args.world_size *= num_gpus + # Spawn multiple processes for each GPU + 
torch.multiprocessing.spawn(execute_worker_process, nprocs=num_gpus, args=(num_gpus, args)) + else: + # If using a single GPU + print("INFO:PyTorch: Using GPU 0 for single GPU training") + args.gpu = 0 + # Call main worker for single GPU + execute_worker_process(args.gpu, num_gpus, args) + +def execute_worker_process(gpu, num_gpus, args): + global best_acc1 + args.gpu = gpu + # Set the directory where models will be saved + args.model_dir = os.path.join(HOME, "models", "coremodel", str(args.spid)) + + # Initialize the decentralized training process group if needed + if args.is_decentralized: + print("INFO:PyTorch: Initializing process group for decentralized training.") + if args.dist_url == "env://" and args.rank == -1: + args.rank = int(os.environ["RANK"]) + if args.multiprocessing_decentralized: + args.rank = args.rank * num_gpus + gpu + decentralized.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) + + # Set the GPU to be used for training or evaluation + if args.gpu is not None: + print(f"INFO:PyTorch: GPU {args.gpu} in use for training (Rank: {args.rank})" if not args.evaluate else f"INFO:PyTorch: GPU {args.gpu} in use for evaluation (Rank: {args.rank})") + + # Set process title for better identification in system process monitors + setproctitle.setproctitle(f"{args.proc_name}centralized_rank{args.rank}") + + # Initialize a SummaryWriter for TensorBoard logging + val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val')) + + # Use label smoothing if enabled, otherwise use standard cross-entropy loss + criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss() + + # Instantiate the model + model = coremodel.coremodel(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) + print(f"INFO:PyTorch: Model '{args.arch}' has {metric.get_the_number_of_params(model)} parameters") + + # If summary is requested, print model and exit + if args.is_summary: + print(model) + return + + # Save model configuration and hyperparameters + summary.save_model_to_json(args, model) + + # Convert BatchNorm layers to synchronized BatchNorm for decentralized training + if args.is_decentralized and args.world_size > 1 and args.is_syncbn: + print("INFO:PyTorch: Converting BatchNorm to SyncBatchNorm") + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + # Set up the model for GPU-based training + if args.gpu is not None: + torch.cuda.set_device(args.gpu) + model.cuda(args.gpu) + args.batch_size = int(args.batch_size / num_gpus) # Adjust batch size for multiple GPUs + args.workers = int((args.workers + num_gpus - 1) / num_gpus) # Adjust number of workers + # Use decentralized data parallel model + model = nn.parallel.decentralizedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) + else: + # Use standard DataParallel for multi-GPU training + model = nn.DataParallel(model).cuda() + + # Create the optimizer + optimizer = create_optimizer(args, model) + # Set up the gradient scaler for mixed precision training, if enabled + scaler = amp.GradScaler() if args.is_amp else None + + # If resuming from a checkpoint, load model and optimizer state + if args.resume: + load_checkpoint(args, model, optimizer, scaler) + + cudnn.performance_test = True # Enable cuDNN performance optimizations + + # Set up data loader parameters + data_loader_params = { + 'split_factor': args.loop_factor if args.is_diff_data_train else 1, + 'batch_size': args.batch_size, 
+ 'crop_size': args.crop_size, + 'dataset': args.dataset, + 'is_decentralized': args.is_decentralized, + 'num_workers': args.workers, + 'randaa': args.randaa, + 'is_autoaugment': args.is_autoaugment, + 'is_cutout': args.is_cutout, + 'erase_p': args.erase_p, + } + + # Get the training and validation data loaders + train_loader, train_sampler = factory.obtain_data_loader(args.data, split="train", **data_loader_params) + val_loader = factory.obtain_data_loader(args.data, split="val", batch_size=args.eval_batch_size, crop_size=args.crop_size, num_workers=args.workers) + + # Set up the learning rate scheduler + scheduler = lr_scheduler.create_scheduler(args, len(train_loader)) + + # If evaluating, run the validation function and exit + if args.evaluate: + validate(val_loader, model, args) + return + + # Begin training and evaluation + train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus) + +# Function to create the optimizer +def create_optimizer(args, model): + param_groups = model.parameters() if args.is_wd_all else lr_scheduler.get_parameter_groups(model) + # Select the optimizer based on input arguments + if args.optimizer == 'SGD': + return torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.is_nesterov) + elif args.optimizer == 'AdamW': + return torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-4, weight_decay=args.weight_decay) + elif args.optimizer == 'RMSprop': + return torch.optim.RMSprop(param_groups, lr=args.lr, alpha=0.9, momentum=0.9, weight_decay=args.weight_decay) + else: + # Raise error if unsupported optimizer is selected + raise NotImplementedError(f"Optimizer {args.optimizer} not implemented") + +# Function to load a checkpoint and resume training +def load_checkpoint(args, model, optimizer, scaler): + if os.path.isfile(args.resume): + print(f"INFO:PyTorch: Loading checkpoint from '{args.resume}'") + loc = f'cuda:{args.gpu}' if args.gpu is not None else None + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + global best_acc1 + best_acc1 = checkpoint['best_acc1'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + if "scaler" in checkpoint: + scaler.load_state_dict(checkpoint['scaler']) + print(f"INFO:PyTorch: Checkpoint loaded, epoch {checkpoint['epoch']}") + else: + print(f"INFO:PyTorch: No checkpoint found at '{args.resume}'") + +# Function to train and evaluate the model over multiple epochs +def train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus): + for epoch in range(args.start_epoch, args.epochs + 1): + if args.is_decentralized: + train_sampler.set_epoch(epoch) + + train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args) + + # Evaluate the model every 'eval_per_epoch' epochs + if (epoch + 1) % args.eval_per_epoch == 0: + acc_all = validate(val_loader, model, args) + global best_acc1 + is_best = acc_all[0] > best_acc1 # Track the best accuracy + best_acc1 = max(acc_all[0], best_acc1) + # Save the model checkpoint + save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best) + +# Function to perform one training epoch +def train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args): + metric_storage = create_metric_storage(args.loop_factor) + model.train() # Set the 
model to training mode + data_loader = prefetch.data_prefetcher(train_loader) # Use data prefetching to improve efficiency + images, target = data_loader.next() + + optimizer.zero_grad() # Reset gradients + while images is not None: + # Adjust the learning rate based on the scheduler + scheduler(optimizer, epoch) + + # Perform forward pass with mixed precision if enabled + if args.is_amp: + with amp.autocast(): + ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch) + else: + ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch) + + # Calculate total loss and normalize + total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate + val_writer.add_scalar('average_training_loss', total_loss, global_step=epoch) + + # Perform backward pass and update gradients with mixed precision if enabled + if args.is_amp: + scaler.scale(total_loss).backward() + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + images, target = data_loader.next() # Fetch the next batch of data + +# Function to save the model checkpoint +def save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best): + ckpt = { + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'best_acc1': best_acc1, + 'optimizer': optimizer.state_dict(), + } + if args.is_amp: + ckpt['scaler'] = scaler.state_dict() + metric.save_checkpoint(ckpt, is_best, args.model_dir, filename=f"checkpoint_{epoch}.pth.tar") + +# Function to validate the model on the validation dataset +def validate(val_loader, model, args): + metric_storage = create_metric_storage(args.loop_factor) + model.eval() # Set the model to evaluation mode + + with torch.no_grad(): + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # Perform forward pass with mixed precision if enabled + if args.is_amp: + with amp.autocast(): + ensemble_output, outputs, ce_loss = model(images, target=target, mode='val') + else: + ensemble_output, outputs, ce_loss = model(images, target=target, mode='val') + + batch_size = images.size(0) + acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5)) + + metric_storage.update(acc1, acc5, ce_loss, batch_size) + + return metric_storage.results() + +# Helper function to create a storage for metrics during training and validation +def create_metric_storage(loop_factor): + # Initialize metrics for accuracy and other performance metrics + top1_all = [metric.AverageMeter(f'Acc@1_{i}', ':6.2f') for i in range(loop_factor)] + avg_top1 = metric.AverageMeter('Avg_Acc@1', ':6.2f') + return metric.ProgressMeter(len(top1_all), top1_all, avg_top1) + +# Main entry point for the script +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Centralized Training') + args = train_params.add_parser_params(parser) # Add parameters to the argument parser + assert args.is_fed == 0, "Centralized training requires args.is_fed to be False" + os.makedirs(args.model_dir, exist_ok=True) # Create model directory if it doesn't exist + main(args) # Call the main function diff --git a/EdgeFLite/run_prox.py b/EdgeFLite/run_prox.py new file mode 100644 index 0000000..e050816 --- /dev/null +++ b/EdgeFLite/run_prox.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import os +import warnings +import torch +import torch.cuda.amp as autocast +from torch import 
nn +from torch.backends import cudnn +from tensorboardX import SummaryWriter +from config import * +from params import train_settings +from utils import label_smooth, metrics, scheduler, prefetch_loader +from model import net_splitter +from dataset import data_factory +import numpy as np +from tqdm import tqdm +from params.train_settings import save_hyperparams_to_json + +# Set the visible GPU to use for training +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# Variable to store the best accuracy achieved during training +best_accuracy = 0 + + +# Helper function to compute the average of a list +def compute_average(lst): + return sum(lst) / len(lst) + + +# Main function to initialize the training process +def main(args): + if args.gpu_index is not None: + # Warn if a specific GPU is selected, disabling data parallelism + warnings.warn("Specific GPU chosen, disabling data parallelism.") + + # Adjust loop factor based on training setup + args.loop_factor = 1 if args.separate_training or args.single_branch else args.split_factor + # Determine if decentralized training is required + args.decentralized_training = args.world_size > 1 or args.multiprocessing_decentralized + num_gpus = torch.cuda.device_count() + args.num_gpus = num_gpus + + # If decentralized multiprocessing is enabled, spawn multiple processes + if args.multiprocessing_decentralized: + args.world_size = num_gpus * args.world_size + torch.multiprocessing.spawn(worker_process, nprocs=num_gpus, args=(num_gpus, args)) + else: + # Otherwise, proceed with single-GPU training + print(f"INFO:PyTorch: Detected {num_gpus} GPU(s) available.") + args.gpu_index = 0 + worker_process(args.gpu_index, num_gpus, args) + + +# Client-side training function for federated learning updates +def client_train_update(args, round_num, client_model, global_model, sched, opt, train_loader, epochs=5, scaler=None): + client_model.train() + + for epoch in range(epochs): + # Prefetch data for training + loader = prefetch_loader.DataPrefetcher(train_loader) + images, targets = loader.next() + batch_idx = 0 + opt.zero_grad() + + while images is not None: + # Apply learning rate scheduling + sched(opt, batch_idx) + + # Use automatic mixed precision if enabled + if args.amp_enabled: + with autocast.autocast(): + ensemble_out, model_outputs, loss_ce, loss_cot = client_model(images, targets=targets, mode='train', + epoch=epoch) + else: + ensemble_out, model_outputs, loss_ce, loss_cot = client_model(images, targets=targets, mode='train', + epoch=epoch) + + # Compute accuracy for top-1 predictions + batch_size = images.size(0) + for j in range(args.loop_factor): + top1_acc = metrics.accuracy(model_outputs[j], targets, topk=(1,)) + + # Compute the proximal term for FedProx loss + prox_term = sum((param - global_param).norm(2) for param, global_param in + zip(client_model.parameters(), global_model.parameters())) + # Compute the total loss (cross-entropy + contrastive loss + proximal term) + total_loss = (loss_ce + loss_cot) / args.accum_steps + (args.mu / 2) * prox_term + + # Backward pass with mixed precision scaling if enabled + if args.amp_enabled: + scaler.scale(total_loss).backward() + if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)): + scaler.step(opt) + scaler.update() + opt.zero_grad() + else: + total_loss.backward() + if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)): + opt.step() + opt.zero_grad() + + images, targets = loader.next() + + return total_loss.item() + + +# Function to aggregate model weights from 
clients on the server +def server_compute_average_weights(global_model, client_models): + global_state_dict = global_model.state_dict() + # Average weights across all clients + for key in global_state_dict.keys(): + global_state_dict[key] = torch.stack( + [client_models[i].state_dict()[key].float() for i in range(len(client_models))], 0).mean(0) + global_model.load_state_dict(global_state_dict) + + # Update clients with the averaged global model + for model in client_models: + model.load_state_dict(global_model.state_dict()) + + +# Function to validate the model on the validation set +def validate_model(val_loader, model, args): + model.eval() + acc1_list, acc5_list, loss_ce_list = [], [], [] + + # Perform validation without gradient calculation + with torch.no_grad(): + for images, targets in val_loader: + if args.gpu_index is not None: + images, targets = images.cuda(args.gpu_index, non_blocking=True), targets.cuda(args.gpu_index, + non_blocking=True) + + if args.amp_enabled: + with autocast.autocast(): + ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val') + else: + ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val') + + for j in range(args.loop_factor): + acc1, acc5 = metrics.accuracy(model_outputs[j], targets, topk=(1, 5)) + + avg_acc1, avg_acc5 = metrics.accuracy(ensemble_out, targets, topk=(1, 5)) + acc1_list.append(avg_acc1) + acc5_list.append(avg_acc5) + loss_ce_list.append(loss_ce) + + return compute_average(loss_ce_list), compute_average(acc1_list) + + +# Function to handle the worker process for training on a specific GPU +def worker_process(gpu_index, num_gpus, args): + global best_accuracy + args.gpu_index = gpu_index + args.model_path = os.path.join(HOME, "models", "coremodel", str(args.model_id)) + + # Create summary writer for validation if not using decentralized training + if not args.decentralized_training or (args.multiprocessing_decentralized and args.rank % num_gpus == 0): + val_summary_writer = SummaryWriter(log_dir=os.path.join(args.model_path, 'validation')) + + # Set the loss function based on the label smoothing option + criterion = label_smooth.smooth_ce_loss(reduction='mean') if args.use_label_smooth else nn.CrossEntropyLoss() + # Initialize the global model and client models + global_model = net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion) + client_models = [net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion) for _ in + range(args.num_clients)] + + # Save hyperparameters to JSON if required + if args.save_summary: + save_hyperparams_to_json(args) + return + + # Move models to GPU + global_model = global_model.cuda() + for model in client_models: + model.cuda() + model.load_state_dict(global_model.state_dict()) + + # Create optimizers for each client + opt_list = [torch.optim.SGD(client.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, + nesterov=args.use_nesterov) for client in client_models] + + # Initialize gradient scaler if AMP is enabled + scaler = torch.cuda.amp.GradScaler() if args.amp_enabled else None + cudnn.performance_test = True + + # Resume training from checkpoint if specified + if args.resume_training: + if os.path.isfile(args.resume_checkpoint): + checkpoint = torch.load(args.resume_checkpoint, + map_location=f'cuda:{args.gpu_index}' if args.gpu_index else None) + args.start_round = checkpoint['round'] + best_accuracy = checkpoint['best_acc1'] + 
global_model.load_state_dict(checkpoint['state_dict']) + for opt in opt_list: + opt.load_state_dict(checkpoint['optimizer']) + if "scaler" in checkpoint: + scaler.load_state_dict(checkpoint['scaler']) + for client_model in client_models: + client_model.load_state_dict(global_model.state_dict()) + else: + args.start_round = 0 + else: + args.start_round = 0 + + # Load training and validation data + train_loader, _ = data_factory.load_data(args.data_dir, args.batch_size, args.split_factor, + dataset_name=args.dataset_name, split="train", + num_workers=args.num_workers, decentralized=args.decentralized_training) + val_loader = data_factory.load_data(args.data_dir, args.eval_batch_size, args.split_factor, + dataset_name=args.dataset_name, split="val", num_workers=args.num_workers) + + # Federated learning rounds + for round_num in range(args.start_round, args.num_rounds + 1): + if args.fixed_cluster: + # Select clients from fixed clusters for each round + selected_clusters = np.random.permutation(args.num_clusters)[:args.num_clients] + for i in tqdm(range(args.num_clients)): + selected_clients = np.arange(start=selected_clusters[i] * args.split_factor, + stop=(selected_clusters[i] + 1) * args.split_factor) + for client in selected_clients: + loss = client_train diff --git a/EdgeFLite/run_splitfed.py b/EdgeFLite/run_splitfed.py new file mode 100644 index 0000000..083c316 --- /dev/null +++ b/EdgeFLite/run_splitfed.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn +import numpy as np +import argparse +import warnings +from tqdm import tqdm +from tensorboardX import SummaryWriter +from dataset import factory +from config import * +from model import coremodelsl +from utils import label_smoothing, norm, metric, lr_scheduler, prefetch +from params import train_params +from params.train_params import save_hp_to_json + +# Set the visible GPU devices for the training +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# Global best accuracy to track the performance +best_acc1 = 0 # Global best accuracy + +def average(values): + """Calculate the average of a list of values.""" + return sum(values) / len(values) + +def combine_model_weights(global_model_client, global_model_server, client_models, server_models): + """ + Aggregate weights from client and server models using the mean method. + This function updates the global model weights by averaging the weights + from all client and server models. 
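The client update in run_prox.py above follows FedProx: on top of the task loss, each client is penalized for drifting from the global weights it started the round with. In the canonical formulation the penalty is (mu/2)·||w − w_global||², i.e. a squared L2 distance, whereas the snippet above sums un-squared parameter-wise norms, so squaring the differences would match the published objective. A minimal sketch of the proximal loss, assuming `model` and `global_model` share the same architecture and device, is:

```python
# FedProx-style proximal loss; mu is the proximal coefficient (args.mu above).
import torch

def fedprox_loss(task_loss, model, global_model, mu):
    prox = torch.zeros((), device=task_loss.device)
    for p, g in zip(model.parameters(), global_model.parameters()):
        # Global weights are detached: they act as a fixed anchor for this round.
        prox = prox + (p - g.detach()).pow(2).sum()
    return task_loss + 0.5 * mu * prox
```

Only the local model's parameters receive gradients from the penalty; the detached global copy stays fixed until the next aggregation.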
+ """ + # Get the state dictionaries (weights) for both client and server models + client_state_dict = global_model_client.state_dict() + server_state_dict = global_model_server.state_dict() + + # Average the weights across all client models + for key in client_state_dict.keys(): + client_state_dict[key] = torch.stack([model.state_dict()[key].float() for model in client_models], dim=0).mean(0) + global_model_client.load_state_dict(client_state_dict) + + # Average the weights across all server models + for key in server_state_dict.keys(): + server_state_dict[key] = torch.stack([model.state_dict()[key].float() for model in server_models], dim=0).mean(0) + global_model_server.load_state_dict(server_state_dict) + + # Load the updated global model weights back into the client models + for model in client_models: + model.load_state_dict(global_model_client.state_dict()) + + # Load the updated global model weights back into the server models + for model in server_models: + model.load_state_dict(global_model_server.state_dict()) + +def client_training(args, round_num, client_model, server_model, scheduler_client, scheduler_server, optimizer_client, optimizer_server, data_loader, epochs=5, streams=None): + """ + Perform client-side model training for the given number of epochs. + The client model performs the forward pass and sends intermediate outputs + to the server model for further computation. + """ + client_model.train() + server_model.train() + + for epoch in range(epochs): + # Prefetch data to improve data loading speed + prefetcher = prefetch.data_prefetcher(data_loader) + images, target = prefetcher.next() + i = 0 + optimizer_client.zero_grad() + optimizer_server.zero_grad() + + while images is not None: + # Adjust learning rates using the schedulers + scheduler_client(optimizer_client, i, round_num) + scheduler_server(optimizer_server, i, round_num) + i += 1 + + # Forward pass on the client model + outputs_client, y_a, y_b, lam = client_model(images, target=target, mode='train', epoch=epoch, streams=streams) + client_fx = [outputs.clone().detach().requires_grad_(True) for outputs in outputs_client] + + # Forward pass on the server model and compute losses + ensemble_output, outputs_server, ce_loss, cot_loss = server_model(client_fx, y_a, y_b, lam, target=target, mode='train', epoch=epoch, streams=streams) + total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate + total_loss.backward() + + # Backpropagate the gradients to the client model + for fx, grad in zip(outputs_client, client_fx): + fx.backward(grad.grad) + + # Perform optimization step when the accumulation condition is met + if i % args.iters_to_accumulate == 0 or i == len(data_loader): + optimizer_client.step() + optimizer_server.step() + optimizer_client.zero_grad() + optimizer_server.zero_grad() + + # Fetch the next batch of data + images, target = prefetcher.next() + + return total_loss.item() + +def validate_model(val_loader, client_model, server_model, args, streams=None): + """ + Validate the performance of client and server models. + This function performs forward passes without updating the model weights + and computes validation accuracy and loss. 
+ """ + client_model.eval() + server_model.eval() + + acc1_list, acc5_list, ce_loss_list = [], [], [] + + with torch.no_grad(): + for i, (images, target) in enumerate(val_loader): + if args.gpu is not None: + images = images.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # Forward pass on the client model + outputs_client = client_model(images, target=target, mode='val') + client_fx = [output.clone().detach().requires_grad_(True) for output in outputs_client] + + # Forward pass on the server model + ensemble_output, outputs_server, ce_loss = server_model(client_fx, target=target, mode='val') + + # Calculate accuracy and losses + acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5)) + acc1_list.append(acc1) + acc5_list.append(acc5) + ce_loss_list.append(ce_loss) + + # Calculate average accuracy and loss over the validation dataset + avg_acc1 = average(acc1_list) + avg_acc5 = average(acc5_list) + avg_ce_loss = average(ce_loss_list) + + return avg_ce_loss, avg_acc1, avg_acc5 + +def main(args): + """ + The main entry point for the federated learning process. + Initializes models, handles multiprocessing setup, and starts training. + """ + if args.gpu is not None: + warnings.warn("A specific GPU has been chosen. Data parallelism is disabled.") + + # Set loop factor based on training configuration + args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor + ngpus_per_node = torch.cuda.device_count() + args.ngpus_per_node = ngpus_per_node + + if args.multiprocessing_decentralized: + # Spawn a process for each GPU in decentralized setup + args.world_size = ngpus_per_node * args.world_size + torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) + else: + # Use only a single GPU in non-decentralized setup + args.gpu = 0 + execute_worker_process(args.gpu, ngpus_per_node, args) + +def execute_worker_process(gpu, ngpus_per_node, args): + """ + Worker function that handles model initialization, training, and validation. 
+ """ + global best_acc1 + args.gpu = gpu + + if args.gpu is not None: + print(f"Using GPU {args.gpu} for training.") + + # Create tensorboard writer for logging validation metrics + if not args.multiprocessing_decentralized or (args.multiprocessing_decentralized and args.rank % ngpus_per_node == 0): + val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val')) + + # Define loss criterion with label smoothing or cross-entropy + criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss() + + # Initialize global client and server models + global_model_client = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) + global_model_server = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) + + # Initialize client and server models for each selected client + client_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)] + server_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion) for _ in range(args.num_selected)] + + # Save hyperparameters to a JSON file + save_hp_to_json(args) + + # Move global models and client/server models to GPU + global_model_client = global_model_client.cuda() + global_model_server = global_model_server.cuda() + for model in client_models + server_models: + model.cuda() + + # Load global model weights into each client and server model + for model in client_models: + model.load_state_dict(global_model_client.state_dict()) + for model in server_models: + model.load_state_dict(global_model_server.state_dict()) + + # Initialize learning rate schedulers for clients and servers + schedulers_clients = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)] + schedulers_servers = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds, len(factory.obtain_data_loader(args.data)), args.lr_milestones, args.lr_multiplier) for _ in range(args.num_selected)] + + # Start the training and validation loop for the specified number of rounds + for r in range(args.start_round, args.num_rounds + 1): + # Randomly select client indices for training in each round + client_indices = np.random.permutation(args.num_clusters * args.loop_factor)[:args.num_selected * args.loop diff --git a/EdgeFLite/scripts/.DS_Store b/EdgeFLite/scripts/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/EdgeFLite/scripts/.DS_Store differ diff --git a/EdgeFLite/scripts/EdgeFLite_R110_100c_650r.sh b/EdgeFLite/scripts/EdgeFLite_R110_100c_650r.sh new file mode 100644 index 0000000..adcc4e8 --- /dev/null +++ b/EdgeFLite/scripts/EdgeFLite_R110_100c_650r.sh @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load necessary modules +# This section loads essential modules required for the execution environment +source /etc/profile.d/modules.sh # Load the module environment configuration +module load gcc/11.2.0 # Load GCC (GNU Compiler Collection) version 11.2.0 +module load openmpi/4.1.3 # Load OpenMPI version 4.1.3 for parallel computing +module load cuda/11.5/11.5.2 # Load CUDA version 11.5.2 for GPU computing +module load cudnn/8.3/8.3.3 # Load cuDNN version 8.3.3 for deep learning libraries +module load nccl/2.11/2.11.4-1 # Load NCCL version 2.11 
for multi-GPU communication +module load python/3.10/3.10.4 # Load Python version 3.10.4 for executing Python scripts + +# Activate virtual environment +# This activates the virtual environment that contains the required Python packages +source ~/venv/pytorch1.11+horovod/bin/activate + +# Configure log directory +# Sets up the directory for storing logs related to the job execution +LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" +mkdir -p ${LOG_PATH} # Create the log directory if it doesn't exist + +# Prepare dataset directory +# This section prepares the dataset directory by copying data to the local directory for the job +TEMP_DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/" # Define the temporary data path for the current job +cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${TEMP_DATA_PATH} # Copy the dataset to the temporary path + +# Change to project directory +# Navigates to the project directory where the training script is located +cd EdgeFLite + +# Execute training script +# This runs the training script with the specified configuration +python train_EdgeFLite.py \ + --is_fed=1 \ # Enable federated learning mode + --fixed_cluster=0 \ # Do not use a fixed cluster configuration + --split_factor=4 \ # Specify the data split factor for federated learning + --num_clusters=25 \ # Set the number of clusters to 25 + --num_selected=25 \ # Select all 25 clusters for training + --arch="resnet_model_110sl" \ # Use the 'resnet_model_110sl' architecture for the model + --dataset="cifar100" \ # Set the dataset to CIFAR-100 + --num_classes=100 \ # Specify the number of output classes (100 for CIFAR-100) + --is_single_branch=0 \ # Enable multi-branch mode for model training + --is_amp=0 \ # Disable automatic mixed precision (AMP) for this run + --num_rounds=650 \ # Set the total number of federated rounds to 650 + --fed_epochs=1 \ # Set the number of local epochs per round to 1 + --spid="EdgeFLite_R110_100c_650r" \ # Set the session/process ID for the current job + --data=${TEMP_DATA_PATH} # Specify the dataset location (temporary directory) diff --git a/EdgeFLite/scripts/EdgeFLite_R110_80c_650r.sh b/EdgeFLite/scripts/EdgeFLite_R110_80c_650r.sh new file mode 100644 index 0000000..3bb7508 --- /dev/null +++ b/EdgeFLite/scripts/EdgeFLite_R110_80c_650r.sh @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load necessary environment modules +source /etc/profile.d/modules.sh # Source the module environment setup script +module load gcc/11.2.0 # Load GCC compiler version 11.2.0 +module load openmpi/4.1.3 # Load OpenMPI version 4.1.3 for distributed computing +module load cuda/11.5/11.5.2 # Load CUDA version 11.5.2 for GPU acceleration +module load cudnn/8.3/8.3.3 # Load cuDNN version 8.3.3 for deep learning operations +module load nccl/2.11/2.11.4-1 # Load NCCL version 2.11 for multi-GPU communication +module load python/3.10/3.10.4 # Load Python version 3.10.4 + +# Activate the Python virtual environment with PyTorch and Horovod installed +source ~/venv/pytorch1.11+horovod/bin/activate + +# Setup the log directory for the experiment +LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" # Define the log path +rm -rf ${LOG_PATH} # Remove any existing logs in the directory +mkdir -p ${LOG_PATH} # Create the log directory if it doesn't exist + +# Setup the dataset directory, copying data for local use +DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/" # Define the local directory for the dataset +cp -r 
../summit2024/simpleFL/performance_test/cifar100/data ${DATA_PATH} # Copy CIFAR-100 dataset to local storage + +# Set experiment parameters for federated learning +OUTPUT_DIR="./EdgeFLite/models/coremodel/" # Directory where model checkpoints will be saved +FED_MODE=1 # Federated learning mode enabled +CLUSTER_FIXED=0 # Cluster dynamic, not fixed +SPLIT_RATIO=4 # Split the dataset into 4 parts +TOTAL_CLUSTERS=20 # Number of clusters (e.g., number of different clients in federated learning) +SELECTED_CLIENTS=20 # Number of clients selected per round +MODEL_ARCH="resnet_model_110sl" # Model architecture to be used (ResNet-110 with some custom changes) +DATASET_NAME="cifar100" # Dataset being used (CIFAR-100) +NUM_CLASS_LABELS=100 # Number of class labels in the dataset (CIFAR-100 has 100 classes) +SINGLE_BRANCH=0 # Multi-branch model architecture (not single-branch) +AMP_MODE=0 # Disable Automatic Mixed Precision (AMP) for training +ROUNDS=650 # Total number of federated learning rounds +EPOCHS_PER_ROUND=1 # Number of local epochs per round of federated learning +EXP_ID="EdgeFLite_R110_80c_650r" # Experiment ID for tracking + +# Navigate to the project directory +cd EdgeFLite # Change to the EdgeFLite project directory + +# Execute the training process for federated learning with the defined parameters +python train_EdgeFLite.py \ + --is_fed=${FED_MODE} # Enable federated learning mode + --fixed_cluster=${CLUSTER_FIXED} # Use dynamic clusters + --split_factor=${SPLIT_RATIO} # Set the dataset split ratio + --num_clusters=${TOTAL_CLUSTERS} # Total number of clusters (clients) + --num_selected=${SELECTED_CLIENTS} # Number of clients selected per federated round + --arch=${MODEL_ARCH} # Set model architecture (ResNet-110 variant) + --dataset=${DATASET_NAME} # Dataset name (CIFAR-100) + --num_classes=${NUM_CLASS_LABELS} # Number of classes in the dataset + --is_single_branch=${SINGLE_BRANCH} # Use multi-branch model (set to 0) + --is_amp=${AMP_MODE} # Disable automatic mixed precision + --num_rounds=${ROUNDS} # Total number of rounds for federated learning + --fed_epochs=${EPOCHS_PER_ROUND} # Number of local epochs per round + --spid=${EXP_ID} # Set experiment ID for tracking + --data=${DATA_PATH} # Provide dataset path + --model_dir=${OUTPUT_DIR} # Directory where the model will be saved diff --git a/EdgeFLite/scripts/EdgeFLite_W168_96c_650r.sh b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r.sh new file mode 100644 index 0000000..485c70f --- /dev/null +++ b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r.sh @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Initialize environment and load necessary modules +# This sets up the environment for running the necessary libraries like GCC, OpenMPI, CUDA, cuDNN, NCCL, and Python +source /etc/profile.d/modules.sh +module load gcc/11.2.0 # Load GCC version 11.2.0 for compiling +module load openmpi/4.1.3 # Load OpenMPI version 4.1.3 for distributed computing +module load cuda/11.5/11.5.2 # Load CUDA version 11.5.2 for GPU computing +module load cudnn/8.3/8.3.3 # Load cuDNN version 8.3.3 for deep learning frameworks +module load nccl/2.11/2.11.4-1 # Load NCCL version 2.11.4-1 for multi-GPU communication +module load python/3.10/3.10.4 # Load Python version 3.10.4 + +# Activate the Python virtual environment +# This activates the pre-configured virtual environment where necessary Python packages (e.g., PyTorch, Horovod) are installed +source ~/venv/pytorch1.11+horovod/bin/activate + +# Prepare the log directory and clean up any old records +# Create 
a log directory for this job run and remove any previous log records +LOG_DIRECTORY="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" +rm -rf ${LOG_DIRECTORY} # Remove old logs if they exist +mkdir -p ${LOG_DIRECTORY} # Create a new directory for current job logs + +# Set up local data directory and copy dataset +# Define local data storage and copy the dataset for training the model +DATA_STORAGE="${SGE_LOCALDIR}/${JOB_ID}/" +cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_STORAGE} # Copy CIFAR-100 dataset to the local directory + +# Change directory to project location +# Navigate to the EdgeFLite project directory to execute the training script +cd EdgeFLite + +# Execute the training process for the federated learning model +# This runs the model training with specific hyperparameters for federated learning, including architecture, dataset, and configuration settings +python train_EdgeFLite.py \ + --is_fed=1 \ # Enable federated learning mode + --fixed_cluster=0 \ # Disable fixed clusters, allowing dynamic changes + --split_factor=16 \ # Set data split factor to 16 + --num_clusters=6 \ # Use 6 clusters for the federated learning process + --num_selected=6 \ # Select 6 clients for each training round + --arch="wide_resnetsl16_8" \ # Use a Wide ResNet architecture with depth 16 and width 8 + --dataset="cifar100" \ # Specify CIFAR-100 as the dataset + --num_classes=100 \ # CIFAR-100 has 100 output classes + --is_single_branch=0 \ # Use multi-branch (multi-head) learning + --is_amp=0 \ # Disable automatic mixed precision training + --num_rounds=650 \ # Train for 650 communication rounds + --fed_epochs=1 \ # Each client trains for 1 epoch per round + --spid="EdgeFLite_W168_96c_650r" \ # Set the unique identifier for the job + --data=${DATA_STORAGE} # Provide the location of the dataset diff --git a/EdgeFLite/scripts/EdgeFLite_W168_96c_650r2.sh b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r2.sh new file mode 100644 index 0000000..5e05c7e --- /dev/null +++ b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r2.sh @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load necessary system modules for the environment +source /etc/profile.d/modules.sh +module load gcc/11.2.0 # Load GCC compiler version 11.2.0 +module load openmpi/4.1.3 # Load OpenMPI version 4.1.3 for parallel processing +module load cuda/11.5/11.5.2 # Load CUDA version 11.5.2 for GPU acceleration +module load cudnn/8.3/8.3.3 # Load cuDNN version 8.3.3 for deep learning libraries +module load nccl/2.11/2.11.4-1 # Load NCCL version 2.11.4-1 for multi-GPU communication +module load python/3.10/3.10.4 # Load Python version 3.10.4 + +# Activate the virtual environment for PyTorch and Horovod +source ~/venv/pytorch1.11+horovod/bin/activate + +# Set up the log directory and remove any previous log records +LOG_OUTPUT="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" +rm -rf ${LOG_OUTPUT} # Clean previous logs +mkdir -p ${LOG_OUTPUT} # Create new log directory + +# Prepare local storage for the dataset +LOCAL_DATA_DIR="${SGE_LOCALDIR}/${JOB_ID}/" # Set local storage path +cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${LOCAL_DATA_DIR} # Copy CIFAR-100 data to local storage + +# Move to the project directory +cd EdgeFLite + +# Run the federated learning experiment with the specified parameters +python run_gkt.py \ + --is_fed=1 \ # Enable federated learning + --fixed_cluster=0 \ # Use dynamic clustering + --split_factor=1 \ # Set 
split factor + --num_clusters=20 \ # Number of clusters in the federation + --num_selected=20 \ # Number of selected clients per round + --arch=resnet_model_110sl \ # Model architecture: ResNet-110 small layer + --dataset=cifar100 \ # Dataset: CIFAR-100 + --num_classes=100 \ # Number of classes in the dataset + --is_single_branch=0 \ # Enable multi-branch model + --is_amp=0 \ # Disable automatic mixed precision + --num_rounds=650 \ # Total number of federated learning rounds + --fed_epochs=1 \ # Number of local epochs per round + --cifar100_non_iid="quantity_skew" \ # Specify non-IID scenario: quantity skew + --spid="FGKT_R110_20c_skew" \ # Experiment identifier + --data=${LOCAL_DATA_DIR} # Path to the local dataset diff --git a/EdgeFLite/scripts/EdgeFLite_W168_96c_650r32.sh b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r32.sh new file mode 100644 index 0000000..9f9135f --- /dev/null +++ b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r32.sh @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load environment modules required for execution +# This block sets up necessary modules, including compilers and deep learning libraries +source /etc/profile.d/modules.sh +module load gcc/11.2.0 # Load GCC compiler version 11.2.0 +module load openmpi/4.1.3 # Load OpenMPI for distributed computing +module load cuda/11.5/11.5.2 # Load CUDA 11.5 for GPU acceleration +module load cudnn/8.3/8.3.3 # Load cuDNN for deep neural network operations +module load nccl/2.11/2.11.4-1 # Load NCCL for multi-GPU communication +module load python/3.10/3.10.4 # Load Python version 3.10 + +# Activate the Python environment +# This line activates a Python virtual environment with required packages (e.g., PyTorch and Horovod) +source ~/venv/pytorch1.11+horovod/bin/activate + +# Create and clean the log directory for this job +# The log directory is where all training logs will be stored for this specific job +LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" +rm -rf ${LOG_PATH} # Remove any pre-existing log directory +mkdir -p ${LOG_PATH} # Create a new log directory + +# Prepare the local dataset storage +# This copies the dataset to a local directory for faster access during training +DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/" +cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_PATH} + +# Change to the working directory of the federated training scripts +# The working directory contains the necessary scripts for running the training process +cd EdgeFLite + +# Execute the federated training process with the specified configuration +# This command runs the federated learning training script with several parameters +python run_gkt.py \ + --is_fed=1 # Enables federated learning mode + --fixed_cluster=0 # Dynamic clusters during training + --split_factor=1 # Data split factor for federated learning + --num_clusters=20 # Number of clusters for federated training + --num_selected=20 # Number of selected devices per round + --arch="wide_resnetsl50_2" # Model architecture (Wide ResNet with layers) + --dataset="pill_base" # Dataset being used for training + --num_classes=98 # Number of classes in the dataset + --is_single_branch=0 # Multi-branch model + --is_amp=0 # Disable automatic mixed precision + --num_rounds=350 # Total number of communication rounds in federated learning + --fed_epochs=1 # Number of local epochs per device + --batch_size=32 # Batch size for training + --crop_size=224 # Crop size for image preprocessing + --spid="FGKT_W502_20c_350r" # Unique 
identifier for the specific training experiment + --data=${DATA_PATH} # Path to the dataset being used for training diff --git a/EdgeFLite/scripts/EdgeFLite_W168_96c_650r4.sh b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r4.sh new file mode 100644 index 0000000..c44c219 --- /dev/null +++ b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r4.sh @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load necessary modules and dependencies +source /etc/profile.d/modules.sh +module load gcc/11.2.0 +module load openmpi/4.1.3 +module load cuda/11.5/11.5.2 +module load cudnn/8.3/8.3.3 +module load nccl/2.11/2.11.4-1 +module load python/3.10/3.10.4 + +# Activate the Python environment +source ~/venv/pytorch1.11+horovod/bin/activate + +# Configure log directory and clean up any existing records +OUTPUT_LOG_DIR="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" +rm -rf ${OUTPUT_LOG_DIR} +mkdir -p ${OUTPUT_LOG_DIR} + +# Copy dataset to local directory for processing +LOCAL_DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/" +cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${LOCAL_DATA_PATH} + +# Switch to the working directory containing the training scripts +cd EdgeFLite + +# Run the training script with specified settings for federated learning +python run_gkt.py \ + --is_fed=1 \ # Enable federated learning mode + --fixed_cluster=0 \ # Use dynamic clustering + --split_factor=1 \ # Split factor for distributed computation + --num_clusters=20 \ # Number of clusters to create + --num_selected=20 \ # Number of selected clients per round + --arch="wide_resnet16_8" \ # Architecture to use (Wide ResNet-16-8) + --dataset="cifar10" \ # Dataset to use (CIFAR-10) + --num_classes=10 \ # Number of classes in the dataset + --is_single_branch=0 \ # Disable single branch training mode + --is_amp=0 \ # Disable automatic mixed precision + --num_rounds=300 \ # Number of communication rounds + --fed_epochs=1 \ # Number of local epochs for each client per round + --spid="fedgkt_wrn168_split1_cifar10_20clients_20choose_300rounds" \ # Unique ID for the experiment + --data=${LOCAL_DATA_PATH} # Local path to the dataset diff --git a/EdgeFLite/scripts/EdgeFLite_W168_96c_650r8.sh b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r8.sh new file mode 100644 index 0000000..692f2ae --- /dev/null +++ b/EdgeFLite/scripts/EdgeFLite_W168_96c_650r8.sh @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load environment modules and required dependencies +source /etc/profile.d/modules.sh +module load gcc/11.2.0 # Load GCC version 11.2.0 +module load openmpi/4.1.3 # Load OpenMPI version 4.1.3 +module load cuda/11.5/11.5.2 # Load CUDA version 11.5.2 +module load cudnn/8.3/8.3.3 # Load cuDNN version 8.3.3 +module load nccl/2.11/2.11.4-1 # Load NCCL version 2.11.4-1 +module load python/3.10/3.10.4 # Load Python version 3.10.4 + +# Activate the virtual Python environment +source ~/venv/pytorch1.11+horovod/bin/activate # Activate a virtual environment for PyTorch and Horovod + +# Define the log directory, clean up old records if any, and recreate the directory +LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" +rm -rf ${LOG_PATH} # Remove any existing log directory +mkdir -p ${LOG_PATH} # Create a new log directory + +# Set up the local data directory and copy the dataset into it +DATA_STORAGE="${SGE_LOCALDIR}/${JOB_ID}/" # Define a local data directory for the job +cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_STORAGE} # Copy 
CIFAR-100 dataset to the local directory + +# Navigate to the working directory where training scripts are located +cd EdgeFLite # Change directory to the EdgeFLite project + +# Execute the training script with federated learning parameters +python run_gkt.py \ + --is_fed=1 \ # Enable federated learning + --fixed_cluster=0 \ # Allow dynamic cluster formation + --split_factor=1 \ # Data split factor + --num_clusters=20 \ # Number of clusters + --num_selected=20 \ # Number of selected clients per round + --arch="wide_resnet16_8" \ # Network architecture: Wide ResNet 16-8 + --dataset="cifar10" \ # Use CIFAR-10 dataset + --num_classes=10 \ # Number of classes in CIFAR-10 + --is_single_branch=0 \ # Multi-branch network + --is_amp=0 \ # Disable Automatic Mixed Precision (AMP) + --num_rounds=300 \ # Number of federated learning rounds + --fed_epochs=1 \ # Number of local training epochs per round + --cifar10_non_iid="quantity_skew" \ # Non-IID data distribution: quantity skew + --spid="FGKT_W168_20c_skew" \ # Set a specific job identifier + --data=${DATA_STORAGE} # Path to the dataset diff --git a/EdgeFLite/scripts/FGKT_R110_20c_650r.sh b/EdgeFLite/scripts/FGKT_R110_20c_650r.sh new file mode 100644 index 0000000..879fbd5 --- /dev/null +++ b/EdgeFLite/scripts/FGKT_R110_20c_650r.sh @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load necessary system modules for the job +source /etc/profile.d/modules.sh +module load gcc/11.2.0 # Load GCC compiler +module load openmpi/4.1.3 # Load OpenMPI for distributed computing +module load cuda/11.5/11.5.2 # Load CUDA for GPU acceleration +module load cudnn/8.3/8.3.3 # Load cuDNN for deep learning frameworks +module load nccl/2.11/2.11.4-1 # Load NCCL for multi-GPU communication +module load python/3.10/3.10.4 # Load Python 3.10 environment + +# Activate the required Python virtual environment +source ~/venv/pytorch1.11+horovod/bin/activate # Activate PyTorch 1.11 + Horovod environment + +# Define log directory and clean up any existing records before starting +LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" # Set log path +rm -rf ${LOG_PATH} # Remove any existing log directory +mkdir -p ${LOG_PATH} # Create new log directory + +# Copy the dataset to the local temporary directory +DATA_DIR="${SGE_LOCALDIR}/${JOB_ID}/" # Set the local directory for dataset +cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_DIR} # Copy CIFAR-100 dataset to the local directory + +# Move to the directory containing the training scripts +cd EdgeFLite # Change to EdgeFLite project directory + +# Start the federated learning training process with the specified parameters +python run_gkt.py \ + --is_fed=1 \ # Enable federated learning + --fixed_cluster=0 \ # Use dynamic clustering + --split_factor=1 \ # Set data split factor + --num_clusters=20 \ # Set the number of clusters + --num_selected=20 \ # Number of selected clients per round + --arch="resnet_model_110sl" \ # Model architecture (ResNet 110 with single-layer output) + --dataset="cifar100" \ # Dataset used (CIFAR-100) + --num_classes=100 \ # Number of classes in the dataset + --is_single_branch=0 \ # Enable multi-branch model + --is_amp=0 \ # Disable automatic mixed precision + --num_rounds=650 \ # Number of federated learning rounds + --fed_epochs=1 \ # Number of local epochs per federated round + --spid="FGKT_R110_20c_650r" \ # Experiment ID for logging and tracking + --data=${DATA_DIR} # Specify the path to the dataset diff --git 
a/EdgeFLite/scripts/FGKT_R110_20c_skew.sh b/EdgeFLite/scripts/FGKT_R110_20c_skew.sh new file mode 100644 index 0000000..c3361a2 --- /dev/null +++ b/EdgeFLite/scripts/FGKT_R110_20c_skew.sh @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load necessary system modules +source /etc/profile.d/modules.sh + +# Load the GCC module version 11.2.0 +module load gcc/11.2.0 + +# Load the OpenMPI module version 4.1.3 +module load openmpi/4.1.3 + +# Load the CUDA module version 11.5.2 +module load cuda/11.5/11.5.2 + +# Load the cuDNN module version 8.3.3 +module load cudnn/8.3/8.3.3 + +# Load the NCCL module version 2.11.4-1 +module load nccl/2.11/2.11.4-1 + +# Load the Python module version 3.10.4 +module load python/3.10/3.10.4 + +# Activate the virtual environment for PyTorch and Horovod +source ~/venv/pytorch1.11+horovod/bin/activate + +# Set up the log directory and clean previous records if they exist +LOG_OUTPUT="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}" +rm -rf ${LOG_OUTPUT} # Remove previous log files +mkdir -p ${LOG_OUTPUT} # Create a new directory for logs + +# Prepare local storage for the dataset by copying it to a local directory +LOCAL_DATA_DIR="${SGE_LOCALDIR}/${JOB_ID}/" +cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${LOCAL_DATA_DIR} + +# Navigate to the EdgeFLite project directory +cd EdgeFLite + +# Run the federated learning experiment with the specified parameters +python run_gkt.py \ + --is_fed=1 \ # Enable federated learning + --fixed_cluster=0 \ # Disable fixed cluster settings + --split_factor=1 \ # Use split factor of 1 + --num_clusters=20 \ # Set the number of clusters to 20 + --num_selected=20 \ # Select 20 clients for each round + --arch=resnet_model_110sl \ # Use ResNet110 single branch architecture + --dataset=cifar100 \ # Use CIFAR-100 dataset + --num_classes=100 \ # Set the number of classes to 100 + --is_single_branch=0 \ # Use multiple branches in the model + --is_amp=0 \ # Disable automatic mixed precision + --num_rounds=650 \ # Set the number of communication rounds to 650 + --fed_epochs=1 \ # Set the number of federated epochs to 1 + --cifar100_non_iid="quantity_skew" \ # Apply non-IID data partitioning (quantity skew) + --spid="FGKT_R110_20c_skew" \ # Set the experiment ID + --data=${LOCAL_DATA_DIR} # Set the path to the dataset in local storage diff --git a/EdgeFLite/scripts/FGKT_W168_20c_300r.sh b/EdgeFLite/scripts/FGKT_W168_20c_300r.sh new file mode 100644 index 0000000..870b273 --- /dev/null +++ b/EdgeFLite/scripts/FGKT_W168_20c_300r.sh @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +# Load necessary modules and dependencies +source /etc/profile.d/modules.sh + +# Load GCC version 11.2.0 +module load gcc/11.2.0 + +# Load OpenMPI version 4.1.3 for distributed computing +module load openmpi/4.1.3 + +# Load CUDA version 11.5 (subversion 11.5.2) for GPU acceleration +module load cuda/11.5/11.5.2 + +# Load cuDNN version 8.3 (subversion 8.3.3) for deep learning operations +module load cudnn/8.3/8.3.3 + +# Load NCCL version 2.11 (subversion 2.11.4-1) for multi-GPU communication +module load nccl/2.11/2.11.4-1 + +# Load Python version 3.10 (subversion 3.10.4) +module load python/3.10/3.10.4 + +# Activate the Python virtual environment for PyTorch 1.11 + Horovod +source ~/venv/pytorch1.11+horovod/bin/activate + +# Configure the output log directory and clean up any existing records 
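# Note on the python invocations in these job scripts: in bash a `#` comment cannot
# follow a line-continuation backslash on the same line (the backslash then escapes
# the space instead of the newline), and several scripts omit the trailing backslash
# on the --flag lines entirely. Either way the multi-line command terminates early and
# the remaining flags are executed as separate commands. Keeping the explanations as
# full-line comments above each flag, with a bare trailing `\` on every continued
# line, preserves the intended single invocation.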
+OUTPUT_LOG_DIR="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
+
+# Remove any previous log files from the directory
+rm -rf ${OUTPUT_LOG_DIR}
+
+# Create a fresh directory for storing logs
+mkdir -p ${OUTPUT_LOG_DIR}
+
+# Copy the dataset to a local directory for processing during training
+LOCAL_DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/"
+
+# Copy the dataset files from the performance test directory to the local directory
+cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${LOCAL_DATA_PATH}
+
+# Switch to the working directory containing the EdgeFLite training scripts
+cd EdgeFLite
+
+# Run federated training: 20 clusters with 20 selected clients per round, dynamic
+# clustering, split factor 1, a Wide ResNet 16-8 backbone, CIFAR-10 (10 classes),
+# multi-branch model, AMP disabled, 300 rounds with 1 local epoch per round.
+python run_gkt.py \
+    --is_fed=1 \
+    --fixed_cluster=0 \
+    --split_factor=1 \
+    --num_clusters=20 \
+    --num_selected=20 \
+    --arch="wide_resnet16_8" \
+    --dataset="cifar10" \
+    --num_classes=10 \
+    --is_single_branch=0 \
+    --is_amp=0 \
+    --num_rounds=300 \
+    --fed_epochs=1 \
+    --spid="fedgkt_wrn168_split1_cifar10_20clients_20choose_300rounds" \
+    --data=${LOCAL_DATA_PATH}
diff --git a/EdgeFLite/scripts/FGKT_W168_20c_skew.sh b/EdgeFLite/scripts/FGKT_W168_20c_skew.sh new file mode 100644 index 0000000..89f7e1c --- /dev/null +++ b/EdgeFLite/scripts/FGKT_W168_20c_skew.sh @@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+# @Author: Weisen Pan
+
+# Load environment modules and required dependencies
+source /etc/profile.d/modules.sh
+module load gcc/11.2.0          # GCC compiler version 11.2.0
+module load openmpi/4.1.3       # OpenMPI version 4.1.3 for distributed computing
+module load cuda/11.5/11.5.2    # CUDA version 11.5.2 for GPU acceleration
+module load cudnn/8.3/8.3.3     # cuDNN version 8.3.3 for deep learning libraries
+module load nccl/2.11/2.11.4-1  # NCCL version 2.11.4 for multi-GPU communication
+module load python/3.10/3.10.4  # Python version 3.10.4
+
+# Activate the virtual Python environment (PyTorch 1.11 and Horovod)
+source ~/venv/pytorch1.11+horovod/bin/activate
+
+# Define the log directory, clean up old records if any, and recreate the directory
+LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
+rm -rf ${LOG_PATH}    # Remove the existing log directory if it exists
+mkdir -p ${LOG_PATH}  # Create the log directory
+
+# Set up the local data directory and copy the dataset into it
+DATA_STORAGE="${SGE_LOCALDIR}/${JOB_ID}/"
+cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_STORAGE}  # Copy the CIFAR-100 dataset into local storage
+
+# Navigate to the working directory where the training scripts are located
+cd EdgeFLite
+
+# Execute the training script: 20 clusters with 20 selected clients per round,
+# dynamic cluster selection, split factor 1, a Wide ResNet 16-8 backbone,
+# CIFAR-10 (10 classes), multi-branch model, AMP disabled, 300 rounds with
+# 1 local epoch per round, and a quantity-skew non-IID partition.
+python run_gkt.py \
+    --is_fed=1 \
+    --fixed_cluster=0 \
+    --split_factor=1 \
+    --num_clusters=20 \
+    --num_selected=20 \
+    --arch="wide_resnet16_8" \
+    --dataset="cifar10" \
+    --num_classes=10 \
+    --is_single_branch=0 \
+    --is_amp=0 \
+    --num_rounds=300 \
+    --fed_epochs=1 \
+    --cifar10_non_iid="quantity_skew" \
+    --spid="FGKT_W168_20c_skew" \
+    --data=${DATA_STORAGE}
diff --git a/EdgeFLite/scripts/FGKT_W502_20c_350r.sh b/EdgeFLite/scripts/FGKT_W502_20c_350r.sh new file mode 100644 index 0000000..91255b7 --- /dev/null +++ b/EdgeFLite/scripts/FGKT_W502_20c_350r.sh @@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# @Author: Weisen Pan
+
+# Load environment modules required for execution
+source /etc/profile.d/modules.sh
+
+# Load the GCC compiler version 11.2.0
+module load gcc/11.2.0
+
+# Load the OpenMPI version 4.1.3 for distributed computing
+module load openmpi/4.1.3
+
+# Load CUDA version 11.5 (subversion 11.5.2) for GPU acceleration
+module load cuda/11.5/11.5.2
+
+# Load cuDNN version 8.3 (subversion 8.3.3) for deep learning libraries
+module load cudnn/8.3/8.3.3
+
+# Load NCCL version 2.11 (subversion 2.11.4-1) for multi-GPU communication
+module load nccl/2.11/2.11.4-1
+
+# Load Python version 3.10 (subversion 3.10.4) as the programming language
+module load python/3.10/3.10.4
+
+# Activate the Python virtual environment for PyTorch and Horovod
+source ~/venv/pytorch1.11+horovod/bin/activate
+
+# Create and clean the log directory for this job
+LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
+# Remove any existing log directory to avoid conflicts
+rm -rf ${LOG_PATH}
+# Create a fresh log directory for the current job
+mkdir -p ${LOG_PATH}
+
+# Prepare the local dataset storage
+DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/"
+# Copy the dataset for local processing to improve performance
+cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_PATH}
+
+# Change to the working directory of the federated training scripts
+cd EdgeFLite
+
+# Execute federated training: 20 clusters with 20 selected clients per round,
+# dynamic clustering, split factor 1, the wide_resnetsl50_2 (split Wide
+# ResNet-50-2) architecture, the Pill Base dataset (98 classes), multi-branch
+# model, AMP disabled, 350 rounds with 1 local epoch per round, batch size 32,
+# and a 224x224 crop size.
+python run_gkt.py \
+    --is_fed=1 \
+    --fixed_cluster=0 \
+    --split_factor=1 \
+    --num_clusters=20 \
+    --num_selected=20 \
+    --arch="wide_resnetsl50_2" \
+    --dataset="pill_base" \
+    --num_classes=98 \
+    --is_single_branch=0 \
+    --is_amp=0 \
+    --num_rounds=350 \
+    --fed_epochs=1 \
+    --batch_size=32 \
+    --crop_size=224 \
+    --spid="FGKT_W502_20c_350r" \
+    --data=${DATA_PATH}
diff --git a/EdgeFLite/settings.py b/EdgeFLite/settings.py new file mode 100644 index 0000000..c8a5959 --- /dev/null +++ b/EdgeFLite/settings.py @@ -0,0 +1,7 @@
+import os # Import the 'os' module,
which provides functions for interacting with the operating system + +current_directory = os.getcwd() + +# Define a variable 'data_directory' to store the path to the directory where data will be stored. +# In this case, it's being set to the same path as 'current_directory', meaning the data will be stored in the same location as the current working directory +data_directory = current_directory diff --git a/EdgeFLite/thop/.DS_Store b/EdgeFLite/thop/.DS_Store new file mode 100644 index 0000000..e8a5f76 Binary files /dev/null and b/EdgeFLite/thop/.DS_Store differ diff --git a/EdgeFLite/thop/helper_utils.py b/EdgeFLite/thop/helper_utils.py new file mode 100644 index 0000000..b0346b0 --- /dev/null +++ b/EdgeFLite/thop/helper_utils.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan +from collections.abc import Iterable + +# Define a function named 'clever_format' that takes two arguments: +# 1. 'nums' - either a single number or a list of numbers to format. +# 2. 'fmt' - an optional string argument specifying the format for the numbers (default is "%.2f", meaning two decimal places). +def clever_format(nums, fmt="%.2f"): + + # Check if the input 'nums' is not an instance of an iterable (like a list or tuple). + # If it is not iterable, convert the single number into a list for uniform processing later. + if not isinstance(nums, Iterable): + nums = [nums] + + # Create an empty list to store the formatted numbers. + formatted_nums = [] + + # Loop through each number in the 'nums' list. + for num in nums: + # Check if the number is greater than 1 trillion (1e12). If so, format it by dividing it by 1 trillion and appending 'T' (for trillions). + if num > 1e12: + formatted_nums.append(fmt % (num / 1e12) + "T") + # If the number is greater than 1 billion (1e9), format it by dividing by 1 billion and appending 'G' (for billions). + elif num > 1e9: + formatted_nums.append(fmt % (num / 1e9) + "G") + # If the number is greater than 1 million (1e6), format it by dividing by 1 million and appending 'M' (for millions). + elif num > 1e6: + formatted_nums.append(fmt % (num / 1e6) + "M") + # If the number is greater than 1 thousand (1e3), format it by dividing by 1 thousand and appending 'K' (for thousands). + elif num > 1e3: + formatted_nums.append(fmt % (num / 1e3) + "K") + # If the number is less than 1 thousand, simply format it using the provided format and append 'B' (for base or basic). + else: + formatted_nums.append(fmt % num + "B") + + # If only one number was passed, return just the formatted string for that number. + # If multiple numbers were passed, return a tuple containing all formatted numbers. 
+ return formatted_nums[0] if len(formatted_nums) == 1 else tuple(formatted_nums) diff --git a/EdgeFLite/thop/hooks_basic.py b/EdgeFLite/thop/hooks_basic.py new file mode 100644 index 0000000..f3acdc4 --- /dev/null +++ b/EdgeFLite/thop/hooks_basic.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import argparse +import logging + +import torch +import torch.nn as nn +from torch.nn.modules.conv import _ConvNd + +multiply_adds = 1 + +def count_parameters(m, x, y): + """Counts the number of parameters in a model.""" + total_params = sum(p.numel() for p in m.parameters()) + m.total_params[0] = torch.DoubleTensor([total_params]) + +def zero_ops(m, x, y): + """Sets total operations to zero.""" + m.total_ops += torch.DoubleTensor([0]) + +def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): + """Counts operations for convolutional layers.""" + x = x[0] + kernel_ops = m.weight[0][0].numel() # Kw x Kh + bias_ops = 1 if m.bias is not None else 0 + total_ops = y.nelement() * (m.in_channels // m.groups * kernel_ops + bias_ops) + m.total_ops += torch.DoubleTensor([total_ops]) + +def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): + """Alternative method for counting operations for convolutional layers.""" + x = x[0] + output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel() + kernel_ops = m.weight.numel() + (m.bias.numel() if m.bias is not None else 0) + m.total_ops += torch.DoubleTensor([output_size * kernel_ops]) + +def count_bn(m, x, y): + """Counts operations for batch normalization layers.""" + x = x[0] + nelements = x.numel() + if not m.training: + total_ops = 2 * nelements + m.total_ops += torch.DoubleTensor([total_ops]) + +def count_relu(m, x, y): + """Counts operations for ReLU activation.""" + x = x[0] + nelements = x.numel() + m.total_ops += torch.DoubleTensor([nelements]) + +def count_softmax(m, x, y): + """Counts operations for softmax.""" + x = x[0] + batch_size, nfeatures = x.size() + total_ops = batch_size * (2 * nfeatures - 1) + m.total_ops += torch.DoubleTensor([total_ops]) + +def count_avgpool(m, x, y): + """Counts operations for average pooling layers.""" + num_elements = y.numel() + m.total_ops += torch.DoubleTensor([num_elements]) + +def count_adap_avgpool(m, x, y): + """Counts operations for adaptive average pooling layers.""" + kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor(list((m.output_size,))).squeeze() + kernel_ops = torch.prod(kernel) + 1 + num_elements = y.numel() + m.total_ops += torch.DoubleTensor([kernel_ops * num_elements]) + +def count_upsample(m, x, y): + """Counts operations for upsample layers.""" + if m.mode not in ("nearest", "linear", "bilinear", "bicubic"): + logging.warning(f"Mode {m.mode} is not implemented yet, assuming zero ops") + return zero_ops(m, x, y) + + if m.mode == "nearest": + return zero_ops(m, x, y) + + total_ops = { + "linear": 5, + "bilinear": 11, + "bicubic": 259, # 224 muls + 35 adds + "trilinear": 31 # 2 * bilinear + 1 * linear + }.get(m.mode, 0) * y.nelement() + + m.total_ops += torch.DoubleTensor([total_ops]) + +def count_linear(m, x, y): + """Counts operations for linear layers.""" + total_ops = m.in_features * y.numel() + m.total_ops += torch.DoubleTensor([total_ops]) diff --git a/EdgeFLite/thop/hooks_rnn.py b/EdgeFLite/thop/hooks_rnn.py new file mode 100644 index 0000000..700041d --- /dev/null +++ b/EdgeFLite/thop/hooks_rnn.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +# @Author: Weisen Pan + +import torch +import torch.nn as nn + +def 
_count_rnn_cell(input_size, hidden_size, bias=True): + """Calculate the total operations for a single RNN cell. + + Args: + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + bias (bool, optional): Whether the RNN cell uses bias. Defaults to True. + + Returns: + int: Total number of operations for the RNN cell. + """ + ops = hidden_size * (input_size + hidden_size) + hidden_size + if bias: + ops += hidden_size * 2 + return ops + +def count_rnn_cell(cell: nn.RNNCell, x: torch.Tensor): + """Count operations for the RNNCell over a batch of input. + + Args: + cell (nn.RNNCell): The RNNCell to count operations for. + x (torch.Tensor): Input tensor. + """ + ops = _count_rnn_cell(cell.input_size, cell.hidden_size, cell.bias) + batch_size = x[0].size(0) + total_ops = ops * batch_size + cell.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_gru_cell(input_size, hidden_size, bias=True): + """Calculate the total operations for a single GRU cell. + + Args: + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + bias (bool, optional): Whether the GRU cell uses bias. Defaults to True. + + Returns: + int: Total number of operations for the GRU cell. + """ + ops = (hidden_size + input_size) * hidden_size + hidden_size + if bias: + ops += hidden_size * 2 + ops *= 2 # For reset and update gates + + ops += (hidden_size + input_size) * hidden_size + hidden_size # Calculate new gate + if bias: + ops += hidden_size * 2 + ops += hidden_size # Hadamard product + ops += hidden_size * 3 # Final output + + return ops + +def count_gru_cell(cell: nn.GRUCell, x: torch.Tensor): + """Count operations for the GRUCell over a batch of input. + + Args: + cell (nn.GRUCell): The GRUCell to count operations for. + x (torch.Tensor): Input tensor. + """ + ops = _count_gru_cell(cell.input_size, cell.hidden_size, cell.bias) + batch_size = x[0].size(0) + total_ops = ops * batch_size + cell.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_lstm_cell(input_size, hidden_size, bias=True): + """Calculate the total operations for a single LSTM cell. + + Args: + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + bias (bool, optional): Whether the LSTM cell uses bias. Defaults to True. + + Returns: + int: Total number of operations for the LSTM cell. + """ + ops = (input_size + hidden_size) * hidden_size + hidden_size + if bias: + ops += hidden_size * 2 + ops *= 4 # For input, forget, output, and cell gates + + ops += hidden_size * 3 # Cell state update + ops += hidden_size # Final output + + return ops + +def count_lstm_cell(cell: nn.LSTMCell, x: torch.Tensor): + """Count operations for the LSTMCell over a batch of input. + + Args: + cell (nn.LSTMCell): The LSTMCell to count operations for. + x (torch.Tensor): Input tensor. + """ + ops = _count_lstm_cell(cell.input_size, cell.hidden_size, cell.bias) + batch_size = x[0].size(0) + total_ops = ops * batch_size + cell.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_rnn_layers(model: nn.RNN, num_layers, input_size, hidden_size): + """Calculate the total operations for RNN layers. + + Args: + model (nn.RNN): The RNN model. + num_layers (int): Number of layers in the RNN. + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + + Returns: + int: Total number of operations for the RNN layers. 
+ """ + ops = _count_rnn_cell(input_size, hidden_size, model.bias) + for _ in range(num_layers - 1): + ops += _count_rnn_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias) + return ops + +def count_rnn(model: nn.RNN, x: torch.Tensor): + """Count operations for the entire RNN over a batch of input. + + Args: + model (nn.RNN): The RNN model. + x (torch.Tensor): Input tensor. + """ + batch_size = x[0].size(0) if model.batch_first else x[0].size(1) + num_steps = x[0].size(1) if model.batch_first else x[0].size(0) + + ops = _count_rnn_layers(model, model.num_layers, model.input_size, model.hidden_size) + total_ops = ops * num_steps * batch_size + model.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_gru_layers(model: nn.GRU, num_layers, input_size, hidden_size): + """Calculate the total operations for GRU layers. + + Args: + model (nn.GRU): The GRU model. + num_layers (int): Number of layers in the GRU. + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + + Returns: + int: Total number of operations for the GRU layers. + """ + ops = _count_gru_cell(input_size, hidden_size, model.bias) + for _ in range(num_layers - 1): + ops += _count_gru_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias) + return ops + +def count_gru(model: nn.GRU, x: torch.Tensor): + """Count operations for the entire GRU over a batch of input. + + Args: + model (nn.GRU): The GRU model. + x (torch.Tensor): Input tensor. + """ + batch_size = x[0].size(0) if model.batch_first else x[0].size(1) + num_steps = x[0].size(1) if model.batch_first else x[0].size(0) + + ops = _count_gru_layers(model, model.num_layers, model.input_size, model.hidden_size) + total_ops = ops * num_steps * batch_size + model.total_ops += torch.DoubleTensor([int(total_ops)]) + +def _count_lstm_layers(model: nn.LSTM, num_layers, input_size, hidden_size): + """Calculate the total operations for LSTM layers. + + Args: + model (nn.LSTM): The LSTM model. + num_layers (int): Number of layers in the LSTM. + input_size (int): Size of the input. + hidden_size (int): Size of the hidden state. + + Returns: + int: Total number of operations for the LSTM layers. + """ + ops = _count_lstm_cell(input_size, hidden_size, model.bias) + for _ in range(num_layers - 1): + ops += _count_lstm_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias) + return ops + +def count_lstm(model: nn.LSTM, x: torch.Tensor): + """Count operations for the entire LSTM over a batch of input. + + Args: + model (nn.LSTM): The LSTM model. + x (torch.Tensor): Input tensor. 
+    """
+    batch_size = x[0].size(0) if model.batch_first else x[0].size(1)
+    num_steps = x[0].size(1) if model.batch_first else x[0].size(0)
+
+    ops = _count_lstm_layers(model, model.num_layers, model.input_size, model.hidden_size)
+    total_ops = ops * num_steps * batch_size
+    model.total_ops += torch.DoubleTensor([int(total_ops)])
diff --git a/EdgeFLite/thop/profiling.py b/EdgeFLite/thop/profiling.py new file mode 100644 index 0000000..158ec8d --- /dev/null +++ b/EdgeFLite/thop/profiling.py @@ -0,0 +1,168 @@
+# -*- coding: utf-8 -*-
+# @Author: Weisen Pan
+
+# Import the necessary modules
+import logging                              # Used for the version warning below
+import torch
+import torch.nn as nn
+from distutils.version import LooseVersion  # Used for version comparisons
+from .hooks_basic import *  # Basic hooks (op-counting functions for common layers)
+from .hooks_rnn import *    # Hooks specific to RNN, GRU, and LSTM layers
+
+# Uncomment the following to configure a module-level logger
+# logger = logging.getLogger(__name__)  # Create a logger instance
+# logger.setLevel(logging.INFO)         # Set the log level to INFO
+
+# Functions to print text in different colors,
+# useful for visually differentiating output in the terminal
+def prRed(skk):
+    print("\033[91m{}\033[00m".format(skk))  # Print red text
+def prGreen(skk):
+    print("\033[92m{}\033[00m".format(skk))  # Print green text
+def prYellow(skk):
+    print("\033[93m{}\033[00m".format(skk))  # Print yellow text
+
+# Check whether the installed version of PyTorch is outdated
+if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
+    # If the version is below 1.0.0, print a warning
+    logging.warning(
+        f"You are using an old version of PyTorch {torch.__version__}, which THOP may not support in the future."
+    )
+
+# Set the default data type for the op/parameter accumulators
+default_dtype = torch.float64  # Use 64-bit floats for the counters
+
+# Register hooks for different layer types in PyTorch;
+# each layer type is mapped to its respective counting function
+register_hooks = {
+    nn.ZeroPad2d: zero_ops,
+    nn.Conv1d: count_convNd, nn.Conv2d: count_convNd, nn.Conv3d: count_convNd,
+    nn.ConvTranspose1d: count_convNd, nn.ConvTranspose2d: count_convNd, nn.ConvTranspose3d: count_convNd,
+    nn.BatchNorm1d: count_bn, nn.BatchNorm2d: count_bn, nn.BatchNorm3d: count_bn, nn.SyncBatchNorm: count_bn,
+    nn.ReLU: zero_ops, nn.ReLU6: zero_ops, nn.LeakyReLU: count_relu,
+    nn.MaxPool1d: zero_ops, nn.MaxPool2d: zero_ops, nn.MaxPool3d: zero_ops,
+    nn.AdaptiveMaxPool1d: zero_ops, nn.AdaptiveMaxPool2d: zero_ops, nn.AdaptiveMaxPool3d: zero_ops,
+    nn.AvgPool1d: count_avgpool, nn.AvgPool2d: count_avgpool, nn.AvgPool3d: count_avgpool,
+    nn.AdaptiveAvgPool1d: count_adap_avgpool, nn.AdaptiveAvgPool2d: count_adap_avgpool, nn.AdaptiveAvgPool3d: count_adap_avgpool,
+    nn.Linear: count_linear, nn.Dropout: zero_ops,
+    nn.Upsample: count_upsample, nn.UpsamplingBilinear2d: count_upsample, nn.UpsamplingNearest2d: count_upsample,
+    nn.RNNCell: count_rnn_cell, nn.GRUCell: count_gru_cell, nn.LSTMCell: count_lstm_cell,
+    nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm,
+}
+
+# Function for profiling model operations and parameters.
+# It tracks how many operations (ops) and parameters (params) a model uses.
+def profile_origin(model, inputs, custom_ops=None, verbose=True):
+    handler_collection = []   # Collection of hooks
+    types_collection = set()  # Keep track of registered layer types
+    custom_ops = custom_ops or {}  # Custom operation handling
+
+    def add_hooks(m):
+        # Ignore compound modules (those that contain other modules)
+        if len(list(m.children())) > 0:
+            return
+
+        # Check if the
module already has the required attributes + if hasattr(m, "total_ops") or hasattr(m, "total_params"): + logging.warning(f"Either .total_ops or .total_params is already defined in {str(m)}. Be cautious.") + + # Add buffers to store the total number of operations and parameters + m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype)) + m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype)) + + # Count the number of parameters for this module + for p in m.parameters(): + m.total_params += torch.DoubleTensor([p.numel()]) + + # Determine which function to use for counting operations + m_type = type(m) + fn = custom_ops.get(m_type, register_hooks.get(m_type, None)) + + if fn: + # If the function exists, register the forward hook + if m_type not in types_collection and verbose: + print(f"[INFO] {'Customize rule' if m_type in custom_ops else 'Register'} {fn.__qualname__} for {m_type}.") + handler = m.register_forward_hook(fn) + handler_collection.append(handler) + else: + # Warn if no counting rule is found + if m_type not in types_collection and verbose: + prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero Params.") + + types_collection.add(m_type) + + # Set the model to evaluation mode (no gradients) + model.eval() + model.apply(add_hooks) + + # Run a forward pass with no gradients + with torch.no_grad(): + model(*inputs) + + # Sum up the total operations and parameters across all layers + total_ops = sum(m.total_ops.item() for m in model.modules() if hasattr(m, 'total_ops')) + total_params = sum(m.total_params.item() for m in model.modules() if hasattr(m, 'total_params')) + + # Restore the model to training mode and remove hooks + model.train() + for handler in handler_collection: + handler.remove() + for m in model.modules(): + if hasattr(m, "total_ops"): del m._buffers['total_ops'] + if hasattr(m, "total_params"): del m._buffers['total_params'] + + return total_ops, total_params # Return the total number of ops and params + +# Updated profiling function with a different approach for hierarchical modules +def profile(model: nn.Module, inputs, custom_ops=None, verbose=True): + handler_collection = {} # Dictionary to store handlers + types_collection = set() # Store layer types that have been processed + custom_ops = custom_ops or {} # Custom operation handling + + def add_hooks(m: nn.Module): + # Add buffers for storing total ops and params + m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype)) + m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype)) + + # Find the appropriate counting function for this layer + fn = custom_ops.get(type(m), register_hooks.get(type(m), None)) + if fn: + # Register hooks for both operations and parameters + handler_collection[m] = (m.register_forward_hook(fn), m.register_forward_hook(count_parameters)) + if type(m) not in types_collection and verbose: + print(f"[INFO] {'Customize rule' if type(m) in custom_ops else 'Register'} {fn.__qualname__} for {type(m)}.") + else: + # Warn if no rule is found for this layer + if type(m) not in types_collection and verbose: + prRed(f"[WARN] Cannot find rule for {type(m)}. 
Treat it as zero MACs and zero Params.")
+            types_collection.add(type(m))
+
+    # Set the model to evaluation mode
+    model.eval()
+    model.apply(add_hooks)
+
+    # Run a forward pass with no gradients
+    with torch.no_grad():
+        model(*inputs)
+
+    # Recursive function to count ops and params for hierarchical models
+    def dfs_count(module: nn.Module) -> (int, int):
+        total_ops, total_params = 0, 0
+        for m in module.children():
+            if m in handler_collection:
+                m_ops, m_params = m.total_ops.item(), m.total_params.item()
+            else:
+                m_ops, m_params = dfs_count(m)
+            total_ops += m_ops
+            total_params += m_params
+        return total_ops, total_params
+
+    total_ops, total_params = dfs_count(model)  # Perform the depth-first count
+
+    # Restore the model to training mode and remove hooks
+    model.train()
+    for m, (op_handler, params_handler) in handler_collection.items():
+        op_handler.remove()
+        params_handler.remove()
+        del m._buffers['total_ops']
+        del m._buffers['total_params']
+
+    return total_ops, total_params  # Return the total ops and params
diff --git a/EdgeFLite/train_EdgeFLite.py b/EdgeFLite/train_EdgeFLite.py new file mode 100644 index 0000000..c8cb6e8 --- /dev/null +++ b/EdgeFLite/train_EdgeFLite.py @@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+# @Author: Weisen Pan
+
+import os
+import argparse
+import setproctitle
+import torch
+import torch.nn as nn
+from config import *  # Import configuration
+from params import train_params  # Import training parameters
+from model import coremodel, coremodelsl  # Import models
+from utils import (  # Import utility functions
+    label_smoothing, norm, metric, lr_scheduler, prefetch,
+    save_hp_to_json, profile, clever_format
+)
+from dataset import factory  # Import dataset factory
+
+# Specify the GPU to be used
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+# Global variable for tracking the best accuracy
+best_acc1 = 0
+
+# Function to calculate the average of a list of values
+def average(values):
+    """Calculate the average of a list."""
+    return sum(values) / len(values)
+
+# Function to aggregate the models from multiple clients into a global model
+def merge_models(global_model_main, global_model_proxy, client_main_models, client_proxy_models):
+    """Aggregate the client weights into the global models using a simple mean."""
+    # Get the state dictionaries of the global models
+    global_main_state = global_model_main.state_dict()
+    global_proxy_state = global_model_proxy.state_dict()
+
+    # Aggregate the main client models by averaging the weights
+    for key in global_main_state.keys():
+        global_main_state[key] = torch.stack([client.state_dict()[key].float() for client in client_main_models], 0).mean(0)
+    global_model_main.load_state_dict(global_main_state)
+
+    # Aggregate the proxy client models similarly
+    for key in global_proxy_state.keys():
+        global_proxy_state[key] = torch.stack([client.state_dict()[key].float() for client in client_proxy_models], 0).mean(0)
+    global_model_proxy.load_state_dict(global_proxy_state)
+
+    # Synchronize the client models with the updated global models
+    for client in client_main_models:
+        client.load_state_dict(global_model_main.state_dict())
+    for client in client_proxy_models:
+        client.load_state_dict(global_model_proxy.state_dict())
+
+# Function to perform client-side training updates
+def client_update(args, round_idx, main_model, proxy_models, schedulers_main, schedulers_proxy, optimizers_main, optimizers_proxy, train_loader, epochs=5, streams=None):
+    """Client-side training update."""
+    main_model.train()
+    proxy_models.train()
+
+    # Train for the given number of epochs
+    for epoch in
range(epochs): + # Prefetch data for faster loading + prefetcher = prefetch.data_prefetcher(train_loader) + images, targets = prefetcher.next() + batch_idx = 0 + + # Zero the gradients + optimizers_main.zero_grad() + optimizers_proxy.zero_grad() + + # Process each batch of data + while images is not None: + # Adjust learning rates using the scheduler + schedulers_main(optimizers_main, batch_idx, round_idx) + schedulers_proxy(optimizers_proxy, batch_idx, round_idx) + + # Forward pass for the main model + outputs, y_a, y_b, lam = main_model(images, target=targets, mode='train', epoch=epoch, streams=streams) + main_fx = [output.clone().detach().requires_grad_(True) for output in outputs] + + # Forward pass for the proxy model with outputs from the main model + ensemble_output, proxy_outputs, ce_loss, cot_loss = proxy_models(main_fx, y_a, y_b, lam, target=targets, mode='train', epoch=epoch, streams=streams) + + # Calculate total loss and perform backpropagation + total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate + total_loss.backward() + + # Backpropagate gradients for the main model + for j in range(len(main_fx)): + outputs[j].backward(main_fx[j].grad) + + # Update the model weights periodically + if batch_idx % args.iters_to_accumulate == 0 or batch_idx == len(train_loader): + optimizers_main.step() + optimizers_main.zero_grad() + optimizers_proxy.step() + optimizers_proxy.zero_grad() + + # Fetch the next batch of images + images, targets = prefetcher.next() + batch_idx += 1 + + return total_loss.item() + +# Function to validate the models on a validation set +def validate(val_loader, main_model, proxy_models, args, streams=None): + """Validation function to evaluate models.""" + main_model.eval() + proxy_models.eval() + + # Initialize metrics for accuracy tracking + top1_metrics = [metric.AverageMeter(f"Acc@1_{i}", ":6.2f") for i in range(args.loop_factor)] + acc1_list, acc5_list, ce_loss_list = [], [], [] + + # Disable gradient computation for validation + with torch.no_grad(): + for images, targets in val_loader: + images, targets = images.cuda(), targets.cuda() + + # Forward pass for main model + outputs = main_model(images, target=targets, mode='val') + main_fx = [output.clone().detach().requires_grad_(True) for output in outputs] + + # Forward pass for proxy model + ensemble_output, proxy_outputs, ce_loss = proxy_models(main_fx, target=targets, mode='val') + + # Calculate accuracy + acc1, acc5 = metric.accuracy(ensemble_output, targets, topk=(1, 5)) + acc1_list.append(acc1) + acc5_list.append(acc5) + ce_loss_list.append(ce_loss) + + # Calculate average metrics over the validation set + avg_acc1 = average(acc1_list) + avg_acc5 = average(acc5_list) + avg_ce_loss = average(ce_loss_list) + + return avg_ce_loss, avg_acc1, top1_metrics + +# Main function to set up and start decentralized training +def main(args): + """Main function to set up decentralized training.""" + # Set loop factor based on training configuration + args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor + # Determine if decentralized training is needed + args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized + + # Get the number of GPUs available + ngpus_per_node = torch.cuda.device_count() + args.ngpus_per_node = ngpus_per_node + + # If using decentralized training with multiprocessing + if args.multiprocessing_decentralized: + args.world_size *= ngpus_per_node + torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, 
args=(ngpus_per_node, args))
+    else:
+        # If not using multiprocessing, proceed with a single GPU
+        args.gpu = 0
+        execute_worker_process(args.gpu, ngpus_per_node, args)
+
+# Main worker function to handle training with multiple GPUs or a single GPU
+def execute_worker_process(gpu, ngpus_per_node, args):
+    """Main worker function for multi-GPU or single-GPU training."""
+    global best_acc1
+    args.gpu = gpu
+
+    # Set the process title
+    setproctitle.setproctitle(f"{args.proc_name}_EdgeFLite_rank{args.rank}")
+
+    # Set the criterion for loss calculation
+    if args.is_label_smoothing:
+        criterion = label_smoothing.label_smoothing_CE(reduction='mean')
+    else:
+        criterion = nn.CrossEntropyLoss()
+
+    # Create the global main and proxy models
+    main_model = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()
+    proxy_model = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()
+
+    # Initialize client models for federated learning
+    client_main_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() for _ in range(args.num_selected)]
+    client_proxy_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() for _ in range(args.num_selected)]
+
+    # Synchronize client models with the global models
+    for client in client_main_models:
+        client.load_state_dict(main_model.state_dict())
+    for client in client_proxy_models:
+        client.load_state_dict(proxy_model.state_dict())
+
+    # Load training and validation data
+    train_loader = factory.obtain_data_loader(args.data, batch_size=args.batch_size, dataset=args.dataset, split="train", num_workers=args.workers)
+    val_loader = factory.obtain_data_loader(args.data, batch_size=args.eval_batch_size, dataset=args.dataset, split="val", num_workers=args.workers)
+
+    # Loop over training rounds
+    for r in range(args.start_round, args.num_rounds + 1):
+        # Train each selected client locally; a fresh SGD optimizer is created per
+        # client and the shared learning-rate scheduler is reused for both sub-models
+        # (assumes the learning rate is exposed as args.lr by train_params).
+        for client_main, client_proxy in zip(client_main_models, client_proxy_models):
+            optimizer_main = torch.optim.SGD(client_main.parameters(), lr=args.lr)
+            optimizer_proxy = torch.optim.SGD(client_proxy.parameters(), lr=args.lr)
+            client_update(args, r, client_main, client_proxy,
+                          lr_scheduler.lr_scheduler, lr_scheduler.lr_scheduler,
+                          optimizer_main, optimizer_proxy,
+                          train_loader, epochs=args.fed_epochs)
+
+        # Aggregate the locally trained client models into the global models
+        merge_models(main_model, proxy_model, client_main_models, client_proxy_models)
+
+        # Validate the aggregated global models
+        test_loss, acc, top1 = validate(val_loader, main_model, proxy_model, args)
+
+        # Track the best accuracy achieved
+        best_acc1 = max(acc, best_acc1)
+
+# Entry point for the script
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Training EdgeFLite")
+    args = train_params.add_parser_params(parser)
+    main(args)
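
The aggregation step in `train_EdgeFLite.py` is a plain federated average: each parameter tensor of the global model is replaced by the element-wise mean of the corresponding client tensors, and the clients are then re-synchronized with the new global weights. The following minimal, self-contained sketch reproduces that behavior outside the training pipeline; the `TinyNet` module, the client count, and the helper name `average_state_dicts` are illustrative only and are not part of the patch.

```python
# Standalone sketch of the simple-mean aggregation performed by merge_models in
# train_EdgeFLite.py. TinyNet and the client count are illustrative only.
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def average_state_dicts(global_model, client_models):
    """Overwrite global_model with the element-wise mean of the client weights."""
    global_state = global_model.state_dict()
    for key in global_state.keys():
        global_state[key] = torch.stack(
            [client.state_dict()[key].float() for client in client_models], 0
        ).mean(0)
    global_model.load_state_dict(global_state)
    # Re-synchronize the clients with the new global weights
    for client in client_models:
        client.load_state_dict(global_model.state_dict())


if __name__ == "__main__":
    global_model = TinyNet()
    clients = [TinyNet() for _ in range(3)]
    average_state_dicts(global_model, clients)
    print(global_model.state_dict()["fc.weight"])
```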
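
The helpers under `EdgeFLite/thop/` follow the usual hook-based op counting: `profile()` registers a counting hook per supported layer type, runs one forward pass under `no_grad`, and sums the per-module `total_ops`/`total_params` buffers, while `clever_format()` renders the raw counts with K/M/G/T suffixes. Below is a hedged usage sketch; the toy `nn.Sequential` model, the 32x32 input, and the assumption that the package is importable as `EdgeFLite.thop` (e.g. with `__init__.py` files on the path) are not part of the patch.

```python
# Usage sketch for the op/parameter counters added under EdgeFLite/thop/.
# Assumes the package is importable as EdgeFLite.thop; the model and input
# shape below are placeholders, not part of the patch.
import torch
import torch.nn as nn

from EdgeFLite.thop.profiling import profile
from EdgeFLite.thop.helper_utils import clever_format

# A small stand-in model for a CIFAR-sized 32x32 RGB input.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

dummy_input = torch.randn(1, 3, 32, 32)

# profile() registers forward hooks, runs one forward pass without gradients,
# and returns the summed operation and parameter counts.
macs, params = profile(model, inputs=(dummy_input,), verbose=False)

# clever_format() turns the raw counts into human-readable strings.
print(clever_format([macs, params], "%.2f"))
```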