Edge Federated Learning for Improved Training Efficiency
Change-Id: Ic4e43992e1674946cb69e0221659b0261259196c
parent bc2236740c
commit 4ec0a23e73
BIN
EdgeFLite/.DS_Store
vendored
Normal file
Binary file not shown.
41
EdgeFLite/README.md
Normal file
@@ -0,0 +1,41 @@
# EdgeFLite: Edge Federated Learning for Improved Training Efficiency

- EdgeFLite is a framework designed to overcome the memory limitations of federated learning (FL) on resource-constrained edge devices. It partitions large convolutional neural networks (CNNs) into smaller sub-models and distributes their training across local clients, enabling efficient learning while preserving data privacy. Clients within a cluster collaborate by sharing learned representations, which a central server aggregates to refine the global model. Experiments on medical imaging and natural-image datasets show that EdgeFLite consistently outperforms competing FL frameworks.

- Within 6G-enabled mobile edge computing (MEC) networks, EdgeFLite addresses client heterogeneity and resource constraints by jointly optimizing local models and resource allocation. A convergence analysis establishes the relationship between training loss and resource usage, and the Intelligent Frequency Band Allocation (IFBA) algorithm reduces latency and improves training efficiency by 5-10%, making EdgeFLite a robust solution across a wide range of edge environments.
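At a high level, each training round alternates local sub-model training with server-side aggregation. The sketch below illustrates one common aggregation scheme (FedAvg-style weight averaging) for intuition only; the names (`federated_round`, `train_fn`, `client.data`) are hypothetical and not the repository's actual API:

```python
import copy
import torch

def federated_round(global_model, clusters, train_fn):
    """One round: local training on every client, then weight averaging on the server."""
    client_states = []
    for cluster in clusters:                            # clusters of collaborating clients
        for client in cluster:
            local_model = copy.deepcopy(global_model)   # start from the global weights
            train_fn(local_model, client.data)          # local training on private data
            client_states.append(local_model.state_dict())

    # Server-side aggregation: element-wise mean of the collected client weights.
    avg_state = {
        key: torch.stack([s[key].float() for s in client_states]).mean(dim=0)
        for key in client_states[0]
    }
    global_model.load_state_dict(avg_state)
    return global_model
```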
## Preparation

### Dataset Setup

- The CIFAR-10 and CIFAR-100 datasets, both derived from the Tiny Images dataset, are downloaded automatically. CIFAR-10 contains 60,000 32x32 color images in 10 categories: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. Each category has 6,000 images, split into 5,000 for training and 1,000 for testing.

- CIFAR-100 is more challenging, with 100 categories and correspondingly fewer images per class. Both datasets are standard benchmarks for image classification and provide a robust evaluation environment for machine learning models.

### Dependency Installation

Install PyTorch 1.10.2 and OpenCV 4.5.5, for example:

```bash
pip install torch==1.10.2 "opencv-python>=4.5.5,<4.6"
```

## Running Experiments

*Top-1 accuracy (%) of EdgeFLite compared to state-of-the-art FL methods on the CIFAR-10 and CIFAR-100 test sets.*

1. **Specify Experiment Name:**

   Add `--spid` to specify the experiment name in each training script, like this:

   ```bash
   python run_gkt.py --spid=<experiment_name> --is_fed=1 --fixed_cluster=0 --split_factor=1 --num_clusters=20 --num_selected=20 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --num_rounds=300 --fed_epochs=1
   ```

2. **Training Scripts for CIFAR-10:**

   - **Centralized Training:**

     ```bash
     python run_local.py --is_fed=0 --split_factor=1 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --epochs=300
     ```

   - **EdgeFLite:**

     ```bash
     python train_EdgeFLite.py --is_fed=1 --fixed_cluster=0 --split_factor=4 --num_clusters=5 --num_selected=5 --dataset=cifar10 --num_classes=10 --is_single_branch=0 --is_amp=0 --num_rounds=300 --fed_epochs=1
     ```

---
BIN
EdgeFLite/architecture/.DS_Store
vendored
Normal file
Binary file not shown.
245
EdgeFLite/architecture/coremodel.py
Normal file
@@ -0,0 +1,245 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn
import torch.nn.functional as F
from .mixup import mixup_loss_criterion, combine_mixup_data
from . import resnet, resnet_sl

# Exported members of the module
__all__ = ['CoreModelClient', 'coremodelProxyClient']


def _retrieve_network(arch='resnet_model_110sl'):
    """Retrieve the specific network architecture based on the provided name."""
    available_networks = {
        'resnet_model_110sl': resnet_sl.resnet_model_110sl,
        'wide_resnetsl50_2': resnet_sl.wide_resnetsl50_2,
        'wide_resnetsl16_8': resnet_sl.wide_resnetsl16_8,
    }
    # Ensure the requested architecture exists among the available networks
    assert arch in available_networks, f"Architecture '{arch}' is not supported."
    return available_networks[arch]


class CoreModelClient(nn.Module):
    """Main client model for training and inference, managing multiple sub-networks."""

    def __init__(self, args, norm_layer=None, criterion=None, progress=True):
        super(CoreModelClient, self).__init__()

        # Parameters and configurations for the client model
        self.split_factor = args.split_factor
        self.arch = args.arch
        self.loop_factor = args.loop_factor
        self.is_train_sep = args.is_train_sep
        self.epochs = args.epochs
        self.num_classes = args.num_classes
        self.is_diff_data_train = args.is_diff_data_train
        self.is_mixup = args.is_mixup
        self.mix_alpha = args.mix_alpha

        # Model arguments
        model_kwargs = {
            'num_classes': self.num_classes,
            'norm_layer': norm_layer,
            'dataset': args.dataset,
            'split_factor': self.split_factor,
            'output_stride': args.output_stride
        }

        # Initialize multiple instances of the network architecture for the main client
        if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']:
            self.main_client_models = nn.ModuleList(
                [_retrieve_network(self.arch)(models_pretrained=args.models_pretrained, **model_kwargs)[0]
                 for _ in range(self.loop_factor)]
            )
        else:
            raise NotImplementedError(f"Architecture '{self.arch}' not implemented.")

        # Identical initialization of the model if specified
        if args.is_identical_init:
            print("INFO:PyTorch: Using identical initialization.")
            self._identical_init()

    def forward(self, x, target=None, mode='train', epoch=0, streams=None):
        """Forward pass for the main client. Handles both training and evaluation modes."""
        main_client_outputs = []

        if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']:
            if mode == 'train':
                # Apply mixup augmentation if enabled; otherwise fall back to the raw targets
                if self.is_mixup:
                    x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha)
                else:
                    y_a, y_b, lam = target, target, 1.0

                # Split input data across multiple sub-networks during training
                all_x = torch.chunk(x, chunks=self.loop_factor, dim=1) if self.is_diff_data_train else [x] * self.loop_factor

                for i in range(self.loop_factor):
                    fx = self.main_client_models[i](all_x[i])
                    # Detach so the proxy side can back-propagate into its own copy
                    main_client_outputs.append(fx.clone().detach().requires_grad_(True))

                return main_client_outputs, y_a, y_b, lam
            elif mode in ['val', 'test']:
                # Forward pass during evaluation or testing
                for i in range(self.loop_factor):
                    fx = self.main_client_models[i](x)
                    main_client_outputs.append(fx.clone().detach().requires_grad_(True))

                return main_client_outputs
            else:
                # Return a dummy tensor if the mode is unsupported
                return torch.ones(1)
        else:
            raise NotImplementedError(f"Mode '{mode}' not supported for architecture '{self.arch}'.")

    def _identical_init(self):
        """Ensure identical initialization of weights for sub-networks."""
        with torch.no_grad():
            # Copy weights from the first model to all subsequent models
            for i in range(1, len(self.main_client_models)):
                for (name1, param1), (name2, param2) in zip(self.main_client_models[i].named_parameters(),
                                                            self.main_client_models[0].named_parameters()):
                    if 'weight' in name1:
                        param1.data.copy_(param2.data)


class coremodelProxyClient(nn.Module):
    """Proxy client model to handle downstream processing and training logic."""

    def __init__(self, args, norm_layer=None, criterion=None, progress=True):
        super(coremodelProxyClient, self).__init__()

        # Parameters and configurations for the proxy client model
        self.split_factor = args.split_factor
        self.arch = args.arch
        self.loop_factor = args.loop_factor
        self.epochs = args.epochs
        self.num_classes = args.num_classes
        self.criterion = criterion
        self.is_mixup = args.is_mixup
        self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False
        self.ensembled_loss_weight = args.ensembled_loss_weight
        self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False
        self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False
        self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False
        self.cot_weight = args.cot_weight
        self.is_cot_weight_warm_up = args.is_cot_weight_warm_up
        self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs
        self.cot_loss_choose = args.cot_loss_choose

        # Model arguments for the proxy client
        model_kwargs = {
            'num_classes': self.num_classes,
            'norm_layer': norm_layer,
            'dataset': args.dataset,
            'split_factor': self.split_factor,
            'output_stride': args.output_stride
        }

        # Initialize multiple instances of the network architecture for the proxy client
        if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']:
            self.proxy_clients_models = nn.ModuleList(
                [_retrieve_network(self.arch)(models_pretrained=args.models_pretrained, **model_kwargs)[1]
                 for _ in range(self.loop_factor)]
            )
        else:
            raise NotImplementedError(f"Architecture '{self.arch}' not implemented.")

        # Identical initialization of the model if specified
        if args.is_identical_init:
            print("INFO:PyTorch: Using identical initialization.")
            self._identical_init()

    def forward(self, main_client_outputs, y_a=None, y_b=None, lam=None, target=None, mode='train', epoch=0, streams=None):
        """Forward pass for the proxy client. Manages multiple sub-networks and ensemble outputs."""
        outputs = []
        ce_losses = []

        if self.arch in ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']:
            if mode == 'train':
                # Calculate loss and forward pass during training
                for i in range(self.loop_factor):
                    output = self.proxy_clients_models[i](main_client_outputs[i])
                    loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target)
                    outputs.append(output)
                    ce_losses.append(loss)

                ensemble_output = self._collect_ensemble_output(outputs)
                ce_loss = torch.sum(torch.stack(ce_losses, dim=0))

                # Calculate co-training loss if enabled
                if self.is_cot_loss:
                    cot_loss = self._calculate_co_training_loss(outputs, epoch)
                else:
                    cot_loss = torch.zeros_like(ce_loss)

                return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss

            elif mode in ['val', 'test']:
                # Forward pass during evaluation or testing
                for i in range(self.loop_factor):
                    output = self.proxy_clients_models[i](main_client_outputs[i])
                    loss = self.criterion(output, target) if self.criterion else torch.zeros(1)
                    outputs.append(output)
                    ce_losses.append(loss)

                ensemble_output = self._collect_ensemble_output(outputs)
                ce_loss = torch.sum(torch.stack(ce_losses, dim=0))
                return ensemble_output, torch.stack(outputs, dim=0), ce_loss
            else:
                # Return a dummy tensor if the mode is unsupported
                return torch.ones(1)
        else:
            raise NotImplementedError(f"Mode '{mode}' not supported for architecture '{self.arch}'.")

    def _collect_ensemble_output(self, outputs):
        """Calculate the ensemble output from multiple sub-networks."""
        stacked_outputs = torch.stack(outputs, dim=0)

        # Apply softmax to the outputs before ensembling if specified
        if self.is_ensembled_after_softmax:
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(F.softmax(stacked_outputs, dim=-1), dim=0)
            else:
                ensemble_output = torch.mean(F.softmax(stacked_outputs, dim=-1), dim=0)
        else:
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(stacked_outputs, dim=0)
            else:
                ensemble_output = torch.mean(stacked_outputs, dim=0)

        return ensemble_output

    def _calculate_co_training_loss(self, outputs, epoch):
        """Calculate the co-training loss between outputs of different sub-networks."""
        # Adjust the weight of the co-training loss during warm-up epochs
        weight_now = self.cot_weight if not self.is_cot_weight_warm_up or epoch >= self.cot_weight_warm_up_epochs else max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005)

        # Different methods of calculating co-training loss
        if self.cot_loss_choose == 'js_divergence':
            outputs_all = torch.stack(outputs, dim=0)
            p_all = F.softmax(outputs_all, dim=-1)
            p_mean = torch.mean(p_all, dim=0)
            H_mean = (-p_mean * torch.log(p_mean + 1e-8)).sum(-1).mean()
            H_sep = (-p_all * F.log_softmax(outputs_all, dim=-1)).sum(-1).mean()
            return weight_now * (H_mean - H_sep)
        elif self.cot_loss_choose == 'kl_separate':
            outputs_all = torch.stack(outputs, dim=0)
            outputs_r1 = torch.repeat_interleave(outputs_all, self.split_factor - 1, dim=0)
            index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i]
            outputs_r2 = torch.index_select(outputs_all, dim=0,
                                            index=torch.tensor(index_list, dtype=torch.long, device=outputs_all.device))
            kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1), F.softmax(outputs_r2, dim=-1).detach(), reduction='none')
            return weight_now * kl_loss.sum(-1).mean(-1).sum() / (self.split_factor - 1)
        else:
            raise NotImplementedError(f"Co-training loss '{self.cot_loss_choose}' not implemented.")

    def _identical_init(self):
        """Ensure identical initialization of weights for sub-networks."""
        with torch.no_grad():
            # Copy weights from the first model to all subsequent models
            for i in range(1, len(self.proxy_clients_models)):
                for (name1, param1), (name2, param2) in zip(self.proxy_clients_models[i].named_parameters(),
                                                            self.proxy_clients_models[0].named_parameters()):
                    if 'weight' in name1:
                        param1.data.copy_(param2.data)
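A minimal usage sketch for the two classes above (illustrative only; it assumes `args` carries every attribute read in the constructors, such as `arch`, `split_factor`, and `loop_factor`):

```python
import torch.nn as nn

def split_forward(args, images, labels):
    main_client = CoreModelClient(args)
    proxy_client = coremodelProxyClient(args, criterion=nn.CrossEntropyLoss())

    # Client side: intermediate features, detached but grad-enabled so that
    # gradients can be handed back across the split point during training.
    feats, y_a, y_b, lam = main_client(images, target=labels, mode='train')

    # Proxy side: finish the forward pass and compute the training losses.
    ensemble_out, per_model_out, ce_loss, cot_loss = proxy_client(
        feats, y_a=y_a, y_b=y_b, lam=lam, target=labels, mode='train')
    return ensemble_out, ce_loss + cot_loss
```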
60
EdgeFLite/architecture/mixup.py
Normal file
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import numpy as np


@torch.no_grad()
def combine_mixup_data(x, y, alpha=1.0, use_cuda=True):
    """
    Perform the mixup operation on input data.

    Args:
        x (Tensor): Input features, typically from the dataset.
        y (Tensor): Input labels corresponding to the features.
        alpha (float): Mixup interpolation coefficient. The default value is 1.0.
            A higher value results in more mixing between samples.
        use_cuda (bool): Whether CUDA should be used if available.

    Returns:
        mixed_x (Tensor): Mixed input features, a linear combination of x and a permuted version of x.
        y_a (Tensor): Original input labels corresponding to x.
        y_b (Tensor): Permuted input labels corresponding to the mixed samples.
        lam (float): The lambda value used for interpolation between samples.
    """
    # Draw lambda from the Beta distribution if alpha > 0; otherwise set lam to 1 (no mixup)
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Get the batch size from the input tensor
    batch_size = x.size(0)

    # Generate a random permutation of indices for mixing (on GPU if requested)
    index = torch.randperm(batch_size).cuda() if use_cuda else torch.randperm(batch_size)

    # Mix the features of the original and permuted samples using lambda
    mixed_x = lam * x + (1 - lam) * x[index, :]

    # Assign original and permuted labels to y_a and y_b, respectively
    y_a, y_b = y, y[index]

    # Return mixed features, original and permuted labels, and the lambda value
    return mixed_x, y_a, y_b, lam


def mixup_loss_criterion(criterion, pred, y_a, y_b, lam):
    """
    Compute the mixup loss using the provided criterion.

    Args:
        criterion (function): The loss function used to compute the error (e.g., CrossEntropyLoss).
        pred (Tensor): The model predictions, typically the output of a neural network.
        y_a (Tensor): The original labels corresponding to the original input features.
        y_b (Tensor): The permuted labels corresponding to the mixed input features.
        lam (float): The lambda value for mixup, used to interpolate between the two losses.

    Returns:
        loss (Tensor): The final mixup loss, computed as a weighted sum of the two losses.
    """
    # Combine the losses from the original and permuted labels
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
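A short usage example for the two helpers above (illustrative, not part of the file):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 3, 32, 32)          # batch of 8 CIFAR-sized images
y = torch.randint(0, 10, (8,))         # integer class labels
mixed_x, y_a, y_b, lam = combine_mixup_data(x, y, alpha=1.0, use_cuda=False)

logits = torch.randn(8, 10, requires_grad=True)   # stand-in for model output
loss = mixup_loss_criterion(nn.CrossEntropyLoss(), logits, y_a, y_b, lam)
loss.backward()                         # mixup loss is differentiable as usual
```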
237
EdgeFLite/architecture/resnet.py
Normal file
@@ -0,0 +1,237 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn

# Try to import the method to load model weights from a URL, with a fallback in case of ImportError
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# ResNet variants defined in this file
__all__ = ['resnet_model_18', 'resnet_model_34', 'resnet_model_50',
           'resnet_model_101', 'resnet_model_152', 'resnet_model_200']

# Pre-trained checkpoint URLs (the standard torchvision checkpoints)
model_urls = {
    'resnet_model_18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet_model_34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet_model_50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet_model_101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet_model_152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet_model_50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet_model_101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


# Function for a 3x3 convolution with padding
def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


# Function for a 1x1 convolution
def apply_1x1_convolution(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


# BasicBlock class for the ResNet architecture
class BasicBlock(nn.Module):
    expansion = 1  # Expansion factor for the output channels

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        # If norm_layer is not provided, use BatchNorm2d as the default
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Ensure BasicBlock is restricted to specific parameters
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock is restricted to groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("BasicBlock does not support dilation greater than 1")

        # Define the layers for the BasicBlock
        self.conv1 = apply_3x3_convolution(inplanes, planes, stride)  # First 3x3 convolution
        self.bn1 = norm_layer(planes)  # First BatchNorm layer
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.conv2 = apply_3x3_convolution(planes, planes)  # Second 3x3 convolution
        self.bn2 = norm_layer(planes)  # Second BatchNorm layer
        self.downsample = downsample  # Optional downsample layer
        self.stride = stride

    # Define the forward pass for BasicBlock
    def forward(self, x):
        identity = x  # Save the input for the skip connection

        out = self.conv1(x)  # First convolution
        out = self.bn1(out)  # BatchNorm after first convolution
        out = self.relu(out)  # ReLU activation

        out = self.conv2(out)  # Second convolution
        out = self.bn2(out)  # BatchNorm after second convolution

        # Apply downsample if defined
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity  # Add the skip connection
        out = self.relu(out)  # Final ReLU activation

        return out


# Bottleneck class for the ResNet architecture, used in deeper ResNet models
class Bottleneck(nn.Module):
    expansion = 4  # Expansion factor for the output channels

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups  # Width based on base width and groups

        # Define the layers for the Bottleneck block
        self.conv1 = apply_1x1_convolution(inplanes, width)  # 1x1 convolution to reduce dimensions
        self.bn1 = norm_layer(width)  # BatchNorm after 1x1 convolution
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)  # 3x3 convolution
        self.bn2 = norm_layer(width)  # BatchNorm after 3x3 convolution
        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)  # 1x1 convolution to expand dimensions
        self.bn3 = norm_layer(planes * self.expansion)  # BatchNorm after the final 1x1 convolution
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.downsample = downsample  # Optional downsample layer
        self.stride = stride

    # Define the forward pass for Bottleneck
    def forward(self, x):
        identity = x  # Save the input for the skip connection

        out = self.conv1(x)  # First convolution
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)  # Second convolution
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)  # Third convolution
        out = self.bn3(out)

        # Apply downsample if defined
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity  # Add the skip connection
        out = self.relu(out)  # Final ReLU activation

        return out


# Main ResNet class (CIFAR-style: three stages of 16/32/64 base channels)
class ResNet(nn.Module):

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d  # Default normalization layer
        self._norm_layer = norm_layer

        self.groups = groups  # Number of groups in convolutions
        self.inplanes = 16 if dataset in ['cifar10', 'cifar100'] else 64  # Fewer initial planes for CIFAR

        # First layer: a combination of convolution, normalization, and ReLU
        self.layer0 = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
        )

        # Subsequent ResNet stages built with the _create_model_layer method
        self.layer1 = self._create_model_layer(block, 16, layers[0])
        self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Global average pooling
        self.fc = nn.Linear(64 * block.expansion, num_classes)  # Fully connected layer for classification

        # Initialization for model weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 1e-3)

        # Zero-initialize the last BatchNorm in each residual branch if required
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    # Helper function to create a stage of blocks
    def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer  # Set normalization layer
        downsample = None
        # If the stride is not 1 or input/output planes do not match, create a downsample layer
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]  # First block, possibly with downsampling
        self.inplanes = planes * block.expansion  # Update inplanes for the next blocks
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))  # Subsequent blocks without downsampling
        return nn.Sequential(*layers)

    # Forward pass through the ResNet architecture
    def forward(self, x):
        x = self.layer0(x)  # Pass input through the first layer
        x = self.layer1(x)  # First ResNet stage
        x = self.layer2(x)  # Second ResNet stage
        x = self.layer3(x)  # Third ResNet stage
        x = self.avgpool(x)  # Global average pooling
        x = torch.flatten(x, 1)  # Flatten the output for the fully connected layer
        x = self.fc(x)  # Pass through the fully connected layer
        return x


# Helper function to instantiate ResNet with pretrained weights if available
def _resnet(arch, block, layers, models_pretrained, progress, **kwargs):
    model = ResNet(arch, block, layers, **kwargs)  # Create a ResNet model
    if models_pretrained:  # Load pretrained weights if requested
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


# Functions to create specific ResNet variants
def resnet_model_18(models_pretrained=False, progress=True, **kwargs):
    return _resnet('resnet_model_18', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs)


def resnet_model_34(models_pretrained=False, progress=True, **kwargs):
    return _resnet('resnet_model_34', BasicBlock, [3, 4, 6, 3], models_pretrained, progress, **kwargs)


def resnet_model_50(models_pretrained=False, progress=True, **kwargs):
    return _resnet('resnet_model_50', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs)


def resnet_model_101(models_pretrained=False, progress=True, **kwargs):
    return _resnet('resnet_model_101', Bottleneck, [3, 4, 23, 3], models_pretrained, progress, **kwargs)


def resnet_model_152(models_pretrained=False, progress=True, **kwargs):
    return _resnet('resnet_model_152', Bottleneck, [3, 8, 36, 3], models_pretrained, progress, **kwargs)


def resnet_model_200(models_pretrained=False, progress=True, **kwargs):
    return _resnet('resnet_model_200', Bottleneck, [3, 24, 36, 3], models_pretrained, progress, **kwargs)
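A quick sanity check for the CIFAR-style models defined above (illustrative, not part of the file; random weights, no pretrained download):

```python
import torch

model = resnet_model_50(num_classes=10, dataset='cifar10', split_factor=1)
x = torch.randn(4, 3, 32, 32)   # a dummy CIFAR-sized batch
logits = model(x)
print(logits.shape)             # torch.Size([4, 10])
```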
312
EdgeFLite/architecture/resnet_sl.py
Normal file
@@ -0,0 +1,312 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Importing necessary PyTorch libraries
import torch
import torch.nn as nn

# Attempt to import model loading utilities from torch.hub; fall back to torch.utils.model_zoo if unavailable
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Functions exported by this module (resnet_model_110sl matches the name defined below)
__all__ = ['resnet_model_110sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']


# Function for 3x3 convolution with padding
def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


# Function for 1x1 convolution, typically used to change the number of channels
def apply_1x1_convolution(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


# Basic block for ResNet (used in smaller networks such as ResNet-18/34)
class BasicBlock(nn.Module):
    expansion = 1  # Expansion factor for output channels

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # BasicBlock only supports groups=1 and base_width=64
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("BasicBlock does not support dilation greater than 1")

        # Two 3x3 convolution layers with batch normalization and ReLU activation
        self.conv1 = apply_3x3_convolution(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = apply_3x3_convolution(planes, planes)
        self.bn2 = norm_layer(planes)
        # Optional downsample layer for changing the dimensions
        self.downsample = downsample
        self.stride = stride

    # Forward function defining the data flow through the block
    def forward(self, x):
        identity = x  # Save the input for the residual connection

        # First convolution, batch norm, and ReLU
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Second convolution, batch norm
        out = self.conv2(out)
        out = self.bn2(out)

        # Apply downsample if needed to match dimensions for the residual addition
        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual connection (add identity to output)
        out += identity
        out = self.relu(out)

        return out


# Bottleneck block for deeper ResNet architectures (e.g., ResNet-50/101)
class Bottleneck(nn.Module):
    expansion = 4  # Expansion factor for output channels (output = input * 4)

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        # Width of the block based on base_width and groups
        width = int(planes * (base_width / 64.)) * groups

        # 1x1, 3x3, and 1x1 convolutions with batch norm and ReLU activation
        self.conv1 = apply_1x1_convolution(inplanes, width)  # First 1x1 convolution
        self.bn1 = norm_layer(width)
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)  # Main 3x3 convolution
        self.bn2 = norm_layer(width)
        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)  # Final 1x1 convolution
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample  # Downsample layer for dimension adjustment
        self.stride = stride

    # Forward function defining the data flow through the bottleneck block
    def forward(self, x):
        identity = x  # Save the input for the residual connection

        # First 1x1 convolution, batch norm, and ReLU
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Second 3x3 convolution, batch norm, and ReLU
        # (applied to the running activation, not the raw input)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # Third 1x1 convolution, batch norm
        out = self.conv3(out)
        out = self.bn3(out)

        # Apply downsample if needed for the residual connection
        if self.downsample is not None:
            identity = self.downsample(x)

        # Residual connection (add identity to output)
        out += identity
        out = self.relu(out)

        return out


# ResNet model for the main client (runs the first part of the split network)
class PrimaryResNetClient(nn.Module):

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
        super(PrimaryResNetClient, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Initialize the number of input channels based on the dataset and split factor
        inplanes_dict = {
            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
        }
        self.inplanes = inplanes_dict[dataset][split_factor]

        # Adjust input planes if using a wide ResNet
        if 'wide_resnet' in arch:
            widen_factor = int(arch.split('_')[-1])
            self.inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0))

        self.base_width = width_per_group
        self.dilation = 1
        replace_stride_with_dilation = replace_stride_with_dilation or [False, False, False]

        # Check that replace_stride_with_dilation is properly defined
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation must either be None or a tuple with three elements")

        # Initialize the input layer depending on the dataset (small or large images)
        if dataset in ['skin_dataset', 'pill_base', 'medical_images']:
            self.layer0 = self._initialize_primary_layer_large()
        else:
            self.layer0 = self._init_layer0_small()

        # Initialize model weights
        self._init_model_weights(zero_init_residual)

    # Large initial convolution layer for large-image datasets
    def _initialize_primary_layer_large(self):
        return nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

    # Small initial convolution layer for datasets like CIFAR
    def _init_layer0_small(self):
        return nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            self._norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
        )

    # Initialize weights in the network
    def _init_model_weights(self, zero_init_residual):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=1e-3)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BatchNorm in each residual branch if specified
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    # Forward pass: the primary client only computes the stem (layer0)
    def forward(self, x):
        x = self.layer0(x)
        return x


# ResNet model for proxy clients (continues the network from the split point)
class ResNetProxies(nn.Module):

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
        super(ResNetProxies, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Set input channels based on architecture, dataset, and split factor
        self.inplanes = self._set_input_planes(arch, dataset, split_factor, width_per_group)
        self.base_width = width_per_group

        # Define the network stages. Stage widths are taken relative to the initial
        # width (base, base*2, base*4), mirroring the stage layout in resnet.py.
        base = self.inplanes
        self.layer1 = self._create_model_layer(block, base, layers[0], stride=1)
        self.layer2 = self._create_model_layer(block, base * 2, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, base * 4, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Adaptive average pooling layer
        self.fc = nn.Linear(base * 4 * block.expansion, num_classes)

        # Initialize model weights
        self._init_model_weights(zero_init_residual)

    # Set input channels based on dataset and split factor
    def _set_input_planes(self, arch, dataset, split_factor, width_per_group):
        inplanes_dict = {
            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6},
            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6},
            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
        }
        inplanes = inplanes_dict[dataset][split_factor]

        # Adjust input planes for wide ResNet
        if 'wide_resnet' in arch:
            widen_factor = float(arch.split('_')[-1])
            inplanes *= int(max(widen_factor / (split_factor ** 0.5) + 0.4, 1.0))

        return inplanes

    # Create one stage of the network (a sequence of blocks)
    def _create_model_layer(self, block, planes, blocks, stride=1):
        downsample = None
        # A projection shortcut is required whenever the spatial or channel dimensions change
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, planes * block.expansion, stride),
                self._norm_layer(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]  # First block
        self.inplanes = planes * block.expansion  # Update input planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))  # Additional blocks
        return nn.Sequential(*layers)

    # Initialize weights in the network
    def _init_model_weights(self, zero_init_residual):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=1e-3)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BatchNorm in each residual branch if specified
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    # Forward pass through the proxy-side network
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


# Helper function to create the main ResNet client
def _resnetsl_primary_client_(arch, block, layers, models_pretrained, progress, **kwargs):
    return PrimaryResNetClient(arch, block, layers, **kwargs)


# Helper function to create the proxy ResNet client
def _resnetsl_secondary_client_(arch, block, layers, models_pretrained, progress, **kwargs):
    return ResNetProxies(arch, block, layers, **kwargs)


# ResNet-110 split model: returns (primary, proxy) clients
def resnet_model_110sl(models_pretrained=False, progress=True, **kwargs):
    assert 'cifar' in kwargs['dataset']  # This model is defined for the CIFAR datasets
    return _resnetsl_primary_client_('resnet110_sl', Bottleneck, [12, 12, 12, 12], models_pretrained, progress, **kwargs), \
        _resnetsl_secondary_client_('resnet110_sl', Bottleneck, [12, 12, 12, 12], models_pretrained, progress, **kwargs)


# Wide ResNet-50-2 split model: returns (primary, proxy) clients
def wide_resnetsl50_2(models_pretrained=False, progress=True, **kwargs):
    kwargs['width_per_group'] = 64 * 2  # Widen the blocks for Wide ResNet
    return _resnetsl_primary_client_('wide_resnetsl50_2', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs), \
        _resnetsl_secondary_client_('wide_resnetsl50_2', Bottleneck, [3, 4, 6, 3], models_pretrained, progress, **kwargs)


# Wide ResNet-16-8 split model: returns (primary, proxy) clients
def wide_resnetsl16_8(models_pretrained=False, progress=True, **kwargs):
    kwargs['width_per_group'] = 64  # Base width for Wide ResNet
    return _resnetsl_primary_client_('wide_resnetsl16_8', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs), \
        _resnetsl_secondary_client_('wide_resnetsl16_8', BasicBlock, [2, 2, 2, 2], models_pretrained, progress, **kwargs)
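Each constructor above returns a (primary, proxy) pair. A minimal sketch of how the two halves compose (illustrative; it assumes the stage widths reconstructed above):

```python
import torch

primary, proxy = wide_resnetsl16_8(num_classes=10, dataset='cifar10', split_factor=1)
x = torch.randn(2, 3, 32, 32)
feats = primary(x)        # client-side stem (layer0)
logits = proxy(feats)     # server-side remainder of the network
print(logits.shape)       # torch.Size([2, 10])
```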
212
EdgeFLite/architecture/splitnet.py
Normal file
@@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn
import torch.nn.functional as F
from .mixup import mixup_loss_criterion, combine_mixup_data
from . import resnet, resnet_sl

__all__ = ['coremodel']


def _retrieve_network(arch='wide_resnet28_10'):
    """
    Get the network architecture based on the provided name.

    Args:
        arch (str): Name of the architecture.

    Returns:
        Callable: The network class or function corresponding to the given architecture.
    """
    networks = {
        'wide_resnet28_10': resnet.wide_resnet28_10,
        'wide_resnet16_8': resnet.wide_resnet16_8,
        'resnet110': resnet.resnet110,
        'wide_resnet_model_50_2': resnet.wide_resnet_model_50_2
    }
    if arch not in networks:
        raise ValueError(f"Architecture {arch} is not supported.")
    return networks[arch]


class coremodel(nn.Module):
    def __init__(self, args, norm_layer=None, criterion=None, progress=True):
        """
        Initialize the coremodel model with multiple sub-networks.

        Args:
            args (argparse.Namespace): Configuration arguments.
            norm_layer (callable, optional): Normalization layer.
            criterion (callable, optional): Loss function.
            progress (bool): Whether to show progress.
        """
        super(coremodel, self).__init__()

        # Configuration parameters
        self.split_factor = args.split_factor
        self.arch = args.arch
        self.loop_factor = args.loop_factor
        self.is_train_sep = args.is_train_sep
        self.epochs = args.epochs
        self.criterion = criterion
        self.is_diff_data_train = args.is_diff_data_train
        self.is_mixup = args.is_mixup
        self.mix_alpha = args.mix_alpha

        # Supported model architectures
        valid_archs = [
            'resnet_model_50', 'resnet_model_101', 'resnet_model_152', 'resnet_model_200',
            'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d',
            'resnext29_8x64d', 'resnext29_16x64d', 'resnet110', 'resnet164',
            'wide_resnet16_8', 'wide_resnet16_12', 'wide_resnet28_10', 'wide_resnet40_10',
            'wide_resnet52_8', 'wide_resnet_model_50_2', 'wide_resnet_model_50_3', 'wide_resnet_model_101_2'
        ]

        if self.arch not in valid_archs:
            raise NotImplementedError(f"Architecture {self.arch} is not implemented.")

        model_args = {
            'num_classes': args.num_classes,
            'norm_layer': norm_layer,
            'dataset': args.dataset,
            'split_factor': self.split_factor,
            'output_stride': args.output_stride
        }
        # Initialize multiple sub-models based on the loop factor
        self.models = nn.ModuleList(
            [_retrieve_network(self.arch)(models_pretrained=args.models_pretrained, **model_args)
             for _ in range(self.loop_factor)])

        if args.is_identical_init:
            print("INFO: Using identical initialization.")
            self._identical_init()

        # Ensemble settings
        self.is_ensembled_loss = args.is_ensembled_loss if self.split_factor > 1 else False
        self.ensembled_loss_weight = args.ensembled_loss_weight
        self.is_ensembled_after_softmax = args.is_ensembled_after_softmax if self.split_factor > 1 else False
        self.is_max_ensemble = args.is_max_ensemble if self.split_factor > 1 else False

        # Co-training settings
        self.is_cot_loss = args.is_cot_loss if self.split_factor > 1 else False
        self.cot_weight = args.cot_weight
        self.is_cot_weight_warm_up = args.is_cot_weight_warm_up
        self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs
        self.cot_loss_choose = args.cot_loss_choose
        print(f"INFO: The co-training loss is {self.cot_loss_choose}.")
        self.num_classes = args.num_classes

    def forward(self, x, target=None, mode='train', epoch=0, streams=None):
        """
        Forward pass through the model with optional mixup and co-training loss.

        Args:
            x (Tensor): Input tensor.
            target (Tensor, optional): Target tensor for loss computation.
            mode (str): Mode of operation ('train', 'val', or 'test').
            epoch (int): Current epoch.
            streams (optional): Additional data streams.

        Returns:
            Tuple:
                - ensemble_output (Tensor): The ensemble output of shape [batch_size, num_classes].
                - outputs (Tensor): Stack of individual outputs of shape [split_factor, batch_size, num_classes].
                - ce_loss (Tensor): Sum of cross-entropy losses for each model.
                - cot_loss (Tensor): Co-training loss if applicable.
        """
        outputs, ce_losses = [], []

        if 'train' in mode:
            if self.is_mixup:
                x, y_a, y_b, lam = combine_mixup_data(x, target, alpha=self.mix_alpha)

            # Split input data based on the loop factor
            all_x = torch.chunk(x, chunks=self.loop_factor, dim=1) if self.is_diff_data_train else [x] * self.loop_factor

            for i in range(self.loop_factor):
                x_input = all_x[i]
                output = self.models[i](x_input)
                loss = mixup_loss_criterion(self.criterion, output, y_a, y_b, lam) if self.is_mixup else self.criterion(output, target)
                outputs.append(output)
                ce_losses.append(loss)

        elif mode in ['val', 'test']:
            for i in range(self.loop_factor):
                output = self.models[i](x)
                loss = self.criterion(output, target) if self.criterion else torch.zeros(1)
                outputs.append(output)
                ce_losses.append(loss)

        else:
            return torch.ones(1), None, None, None

        # Calculate ensemble output and losses
        ensemble_output = self._collect_ensemble_output(outputs)
        ce_loss = torch.sum(torch.stack(ce_losses))

        if mode in ['val', 'test']:
            return ensemble_output, torch.stack(outputs, dim=0), ce_loss

        if self.is_cot_loss:
            cot_loss = self._calculate_co_training_loss(outputs, self.cot_loss_choose, epoch)
        else:
            cot_loss = torch.zeros_like(ce_loss)

        return ensemble_output, torch.stack(outputs, dim=0), ce_loss, cot_loss

    def _collect_ensemble_output(self, outputs):
        """
        Calculate the ensemble output from a list of tensors.

        Args:
            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].

        Returns:
            Tensor: The ensemble output with shape [batch_size, num_classes].
        """
        stacked_outputs = torch.stack(outputs, dim=0)

        if self.is_ensembled_after_softmax:
            softmax_outputs = F.softmax(stacked_outputs, dim=-1)
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(softmax_outputs, dim=0)
            else:
                ensemble_output = torch.mean(softmax_outputs, dim=0)
        else:
            if self.is_max_ensemble:
                ensemble_output, _ = torch.max(stacked_outputs, dim=0)
            else:
                ensemble_output = torch.mean(stacked_outputs, dim=0)

        return ensemble_output

    def _calculate_co_training_loss(self, outputs, loss_choose, epoch=0):
        """
        Calculate the co-training loss between outputs of different networks.

        Args:
            outputs (list of tensors): A list where each tensor has shape [batch_size, num_classes].
            loss_choose (str): Type of co-training loss to compute ('js_divergence' or 'kl_separate').
            epoch (int): Current epoch.

        Returns:
            Tensor: The computed co-training loss.
        """
        weight_now = self.cot_weight
        if self.is_cot_weight_warm_up and epoch < self.cot_weight_warm_up_epochs:
            weight_now = max(self.cot_weight * epoch / self.cot_weight_warm_up_epochs, 0.005)

        stacked_outputs = torch.stack(outputs, dim=0)

        if loss_choose == 'js_divergence':
            p_all = F.softmax(stacked_outputs, dim=-1)
            p_mean = torch.mean(p_all, dim=0)
            H_mean = (-p_mean * torch.log(p_mean + 1e-8)).sum(-1).mean()
            H_sep = (-p_all * F.log_softmax(stacked_outputs, dim=-1)).sum(-1).mean()
            cot_loss = weight_now * (H_mean - H_sep)

        elif loss_choose == 'kl_separate':
            outputs_r1 = torch.repeat_interleave(stacked_outputs, self.split_factor - 1, dim=0)
            index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j != i]
            outputs_r2 = torch.index_select(stacked_outputs, dim=0, index=torch.tensor(index_list, dtype=torch.long, device=stacked_outputs.device))
            kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1),
                               F.softmax(outputs_r2, dim=-1).detach(), reduction='none')
            cot_loss = weight_now * kl_loss.sum(-1).mean(-1).sum() / (self.split_factor - 1)

        else:
            raise NotImplementedError(f"Co-training loss '{loss_choose}' is not implemented.")

        return cot_loss

    def _identical_init(self):
        """Ensure identical initialization of weights across sub-networks.
        (Reconstructed from the identical helper in coremodel.py.)"""
        with torch.no_grad():
            for i in range(1, len(self.models)):
                for (name1, param1), (_, param2) in zip(self.models[i].named_parameters(),
                                                        self.models[0].named_parameters()):
                    if 'weight' in name1:
                        param1.data.copy_(param2.data)
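For intuition, the JS-divergence co-training term above is the Jensen gap between the entropy of the mean prediction and the mean per-model entropy, which shrinks as the sub-networks agree. A small self-contained demo (not part of the file):

```python
import torch
import torch.nn.functional as F

outputs = [torch.randn(4, 10) for _ in range(2)]   # two sub-networks, batch of 4
p_all = F.softmax(torch.stack(outputs), dim=-1)
p_mean = p_all.mean(dim=0)
H_mean = (-p_mean * torch.log(p_mean + 1e-8)).sum(-1).mean()
H_sep = (-p_all * torch.log(p_all + 1e-8)).sum(-1).mean()
print(float(H_mean - H_sep))   # >= 0; approaches 0 as the predictions agree
```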
BIN
EdgeFLite/configurations/.DS_Store
vendored
Normal file
Binary file not shown.
81
EdgeFLite/configurations/training_config.py
Normal file
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

from __future__ import absolute_import, division, print_function

import os  # needed for os.makedirs / os.path.join below
import json

import torch

from config import *


# Save hyperparameters into a JSON file
def store_hyperparameters_json(args):
    """Save hyperparameters to a JSON file."""
    # Create the model directory if it does not exist
    os.makedirs(args.model_dir, exist_ok=True)
    # Choose the filename based on evaluation vs. training mode
    filename = os.path.join(args.model_dir, 'hparams_eval.json' if args.evaluate else 'hparams_train.json')
    # Convert the arguments to a dictionary
    hparams = vars(args)
    # Write the hyperparameters to a JSON file with indentation and sorted keys
    with open(filename, 'w') as f:
        json.dump(hparams, f, indent=4, sort_keys=True)


# Add parser arguments for the command-line interface
def add_parser_arguments(parser):
    # Dataset and model settings
    parser.add_argument('--data', type=str, default=f"{data_dir}/dataset_hub/", help='Path to dataset')
    parser.add_argument('--model_dir', type=str, default="EdgeFLite", help='Directory to save the model')
    parser.add_argument('--arch', type=str, default='wide_resnet16_8', choices=[
        'resnet110', 'resnet_model_110sl', 'wide_resnet16_8', 'wide_resnetsl16_8',
        'wide_resnet_model_50_2', 'wide_resnetsl50_2'], help='Neural architecture name')

    # Normalization and training settings
    parser.add_argument('--norm_mode', type=str, default='batch', choices=['batch', 'group', 'layer', 'instance', 'none'], help='Normalization style')
    parser.add_argument('--is_syncbn', default=0, type=int, help='Use nn.SyncBatchNorm or not')
    parser.add_argument('--workers', default=16, type=int, help='Number of data loading workers')
    parser.add_argument('--epochs', default=650, type=int, help='Total epochs to run')
    parser.add_argument('--start_epoch', default=0, type=int, help='Manual epoch number for restarts')
    parser.add_argument('--eval_per_epoch', default=1, type=int, help='Evaluation frequency per epoch')
    parser.add_argument('--spid', default="EdgeFLite", type=str, help='Experiment name')
    # argparse's type=bool treats any non-empty string (including "False") as True,
    # so a flag-style argument is used instead
    parser.add_argument('--save_weight', action='store_true', default=False, help='Save model weights')

    # Data augmentation settings
    parser.add_argument('--batch_size', default=128, type=int, help='Mini-batch size for training')
    parser.add_argument('--eval_batch_size', default=100, type=int, help='Mini-batch size for evaluation')
    parser.add_argument('--crop_size', default=32, type=int, help='Crop size for images')
    parser.add_argument('--output_stride', default=8, type=int, help='Output stride for the model')
    parser.add_argument('--padding', default=4, type=int, help='Padding size for images')

    # Learning rate settings
    parser.add_argument('--lr_mode', type=str, default='cos', choices=['cos', 'step', 'poly', 'HTD', 'exponential'], help='Learning rate strategy')
    parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, help='Initial learning rate')
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'AdamW', 'RMSprop', 'RMSpropTF'], help='Optimizer choice')
    parser.add_argument('--lr_milestones', nargs='+', type=int, default=[100, 200], help='Epochs at which the learning rate steps down')
    parser.add_argument('--lr_step_multiplier', default=0.1, type=float, help='Multiplier applied at learning rate milestones')
    parser.add_argument('--end_lr', type=float, default=1e-4, help='Final learning rate')

    # Additional hyperparameters
    parser.add_argument('--weight_decay', default=1e-4, type=float, help='Weight decay for L2 regularization')
    parser.add_argument('--momentum', default=0.9, type=float, help='Optimizer momentum (e.g., for SGD)')
    parser.add_argument('--print_freq', default=20, type=int, help='Print frequency for logging')

    # Federated learning settings
    parser.add_argument('--is_fed', default=1, type=int, help='Enable federated learning')
    parser.add_argument('--num_clusters', default=20, type=int, help='Number of clusters for federated learning')
    parser.add_argument('--num_selected', default=20, type=int, help='Number of clients selected for training per round')
    parser.add_argument('--num_rounds', default=300, type=int, help='Total number of federated learning rounds')

    # Processing and decentralized training settings
    parser.add_argument('--gpu', default=None, type=int, help='GPU ID to use')
    parser.add_argument('--no_cuda', action='store_true', default=False, help='Disable CUDA training')
    parser.add_argument('--gpu_ids', type=str, default='0', help='Comma-separated list of GPU IDs for training')

    # Parse command-line arguments
    args = parser.parse_args()

    # Additional configurations
    args.cuda = not args.no_cuda and torch.cuda.is_available()  # Enable CUDA if available and not disabled
    if args.cuda:
        args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]  # Parse GPU IDs from the comma-separated string
        args.num_gpus = len(args.gpu_ids)  # Number of GPUs in use

    return args
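A minimal usage sketch for this module. The import path mirrors the file location above, and it assumes `config` exposes `data_dir` and that an `--evaluate` flag is defined elsewhere in the pipeline (it is referenced by `store_hyperparameters_json` but not declared in this file):

```python
import argparse
from configurations.training_config import add_parser_arguments, store_hyperparameters_json

parser = argparse.ArgumentParser(description="EdgeFLite training")
args = add_parser_arguments(parser)               # parses sys.argv and returns the namespace
args.evaluate = getattr(args, 'evaluate', False)  # assumed to be set elsewhere in the pipeline
store_hyperparameters_json(args)                  # writes hparams_train.json under args.model_dir
```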
BIN
EdgeFLite/data_collection/.DS_Store
vendored
Normal file
Binary file not shown.
103
EdgeFLite/data_collection/augment_auto.py
Normal file
@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

from PIL import Image, ImageEnhance, ImageOps

import numpy as np
import random


class CIFAR10Policy(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate_image", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "adjust_image_sharpness", 1, 0.9, "adjust_image_sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
            SubPolicy(0.2, "shearY", 7, 0.3, "apply_posterization", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "adjust_image_sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "adjust_image_sharpness", 5, fillcolor),
            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "adjust_image_sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "apply_solarize", 2, 0.0, "invert", 3, fillcolor),
            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "apply_solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
            SubPolicy(0.4, "apply_solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "apply_solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
        ]

    def __call__(self, img):
        policy = random.choice(self.policies)
        return policy(img)

    def __repr__(self):
        return "AutoAugment CIFAR-10 Policy"


class SubPolicy(object):
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate_image": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # np.int was removed in NumPy 1.24; plain int is the correct dtype here
            "apply_posterization": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "apply_solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "adjust_image_sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        self.fillcolor = fillcolor
        self.p1 = p1
        self.operation1 = operation1
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = operation2
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self._perform_operation(self.operation1, img, self.magnitude1)
        if random.random() < self.p2:
            img = self._perform_operation(self.operation2, img, self.magnitude2)
        return img

    def _perform_operation(self, operation, img, magnitude):
        # PIL's affine warps go through Image.transform; the original text's
        # img.apply_transformation is not a PIL method
        if operation == "shearX":
            img = img.transform(img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                                Image.BICUBIC, fillcolor=self.fillcolor)
        elif operation == "shearY":
            img = img.transform(img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                                Image.BICUBIC, fillcolor=self.fillcolor)
        elif operation == "translateX":
            img = img.transform(img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                                Image.BICUBIC, fillcolor=self.fillcolor)
            # (the closing arguments above follow the shearX/shearY pattern;
            # the remaining operation branches are truncated in this diff view)
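A quick sanity-check sketch for the policy above, using a synthetic 32x32 RGB image as a stand-in for a CIFAR sample:

```python
from PIL import Image

policy = CIFAR10Policy()
img = Image.new("RGB", (32, 32), (120, 60, 200))  # placeholder image
augmented = policy(img)                           # one randomly chosen SubPolicy applied
print(policy, augmented.size)                     # -> AutoAugment CIFAR-10 Policy (32, 32)
```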
199
EdgeFLite/data_collection/augment_rand.py
Normal file
@ -0,0 +1,199 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import random
import math
from PIL import Image, ImageOps, ImageEnhance, ImageChops
import PIL

# Constants and defaults for image augmentation
_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])  # Version of the PIL library
_FILL = (128, 128, 128)  # Default fill color used by some transformations (gray)
_MAX_LEVEL = 10.0  # Maximum level for augmentations
_HPARAMS_DEFAULT = {
    'translate_const': 250,  # Default translation constant
    'img_mean': _FILL,  # Default fill color
}
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)  # Random interpolation modes


# Randomly choose an interpolation method
def _interpolation(kwargs):
    interpolation = kwargs.pop('resample', Image.BILINEAR)
    return random.choice(interpolation) if isinstance(interpolation, (list, tuple)) else interpolation


# Check whether the PIL version supports the fillcolor argument
def _validate_tensorflow_args(kwargs):
    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
        kwargs.pop('fillcolor')  # Remove fillcolor if PIL version is below 5.0
    kwargs['resample'] = _interpolation(kwargs)  # Add resample method


# Shear image along the x-axis
def apply_shear_x_axis(img, factor, **kwargs):
    _validate_tensorflow_args(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


# Shear image along the y-axis
def shear_y(img, factor, **kwargs):
    _validate_tensorflow_args(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


# Translate image horizontally by a percentage of the image width
def translate_image_x_relative(img, pct, **kwargs):
    pixels = pct * img.size[0]  # Pixels to translate
    _validate_tensorflow_args(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


# Translate image vertically by a percentage of the image height
def translate_image_y_relative(img, pct, **kwargs):
    pixels = pct * img.size[1]  # Pixels to translate
    _validate_tensorflow_args(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


# Translate image horizontally by a fixed number of pixels
def translate_image_x_absolute(img, pixels, **kwargs):
    _validate_tensorflow_args(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


# Translate image vertically by a fixed number of pixels
def translate_image_y_absolute(img, pixels, **kwargs):
    _validate_tensorflow_args(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


# Rotate image by a specified number of degrees
def rotate_image(img, degrees, **kwargs):
    _validate_tensorflow_args(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)  # Image.rotate handles this directly on PIL >= 5.2
    elif _PIL_VER >= (5, 0):
        # Manually rotate the image for older versions of PIL
        w, h = img.size
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15), round(math.sin(angle), 15), 0.0,
            round(-math.sin(angle), 15), round(math.cos(angle), 15), 0.0,
        ]

        def apply_transformation(x, y, matrix):
            return matrix[0] * x + matrix[1] * y + matrix[2], matrix[3] * x + matrix[4] * y + matrix[5]

        matrix[2], matrix[5] = apply_transformation(-rotn_center[0], -rotn_center[1], matrix)
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        return img.rotate(degrees, resample=kwargs['resample'])


# Auto-contrast image
def apply_auto_contrast(img, **kwargs):
    return ImageOps.autocontrast(img)


# Invert image colors
def invert(img, **kwargs):
    return ImageOps.invert(img)


# Equalize image histogram
def equalize(img, **kwargs):
    return ImageOps.equalize(img)


# Apply solarization effect (ImageOps.solarize is the actual PIL entry point)
def apply_solarize(img, thresh, **kwargs):
    return ImageOps.solarize(img, thresh)


# Apply solarization effect with an additional value
def apply_solarize_addition(img, add, thresh=128, **kwargs):
    lut = [min(255, i + add) if i < thresh else i for i in range(256)]
    if img.mode in ("L", "RGB"):
        lut = lut + lut + lut if img.mode == "RGB" else lut
        return img.point(lut)
    else:
        return img


# Posterize image (reduce color depth); ImageOps.posterize is the actual PIL entry point
def apply_posterization(img, bits_to_keep, **kwargs):
    return img if bits_to_keep >= 8 else ImageOps.posterize(img, bits_to_keep)


# Adjust image contrast
def contrast(img, factor, **kwargs):
    return ImageEnhance.Contrast(img).enhance(factor)


# Adjust image color
def color(img, factor, **kwargs):
    return ImageEnhance.Color(img).enhance(factor)


# Adjust image brightness
def brightness(img, factor, **kwargs):
    return ImageEnhance.Brightness(img).enhance(factor)


# Adjust image sharpness (ImageEnhance.Sharpness is the actual PIL class)
def adjust_image_sharpness(img, factor, **kwargs):
    return ImageEnhance.Sharpness(img).enhance(factor)


# Randomly negate a value with 50% probability
def _apply_random_negation(v):
    """With 50% probability, negate the value."""
    return -v if random.random() > 0.5 else v


# Convert augmentation level to an argument tuple
def _map_level_to_argument(level, max_value, hparams):
    level = (level / _MAX_LEVEL) * max_value
    return _apply_random_negation(level),


# Convert translation level to an argument tuple
def _map_absolute_map_level_to_argument(level, hparams):
    translate_const = hparams['translate_const']
    level = (level / _MAX_LEVEL) * float(translate_const)
    return _apply_random_negation(level),


# Convert enhancement level to an argument tuple
def _enhance_map_level_to_argument(level, _hparams):
    return (level / _MAX_LEVEL) * 1.8 + 0.1,


# Mapping of augmentation levels to argument converters
# (each converter returns a tuple so it can be splatted into the op call)
map_level_to_argument = {
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,
    'rotate_image': lambda level, _: _map_level_to_argument(level, 30, None),
    'apply_posterization': lambda level, _: (int((level / _MAX_LEVEL) * 4),),
    'apply_solarize': lambda level, _: (int((level / _MAX_LEVEL) * 256),),
    'Color': _enhance_map_level_to_argument,
    'Contrast': _enhance_map_level_to_argument,
    'Brightness': _enhance_map_level_to_argument,
    'adjust_image_sharpness': _enhance_map_level_to_argument,
    'ShearX': lambda level, _: _map_level_to_argument(level, 0.3, None),
    'ShearY': lambda level, _: _map_level_to_argument(level, 0.3, None),
    'TranslateX': _map_absolute_map_level_to_argument,
    'TranslateY': _map_absolute_map_level_to_argument,
}

# Mapping of augmentation names to functions
NAME_TO_OP = {
    'AutoContrast': apply_auto_contrast,
    'Equalize': equalize,
    'Invert': invert,
    'rotate_image': rotate_image,
    'apply_posterization': apply_posterization,
    'apply_solarize': apply_solarize,
    'Color': color,
    'Contrast': contrast,
    'Brightness': brightness,
    'adjust_image_sharpness': adjust_image_sharpness,
    'ShearX': apply_shear_x_axis,
    'ShearY': shear_y,
    'TranslateX': translate_image_x_absolute,
    'TranslateY': translate_image_y_absolute,
}


# Class for applying augmentations to an image
class AugmentOp:
    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]  # Augmentation function
        self.level_fn = map_level_to_argument[name]  # Level-to-argument converter
        self.prob = prob  # Probability of applying the augmentation
        self.magnitude = magnitude  # Magnitude of the augmentation
        self.hparams = hparams.copy()
        self.kwargs = {
            'fillcolor': hparams.get('img_mean', _FILL),  # Fill color
            # (the remainder of this dict and the class body are truncated in the diff view)
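A worked example of the level mapping above: magnitude 9 for `'rotate_image'` maps to (9 / 10) * 30 = 27 degrees, with the sign chosen uniformly at random. A small sketch using only the dictionaries defined in this file (nothing else assumed):

```python
from PIL import Image

img_in = Image.new("RGB", (32, 32), (128, 128, 128))          # placeholder image
op_args = map_level_to_argument['rotate_image'](9, _HPARAMS_DEFAULT)  # -> (27.0,) or (-27.0,)
img_out = NAME_TO_OP['rotate_image'](img_in, *op_args, resample=Image.BILINEAR)
```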
220
EdgeFLite/data_collection/cifar100_noniid.py
Normal file
@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

#### Get CIFAR-100 dataset in X and Y form
import torchvision
import numpy as np
import random
import torch
from torchvision import transforms as apply_transformations  # torchvision's module is named `transforms`
from torch.utils.data import DataLoader, Dataset
from .cifar10_noniid import *  # module name matches the file added in this commit

# Set random seeds for reproducibility
np.random.seed(68)
random.seed(68)


def get_cifar100(data_dir):
    '''
    Load and return CIFAR-100 train/test data and labels as numpy arrays.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be downloaded/saved.

    Returns:
        x_train (ndarray): Training data.
        y_train (ndarray): Training labels.
        x_test (ndarray): Test data.
        y_test (ndarray): Test labels.
    '''
    # Download CIFAR-100 training and test datasets
    data_train = torchvision.datasets.CIFAR100(data_dir, train=True, download=True)
    data_test = torchvision.datasets.CIFAR100(data_dir, train=False, download=True)

    # Transpose data to channels-first order and convert labels to numpy arrays
    x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets)
    x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets)

    return x_train, y_train, x_test, y_test


def split_cf100_real_world_images(data, labels, n_clients=100, verbose=True):
    '''
    Splits data and labels among n_clients to simulate a non-IID distribution.

    Parameters:
        data (ndarray): Dataset images [n_data x shape].
        labels (ndarray): Dataset labels [n_data].
        n_clients (int): Number of clients to split the data among.
        verbose (bool): Print detailed information if True.

    Returns:
        clients_split (ndarray): Split data and labels for each client.
    '''
    n_labels = np.max(labels) + 1  # Number of unique labels/classes

    def divide_into_sections(n, m):
        '''Return m random integers that sum up to n.'''
        result = [1] * m
        for _ in range(n - m):
            result[random.randint(0, m - 1)] += 1
        return result

    # Shuffle and partition classes
    n_classes = len(set(labels))  # Number of unique classes
    classes = list(range(n_classes))
    np.random.shuffle(classes)  # Shuffle class indices
    label_indices = [list(np.where(labels == class_)[0]) for class_ in classes]  # Indices of each class

    # Define the number of classes for each client (randomized)
    tmp = [np.random.randint(1, 100) for _ in range(n_clients)]
    total_partition = sum(tmp)
    class_partition = divide_into_sections(total_partition, len(classes))  # Partition classes randomly

    # Split class indices among clients
    class_partition = sorted(class_partition, reverse=True)
    class_partition_split = {}

    for idx, class_ in enumerate(classes):
        # Split each class's indices according to the partition
        class_partition_split[class_] = [list(i) for i in np.array_split(label_indices[idx], class_partition[idx])]

    clients_split = []
    for i in range(n_clients):
        n = tmp[i]  # Number of classes for this client
        indices = []
        j = 0

        # Assign class data to the client
        while n > 0:
            class_ = classes[j]
            if class_partition_split[class_]:
                indices.extend(class_partition_split[class_].pop())  # Add indices of the class to the client
                n -= 1
            j += 1

        clients_split.append([data[indices], labels[indices]])  # Add the client's data split

        # Re-sort classes based on available data to balance further splits
        classes = sorted(classes, key=lambda x: len(class_partition_split[x]), reverse=True)

    # Raise an error if the client partition criteria cannot be met
    if n > 0:
        raise ValueError("Unable to fulfill the client partition criteria.")

    # Verbose option to print split information
    if verbose:
        display_data_split(clients_split)

    # dtype=object because the per-client arrays have different lengths
    return np.array(clients_split, dtype=object)


def display_data_split(clients_split):
    '''Print the split information of the dataset for each client.'''
    print("Data split:")
    for i, client in enumerate(clients_split):
        split = np.sum(client[1].reshape(1, -1) == np.arange(np.max(client[1]) + 1).reshape(-1, 1), axis=1)
        print(f"  - Client {i}: {split}")
    print()


def get_default_data_apply_transformations_cf100(train=True, verbose=True):
    '''
    Return default data transformations for CIFAR-100.

    Parameters:
        train (bool): Whether to build transformations for training data.
        verbose (bool): Print transformation details if True.

    Returns:
        apply_transformations_train (Compose): Training transformations.
        apply_transformations_eval (Compose): Evaluation (test) transformations.
    '''
    # Transformations for training data
    apply_transformations_train = {
        'cifar100': apply_transformations.Compose([
            apply_transformations.ToPILImage(),
            apply_transformations.RandomCrop(32, padding=4),
            apply_transformations.RandomHorizontalFlip(),
            apply_transformations.ToTensor(),
            apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
    }

    # Transformations for test data
    apply_transformations_eval = {
        'cifar100': apply_transformations.Compose([
            apply_transformations.ToPILImage(),
            apply_transformations.ToTensor(),
            apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
    }

    # Verbose option to print transformation steps
    if verbose:
        print("\nData preprocessing:")
        for transformation in apply_transformations_train['cifar100'].transforms:
            print(f'  - {transformation}')
        print()

    return apply_transformations_train['cifar100'], apply_transformations_eval['cifar100']


def obtain_data_loaders_train_cf100(data_dir, n_clients, batch_size, classes_per_client=10, verbose=True,
                                    apply_transformations_train=None, apply_transformations_eval=None, non_iid=None, split_factor=1):
    '''
    Return data loaders for training on CIFAR-100.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be saved.
        n_clients (int): Number of clients for splitting the dataset.
        batch_size (int): Batch size for each data loader.
        classes_per_client (int): Number of classes per client.
        verbose (bool): Print detailed information if True.
        apply_transformations_train (Compose): Transformations for training data.
        apply_transformations_eval (Compose): Transformations for evaluation data.
        non_iid (str): Strategy to create a non-IID dataset split.
        split_factor (float): Factor controlling the degree of splitting.

    Returns:
        client_loaders (list): Data loaders for each client.
    '''
    x_train, y_train, _, _ = get_cifar100(data_dir)

    # Verbose option to print dataset statistics
    if verbose:
        display_data_statistics(x_train, y_train, "Train")  # helper defined in cifar10_noniid.py

    # Split data according to the non-IID strategy (e.g., quantity_skew)
    split = None
    if non_iid == 'quantity_skew':
        split = split_cf100_real_world_images(x_train, y_train, n_clients=n_clients, verbose=verbose)

    split_tmp = randomize_data_order(split)  # shuffling helper from cifar10_noniid.py

    # Create DataLoaders for each client
    client_loaders = [DataLoader(CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor),
                                 batch_size=batch_size, shuffle=True) for x, y in split_tmp]

    return client_loaders


def obtain_data_loaders_test_cf100(data_dir, batch_size, verbose=True, apply_transformations_eval=None):
    '''
    Return data loaders for testing on CIFAR-100.

    Parameters:
        data_dir (str): Directory where the CIFAR-100 dataset will be saved.
        batch_size (int): Batch size for the test data loader.
        verbose (bool): Print detailed information if True.
        apply_transformations_eval (Compose): Transformations for evaluation data.

    Returns:
        test_loader (DataLoader): Test data loader.
    '''
    _, _, x_test, y_test = get_cifar100(data_dir)

    # Verbose option to print dataset statistics
    if verbose:
        display_data_statistics(x_test, y_test, "Test")  # helper defined in cifar10_noniid.py

    # Create the DataLoader for the test dataset
    # (use the batch_size argument rather than a hard-coded value)
    test_loader = DataLoader(CustomImageDataset(x_test, y_test, apply_transformations_eval, split_factor=1),
                             batch_size=batch_size, shuffle=False)

    return test_loader
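A hedged end-to-end sketch of these helpers; the data directory and client count are illustrative, and downloading CIFAR-100 happens on first use:

```python
train_tf, eval_tf = get_default_data_apply_transformations_cf100(train=True, verbose=False)
client_loaders = obtain_data_loaders_train_cf100(
    './data', n_clients=20, batch_size=128,
    apply_transformations_train=train_tf, non_iid='quantity_skew', split_factor=1)
test_loader = obtain_data_loaders_test_cf100('./data', batch_size=100,
                                             apply_transformations_eval=eval_tf)
```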
179
EdgeFLite/data_collection/cifar10_noniid.py
Normal file
@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

#### Load CIFAR-10 dataset and preprocess it
import torchvision
import numpy as np
import random
import torch
from torchvision import transforms as apply_transformations  # torchvision's module is named `transforms`
from torch.utils.data import DataLoader, Dataset

# Set random seed for reproducibility
np.random.seed(68)  # Ensures that random operations have consistent outputs
random.seed(68)


def get_cifar10(data_dir):
    """Return CIFAR-10 train/test data and labels as numpy arrays"""
    # Download the CIFAR-10 dataset
    data_train = torchvision.datasets.CIFAR10(data_dir, train=True, download=True)
    data_test = torchvision.datasets.CIFAR10(data_dir, train=False, download=True)

    # Preprocess the train and test data to the correct format (channels first)
    x_train, y_train = data_train.data.transpose((0, 3, 1, 2)), np.array(data_train.targets)
    x_test, y_test = data_test.data.transpose((0, 3, 1, 2)), np.array(data_test.targets)

    return x_train, y_train, x_test, y_test


def display_data_statistics(data, labels, dataset_type):
    """Print statistics of the dataset"""
    print(f"\n{dataset_type} Set: ({data.shape}, {labels.shape}), Range: [{np.min(data):.3f}, {np.max(data):.3f}], "
          f"Labels: {np.min(labels)},..,{np.max(labels)}")


def randomize_client_distribution(train_len, n_clients):
    """
    Distribute data among clients with a random distribution.
    Returns a list with the number of samples for each client.
    """
    # Randomly assign a number of samples to each client, ensuring the total matches train_len
    client_sizes = [random.randint(10, 100) for _ in range(n_clients - 1)]
    total = sum(client_sizes)
    client_sizes = np.array(client_sizes)
    client_distributions = ((client_sizes / total) * train_len).astype(int)  # Normalize to match train_len
    client_distributions = list(client_distributions)
    client_distributions.append(train_len - sum(client_distributions))  # Ensure all data is allocated
    return client_distributions


def divide_into_sections(n, m):
    """Return 'm' random integers that sum to 'n'"""
    # Break the number 'n' into 'm' random parts that sum to 'n'
    partitions = [1] * m
    for _ in range(n - m):
        partitions[random.randint(0, m - 1)] += 1
    return partitions


def split_data_real_world_scenario(data, labels, n_clients=100):
    """Split data among clients simulating a real-world non-IID distribution"""
    n_classes = len(set(labels))  # Number of unique classes
    class_indices = [np.where(labels == class_)[0] for class_ in range(n_classes)]  # Indices for each class

    client_classes = [np.random.randint(1, 10) for _ in range(n_clients)]  # Random number of classes per client
    total_partitions = sum(client_classes)

    class_partition = divide_into_sections(total_partitions, len(class_indices))  # Partition classes to distribute
    class_partition_split = {cls: np.array_split(class_indices[cls], n) for cls, n in enumerate(class_partition)}

    clients_split = []
    for client in client_classes:
        selected_indices = []
        for class_ in range(n_classes):
            if class_partition_split[class_]:
                selected_indices.extend(class_partition_split[class_].pop())
                client -= 1
                if client <= 0:
                    break
        clients_split.append([data[selected_indices], labels[selected_indices]])

    # dtype=object because the per-client arrays are ragged
    return np.array(clients_split, dtype=object)


def split_data_iid(data, labels, n_clients=100, classes_per_client=10, shuffle=True):
    """Split data among clients with an IID (independent and identically distributed) distribution"""
    data_per_client = randomize_client_distribution(len(data), n_clients)
    label_indices = [np.where(labels == label)[0] for label in range(np.max(labels) + 1)]

    if shuffle:
        for indices in label_indices:
            np.random.shuffle(indices)

    clients_split = []
    for client_data in data_per_client:
        client_indices = []
        class_ = np.random.randint(len(label_indices))
        while client_data > 0:
            take = min(client_data, len(label_indices[class_]))
            client_indices.extend(label_indices[class_][:take])
            label_indices[class_] = label_indices[class_][take:]
            client_data -= take
            class_ = (class_ + 1) % len(label_indices)

        clients_split.append([data[client_indices], labels[client_indices]])

    return np.array(clients_split, dtype=object)


def randomize_data_order(data):
    """Shuffle data while maintaining the mapping between inputs and labels"""
    for i in range(len(data)):
        index = np.arange(len(data[i][0]))
        np.random.shuffle(index)
        data[i][0], data[i][1] = data[i][0][index], data[i][1][index]
    return data


class CustomImageDataset(Dataset):
    """Custom Dataset class for image data"""
    def __init__(self, inputs, labels, apply_transformations=None, split_factor=1):
        # Convert input data to torch tensors and store the transformations, if provided
        self.inputs = torch.Tensor(inputs)
        self.labels = labels
        self.apply_transformations = apply_transformations
        self.split_factor = split_factor

    def __getitem__(self, index):
        img, label = self.inputs[index], self.labels[index]
        # Apply the transformations to the image multiple times if split_factor > 1
        imgs = [self.apply_transformations(img) for _ in range(self.split_factor)] if self.apply_transformations else [img]
        return torch.cat(imgs, dim=0), label

    def __len__(self):
        return len(self.inputs)


def get_default_apply_transformations(verbose=True):
    """Return default transformations for training and evaluation"""
    apply_transformations_train = apply_transformations.Compose([
        apply_transformations.ToPILImage(),  # Convert tensor to a PIL image
        apply_transformations.RandomCrop(32, padding=4),  # Randomly crop to 32x32 with padding
        apply_transformations.RandomHorizontalFlip(),  # Randomly flip images horizontally
        apply_transformations.ToTensor(),  # Convert image to tensor
        apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # Normalize with CIFAR-10 mean and std
    ])

    apply_transformations_eval = apply_transformations.Compose([
        apply_transformations.ToPILImage(),
        apply_transformations.ToTensor(),
        apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # Same normalization for evaluation
    ])

    if verbose:
        print("\nData preprocessing steps:")
        for transformation in apply_transformations_train.transforms:
            print(f"  - {transformation}")

    return apply_transformations_train, apply_transformations_eval


def obtain_data_loaders(data_dir, n_clients, batch_size, classes_per_client=10, non_iid=None, split_factor=1):
    """Return DataLoader objects for clients with either IID or non-IID data split"""
    x_train, y_train, _, _ = get_cifar10(data_dir)
    display_data_statistics(x_train, y_train, "Train")

    # Split data based on the non-IID method specified (either 'quantity_skew' or 'label_skew')
    if non_iid == 'quantity_skew':
        clients_data = split_data_real_world_scenario(x_train, y_train, n_clients)
    elif non_iid == 'label_skew':
        clients_data = split_data_iid(x_train, y_train, n_clients, classes_per_client)

    shuffled_clients_data = randomize_data_order(clients_data)

    apply_transformations_train, apply_transformations_eval = get_default_apply_transformations(verbose=False)
    client_loaders = [DataLoader(CustomImageDataset(x, y, apply_transformations_train, split_factor=split_factor),
                                 batch_size=batch_size, shuffle=True) for x, y in shuffled_clients_data]

    return client_loaders


def get_test_data_loader(data_dir, batch_size):
    """Return DataLoader for test data"""
    _, _, x_test, y_test = get_cifar10(data_dir)
    display_data_statistics(x_test, y_test, "Test")

    _, apply_transformations_eval = get_default_apply_transformations(verbose=False)
    test_loader = DataLoader(CustomImageDataset(x_test, y_test, apply_transformations_eval), batch_size=batch_size, shuffle=False)

    return test_loader
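A short sketch exercising the split-and-load path above; the path and client count are illustrative:

```python
loaders = obtain_data_loaders('./data', n_clients=10, batch_size=64, non_iid='quantity_skew')
x_batch, y_batch = next(iter(loaders[0]))
print(x_batch.shape)   # e.g. torch.Size([64, 3, 32, 32]) with split_factor=1
```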
71
EdgeFLite/data_collection/data_cutout.py
Normal file
@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import numpy as np


class Cutout:
    """Applies random cutout augmentation by masking patches in an image.

    This technique randomly cuts out square patches from the image to
    augment the dataset, helping the model become invariant to occlusions.

    Args:
        n_holes (int): Number of patches to remove from the image.
        length (int): Side length (in pixels) of each square patch.
    """

    def __init__(self, n_holes, length):
        """
        Initializes the Cutout class with the number of patches to be removed
        and the size of each patch.

        Args:
            n_holes (int): Number of patches (holes) to cut out from the image.
            length (int): Size of each square patch.
        """
        self.n_holes = n_holes  # Number of holes (patches) to remove.
        self.length = length  # Side length of each square patch.

    def __call__(self, img):
        """
        Applies the cutout augmentation on the input image.

        Args:
            img (Tensor): The input image tensor with shape (C, H, W),
                where C is the number of channels, H is the height,
                and W is the width of the image.

        Returns:
            Tensor: The augmented image tensor with `n_holes` patches of size
                `length x length` cut out, filled with zeros.
        """
        # Get the height and width of the image (ignoring the channel dimension)
        height, width = img.size(1), img.size(2)

        # Create a mask initialized with ones, same height and width as the image
        # (each pixel is set to 1, representing no masking initially)
        mask = np.ones((height, width), dtype=np.float32)

        # Randomly remove `n_holes` patches from the image
        for _ in range(self.n_holes):
            # Randomly choose the center of a patch (x_center, y_center)
            y_center = np.random.randint(height)
            x_center = np.random.randint(width)

            # Define the coordinates of the patch based on the center,
            # clipped so the patch stays within the image boundaries
            y1 = np.clip(y_center - self.length // 2, 0, height)
            y2 = np.clip(y_center + self.length // 2, 0, height)
            x1 = np.clip(x_center - self.length // 2, 0, width)
            x2 = np.clip(x_center + self.length // 2, 0, width)

            # Set the mask to 0 for the patch (mark the patch as cut out)
            mask[y1:y2, x1:x2] = 0.0

        # Convert the mask from a numpy array to a PyTorch tensor
        mask_tensor = torch.from_numpy(mask).expand_as(img)

        # Multiply the input image by the mask (cut out the selected patches)
        return img * mask_tensor
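A minimal sketch of Cutout applied directly to an image tensor (a random tensor stands in for a real sample):

```python
import torch

cutout = Cutout(n_holes=1, length=16)
img = torch.rand(3, 32, 32)   # stand-in CHW image tensor
out = cutout(img)             # one random 16x16 patch zeroed
print((out == 0).sum())       # roughly 16*16*3 zeroed elements (less when the patch is clipped at a border)
```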
178
EdgeFLite/data_collection/dataset_cifar.py
Normal file
@ -0,0 +1,178 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Import necessary libraries
from PIL import Image  # For image handling
import os  # For file path operations
import numpy as np  # For numerical operations
import pickle  # For loading serialized data
import torch  # For PyTorch operations

# Import custom classes and functions from the current package
from .vision import VisionDataset
from .utils import validate_integrity, fetch_and_extract_archive


# CIFAR10 dataset class
class CIFAR10(VisionDataset):
    """
    CIFAR10 Dataset class that handles CIFAR-10 loading, processing, and transformations.

    Args:
        root (str): Directory where the dataset is stored or will be downloaded to.
        train (bool, optional): If True, load the training set. Otherwise, load the test set.
        apply_transformation (callable, optional): A function/transform that takes a PIL image and returns a transformed version.
        target_apply_transformation (callable, optional): A function/transform applied to the target.
        download (bool, optional): If True, download the dataset if it's not found locally.
        split_factor (int, optional): Number of transformed views produced per image. Default is 1.
    """
    # Directory and URL details for downloading the CIFAR-10 dataset
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'  # MD5 checksum to verify the file's integrity

    # Training batches with their corresponding MD5 checksums
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb']
    ]

    # Test batches with their corresponding MD5 checksums
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e']
    ]

    # Info map describing the label-name file
    # (the file inside the CIFAR-10 archive is literally named 'batches.meta')
    info_map = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888'
    }

    # Initialization
    def __init__(self, root, train=True, apply_transformation=None, target_apply_transformation=None, download=False, split_factor=1):
        super(CIFAR10, self).__init__(root, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation)
        self.train = train  # Whether to load the training set or the test set
        self.split_factor = split_factor  # Number of transformed views per image

        # Download the dataset if requested
        if download:
            self.download()

        # Check that the dataset is present and valid
        if not self._validate_integrity():
            raise RuntimeError('Dataset not found or corrupted. Use download=True to download it.')

        # Load the dataset
        self.data, self.targets = self._load_data()

        # Load the label info map (to get class names)
        self._load_info_map()

    # Load the dataset from the batch files
    def _load_data(self):
        data, targets = [], []  # Lists to hold data and labels
        files = self.train_list if self.train else self.test_list  # Choose train or test files

        # Load each file, deserialize with pickle, and collect data and labels
        for file_name, _ in files:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')  # Load file
                data.append(entry['data'])  # Append image data
                targets.extend(entry.get('labels', entry.get('fine_labels', [])))  # Append labels

        # Reshape the data to (num_samples, height, width, channels)
        data = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # HWC format
        return data, targets

    # Load label names (info map)
    def _load_info_map(self):
        info_map_path = os.path.join(self.root, self.base_folder, self.info_map['filename'])  # Path to info map
        if not validate_integrity(info_map_path, self.info_map['md5']):  # Check the info map's integrity
            raise RuntimeError('Metadata file not found or corrupted. Use download=True to download it.')

        # Load the label names
        with open(info_map_path, 'rb') as info_map_file:
            info_map_data = pickle.load(info_map_file, encoding='latin1')  # Load label names
        self.classes = info_map_data[self.info_map['key']]  # Extract class labels
        self.class_to_idx = {label: idx for idx, label in enumerate(self.classes)}  # Map class names to indices

    # Get an (image, target) pair by index
    def __getitem__(self, index):
        """
        Get the (image, target) pair at the specified index.

        Args:
            index (int): Index of the data.

        Returns:
            tuple: Concatenated transformed views and the target class.
        """
        img, target = self.data[index], self.targets[index]  # Get image and target label
        img = Image.fromarray(img)  # Convert numpy array to PIL image

        # Apply the transformation multiple times based on split_factor
        imgs = [self.apply_transformation(img) for _ in range(self.split_factor)] if self.apply_transformation else None
        if imgs is None:
            raise NotImplementedError('apply_transformation must be provided.')

        # Apply the target transformation if available
        if self.target_apply_transformation:
            target = self.target_apply_transformation(target)

        return torch.cat(imgs, dim=0), target  # Concatenated transformed views and the target

    # Number of items in the dataset
    def __len__(self):
        return len(self.data)

    # Check that the dataset files are present and valid
    def _validate_integrity(self):
        files = self.train_list + self.test_list  # All files to check
        for file_name, md5 in files:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            if not validate_integrity(file_path, md5):  # Verify integrity using MD5
                return False
        return True

    # Download the dataset if it is not available
    def download(self):
        if self._validate_integrity():
            print('Files already downloaded and verified')
        else:
            fetch_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    # Representation string including the split type (Train/Test)
    def extra_repr(self):
        return f"Split: {'Train' if self.train else 'Test'}"


# CIFAR100 is a subclass of CIFAR10 with minor modifications
class CIFAR100(CIFAR10):
    """
    CIFAR100 Dataset, a subclass of CIFAR10.
    """
    # Directory and URL details for the CIFAR-100 dataset
    # (the archive published by the dataset authors is cifar-100-python.tar.gz)
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'  # MD5 checksum

    # Training and test lists with their corresponding MD5 checksums for CIFAR-100
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d']
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']
    ]

    # Info map describing the fine-label-name file ('meta' inside the CIFAR-100 archive)
    info_map = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48'
    }
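A usage sketch for the split_factor mechanism above (assumes the package's `.vision`/`.utils` helpers are present; the transform is illustrative):

```python
from torchvision import transforms

tf = transforms.Compose([transforms.RandomCrop(32, padding=4),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor()])
ds = CIFAR10('./data', train=True, apply_transformation=tf, download=True, split_factor=2)
imgs, target = ds[0]
print(imgs.shape)  # torch.Size([6, 32, 32]): two augmented 3-channel views stacked on the channel dim
```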
152
EdgeFLite/data_collection/dataset_factory.py
Normal file
@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
from torchvision import transforms as apply_transformations  # torchvision's module is named `transforms`
from .dataset_cifar import CIFAR10, CIFAR100  # CIFAR10/CIFAR100 datasets added in this commit
from .augment_auto import CIFAR10Policy  # CIFAR-10 AutoAugment policy added in this commit

__all__ = ['obtain_data_loader']  # Public API of this module


def obtain_data_loader(
        data_dir,  # Directory where the data is stored
        split_factor=1,  # Used for data partitioning, especially in federated learning
        batch_size=128,  # Batch size for loading data
        crop_size=32,  # Size to crop the input images
        dataset='cifar10',  # Dataset to use (CIFAR-10 by default)
        split="train",  # The split type: 'train', 'val', or 'test'
        is_decentralized=False,  # Whether to use decentralized training
        is_autoaugment=1,  # Use AutoAugment or not
        randaa=None,  # Placeholder for randomized augmentations
        is_cutout=True,  # Whether to apply cutout (random erasing)
        erase_p=0.5,  # Probability of applying random erasing
        num_workers=8,  # Number of workers to load data
        pin_memory=True,  # Use pinned memory for faster GPU transfer
        is_fed=False,  # Whether to use federated learning
        num_clusters=20,  # Number of clients in federated learning
        cifar10_non_iid=False,  # Non-IID option for the CIFAR-10 dataset
        cifar100_non_iid=False  # Non-IID option for the CIFAR-100 dataset
):
    """Get the dataset loader"""
    assert not (is_autoaugment and randaa is not None)  # AutoAugment and randaa cannot be used together

    # Loader settings for multiprocessing
    kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}
    assert split in ['train', 'val', 'test']  # Ensure a valid split

    # For the CIFAR-10 dataset
    if dataset == 'cifar10':
        # Handle the non-IID 'quantity skew' case for CIFAR-10
        if cifar10_non_iid == 'quantity_skew':
            non_iid = 'quantity_skew'
            # Training split
            if 'train' in split:
                print(f"INFO:PyTorch: Using quantity_skew CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.")
                traindir = data_dir  # Set data directory
                # Data transformations for training
                train_apply_transformation = apply_transformations.Compose([
                    apply_transformations.ToPILImage(),
                    apply_transformations.RandomCrop(32, padding=4),
                    apply_transformations.RandomHorizontalFlip(),
                    CIFAR10Policy(),  # AutoAugment policy
                    apply_transformations.ToTensor(),
                    apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
                    apply_transformations.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False),
                ])
                train_sampler = None
                print('INFO:PyTorch: creating quantity_skew CIFAR10 train dataloader...')

                # For federated learning, create loaders for each client
                # (obtain_data_loaders_train/_test are expected to come from the repo's non-IID loader helpers)
                if is_fed:
                    train_loader = obtain_data_loaders_train(
                        traindir,
                        nclients=num_clusters * split_factor,  # Number of clients in federated learning
                        batch_size=batch_size,
                        verbose=True,
                        apply_transformations_train=train_apply_transformation,
                        non_iid=non_iid,  # Non-IID type
                        split_factor=split_factor
                    )
                else:
                    assert is_fed  # This branch requires federated learning
                return train_loader, train_sampler
            else:
                # Validation or test split
                valdir = data_dir  # Set validation data directory
                # Data transformations for validation/testing
                val_apply_transformation = apply_transformations.Compose([
                    apply_transformations.ToPILImage(),
                    apply_transformations.ToTensor(),
                    apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
                ])
                # Create the test loader
                val_loader = obtain_data_loaders_test(
                    valdir,
                    nclients=num_clusters * split_factor,  # Number of clients in federated learning
                    batch_size=batch_size,
                    verbose=True,
                    apply_transformations_eval=val_apply_transformation,
                    non_iid=non_iid,
                    split_factor=1
                )
                return val_loader
        else:
            # Standard IID CIFAR-10 case
            if 'train' in split:
                print(f"INFO:PyTorch: Using CIFAR10 dataset, batch size {batch_size} and crop size is {crop_size}.")
                traindir = data_dir  # Set training data directory
                # Data transformations for training
                train_apply_transformation = apply_transformations.Compose([
                    apply_transformations.RandomCrop(32, padding=4),
                    apply_transformations.RandomHorizontalFlip(),
                    CIFAR10Policy(),  # AutoAugment policy
                    apply_transformations.ToTensor(),
                    apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
                    apply_transformations.RandomErasing(p=erase_p, scale=(0.125, 0.2), ratio=(0.99, 1.0), value=0, inplace=False),
                ])
                # Create the CIFAR-10 dataset object
                train_dataset = CIFAR10(
                    traindir, train=True, apply_transformation=train_apply_transformation, target_apply_transformation=None, download=True, split_factor=split_factor
                )
                train_sampler = None  # No sampler by default

                # Decentralized training setup
                # (torch's sampler for this lives in torch.utils.data.distributed)
                if is_decentralized:
                    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)

                print('INFO:PyTorch: creating CIFAR10 train dataloader...')
                if is_fed:
                    # Federated learning setup
                    images_per_client = int(train_dataset.data.shape[0] / (num_clusters * split_factor))
                    print(f"Images per client: {images_per_client}")
                    data_split = [images_per_client for _ in range(num_clusters * split_factor - 1)]
                    data_split.append(len(train_dataset) - images_per_client * (num_clusters * split_factor - 1))
                    # Split the dataset for each client
                    traindata_split = torch.utils.data.random_split(train_dataset, data_split, generator=torch.Generator().manual_seed(68))
                    # Create a data loader per client
                    train_loader = [torch.utils.data.DataLoader(
                        x, batch_size=batch_size, shuffle=(train_sampler is None), drop_last=True, sampler=train_sampler, **kwargs
                    ) for x in traindata_split]
                else:
                    # Standard data loader
                    train_loader = torch.utils.data.DataLoader(
                        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None), drop_last=True, sampler=train_sampler, **kwargs
                    )
                return train_loader, train_sampler
            else:
                # Validation or test split
                valdir = data_dir  # Set validation data directory
                # Data transformations for validation/testing
                val_apply_transformation = apply_transformations.Compose([
                    apply_transformations.ToTensor(),
                    apply_transformations.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
                ])
                # Create the CIFAR-10 dataset object for validation
                val_dataset = CIFAR10(valdir, train=False, apply_transformation=val_apply_transformation, target_apply_transformation=None, download=True, split_factor=1)
                print('INFO:PyTorch: creating CIFAR10 validation dataloader...')
                # Create the data loader for validation
                val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **kwargs)
                return val_loader
    # Additional dataset logic for CIFAR-100, decentralized setups, or other datasets can be added similarly.
    else:
        raise NotImplementedError(f"The DataLoader for {dataset} is not implemented.")
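A sketch of the federated IID path through the factory above (directory illustrative; returns one DataLoader per client plus the sampler):

```python
train_loaders, sampler = obtain_data_loader(
    './data', split_factor=1, batch_size=128, dataset='cifar10',
    split='train', is_fed=True, num_clusters=20)
print(len(train_loaders))   # 20 client loaders
```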
194
EdgeFLite/data_collection/dataset_imagenet.py
Normal file
@ -0,0 +1,194 @@
|
||||
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import warnings
from contextlib import contextmanager
import os
import shutil
import tempfile
import torch
from .folder import ImageFolder
from .utils import validate_integrity, extract_archive, verify_str_arg

# Dictionary that maps the dataset split (train/val/devkit) to its corresponding archive filename and checksum (md5 hash)
ARCHIVE_info_map = {
    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
    'devkit': ('ILSVRC2012_devkit_t12.tar', 'fa75699e90414af021442c21a62c3abf')
}

# File name where the information map (class info, wnid, etc.) is stored
info_map_FILE = "info_map.bin"


class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Args:
        root (str): Root directory of the ImageNet Dataset.
        split (str, optional): Dataset split, either ``train`` or ``val``.
        apply_transformation (callable, optional): A function/transform to apply to the PIL image.
        target_apply_transformation (callable, optional): A function/transform to apply to the target.
        loader (callable, optional): Function to load an image from its path.

    Attributes:
        classes (list): List of class name tuples.
        class_to_idx (dict): Mapping of class names to indices.
        wnids (list): List of WordNet IDs.
        wnid_to_idx (dict): Mapping of WordNet IDs to class indices.
        imgs (list): List of image path and class index tuples.
        targets (list): Class index values for each image in the dataset.
    """

    def __init__(self, root, split='train', download=None, **kwargs):
        # Check the download flag and raise/warn, since the dataset is no longer publicly accessible
        if download is True:
            raise RuntimeError("The dataset is no longer publicly accessible. Please download archives externally and place them in the root directory.")
        elif download is False:
            warnings.warn("The download flag is deprecated, as the dataset is no longer publicly accessible.", RuntimeWarning)

        # Expand the root directory path
        root = self.root = os.path.expanduser(root)

        # Validate the dataset split (should be either 'train' or 'val')
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # Parse dataset archives (train/val/devkit) and prepare the dataset
        self.extract_archives()

        # Load WordNet ID to class mappings from the info_map file
        wnid_to_classes = load_information_map_file(self.root)[0]

        # Initialize the ImageFolder with the split folder (train/val directory)
        super().__init__(self.divide_folder_contents, **kwargs)

        # Set class-related attributes
        self.root = root
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx

        # Update classes to human-readable names and adjust the class_to_idx mapping
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss}

    def extract_archives(self):
        # Check if the info_map file exists and is valid, otherwise parse the devkit archive
        if not validate_integrity(os.path.join(self.root, info_map_FILE)):
            extract_devkit_archive(self.root)

        # If the dataset folder (train/val) does not exist, extract the respective archive
        if not os.path.isdir(self.divide_folder_contents):
            if self.split == 'train':
                process_train_archive(self.root)
            elif self.split == 'val':
                process_validation_archive(self.root)

    @property
    def divide_folder_contents(self):
        # Return the path of the folder containing the images (train/val)
        return os.path.join(self.root, self.split)

    def extra_repr(self):
        # Additional representation for the dataset object (showing the split)
        return f"Split: {self.split}"


def load_information_map_file(root, file=None):
    # Load the info_map file from the root directory
    file = os.path.join(root, file or info_map_FILE)
    if validate_integrity(file):
        return torch.load(file)
    else:
        raise RuntimeError(f"The info_map file {file} is either missing or corrupted. Please ensure it exists in the root directory.")


def _validate_archive_file(root, file, md5):
    # Verify that the archive file is present and its checksum matches
    if not validate_integrity(os.path.join(root, file), md5):
        raise RuntimeError(f"The archive {file} is either missing or corrupted. Please download it and place it in {root}.")


def extract_devkit_archive(root, file=None):
    """Extract and process the ImageNet 2012 devkit archive to generate info_map information.

    Args:
        root (str): Root directory with the devkit archive.
        file (str, optional): Archive filename. Defaults to 'ILSVRC2012_devkit_t12.tar'.
    """
    import scipy.io as sio

    # Parse info_map.mat from the devkit, containing class and WordNet ID information
    def read_info_map_mat_file(devkit_root):
        info_map_path = os.path.join(devkit_root, "data", "info_map.mat")
        info_map = sio.loadmat(info_map_path, squeeze_me=True)['synsets']
        # Column 4 of the synsets table holds the number of children; keep only leaf synsets
        num_children_list = list(zip(*info_map))[4]
        info_map = [info_map[idx] for idx, num_children in enumerate(num_children_list) if num_children == 0]
        idcs, wnids, classes = list(zip(*info_map))[:3]
        classes = [tuple(clss.split(', ')) for clss in classes]
        return {idx: wnid for idx, wnid in zip(idcs, wnids)}, {wnid: clss for wnid, clss in zip(wnids, classes)}

    # Parse the validation ground truth file for image class labels
    def process_val_groundtruth_txt(devkit_root):
        file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt")
        with open(file) as f:
            return [int(line.strip()) for line in f]

    # Context manager to handle temporary directories for archive extraction
    @contextmanager
    def get_tmp_dir():
        tmp_dir = tempfile.mkdtemp()
        try:
            yield tmp_dir
        finally:
            shutil.rmtree(tmp_dir)

    # Extract and process the devkit archive
    file, md5 = ARCHIVE_info_map["devkit"]
    _validate_archive_file(root, file, md5)

    with get_tmp_dir() as tmp_dir:
        extract_archive(os.path.join(root, file), tmp_dir)
        devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12")
        idx_to_wnid, wnid_to_classes = read_info_map_mat_file(devkit_root)
        val_idcs = process_val_groundtruth_txt(devkit_root)
        val_wnids = [idx_to_wnid[idx] for idx in val_idcs]

        # Save the mappings to the info_map file
        torch.save((wnid_to_classes, val_wnids), os.path.join(root, info_map_FILE))


def process_train_archive(root, file=None, folder="train"):
    """Extract and organize the ImageNet 2012 train dataset.

    Args:
        root (str): Root directory containing the train dataset archive.
        file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_train.tar'.
        folder (str, optional): Destination folder. Defaults to 'train'.
    """
    file, md5 = ARCHIVE_info_map["train"]
    _validate_archive_file(root, file, md5)

    train_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), train_root)

    # Extract each class-specific archive in the train dataset into a folder next to it
    for archive in os.listdir(train_root):
        extract_archive(os.path.join(train_root, archive),
                        os.path.join(train_root, os.path.splitext(archive)[0]),
                        remove_finished=True)


def process_validation_archive(root, file=None, wnids=None, folder="val"):
    """Extract and organize the ImageNet 2012 validation dataset.

    Args:
        root (str): Root directory containing the validation dataset archive.
        file (str, optional): Archive filename. Defaults to 'ILSVRC2012_img_val.tar'.
        wnids (list, optional): WordNet IDs for validation images. Defaults to None (loaded from info_map file).
        folder (str, optional): Destination folder. Defaults to 'val'.
    """
    file, md5 = ARCHIVE_info_map["val"]
    if wnids is None:
        wnids = load_information_map_file(root)[1]

    _validate_archive_file(root, file, md5)

    val_root = os.path.join(root, folder)
    extract_archive(os.path.join(root, file), val_root)

    # Create directories for each WordNet ID (class) and move validation images into their respective folders
    for wnid in set(wnids):
        os.mkdir(os.path.join(val_root, wnid))

    # The diff truncates the next line; the loop is completed here after torchvision's
    # reference implementation, skipping the class directories created above.
    for wnid, img_file in zip(wnids, sorted(f for f in os.listdir(val_root)
                                            if os.path.isfile(os.path.join(val_root, f)))):
        shutil.move(os.path.join(val_root, img_file), os.path.join(val_root, wnid, img_file))
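For reference, a minimal usage sketch of the class above. The root path is hypothetical, the three ILSVRC2012 archives are assumed to already be in place, and the keyword names follow the constructor signature shown in this file:

```python
from torchvision import transforms

val_set = ImageNet('./imagenet_root', split='val',
                   apply_transformation=transforms.ToTensor(), split_factor=1)
print(val_set)  # __repr__ reports the datapoint count, root location, and "Split: val"
```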
229
EdgeFLite/data_collection/directory_utils.py
Normal file
@@ -0,0 +1,229 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Import necessary modules
from .vision import VisionDataset  # Base VisionDataset class
from PIL import Image  # PIL for image loading and processing
import os  # For interacting with the file system
import torch  # PyTorch for tensor operations


def validate_file_extension(filename, extensions):
    """
    Check if a file has an allowed extension.

    Args:
        filename (str): Path to the file.
        extensions (tuple of str): Extensions to consider (in lowercase).

    Returns:
        bool: True if the filename ends with one of the given extensions.
    """
    return filename.lower().endswith(extensions)


def is_image_file(filename):
    """
    Check if a file is an image based on its extension.

    Args:
        filename (str): Path to the file.

    Returns:
        bool: True if the filename is a known image format.
    """
    return validate_file_extension(filename, IMG_EXTENSIONS)


def generate_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
    """
    Create a list of file paths and their corresponding class indices.

    Args:
        directory (str): Root directory.
        class_to_idx (dict): Mapping of class names to class indices.
        extensions (tuple, optional): Allowed file extensions.
        is_valid_file (callable, optional): Function to validate files.

    Returns:
        list: A list of (file_path, class_index) tuples.
    """
    instances = []
    directory = os.path.expanduser(directory)  # Expand the user directory path if needed

    # Ensure exactly one of extensions or is_valid_file is specified
    if (extensions is None and is_valid_file is None) or (extensions is not None and is_valid_file is not None):
        raise ValueError("Specify either 'extensions' or 'is_valid_file', but not both.")

    # Define the validation function if extensions are provided
    if extensions is not None:
        def is_valid_file(x):
            return validate_file_extension(x, extensions)

    # Iterate through the directory, searching for valid image files
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]  # Class index
        target_dir = os.path.join(directory, target_class)  # Target class folder
        if not os.path.isdir(target_dir):  # Skip if it's not a directory
            continue
        # Walk through the directory and its subdirectories
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)  # Full file path
                if is_valid_file(path):  # Check if it's a valid file
                    instances.append((path, class_index))  # Append the file path and class index

    return instances  # Return the dataset


class DatasetFolder(VisionDataset):
    """
    A generic data loader where samples are arranged in subdirectories by class.

    Args:
        root (str): Root directory path.
        loader (callable): Function to load a sample from its file path.
        extensions (tuple[str]): Allowed file extensions.
        apply_transformation (callable, optional): Transformation applied to each sample.
        target_apply_transformation (callable, optional): Transformation applied to each target.
        is_valid_file (callable, optional): Function to validate files.
        split_factor (int, optional): Number of times to apply the transformation.

    Attributes:
        classes (list): Sorted list of class names.
        class_to_idx (dict): Mapping of class names to class indices.
        samples (list): List of (sample_path, class_index) tuples.
        targets (list): List of class indices corresponding to each sample.
    """

    def __init__(self, root, loader, extensions=None, apply_transformation=None,
                 target_apply_transformation=None, is_valid_file=None, split_factor=1):
        super().__init__(root, apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation)
        self.classes, self.class_to_idx = self._discover_classes(self.root)  # Discover classes in the root directory
        self.samples = generate_dataset(self.root, self.class_to_idx, extensions, is_valid_file)  # Create the dataset from files

        # Raise an error if no valid files are found
        if len(self.samples) == 0:
            raise RuntimeError(f"Found 0 files in subfolders of: {self.root}. "
                               f"Supported extensions are: {','.join(extensions)}")

        self.loader = loader  # Function to load a sample
        self.extensions = extensions  # Allowed file extensions
        self.targets = [s[1] for s in self.samples]  # Target class indices
        self.split_factor = split_factor  # Number of transformations to apply

    def _discover_classes(self, dir):
        """
        Discover class subdirectories in the root directory.

        Args:
            dir (str): Root directory.

        Returns:
            tuple: (classes, class_to_idx) where classes are subdirectories of 'dir',
                and class_to_idx is a mapping of class names to indices.
        """
        classes = sorted([d.name for d in os.scandir(dir) if d.is_dir()])  # Subdirectory names (classes)
        class_to_idx = {classes[i]: i for i in range(len(classes))}  # Map class names to indices
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Retrieve a sample and its target by index.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: (sample, target), where the sample is the transformed image and
                the target is the class index.
        """
        path, target = self.samples[index]  # File path and target class index
        sample = self.loader(path)  # Load the sample (image)

        # Apply the transformation to the sample 'split_factor' times
        # (raising explicitly here; the original assigned the NotImplementedError class instead)
        if self.apply_transformation is None:
            raise NotImplementedError("DatasetFolder requires 'apply_transformation' to be set.")
        imgs = [self.apply_transformation(sample) for _ in range(self.split_factor)]

        # Apply the target transformation if specified
        if self.target_apply_transformation:
            target = self.target_apply_transformation(target)

        return torch.cat(imgs, dim=0), target  # Concatenated transformed images and the target

    def __len__(self):
        # Return the number of samples in the dataset
        return len(self.samples)


# List of supported image file extensions
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def load_image_pil(path):
    """
    Load an image from the given path using PIL.

    Args:
        path (str): Path to the image.

    Returns:
        Image: RGB image.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)  # Open the image file
        return img.convert('RGB')  # Convert the image to RGB format


def load_accimage(path):
    """
    Load an image using the accimage library, falling back to PIL on failure.

    Args:
        path (str): Path to the image.

    Returns:
        Image: Image loaded with accimage or PIL.
    """
    import accimage  # accimage is a faster image loading library
    try:
        return accimage.Image(path)  # Try loading with accimage
    except IOError:
        return load_image_pil(path)  # Fall back to PIL on error


def basic_loader(path):
    """
    Load an image using the default image backend (accimage or PIL).

    Args:
        path (str): Path to the image.

    Returns:
        Image: Loaded image.
    """
    from torchvision import get_image_backend  # Query the default image backend
    return load_accimage(path) if get_image_backend() == 'accimage' else load_image_pil(path)


class ImageFolder(DatasetFolder):
    """
    A dataset loader for images arranged in subdirectories by class.

    Args:
        root (str): Root directory path.
        apply_transformation (callable, optional): Transformation applied to each image.
        target_apply_transformation (callable, optional): Transformation applied to each target.
        loader (callable, optional): Function to load an image from its path.
        is_valid_file (callable, optional): Function to validate files.

    Attributes:
        classes (list): Sorted list of class names.
        class_to_idx (dict): Mapping of class names to class indices.
        imgs (list): List of (image_path, class_index) tuples.
    """

    def __init__(self, root, apply_transformation=None, target_apply_transformation=None, loader=basic_loader, is_valid_file=None, split_factor=1):
        # The tail of this constructor is truncated in the diff; it is completed here to
        # mirror torchvision's ImageFolder, forwarding is_valid_file and split_factor.
        super().__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                         apply_transformation=apply_transformation, target_apply_transformation=target_apply_transformation,
                         is_valid_file=is_valid_file, split_factor=split_factor)
        self.imgs = self.samples
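A minimal usage sketch of ImageFolder as defined above; the directory path is hypothetical and torchvision's transforms are assumed to be available:

```python
from torchvision import transforms

ds = ImageFolder('./data/train', apply_transformation=transforms.ToTensor(), split_factor=2)
x, y = ds[0]
print(x.shape)  # 6 x H x W: two 3-channel transformed copies concatenated on dim 0
print(ds.classes, ds.class_to_idx)
```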
173
EdgeFLite/data_collection/helper_utils.py
Normal file
@@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import hashlib
import gzip
import tarfile
import zipfile
import urllib.request
import requests  # Required by fetch_file_google_drive below; missing from the original imports
from torch.utils.model_zoo import tqdm


def generate_update_progress_bar():
    """Generate a progress-bar callback for tracking download progress."""
    pbar = tqdm(total=None)

    def update_progress_bar(count, block_size, total_size):
        """Update the progress bar based on the downloaded data size."""
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return update_progress_bar


def compute_md5_checksum(fpath, chunk_size=1024 * 1024):
    """Calculate the MD5 checksum of a file."""
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def verify_md5_checksum(fpath, md5):
    """Check whether the MD5 of a file matches the given checksum."""
    return md5 == compute_md5_checksum(fpath)


def validate_integrity(fpath, md5=None):
    """Check the integrity of a file by verifying its existence and MD5 checksum."""
    if not os.path.isfile(fpath):
        return False
    return md5 is None or verify_md5_checksum(fpath, md5)


def download_url(url, root, filename=None, md5=None):
    """Download a file from a URL and save it in the specified directory."""
    root = os.path.expanduser(root)
    filename = filename or os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if validate_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_bar())
    except (urllib.error.URLError, IOError) as e:
        if url.startswith('https'):
            url = url.replace('https:', 'http:')
            print('Failed download. Retrying with http.')
            urllib.request.urlretrieve(url, fpath, reporthook=generate_update_progress_bar())
        else:
            raise e

    if not validate_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")


def list_dir(root, prefix=False):
    """List all directories at the specified root."""
    root = os.path.expanduser(root)
    directories = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]

    return [os.path.join(root, d) for d in directories] if prefix else directories


def list_files(root, suffix, prefix=False):
    """List all files with a specific suffix in the specified root."""
    root = os.path.expanduser(root)
    files = [f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f)) and f.endswith(suffix)]

    return [os.path.join(root, f) for f in files] if prefix else files


def fetch_file_google_drive(file_id, root, filename=None, md5=None):
    """Download a file from Google Drive and save it in the specified directory."""
    url = "https://docs.google.com/uc?export=download"
    root = os.path.expanduser(root)
    filename = filename or file_id
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    if os.path.isfile(fpath) and validate_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    session = requests.Session()
    response = session.get(url, params={'id': file_id}, stream=True)
    token = _get_confirm_token(response)

    if token:
        params = {'id': file_id, 'confirm': token}
        response = session.get(url, params=params, stream=True)

    _store_response_content(response, fpath)


def _get_confirm_token(response):
    """Extract the download token from Google Drive cookies."""
    return next((value for key, value in response.cookies.items() if key.startswith('download_warning')), None)


def _store_response_content(response, destination, chunk_size=32768):
    """Save the response content to a file in chunks."""
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                pbar.update(progress - pbar.n)
        pbar.close()


def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract an archive file (tar, zip, gz) to the specified path."""
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith((".tar", ".tar.gz", ".tgz", ".tar.xz")):
        # tarfile modes use a colon separator (e.g. 'r:gz'); the original 'r' + '.gz' would raise
        mode = 'r' + (':gz' if from_path.endswith(('.tar.gz', '.tgz')) else
                      ':xz' if from_path.endswith('.tar.xz') else '')
        with tarfile.open(from_path, mode) as tar:
            tar.extractall(path=to_path)
    elif from_path.endswith(".gz"):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            out_f.write(zip_f.read())
    elif from_path.endswith(".zip"):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.remove(from_path)


def fetch_and_extract_archive(url, download_root, extract_root=None, filename=None, md5=None, remove_finished=False):
    """Download and extract an archive file from a URL."""
    download_root = os.path.expanduser(download_root)
    extract_root = extract_root or download_root
    filename = filename or os.path.basename(url)

    download_url(url, download_root, filename, md5)
    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)


def iterable_to_str(iterable):
    """Convert an iterable to a string representation."""
    return "'" + "', '".join(map(str, iterable)) + "'"


def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
    """Verify that a string argument is valid and raise an error if not."""
    if not isinstance(value, str):
        msg = "Expected type str" + (f" for argument {arg}" if arg else "") + f", but got type {type(value)}."
        raise ValueError(msg)

    if valid_values is None:
        return value

    if value not in valid_values:
        msg = custom_msg or f"Unknown value '{value}' for argument {arg}. Valid values are {{{iterable_to_str(valid_values)}}}."
        raise ValueError(msg)

    return value
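A short sketch of the download-and-extract helpers above in use; the URL and local paths are illustrative, not part of this module:

```python
fetch_and_extract_archive(
    url="https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
    download_root="./downloads",
    extract_root="./data",
    md5=None,              # pass a real MD5 string to enable integrity checking
    remove_finished=True,  # delete the archive after extraction
)
assert verify_str_arg("train", "split", ("train", "val")) == "train"
```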
84
EdgeFLite/data_collection/pill_data_base.py
Normal file
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

from PIL import Image
from torch.utils.data import DataLoader, Dataset
import torch
import os

# Importing the HOME configuration
from config import HOME


class PillDataBase(Dataset):
    def __init__(self, data_dir=HOME + '/dataset_hub/pill_base', train=True, apply_transformation=None, split_factor=1):
        """
        Initialize the dataset.

        Args:
            data_dir (str): Directory where the dataset is stored.
            train (bool): Flag indicating whether this is the training or testing split.
            apply_transformation (callable): Optional transformation applied to images (e.g., resizing, normalization).
            split_factor (int): Number of transformed copies produced per image, for augmentation purposes.
        """
        self.train = train
        self.apply_transformation = apply_transformation
        self.split_factor = split_factor
        self.data_dir = data_dir + '/pill_base'
        self.dataset = self._load_data()

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataset)

    def _load_data(self):
        """
        Load the dataset by reading the corresponding text file (train.txt or test.txt).

        The dataset text file contains the image file paths and corresponding labels.

        Returns:
            dataset (list): List of image file paths and their respective labels.
        """
        dataset = []
        txt_path = os.path.join(self.data_dir, 'train.txt' if self.train else 'test.txt')

        with open(txt_path, 'r') as file:
            lines = file.readlines()
            for line in lines:
                # Each line contains an image path and a label separated by a space
                filename, label = line.strip().split(' ')
                # Adjust the image path to the correct directory structure
                filename = filename.replace('/home/tung/Tung/research/Open-Pill/FACIL/data/Pill_Base_X', self.data_dir)
                # Append the image file path and the label as an integer
                dataset.append([filename, int(label)])

        return dataset

    def __getitem__(self, index):
        """
        Retrieve a specific sample from the dataset at the given index.

        Args:
            index (int): Index of the image and label to retrieve.

        Returns:
            tuple: A tensor of concatenated transformed images and the corresponding label.
        """
        images = []
        image_path = self.dataset[index][0]
        label = torch.tensor(int(self.dataset[index][1]))

        # Open the image file
        image = Image.open(image_path)

        # Apply the transformation if provided, repeated 'split_factor' times
        if self.apply_transformation:
            for _ in range(self.split_factor):
                images.append(self.apply_transformation(image))

        # Concatenate all transformed copies into a single tensor
        return torch.cat(images, dim=0), label


if __name__ == "__main__":
    # Example of how to instantiate and use the dataset
    dataset = PillDataBase()
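A minimal usage sketch for the class above. Note that a transform is effectively required, since `__getitem__` concatenates the transformed copies; torchvision is assumed to be available and the sizes are illustrative:

```python
from torchvision import transforms

tf = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
ds = PillDataBase(train=False, apply_transformation=tf, split_factor=2)
x, y = ds[0]  # x holds the two transformed copies concatenated on the channel dimension
loader = DataLoader(ds, batch_size=32, shuffle=False)
```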
83
EdgeFLite/data_collection/pill_data_large.py
Normal file
@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset

# Define the folder paths for the training and testing datasets
FOLDER_PATHS = [
    '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/train_images',
    '/media/skydata/alpha0012/workspace/EdgeFLite/coremodel/dataset_hub/medical_images/test_images'
]


# Custom dataset class inheriting from PyTorch's Dataset class
class PillDataLarge(Dataset):
    def __init__(self, train=True, apply_transformation=None, split_factor=1):
        """
        Initialize the dataset object.

        Args:
            train (bool): If True, load the training dataset; otherwise load the test dataset.
            apply_transformation (callable, optional): Optional transformation to apply to an image sample.
            split_factor (int): Number of times to apply the transformation to the image.
        """
        self.train = train  # Flag selecting the training or testing split
        self.apply_transformation = apply_transformation  # Transformation to apply to the images
        self.split_factor = split_factor  # Number of times to apply the transformation
        self.dataset = self._load_data()  # Load the dataset

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.dataset)

    def _load_data(self):
        """
        Load the data from the dataset folders.

        Returns:
            dataset (list): A list containing image file paths and their corresponding class IDs.
        """
        folder_path = FOLDER_PATHS[0] if self.train else FOLDER_PATHS[1]  # Train or test folder path
        class_names = sorted(os.listdir(folder_path))  # Class names from the folder
        class_map = {name: idx for idx, name in enumerate(class_names)}  # Map class names to IDs

        dataset = []
        for class_name, class_id in class_map.items():
            folder_class = os.path.join(folder_path, class_name)  # Path to the class folder
            files_jpg = glob.glob(os.path.join(folder_class, '**', '*.jpg'), recursive=True)  # All jpg files
            for file_path in files_jpg:
                dataset.append([file_path, class_id])  # Append the file path and class ID

        return dataset

    def __getitem__(self, index):
        """
        Return a sample and its corresponding label from the dataset.

        Args:
            index (int): Index of the sample.

        Returns:
            tuple: A tuple of the image tensor and the label tensor.
        """
        Xs = []  # List of transformed image copies
        image_path = self.dataset[index][0]  # Image path from the dataset
        label = torch.tensor(int(self.dataset[index][1]))  # Class label as a tensor

        X = Image.open(image_path)  # Open the image using PIL

        if self.apply_transformation:
            for _ in range(self.split_factor):
                Xs.append(self.apply_transformation(X))  # Apply the transformation multiple times

        return torch.cat(Xs, dim=0), label  # Concatenate the transformed copies and return with the label


if __name__ == "__main__":
    dataset = PillDataLarge()  # Create an instance of the dataset
    print(len(dataset))  # Print the size of the dataset
    print(dataset[0])  # Print the first sample of the dataset
46
EdgeFLite/data_collection/skin_dataset.py
Normal file
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Import the libraries needed for image processing and dataset handling.
from PIL import Image  # Used for opening and manipulating images.
from torch.utils.data import DataLoader, Dataset  # PyTorch utilities for managing datasets and data loading.
import torch  # PyTorch library for tensor operations and deep learning.
# Note: the file originally imported `split` from cv2, which was never used, so it is dropped here.


# Custom dataset class for skin-lesion data, backed by a pandas DataFrame.
class SkinData(Dataset):
    # Initialize the dataset with a DataFrame (df), an optional transformation (apply_transformation), and a split factor (split_factor).
    def __init__(self, df, apply_transformation=None, split_factor=1):
        self.df = df  # DataFrame containing image paths and target labels.
        self.apply_transformation = apply_transformation  # Optional image transformation (e.g., resizing, normalizing).
        self.split_factor = split_factor  # How many transformed copies of each image to produce.
        self.test_same_view = False  # If True, return the same augmentation repeated instead of independent ones.

    # Return the number of samples in the dataset (rows in the DataFrame).
    def __len__(self):
        return len(self.df)

    # Retrieve the image and corresponding label at a specific index.
    def __getitem__(self, index):
        Xs = []  # List of transformed versions of the image.

        # Open the image at the 'path' column for this index and resize it to 64x64.
        X = Image.open(self.df['path'][index]).resize((64, 64))

        # Retrieve the target label from the 'target' column and convert it to a PyTorch tensor.
        y = torch.tensor(int(self.df['target'][index]))

        # If 'test_same_view' is True, apply the transformation once and repeat the result.
        if self.test_same_view:
            if self.apply_transformation:
                aug = self.apply_transformation(X)  # Apply the transformation once.
                # Store the same augmented image 'split_factor' times.
                Xs = [aug for _ in range(self.split_factor)]
        else:
            # Otherwise, apply the transformation independently to create different augmentations.
            if self.apply_transformation:
                Xs = [self.apply_transformation(X) for _ in range(self.split_factor)]

        # Concatenate the copies along the channel dimension and return them with the label.
        return torch.cat(Xs, dim=0), y
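A minimal usage sketch for SkinData; the DataFrame contents below are hypothetical, and pandas and torchvision are assumed to be installed:

```python
import pandas as pd
from torchvision import transforms

df = pd.DataFrame({'path': ['lesions/img_0001.jpg'], 'target': [1]})
ds = SkinData(df, apply_transformation=transforms.ToTensor(), split_factor=2)
x, y = ds[0]  # x: 6 x 64 x 64 (two independent 3-channel views), y: tensor(1)
```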
94
EdgeFLite/data_collection/vision_utils.py
Normal file
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import torch
import torch.utils.data as data


# VisionDataset is a base dataset class inheriting from PyTorch's Dataset class.
# It handles the initialization and representation of a vision-related dataset,
# including optional transformation of input data and targets.
class VisionDataset(data.Dataset):
    _repr_indent = 4  # Indentation level for the dataset representation

    def __init__(self, root, apply_transformations=None, apply_transformation=None, target_apply_transformation=None):
        # Set the root directory and optional transformations.
        # If root is a string, expand any user directory shortcuts like "~".
        self.root = os.path.expanduser(root) if isinstance(root, str) else root

        # Either 'apply_transformations' or 'apply_transformation/target_apply_transformation' may be provided, but not both
        has_apply_transformations = apply_transformations is not None
        has_separate_apply_transformation = apply_transformation is not None or target_apply_transformation is not None

        if has_apply_transformations and has_separate_apply_transformation:
            raise ValueError("Only one of 'apply_transformations' or 'apply_transformation/target_apply_transformation' can be provided.")

        # Store the separate transformations
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

        # If separate transformations are provided, wrap them in a StandardTransform
        if has_separate_apply_transformation:
            apply_transformations = StandardTransform(apply_transformation, target_apply_transformation)
        self.apply_transformations = apply_transformations

    # Placeholder for the method to retrieve an item by index
    def __getitem__(self, index):
        raise NotImplementedError

    # Placeholder for the method to return the dataset length
    def __len__(self):
        raise NotImplementedError

    # Representation of the dataset including number of datapoints, root directory, and transformations
    def __repr__(self):
        head = f"Dataset {self.__class__.__name__}"
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()  # Include any additional representation details
        if hasattr(self, "apply_transformations") and self.apply_transformations is not None:
            body.append(repr(self.apply_transformations))  # Include transformation details if applicable
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    # Utility to format the representation of the apply_transformation and target_apply_transformation attributes
    def _format_apply_transformation_repr(self, apply_transformation, head):
        lines = apply_transformation.__repr__().splitlines()
        return [f"{head}{lines[0]}"] + [f"{' ' * len(head)}{line}" for line in lines[1:]]

    # Hook for adding extra dataset-specific information to the representation
    def extra_repr(self):
        return ""


# StandardTransform applies the input and target transformations
# during dataset iteration or data loading.
class StandardTransform:
    def __init__(self, apply_transformation=None, target_apply_transformation=None):
        # Initialize with optional input and target transformations
        self.apply_transformation = apply_transformation
        self.target_apply_transformation = target_apply_transformation

    # Apply the appropriate transformations to the input and target when invoked
    def __call__(self, input, target):
        if self.apply_transformation is not None:
            input = self.apply_transformation(input)
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)
        return input, target

    # Utility to format the transformation representation
    def _format_apply_transformation_repr(self, apply_transformation, head):
        lines = apply_transformation.__repr__().splitlines()
        return [f"{head}{lines[0]}"] + [f"{' ' * len(head)}{line}" for line in lines[1:]]

    # Representation of the StandardTransform including both input and target transformations
    def __repr__(self):
        body = [self.__class__.__name__]
        if self.apply_transformation is not None:
            body += self._format_apply_transformation_repr(self.apply_transformation, "apply_transformation: ")
        if self.target_apply_transformation is not None:
            body += self._format_apply_transformation_repr(self.target_apply_transformation, "Target apply_transformation: ")

        return '\n'.join(body)
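A toy sketch of how the base class above is meant to be extended; the subclass name and data are hypothetical and only illustrate the `__getitem__`/`__len__` hooks and the StandardTransform wrapping:

```python
class RangeDataset(VisionDataset):
    def __init__(self, root, n=10, apply_transformation=None):
        super().__init__(root, apply_transformation=apply_transformation)
        self.n = n

    def __getitem__(self, index):
        value, target = float(index), index % 2
        # self.apply_transformations is a StandardTransform when separate transforms are given
        return self.apply_transformations(value, target) if self.apply_transformations else (value, target)

    def __len__(self):
        return self.n
```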
47
EdgeFLite/debug_tool.py
Normal file
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Import necessary libraries
import torch  # PyTorch for tensor computations and neural networks
from torch import nn  # Neural network module
# "decentralized" is not a valid import in PyTorch, possibly a typo. Removed for now.

# Check for an available device (CPU or GPU)
# If a GPU (CUDA) is available, use it; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the normalization layer and the number of input channels for the first convolutional layer
batch_norm_layer = nn.BatchNorm2d  # 2D batch normalization to stabilize training
initial_channels = 32  # Number of channels produced by the first convolutional layer

# Define the convolutional neural network (CNN) architecture using nn.Sequential
network = nn.Sequential(
    # 1st convolutional layer: takes 3 input channels (RGB image), outputs 'initial_channels' feature maps.
    # Kernel size 3, stride 2 for downsampling, padding 1 to preserve spatial alignment.
    nn.Conv2d(in_channels=3, out_channels=initial_channels, kernel_size=3, stride=2, padding=1, bias=False),
    batch_norm_layer(initial_channels),  # Apply batch normalization to the output
    nn.ReLU(inplace=True),  # ReLU activation to introduce non-linearity

    # 2nd convolutional layer: same number of input and output feature maps, no downsampling (stride 1)
    nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels, kernel_size=3, stride=1, padding=1, bias=False),
    batch_norm_layer(initial_channels),  # Batch normalization for better convergence
    nn.ReLU(inplace=True),  # ReLU activation

    # 3rd convolutional layer: doubles the number of output channels (deeper features), no downsampling (stride 1)
    nn.Conv2d(in_channels=initial_channels, out_channels=initial_channels * 2, kernel_size=3, stride=1, padding=1, bias=False),
    batch_norm_layer(initial_channels * 2),  # Batch normalization for the increased feature maps
    nn.ReLU(inplace=True),  # ReLU activation

    # Max pooling layer to further downsample the feature maps (reduces spatial dimensions)
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Pooling with kernel size 3 and stride 2
)

# Create a dummy input tensor simulating a batch of 128 RGB images, each of size 64x64
sample_input = torch.randn(128, 3, 64, 64)

# Print the defined network architecture and the shape of the output after a forward pass
print(network)
# Perform a forward pass with the sample input and print the resulting output shape
print(network(sample_input).shape)
BIN
EdgeFLite/fedml_service/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
EdgeFLite/fedml_service/architecture/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
EdgeFLite/fedml_service/architecture/cv/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
EdgeFLite/fedml_service/architecture/cv/models_pretrained/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR10/.DS_Store
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
EdgeFLite/fedml_service/architecture/cv/models_pretrained/CIFAR100/.DS_Store
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
EdgeFLite/fedml_service/architecture/cv/models_pretrained/CINIC10/.DS_Store
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
EdgeFLite/fedml_service/architecture/cv/resnet56_federated/.DS_Store
vendored
Normal file
Binary file not shown.
@@ -0,0 +1,196 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import torch
import torch.nn as nn

# The only factory defined in this file is resnet56_server (the original listed the
# undefined name 'resnet110' here).
__all__ = ['ResNet', 'resnet56_server']


def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def apply_1x1_convolution(in_channels, out_channels, stride=1):
    """1x1 convolution."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """Basic block used in ResNet. Consists of two 3x3 convolutions."""
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride)
        self.bn1 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = apply_3x3_convolution(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Defines the forward pass through the block."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """Bottleneck block used in ResNet. Has three layers: 1x1, 3x3, and 1x1 convolutions."""
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(out_channels * (base_width / 64.)) * groups

        self.conv1 = apply_1x1_convolution(in_channels, width)
        self.bn1 = norm_layer(width)
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion)
        self.bn3 = norm_layer(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Defines the forward pass through the bottleneck block."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1,
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False):
        """Defines the ResNet architecture."""
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 16
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None or a 3-element tuple.")

        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._create_model_layer(block, 16, layers[0])
        self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)
        self.KD = KD

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Creates a layer in ResNet using the specified block type."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Defines the forward pass of the ResNet model.

        Note that conv1/bn1/relu are defined but not applied here; in this server-side
        model the input is presumably the 16-channel feature map already extracted by
        the client models (the stem runs on the clients).
        """
        x = self.layer1(x)   # Output: B x (16 * expansion) x 32 x 32
        x = self.layer2(x)   # Output: B x (32 * expansion) x 16 x 16
        x = self.layer3(x)   # Output: B x (64 * expansion) x 8 x 8

        x = self.avgpool(x)  # Output: B x (64 * expansion) x 1 x 1
        x_f = x.view(x.size(0), -1)  # Flatten: B x (64 * expansion)
        x = self.fc(x_f)     # Output: B x num_classes
        return x


def resnet56_server(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-56 server model (Bottleneck blocks, layer configuration [6, 6, 6]).

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool): If True, returns a model pre-trained on ImageNet.
        path (str): Path to the pre-trained model.
    """
    logging.info("Loading model with path: " + str(path))
    model = ResNet(Bottleneck, [6, 6, 6], num_classes=num_classes, **kwargs)

    if models_pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        model.load_state_dict(new_state_dict)

    return model
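A minimal sketch of building the server model above and running a forward pass; the random tensor stands in for client-extracted feature maps, whose shape follows from `inplanes = 16` and the dimension comments in `forward`:

```python
model = resnet56_server(num_classes=10)
client_features = torch.randn(8, 16, 32, 32)  # forward() starts at layer1, so 16-channel maps are expected
logits = model(client_features)
print(logits.shape)  # torch.Size([8, 10])
```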
@@ -0,0 +1,326 @@
# -*- coding: utf-8 -*-
|
||||
# @Author: Weisen Pan
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1):
|
||||
"""
|
||||
Creates a 3x3 convolutional layer with padding.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
stride (int, optional): Stride of the convolution. Default is 1.
|
||||
groups (int, optional): Number of blocked connections from input to output. Default is 1.
|
||||
dilation (int, optional): Spacing between kernel elements. Default is 1.
|
||||
|
||||
Returns:
|
||||
nn.Conv2d: A 3x3 convolutional layer.
|
||||
"""
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
def apply_1x1_convolution(in_channels, out_channels, stride=1):
|
||||
"""
|
||||
Creates a 1x1 convolutional layer.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
stride (int, optional): Stride of the convolution. Default is 1.
|
||||
|
||||
Returns:
|
||||
nn.Conv2d: A 1x1 convolutional layer.
|
||||
"""
|
||||
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
class BasicBlock(nn.Module):
    """
    A basic block for ResNet.

    This block consists of two convolutional layers with batch normalization and ReLU activation.

    Attributes:
        expansion (int): The expansion factor of the block.
        conv1 (nn.Conv2d): First convolutional layer.
        bn1 (nn.BatchNorm2d): First batch normalization layer.
        conv2 (nn.Conv2d): Second convolutional layer.
        bn2 (nn.BatchNorm2d): Second batch normalization layer.
        downsample (nn.Module): Downsample layer if input and output dimensions differ.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None):
        """
        Initializes the BasicBlock.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int, optional): Stride for the convolutional layers. Default is 1.
            downsample (nn.Module, optional): Downsample layer if input dimensions differ. Default is None.
            norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d.
        """
        super(BasicBlock, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride)
        self.bn1 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = apply_3x3_convolution(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """
        Defines the forward pass for the block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the block.
        """
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """
    A bottleneck block for ResNet.

    This block reduces the number of input channels before performing convolution and then expands it back.

    Attributes:
        expansion (int): The expansion factor of the block.
        conv1 (nn.Conv2d): First 1x1 convolutional layer.
        conv2 (nn.Conv2d): 3x3 convolutional layer.
        conv3 (nn.Conv2d): Second 1x1 convolutional layer.
        downsample (nn.Module): Downsample layer if input and output dimensions differ.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm_layer=None):
        """
        Initializes the Bottleneck block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int, optional): Stride for the convolutional layers. Default is 1.
            downsample (nn.Module, optional): Downsample layer if input dimensions differ. Default is None.
            norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d.
        """
        super(Bottleneck, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        width = out_channels  # Base width of 64 with no grouping, so the width equals out_channels
        self.conv1 = apply_1x1_convolution(in_channels, width)
        self.bn1 = norm_layer(width)
        self.conv2 = apply_3x3_convolution(width, width, stride)
        self.bn2 = norm_layer(width)
        self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion)
        self.bn3 = norm_layer(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """
        Defines the forward pass for the bottleneck block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the block.
        """
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out

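When the stride or channel count changes, the identity path needs its own 1x1 projection; a minimal sketch of wiring a BasicBlock up by hand:

```python
import torch
import torch.nn as nn

# Downsample branch so the identity matches the (32-channel, stride-2) output
downsample = nn.Sequential(
    apply_1x1_convolution(16, 32, stride=2),
    nn.BatchNorm2d(32),
)
block = BasicBlock(16, 32, stride=2, downsample=downsample)

x = torch.randn(4, 16, 32, 32)
print(block(x).shape)  # torch.Size([4, 32, 16, 16])
```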
class ResNet(nn.Module):
    """
    ResNet architecture.

    This class constructs a ResNet model with a specified block type and layer configuration.

    Attributes:
        conv1 (nn.Conv2d): Initial convolutional layer.
        bn1 (nn.BatchNorm2d): Initial batch normalization layer.
        layer1 (nn.Sequential): First residual layer.
        layer2 (nn.Sequential): Second residual layer.
        layer3 (nn.Sequential): Third residual layer.
        fc (nn.Linear): Fully connected output layer.
    """
    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, norm_layer=None):
        """
        Initializes the ResNet architecture.

        Args:
            block (nn.Module): The block type (BasicBlock or Bottleneck).
            layers (list of int): Number of blocks per layer.
            num_classes (int, optional): Number of output classes. Default is 10.
            zero_init_residual (bool, optional): Whether to zero-initialize the last batch norm of each residual block. Default is False.
            norm_layer (nn.Module, optional): Normalization layer. Default is BatchNorm2d.
        """
        super(ResNet, self).__init__()
        norm_layer = norm_layer or nn.BatchNorm2d
        self.in_channels = 16

        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._create_model_layer(block, 16, layers[0])
        self.layer2 = self._create_model_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._create_model_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        self._init_model_weights(zero_init_residual)

    def _create_model_layer(self, block, out_channels, blocks, stride=1):
        """
        Creates a residual layer.

        Args:
            block (nn.Module): The block type.
            out_channels (int): Number of output channels.
            blocks (int): Number of blocks in the layer.
            stride (int, optional): Stride for the first block. Default is 1.

        Returns:
            nn.Sequential: A sequence of residual blocks.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.in_channels, out_channels * block.expansion, stride),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels * block.expansion
        layers.extend(block(self.in_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _init_model_weights(self, zero_init_residual):
        """
        Initializes the weights of the model.

        Args:
            zero_init_residual (bool): If True, zero-initializes the last batch norm layer of each residual block.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            if zero_init_residual and isinstance(m, Bottleneck):
                nn.init.constant_(m.bn3.weight, 0)
            elif zero_init_residual and isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

    def forward(self, x):
        """
        Defines the forward pass of the ResNet.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: Logits and extracted features.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        extracted_features = x

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x_f = x.view(x.size(0), -1)
        logits = self.fc(x_f)
        return logits, extracted_features

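Note that the forward pass returns a (logits, features) pair rather than logits alone; the features are the stem output, taken before layer1. A minimal sketch:

```python
import torch

model = ResNet(BasicBlock, [5, 5, 5], num_classes=10)  # the ResNet-32 layout used below
logits, features = model(torch.randn(2, 3, 32, 32))
print(logits.shape)    # torch.Size([2, 10])
print(features.shape)  # torch.Size([2, 16, 32, 32]): stem output before layer1
```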
def resnet32_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-32 model.

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool, optional): If True, loads pretrained weights. Default is False.
        path (str, optional): Path to the pretrained weights. Default is None.

    Returns:
        ResNet: A ResNet-32 model.
    """
    model = ResNet(BasicBlock, [5, 5, 5], num_classes=num_classes, **kwargs)
    if models_pretrained:
        logging.info("Loading pretrained model from: %s", path)
        model.load_state_dict(_load_models_pretrained_weights(path))
    return model


def resnet56_models_pretrained(num_classes, models_pretrained=False, path=None, **kwargs):
    """
    Constructs a ResNet-56 model.

    Args:
        num_classes (int): Number of output classes.
        models_pretrained (bool, optional): If True, loads pretrained weights. Default is False.
        path (str, optional): Path to the pretrained weights. Default is None.

    Returns:
        ResNet: A ResNet-56 model.
    """
    model = ResNet(Bottleneck, [6, 6, 6], num_classes=num_classes, **kwargs)
    if models_pretrained:
        logging.info("Loading pretrained model from: %s", path)
        model.load_state_dict(_load_models_pretrained_weights(path))
    return model


def _load_models_pretrained_weights(path):
    """
    Loads pretrained weights from a checkpoint.

    Args:
        path (str): Path to the checkpoint file.

    Returns:
        dict: State dictionary with the loaded weights.
    """
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    state_dict = checkpoint['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()

    # Strip the "module." prefix that nn.DataParallel adds to parameter names
    for k, v in state_dict.items():
        new_state_dict[k.replace("module.", "")] = v

    return new_state_dict
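Putting the factory and the checkpoint loader together (a sketch; the path is a hypothetical placeholder, and the file must hold a 'state_dict' entry, as _load_models_pretrained_weights expects):

```python
# "checkpoints/resnet32_cifar10.pth.tar" is a placeholder path, not a shipped file
model = resnet32_models_pretrained(
    num_classes=10,
    models_pretrained=True,
    path="checkpoints/resnet32_cifar10.pth.tar",
)
model.eval()
```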
@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn

__all__ = ['ResNet', 'resnet5_56', 'resnet8_56']


# Function to define a 3x3 convolution layer with padding
def apply_3x3_convolution(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


# Function to define a 1x1 convolution layer
def apply_1x1_convolution(in_channels, out_channels, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)


# BasicBlock class for the ResNet architecture
class BasicBlock(nn.Module):
    expansion = 1  # Expansion factor

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d  # Default normalization layer is BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # First convolution and batch normalization layer
        self.conv1 = apply_3x3_convolution(in_channels, out_channels, stride)
        self.bn1 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        # Second convolution and batch normalization layer
        self.conv2 = apply_3x3_convolution(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)
        self.downsample = downsample  # If a downsample module is provided, use it

    def forward(self, x):
        identity = x  # Keep the original input as the identity for the residual connection

        # Forward pass through the first convolution, batch norm, and ReLU
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Forward pass through the second convolution and batch norm
        out = self.conv2(out)
        out = self.bn2(out)

        # Downsample the identity if a downsample module is provided
        if self.downsample is not None:
            identity = self.downsample(x)

        # Add the residual connection (identity)
        out += identity
        out = self.relu(out)  # Apply ReLU activation after the addition

        return out


# Bottleneck class for deeper ResNet architectures
class Bottleneck(nn.Module):
    expansion = 4  # Expansion factor

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d  # Default normalization layer is BatchNorm2d
        width = int(out_channels * (base_width / 64.)) * groups  # Calculate the width based on the group size

        # First 1x1 convolution
        self.conv1 = apply_1x1_convolution(in_channels, width)
        self.bn1 = norm_layer(width)
        # Second 3x3 convolution
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        # Third 1x1 convolution to match the output channels
        self.conv3 = apply_1x1_convolution(width, out_channels * self.expansion)
        self.bn3 = norm_layer(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.downsample = downsample  # Downsample if provided

    def forward(self, x):
        identity = x  # Keep the original input as the identity for the residual connection

        # First 1x1 convolution and ReLU
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Second 3x3 convolution and ReLU
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # Third 1x1 convolution
        out = self.conv3(out)
        out = self.bn3(out)

        # Add the downsampled identity if necessary
        if self.downsample is not None:
            identity = self.downsample(x)

        # Add the residual connection (identity)
        out += identity
        out = self.relu(out)  # Apply ReLU activation after the addition

        return out


# ResNet class to build the full model
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10, zero_init_residual=False, groups=1,
                 width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, KD=False):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d  # Default normalization layer
        self._norm_layer = norm_layer

        self.inplanes = 16  # Initial number of channels
        self.dilation = 1  # Dilation factor
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]  # Default stride behavior
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        self.groups = groups  # Number of groups for convolutions
        self.base_width = width_per_group  # Base width for groups

        # Initial convolutional layer with 3 input channels (RGB image)
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)  # Batch normalization
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # Max pooling layer (not used in the CIFAR forward pass)
        self.layer1 = self._create_model_layer(block, 16, layers[0])  # First block layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Adaptive average pooling
        self.fc = nn.Linear(16 * block.expansion, num_classes)  # Fully connected layer

        self.KD = KD  # Knowledge Distillation flag
        for m in self.modules():
            # Initialize convolutional weights using He initialization
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            # Initialize batch normalization weights
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last batch norm layer of each block if zero_init_residual is True
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    # Helper function to create layers of blocks
    def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)
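A stage only receives a downsample branch when the stride or channel count changes; with a Bottleneck stage this is easy to see on the constructed model (a sketch):

```python
model = ResNet(Bottleneck, [2, 2, 2], num_classes=10)
first_block = model.layer1[0]
print(first_block.downsample)  # Sequential(1x1 conv + BatchNorm): 16 -> 64 channels
```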
    # Forward pass of the ResNet model
    def forward(self, x):
        x = self.conv1(x)  # Initial convolution
        x = self.bn1(x)  # Batch normalization
        x = self.relu(x)  # ReLU activation
        extracted_features = x  # Feature extraction point
        x = self.layer1(x)  # Pass through the first layer
        x = self.avgpool(x)  # Adaptive average pooling
        x_f = x.view(x.size(0), -1)  # Flatten the features
        logits = self.fc(x_f)  # Fully connected layer for classification
        return logits, extracted_features  # Return the logits and extracted features


# Function to create the ResNet-5 model
def resnet5_56(num_classes, models_pretrained=False, path=None, **kwargs):
    """Constructs a ResNet-5 model."""
    model = ResNet(BasicBlock, [1, 2, 2], num_classes=num_classes, **kwargs)
    if models_pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")  # Strip the DataParallel prefix
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
    return model


# Function to create the ResNet-8 model
def resnet8_56(num_classes, models_pretrained=False, path=None, **kwargs):
    """Constructs a ResNet-8 model."""
    model = ResNet(Bottleneck, [2, 2, 2], num_classes=num_classes, **kwargs)
    if models_pretrained:
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")  # Strip the DataParallel prefix
            new_state_dict[name] = v

        model.load_state_dict(new_state_dict)
    return model
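Both factories share the same checkpoint handling, stripping the "module." prefix that nn.DataParallel adds. Typical use without a checkpoint (a sketch):

```python
import torch

model = resnet8_56(num_classes=100)
logits, features = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 100])
```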
BIN
EdgeFLite/fedml_service/architecture/cv/resnet_federated/.DS_Store
vendored
Normal file
Binary file not shown.
211
EdgeFLite/fedml_service/architecture/cv/resnet_federated/net.py
Normal file
@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn

# Try to import load_state_dict_from_url from torch.hub.
# If that fails (on older versions), fall back to load_url from torch.utils.model_zoo.
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# List of all exportable models
__all__ = ['resnet110_sl', 'wide_resnetsl50_2', 'wide_resnetsl16_8']


def apply_3x3_convolution(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,         # Number of input channels
        out_planes,        # Number of output channels
        kernel_size=3,     # Size of the filter
        stride=stride,     # Stride of the convolution
        padding=dilation,  # Padding for the convolution
        groups=groups,     # Group convolution
        bias=False,        # No bias in the convolution
        dilation=dilation  # Dilation rate for dilated convolutions
    )


def apply_1x1_convolution(in_planes, out_planes, stride=1):
    """1x1 convolution."""
    return nn.Conv2d(
        in_planes,      # Number of input channels
        out_planes,     # Number of output channels
        kernel_size=1,  # Filter size is 1x1
        stride=stride,  # Stride of the convolution
        bias=False      # No bias in the convolution
    )


class BasicBlock(nn.Module):
    """Basic block for ResNet."""

    expansion = 1  # No expansion in BasicBlock

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.conv1 = apply_3x3_convolution(inplanes, planes, stride)  # First 3x3 convolution
        self.bn1 = norm_layer(planes)  # First batch normalization
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.conv2 = apply_3x3_convolution(planes, planes)  # Second 3x3 convolution
        self.bn2 = norm_layer(planes)  # Second batch normalization
        self.downsample = downsample  # Downsampling branch (e.g., on a stride mismatch)

    def forward(self, x):
        identity = x  # Preserve the input as the identity for the skip connection
        out = self.conv1(x)   # Apply the first convolution
        out = self.bn1(out)   # Apply the first batch normalization
        out = self.relu(out)  # Apply the ReLU activation
        out = self.conv2(out)  # Apply the second convolution
        out = self.bn2(out)    # Apply the second batch normalization

        # If a downsample branch exists, apply it to the identity
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity       # Add the skip connection
        out = self.relu(out)  # Final ReLU activation

        return out


class Bottleneck(nn.Module):
    """Bottleneck block for ResNet."""

    expansion = 4  # Bottleneck expands the channels by a factor of 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups  # Width of the block

        # 1x1 convolution (bottleneck entry)
        self.conv1 = apply_1x1_convolution(inplanes, width)
        self.bn1 = norm_layer(width)  # Batch normalization after the 1x1 convolution
        # 3x3 convolution (main block)
        self.conv2 = apply_3x3_convolution(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)  # Batch normalization after the 3x3 convolution
        # 1x1 convolution (bottleneck exit)
        self.conv3 = apply_1x1_convolution(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)  # Batch normalization after the 1x1 exit
        self.relu = nn.ReLU(inplace=True)  # ReLU activation
        self.downsample = downsample  # Downsampling for the skip connection, if needed

    def forward(self, x):
        identity = x  # Store the input as the identity for the skip connection
        out = self.conv1(x)    # Apply the first 1x1 convolution
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)  # Apply the 3x3 convolution
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)  # Apply the final 1x1 convolution
        out = self.bn3(out)

        # If a downsample branch exists, apply it to the identity
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity       # Add the skip connection
        out = self.relu(out)  # Final ReLU activation

        return out

class PrimaryResNetClient(nn.Module):
    """Main ResNet model for a client."""

    def __init__(self, arch, block, layers, num_classes=1000, zero_init_residual=True,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, dataset='cifar10', split_factor=1, output_stride=8, dropout_p=None):
        super(PrimaryResNetClient, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Global average pooling before the fully connected layer

        # Attributes consumed by _create_model_layer (mirroring the ResNet class above)
        self.groups = groups
        self.base_width = width_per_group
        self.dilation = 1

        # Dictionary storing the input channel size per dataset and split factor
        inplanes_dict = {
            'cifar10': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3},
            'cifar100': {1: 16, 2: 12, 4: 8, 8: 6, 16: 4},
            'skin_dataset': {1: 64, 2: 44, 4: 32, 8: 24},
            'pill_base': {1: 64, 2: 44, 4: 32, 8: 24},
            'medical_images': {1: 64, 2: 44, 4: 32, 8: 24},
        }
        self.inplanes = inplanes_dict[dataset][split_factor]  # Set the initial input channels

        self.fc = nn.Linear(self.inplanes * 4 * block.expansion, num_classes)  # Fully connected layer for classification

        # NOTE: assumed reconstruction. The excerpt uses self.layer0 and self.layer1
        # in _forward_impl without constructing them; the stem below follows the
        # sibling ResNet above, and the stage width is chosen to match the fc sizing.
        self.layer0 = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
            norm_layer(self.inplanes),
            nn.ReLU(inplace=True),
        )
        self.layer1 = self._create_model_layer(block, self.inplanes * 4, layers[0])

        # Initialize all layers
        for m in self.modules():
            # Initialize convolutional weights using He initialization
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # Initialize normalization weights
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=1e-3)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Optionally zero-initialize the last batch normalization layer of each block
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _create_model_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Create a residual layer consisting of several blocks."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                apply_1x1_convolution(self.inplanes, planes * block.expansion, stride),  # Adjust the identity for downsampling
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))  # First block, with downsample
        self.inplanes = planes * block.expansion  # Update inplanes for the next block
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))  # Add the remaining blocks

        return nn.Sequential(*layers)  # Return the stacked blocks

    def _forward_impl(self, x):
        """Implementation of the forward pass."""
        x = self.layer0(x)  # Initial (stem) layer
        extracted_features = x  # Save the features after the initial layer
        x = self.layer1(x)  # First residual stage
        x = self.avgpool(x)  # Global average pooling
        x = torch.flatten(x, 1)  # Flatten the features into a 1D tensor per sample
        logits = self.fc(x)  # Pass through the fully connected layer
        return logits, extracted_features  # Return the logits and extracted features

    def forward(self, x):
        """Standard forward method."""
        return self._forward_impl(x)
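The per-client channel budget shrinks as the split factor grows, and the classifier width follows it. A quick sketch of that arithmetic, with values copied from inplanes_dict and Bottleneck's expansion of 4:

```python
inplanes_cifar10 = {1: 16, 2: 12, 4: 8, 8: 6, 16: 4, 32: 3}  # cifar10 row of inplanes_dict
expansion = 4  # Bottleneck.expansion

for split_factor, width in inplanes_cifar10.items():
    # fc in_features = inplanes * 4 * block.expansion, as in PrimaryResNetClient
    print(split_factor, width * 4 * expansion)
```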
BIN
EdgeFLite/fedml_service/data_cleaning/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
EdgeFLite/fedml_service/data_cleaning/cifar10/.DS_Store
vendored
Normal file
Binary file not shown.
@ -0,0 +1,230 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as apply_transformations  # torchvision's transforms module

from .datasets import CIFAR10_truncated

# Configure logging
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


# Function to load the non-IID data distribution
def load_data_distribution(file_path='./data_cleaning/non-iid-distribution/CIFAR10/data_map.txt'):
    """
    Load the data distribution for non-IID data.
    Reads from a text file that maps data classes to clients in a decentralized manner.
    """
    distribution = {}
    current_key = None
    with open(file_path, 'r') as file:
        for line in file.readlines():
            if '{' != line[0] and '}' != line[0]:
                key, value = line.split(':')
                if '{' == value.strip():
                    current_key = int(key)
                    distribution[current_key] = {}
                else:
                    # Class count entry inside the current client's block
                    distribution[current_key][int(key)] = int(value.strip().replace(',', ''))
    return distribution


# Function to load the network data index map
def load_net_dataidx_map(file_path='./data_cleaning/non-iid-distribution/CIFAR10/index_map.txt'):
    """
    Load the index mapping between data samples and clients.
    Reads from a text file that assigns data indices to different clients.
    """
    net_dataidx_map = {}
    with open(file_path, 'r') as file:
        for line in file.readlines():
            if '{' != line[0] and '}' != line[0] and ']' != line[0]:
                key, value = line.split(':')
                if '[' == value.strip():
                    net_dataidx_map[int(key)] = []
                else:
                    net_dataidx_map[int(key)] = [int(i.strip()) for i in value.split(',')]
    return net_dataidx_map


# Function to record and log data statistics for each client
def log_net_data_stats(y_train, net_dataidx_map):
    """
    Log the data statistics for each client by calculating the class distribution.
    """
    net_cls_counts = {}
    for net_id, dataidx in net_dataidx_map.items():
        unique, counts = np.unique(y_train[dataidx], return_counts=True)
        net_cls_counts[net_id] = dict(zip(unique, counts))
    logging.debug('Data statistics: %s', net_cls_counts)
    return net_cls_counts


# Cutout augmentation class for image data
class Cutout:
    """
    Apply the Cutout augmentation technique to images.
    Randomly masks out a square region in the image.
    """
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y, x = np.random.randint(h), np.random.randint(w)

        y1, y2 = np.clip([y - self.length // 2, y + self.length // 2], 0, h)
        x1, x2 = np.clip([x - self.length // 2, x + self.length // 2], 0, w)

        mask[y1:y2, x1:x2] = 0.
        mask = torch.from_numpy(mask).expand_as(img)
        img *= mask
        return img
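Cutout operates on an already-tensorized image, which is why it sits after ToTensor in the training pipeline below; applied in isolation (a sketch):

```python
import torch

img = torch.ones(3, 32, 32)
masked = Cutout(16)(img.clone())  # Cutout modifies the tensor in place
print((masked == 0).any())  # True: a 16x16 square (clipped at the borders) was zeroed
```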
# Function to define the CIFAR-10 data transformations
def cifar10_data_apply_transformations():
    """
    Define the data transformations for the CIFAR-10 dataset.
    Includes random cropping, horizontal flipping, normalization, and Cutout for training.
    """
    CIFAR_MEAN = [0.4914, 0.4822, 0.4465]
    CIFAR_STD = [0.2470, 0.2435, 0.2616]

    train_apply_transformation = apply_transformations.Compose([
        apply_transformations.ToPILImage(),
        apply_transformations.RandomCrop(32, padding=4),
        apply_transformations.RandomHorizontalFlip(),
        apply_transformations.ToTensor(),
        apply_transformations.Normalize(CIFAR_MEAN, CIFAR_STD),
        Cutout(16)
    ])

    valid_apply_transformation = apply_transformations.Compose([
        apply_transformations.ToTensor(),
        apply_transformations.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    return train_apply_transformation, valid_apply_transformation


# Function to load CIFAR-10 data
def load_cifar10(datadir):
    """
    Load the CIFAR-10 dataset with transformations for training and testing.
    """
    train_apply_transformation, test_apply_transformation = cifar10_data_apply_transformations()

    cifar10_train = CIFAR10_truncated(datadir, train=True, download=True, apply_transformation=train_apply_transformation)
    cifar10_test = CIFAR10_truncated(datadir, train=False, download=True, apply_transformation=test_apply_transformation)

    X_train, y_train = cifar10_train.data, cifar10_train.target
    X_test, y_test = cifar10_test.data, cifar10_test.target

    return X_train, y_train, X_test, y_test


# Function to partition CIFAR-10 data across clients
def partition_cifar10_data(dataset, datadir, partition_type, n_nets, alpha):
    """
    Partition the CIFAR-10 dataset across clients for federated learning.
    Supports homogeneous and heterogeneous partitions.
    """
    logging.info("Partitioning CIFAR-10 data...")
    X_train, y_train, X_test, y_test = load_cifar10(datadir)
    n_train = X_train.shape[0]

    if partition_type == "homo":
        # Homogeneous partitioning (equal distribution across clients)
        idxs = np.random.permutation(n_train)
        net_dataidx_map = {i: batch for i, batch in enumerate(np.array_split(idxs, n_nets))}
    elif partition_type == "hetero":
        # Heterogeneous partitioning (non-IID Dirichlet distribution)
        K, N = 10, y_train.shape[0]
        net_dataidx_map = {}
        min_size = 0
        # Resample until every client holds at least 10 samples
        while min_size < 10:
            idx_batch = [[] for _ in range(n_nets)]
            for k in range(K):
                idx_k = np.where(y_train == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
                # Zero out clients that already hold an average-sized share
                proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
                proportions = np.cumsum(proportions / proportions.sum()) * len(idx_k)
                split_idx = proportions.astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, split_idx))]
            min_size = min([len(idx_j) for idx_j in idx_batch])
        net_dataidx_map = {i: np.random.permutation(batch) for i, batch in enumerate(idx_batch)}
    elif partition_type == "hetero-fix":
        # Fixed heterogeneous partitioning (predefined distribution)
        net_dataidx_map = load_net_dataidx_map()

    # Load the data distribution for the 'hetero-fix' partition, otherwise calculate it
    if partition_type == "hetero-fix":
        traindata_cls_counts = load_data_distribution()
    else:
        traindata_cls_counts = log_net_data_stats(y_train, net_dataidx_map)

    return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts
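The alpha parameter controls how skewed the client shares are: a small alpha concentrates each class on a few clients, while a large alpha approaches a uniform split. A standalone sketch of the proportions draw used above:

```python
import numpy as np

for alpha in (0.1, 100.0):
    proportions = np.random.dirichlet(np.repeat(alpha, 5))
    print(alpha, np.round(proportions, 2))  # skewed for 0.1, near-uniform for 100.0
```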
# Function to create data loaders
def get_cifar10_dataloader(datadir, train_bs, test_bs, dataidxs=None):
    """
    Create data loaders for CIFAR-10, with the option to load only specific data indices.
    """
    train_apply_transformation, test_apply_transformation = cifar10_data_apply_transformations()

    train_ds = CIFAR10_truncated(datadir, dataidxs=dataidxs, train=True, apply_transformation=train_apply_transformation, download=True)
    test_ds = CIFAR10_truncated(datadir, train=False, apply_transformation=test_apply_transformation, download=True)

    train_loader = data.DataLoader(dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=True)
    test_loader = data.DataLoader(dataset=test_ds, batch_size=test_bs, shuffle=False, drop_last=True)

    return train_loader, test_loader
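For example, restricting the training loader to one client's indices (a sketch; './data' is a placeholder directory and the dataset is downloaded on first use):

```python
train_loader, test_loader = get_cifar10_dataloader(
    './data', train_bs=64, test_bs=64, dataidxs=list(range(500)))
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 3, 32, 32])
```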
# Function to load decentralized CIFAR-10 data for a specific client
def load_decentralized_cifar10(process_id, dataset, datadir, partition_method, partition_alpha, client_num, batch_size):
    """
    Load decentralized CIFAR-10 data based on the partitioning method and client number.
    Returns either global data loaders or local data loaders, depending on the process ID.
    """
    X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_cifar10_data(
        dataset, datadir, partition_method, client_num, partition_alpha)
    class_num = len(np.unique(y_train))
    logging.info("Class distribution: %s", traindata_cls_counts)

    if process_id == 0:
        # Global data loaders
        train_global, test_global = get_cifar10_dataloader(datadir, batch_size, batch_size)
        return sum(len(net_dataidx_map[r]) for r in range(client_num)), train_global, test_global, 0, None, None, class_num
    else:
        # Local data loaders for the specific client
        dataidxs = net_dataidx_map[process_id - 1]
        train_local, test_local = get_cifar10_dataloader(datadir, batch_size, batch_size, dataidxs)
        return len(dataidxs), None, None, len(dataidxs), train_local, test_local, class_num


# Function to load and partition the CIFAR-10 dataset
def load_cifar10_partitioned(dataset, datadir, partition_method, partition_alpha, client_num, batch_size):
    """
    Load and partition the CIFAR-10 dataset and prepare data loaders for all clients.
    """
    X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts = partition_cifar10_data(
        dataset, datadir, partition_method, client_num, partition_alpha)
    class_num = len(np.unique(y_train))
    logging.info("Global data statistics: %s", traindata_cls_counts)

    # Global data loaders
    train_global, test_global = get_cifar10_dataloader(datadir, batch_size, batch_size)

    # Local data loaders for each client
    data_local_num_dict, train_local_dict, test_local_dict = {}, {}, {}
    for client_idx in range(client_num):
        dataidxs = net_dataidx_map[client_idx]
        local_data_num = len(dataidxs)
        data_local_num_dict[client_idx] = local_data_num
        logging.info("Client %d: Local sample count = %d", client_idx, local_data_num)

        train_local, test_local = get_cifar10_dataloader(datadir, batch_size, batch_size, dataidxs)
        train_local_dict[client_idx], test_local_dict[client_idx] = train_local, test_local

    return (sum(len(net_dataidx_map[r]) for r in range(client_num)), len(test_global), train_global, test_global,
            data_local_num_dict, train_local_dict, test_local_dict, class_num)
109
EdgeFLite/fedml_service/data_cleaning/cifar10/dataset_hub.py
Normal file
@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import numpy as np
import torch.utils.data as data
from PIL import Image
from torchvision.datasets import CIFAR10

# Set up logging
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Supported image extensions
# These are the file extensions that the loaders will support for image formats
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


# Loader using accimage, a faster image-loading library than PIL
def load_accimage(path):
    import accimage
    try:
        # Try to load the image with accimage
        return accimage.Image(path)
    except IOError:
        # If there's an error, fall back to PIL for image loading
        return load_image_pil(path)


# Loader using PIL (Python Imaging Library)
def load_image_pil(path):
    # Open the file in binary mode to avoid resource warnings
    with open(path, 'rb') as f:
        img = Image.open(f)
        # Convert the image to RGB mode (3 channels)
        return img.convert('RGB')


# Default image loader that chooses accimage if available, otherwise PIL
def basic_loader(path):
    from torchvision import get_image_backend
    # Check if the image backend is accimage
    if get_image_backend() == 'accimage':
        return load_accimage(path)
    # Otherwise, fall back to PIL
    return load_image_pil(path)


# Custom CIFAR10 dataset with truncation capabilities
# This class extends torch.utils.data.Dataset to support CIFAR10 with truncation of the data
class CIFAR10_truncated(data.Dataset):

    def __init__(self, root, dataidxs=None, train=True, apply_transformation=None, target_apply_transformation=None, download=False):
        self.root = root  # Root directory for the dataset
        self.dataidxs = dataidxs  # Subset of data indices (optional)
        self.train = train  # Boolean flag indicating if the dataset is for training
        self.apply_transformation = apply_transformation  # Transformations to apply to the images (optional)
        self.target_apply_transformation = target_apply_transformation  # Transformations to apply to the labels (optional)
        self.download = download  # Boolean flag to download the dataset if not available

        # Build the truncated dataset based on the provided indices
        self.data, self.target = self._build_truncated_dataset()

    def _build_truncated_dataset(self):
        # Log whether the dataset is being downloaded
        logger.info(f"Download: {self.download}")

        # Load the CIFAR10 dataset from torchvision
        cifar_data = CIFAR10(self.root, self.train, transform=self.apply_transformation,
                             target_transform=self.target_apply_transformation, download=self.download)

        # Extract data (images) and targets (labels) from the CIFAR10 dataset
        data = cifar_data.data
        target = np.array(cifar_data.targets)

        # If data indices are provided, filter the data and targets accordingly
        if self.dataidxs is not None:
            data = data[self.dataidxs]
            target = target[self.dataidxs]

        # Return the truncated data and targets
        return data, target

    def truncate_channel(self, indices):
        # Zero out the second and third channels (green and blue) for the selected images
        for idx in indices:
            self.data[idx, :, :, 1] = 0.0  # Zero out the green channel
            self.data[idx, :, :, 2] = 0.0  # Zero out the blue channel

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the image

        Returns:
            tuple: (image, target) where target is the class label.
        """
        img, target = self.data[index], self.target[index]

        # Apply image transformations if any are specified
        if self.apply_transformation is not None:
            img = self.apply_transformation(img)
        # Apply target transformations if any are specified
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)

        # Return the transformed image and its corresponding target
        return img, target

    def __len__(self):
        # Return the total number of images in the dataset
        return len(self.data)
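A truncated view behaves like a regular dataset over the selected indices (a sketch; './data' is a placeholder directory):

```python
ds = CIFAR10_truncated('./data', dataidxs=list(range(100)), train=True, download=True)
print(len(ds))      # 100
img, label = ds[0]  # raw HWC uint8 array and its class id (no transform supplied)
```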
BIN
EdgeFLite/fedml_service/data_cleaning/cifar100/.DS_Store
vendored
Normal file
Binary file not shown.
@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as apply_transformations  # torchvision's transforms module

from .datasets import CIFAR100_truncated

# Set up the logging configuration to log information-level events
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)


# Function to read the non-IID distribution data from a file
def read_data_distribution(filename='./data_cleaning/non-iid-distribution/CIFAR10/data_map.txt'):
    distribution = {}
    current_key = None
    # Open the file and read the distribution map
    with open(filename, 'r') as file:
        for line in file.readlines():
            # Skip lines that do not contain distribution data
            if '{' != line[0] and '}' != line[0]:
                key, value = line.split(':')
                if '{' == value.strip():
                    distribution[int(key)] = {}
                    current_key = int(key)
                else:
                    sub_key, sub_value = key, value.strip().replace(',', '')
                    distribution[current_key][int(sub_key)] = int(sub_value)
    return distribution


# Function to read the net data index map from a file
def read_net_dataidx_map(filename='./data_cleaning/non-iid-distribution/CIFAR10/index_map.txt'):
    net_dataidx_map = {}
    # Open the file and read the index map for the dataset
    with open(filename, 'r') as file:
        for line in file.readlines():
            # Skip lines that do not contain index map data
            if '{' != line[0] and '}' != line[0] and ']' != line[0]:
                key, value = line.split(':')
                if '[' == value.strip():
                    net_dataidx_map[int(key)] = []
                else:
                    net_dataidx_map[int(key)] = [int(i.strip()) for i in value.split(',')]
    return net_dataidx_map


# Function to calculate and record statistics of each net's data
def record_net_data_stats(y_train, net_dataidx_map):
    net_cls_counts = {}
    # For each net, count the unique classes and their frequencies in the training data
    for net_id, dataidx in net_dataidx_map.items():
        unique, counts = np.unique(y_train[dataidx], return_counts=True)
        net_cls_counts[net_id] = {unique[i]: counts[i] for i in range(len(unique))}
    logging.debug(f'Data statistics: {net_cls_counts}')
    return net_cls_counts


# Custom Cutout data augmentation class that applies a random mask to an image
class Cutout:
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y, x = np.random.randint(h), np.random.randint(w)

        # Define the region to apply the mask
        y1, y2 = np.clip([y - self.length // 2, y + self.length // 2], 0, h)
        x1, x2 = np.clip([x - self.length // 2, x + self.length // 2], 0, w)

        # Apply the mask and return the augmented image
        mask[y1:y2, x1:x2] = 0
        mask = torch.from_numpy(mask).expand_as(img)
        img *= mask
        return img


# Function to define the CIFAR-100 transformation pipelines for training and validation
def _data_apply_transformations_cifar100():
    # Define the normalization constants for CIFAR-100
    CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
    CIFAR_STD = [0.2673, 0.2564, 0.2762]

    # Data augmentation and transformation pipeline for training data
    train_apply_transformation = apply_transformations.Compose([
        apply_transformations.ToPILImage(),
        apply_transformations.RandomCrop(32, padding=4),
        apply_transformations.RandomHorizontalFlip(),
        apply_transformations.ToTensor(),
        apply_transformations.Normalize(CIFAR_MEAN, CIFAR_STD),
        Cutout(16)  # Apply the Cutout augmentation
    ])

    # Transformation pipeline for validation data
    valid_apply_transformation = apply_transformations.Compose([
        apply_transformations.ToTensor(),
        apply_transformations.Normalize(CIFAR_MEAN, CIFAR_STD)
    ])

    return train_apply_transformation, valid_apply_transformation


# Function to load the CIFAR-100 dataset with the specified transformations
def load_cifar100_data(datadir):
    train_apply_transformation, test_apply_transformation = _data_apply_transformations_cifar100()

    # Load the training and testing datasets
    cifar_train = CIFAR100_truncated(datadir, train=True, download=True, apply_transformation=train_apply_transformation)
    cifar_test = CIFAR100_truncated(datadir, train=False, download=True, apply_transformation=test_apply_transformation)

    return cifar_train.data, cifar_train.target, cifar_test.data, cifar_test.target


# Function to partition the data using IID (independent and identically distributed) or non-IID methods
def partition_data(dataset, datadir, partition, n_nets, alpha):
    logging.info("********* Partitioning Data ***************")
    X_train, y_train, X_test, y_test = load_cifar100_data(datadir)
    n_train = X_train.shape[0]

    # IID partitioning: randomly split the data across the clients
    if partition == "homo":
        idxs = np.random.permutation(n_train)
        net_dataidx_map = {i: idxs_split for i, idxs_split in enumerate(np.array_split(idxs, n_nets))}

    # Non-IID partitioning using a Dirichlet distribution
    elif partition == "hetero":
        min_size, K, N = 0, 100, y_train.shape[0]
        net_dataidx_map = {}

        # Ensure each client has at least 10 samples
        while min_size < 10:
            idx_batch = [[] for _ in range(n_nets)]
            for k in range(K):
                idx_k = np.where(y_train == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
                proportions = np.array([p * (len(batch) < N / n_nets) for p, batch in zip(proportions, idx_batch)])
                # Normalize before computing the cumulative split points
                proportions = (np.cumsum(proportions / proportions.sum()) * len(idx_k)).astype(int)[:-1]
                idx_batch = [batch + idx.tolist() for batch, idx in zip(idx_batch, np.split(idx_k, proportions))]
            min_size = min([len(batch) for batch in idx_batch])

        # Randomly shuffle the data batches for each client
        net_dataidx_map = {i: np.random.permutation(batch) for i, batch in enumerate(idx_batch)}

    # Non-IID fixed partition: read the distribution from a predefined file
    elif partition == "hetero-fix":
        net_dataidx_map = read_net_dataidx_map('./data_cleaning/non-iid-distribution/CIFAR100/index_map.txt')

    # Record the class counts for the partitioned training data
    traindata_cls_counts = read_data_distribution('./data_cleaning/non-iid-distribution/CIFAR100/data_map.txt') \
        if partition == "hetero-fix" else record_net_data_stats(y_train, net_dataidx_map)

    return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts


# Function to get data loaders for centralized and local training
def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None):
    return get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs)


# Function to get data loaders for test data during decentralized training
def get_dataloader_test(dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test):
    return get_dataloader_test_CIFAR100(datadir, train_bs, test_bs, dataidxs_train, dataidxs_test)


# Function to load CIFAR-100 data into PyTorch data loaders for training and testing
def get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs=None):
    apply_transformation_train, apply_transformation_test = _data_apply_transformations_cifar100()

    train_ds = CIFAR100_truncated(datadir, dataidxs=dataidxs, train=True, apply_transformation=apply_transformation_train, download=True)
    test_ds = CIFAR100_truncated(datadir, train=False, apply_transformation=apply_transformation_test, download=True)

    train_dl = data.DataLoader(train_ds, batch_size=train_bs, shuffle=True, drop_last=True)
    test_dl = data.DataLoader(test_ds, batch_size=test_bs, shuffle=False, drop_last=True)

    return train_dl, test_dl


# Function to get data loaders for test data during decentralized training (same as above, but with test data indices)
def get_dataloader_test_CIFAR100(datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None):
    apply_transformation_train, apply_transformation_test = _data_apply_transformations_cifar100()

    train_ds = CIFAR100_truncated(datadir, dataidxs=dataidxs_train, train=True, apply_transformation=apply_transformation_train, download=True)
    test_ds = CIFAR100_truncated(datadir, dataidxs=dataidxs_test, train=False, apply_transformation=apply_transformation_test, download=True)

    train_dl = data.DataLoader(train_ds, batch_size=train_bs, shuffle=True, drop_last=True)
    test_dl = data.DataLoader(test_ds, batch_size=test_bs, shuffle=False, drop_last=True)

    return train_dl, test_dl
157
EdgeFLite/fedml_service/data_cleaning/cifar100/dataset_hub.py
Normal file
@ -0,0 +1,157 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Author: Weisen Pan
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
from torchvision.datasets import CIFAR100
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
# Supported image extensions for loading images
|
||||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
|
||||
|
||||
def load_accimage(path):
    """
    Attempts to load an image using the accimage backend.
    If accimage fails, it falls back to the PIL image loader.

    Args:
        path (str): Path to the image file.

    Returns:
        accimage.Image: The loaded image if successful, otherwise a PIL image.
    """
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # If accimage fails, use PIL to load the image
        return load_image_pil(path)


def load_image_pil(path):
    """
    Loads an image using PIL, ensuring that file handles are properly closed to prevent warnings.

    Args:
        path (str): Path to the image file.

    Returns:
        Image: The image loaded using PIL, converted to RGB format.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def basic_loader(path):
    """
    Selects the appropriate image loader based on the backend configured by torchvision.
    If the backend is 'accimage', it uses load_accimage; otherwise, it uses PIL.

    Args:
        path (str): Path to the image file.

    Returns:
        Image: The loaded image.
    """
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return load_accimage(path)
    else:
        return load_image_pil(path)


class CIFAR100_truncated(data.Dataset):
    """
    Custom dataset class for CIFAR100 with optional data truncation.
    It allows selecting a subset of the data by index and also enables modification of image channels.
    """

    def __init__(self, root, dataidxs=None, train=True, apply_transformation=None, target_apply_transformation=None, download=False):
        """
        Initializes the CIFAR100_truncated dataset.

        Args:
            root (str): The root directory where the dataset is stored.
            dataidxs (list or None): List of indices for truncating the dataset, if applicable.
            train (bool): Whether to load the training set (True) or the test set (False).
            apply_transformation (callable, optional): Transformation function applied to images.
            target_apply_transformation (callable, optional): Transformation function applied to targets (labels).
            download (bool): Whether to download the dataset if it is not found in the root directory.
        """
        self.root = root  # Root directory where the dataset is stored
        self.dataidxs = dataidxs  # List of indices for truncating the dataset
        self.train = train  # Specifies whether to load the training set
        self.apply_transformation = apply_transformation  # Optional transformations on images
        self.target_apply_transformation = target_apply_transformation  # Optional transformations on labels
        self.download = download  # Specifies whether to download the dataset if missing

        # Build the truncated dataset based on the provided indices
        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """
        Constructs the truncated dataset based on the provided data indices.

        Returns:
            tuple: The truncated data and corresponding target labels.
        """
        cifar_dataobj = CIFAR100(self.root, self.train, self.apply_transformation, self.target_apply_transformation, self.download)

        # Load all data and targets
        data = cifar_dataobj.data
        target = np.array(cifar_dataobj.targets)

        # If specific indices are provided, truncate the dataset accordingly
        if self.dataidxs is not None:
            data = data[self.dataidxs]
            target = target[self.dataidxs]

        return data, target

    def truncate_channel(self, index):
        """
        Modifies the selected images by zeroing out the green and blue channels,
        effectively converting them to grayscale-like images.

        Args:
            index (np.array): The indices of images to modify.
        """
        for i in range(index.shape[0]):
            gs_index = index[i]
            self.data[gs_index, :, :, 1] = 0.0  # Set the green channel to 0
            self.data[gs_index, :, :, 2] = 0.0  # Set the blue channel to 0

    def __getitem__(self, index):
        """
        Retrieves an image and its corresponding target (label) at the given index.

        Args:
            index (int): Index of the data point to retrieve.

        Returns:
            tuple: (image, target) where the image is transformed (if specified), and the target is the label.
        """
        img, target = self.data[index], self.target[index]

        # Apply any specified transformations to the image
        if self.apply_transformation is not None:
            img = self.apply_transformation(img)

        # Apply any specified transformations to the target label
        if self.target_apply_transformation is not None:
            target = self.target_apply_transformation(target)

        return img, target

    def __len__(self):
        """
        Returns the total number of data points in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data)
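

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the committed file): build a
    # truncated CIFAR-100 view for one client. The 500-sample index range and
    # the './data' root are hypothetical; any transform pipeline could be used.
    import numpy as np
    import torchvision.transforms as transforms

    client_idxs = np.arange(500)  # first 500 samples go to this client
    ds = CIFAR100_truncated('./data', dataidxs=client_idxs, train=True,
                            apply_transformation=transforms.ToTensor(),
                            download=True)
    img, label = ds[0]
    print(len(ds), img.shape, label)  # 500, torch.Size([3, 32, 32]), <int>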
BIN
EdgeFLite/fedml_service/data_cleaning/pillbase/.DS_Store
vendored
Normal file
Binary file not shown.
@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import random
import pickle
from PIL import Image

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from dataset.pill_dataset_base import PillDataBase  # Custom dataset class for handling Pill data
from config import HOME  # Configuration file for defining the home directory

# Configure logging to capture information during the execution of the script
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Function to load and partition the pill base data
def load_partition_data_pillbase(dataset, data_dir, partition_method, partition_alpha, client_number, batch_size):
    # Define the number of samples in the training and testing datasets
    train_data_num = 8161
    test_data_num = 1619

    # Normalization parameters (mean and standard deviation) for each channel (RGB) based on the dataset's characteristics
    mean, std = [0.4550, 0.5239, 0.5653], [0.2460, 0.2446, 0.2252]

    # Define transformations for training data, including:
    # 1. Randomly resized crops to augment the data.
    # 2. Horizontal flipping for data augmentation.
    # 3. Conversion to tensor and normalization with the provided mean and std.
    # 4. Random erasing to simulate occlusion as part of augmentation.
    train_apply_transformation = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.1, 1.0), interpolation=Image.BILINEAR),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        transforms.RandomErasing(p=0.5, scale=(0.05, 0.12), ratio=(0.5, 1.5), value=0)
    ])

    # Create a training dataset using the PillDataBase class and apply the transformation
    train_dataset = PillDataBase(data_dir, train=True, apply_transformation=train_apply_transformation, split_factor=1)

    # Create a DataLoader for the global training dataset, with shuffling enabled and the last incomplete batch dropped
    train_data_global = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Define transformations for validation data, including:
    # 1. Resizing to a larger scale for testing.
    # 2. Center cropping to ensure the image size is 224x224.
    # 3. Conversion to tensor and normalization with the same mean and std as the training data.
    val_apply_transformation = transforms.Compose([
        transforms.Resize(int(224 * 1.15), interpolation=Image.BILINEAR),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Create a validation dataset using the PillDataBase class with the validation transformations
    val_dataset = PillDataBase(data_dir, train=False, apply_transformation=val_apply_transformation, split_factor=1)

    # Calculate how many images each client will receive for the validation dataset
    images_per_client = len(val_dataset) // client_number
    logger.info(f"Images per client: {images_per_client}")  # Log the number of images assigned to each client

    # Split the validation data among the clients evenly, ensuring the last client gets any remaining images
    data_split = [images_per_client] * (client_number - 1) + [len(val_dataset) - images_per_client * (client_number - 1)]

    # Perform the actual data splitting using torch's random_split function and a fixed random seed for reproducibility
    testdata_split = torch.utils.data.random_split(val_dataset, data_split, generator=torch.Generator().manual_seed(68))

    # Create a DataLoader for each client from their respective validation dataset splits
    test_data_local_dict = [
        torch.utils.data.DataLoader(x, batch_size=16, shuffle=True, drop_last=True)
        for x in testdata_split
    ]

    # Return all necessary data structures, including the number of classes and the training/test data loaders
    class_num = 98  # Total number of classes in the dataset
    return train_data_num, test_data_num, train_data_global, None, None, None, test_data_local_dict, class_num
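
# Call sketch (illustrative, not part of the committed file): the loader is
# typically invoked once at startup; the paths and client count below are
# placeholders.
#
#   (n_train, n_test, train_loader, _, _, _,
#    test_loaders, n_classes) = load_partition_data_pillbase(
#        dataset='pillbase', data_dir='./data/pillbase',
#        partition_method='hetero', partition_alpha=0.5,
#        client_number=20, batch_size=32)
#   # test_loaders[i] is the DataLoader holding client i's validation shard.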
BIN
EdgeFLite/fedml_service/data_cleaning/skin_dataset/.DS_Store
vendored
Normal file
Binary file not shown.
@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import random
from dataset.skin_dataset import SkinData
from config import HOME
import pickle

# Set up logging
logging.basicConfig()  # Configures the basic logging setup
logger = logging.getLogger()  # Gets the root logger
logger.setLevel(logging.INFO)  # Sets the logging level to INFO

def load_partition_data_skin_dataset(dataset, data_dir, partition_method, partition_alpha, client_number, batch_size):
    # Predefined dataset sizes for training and testing
    train_data_num = 8012  # Number of training samples
    test_data_num = 2003  # Number of testing samples

    # Normalization parameters used for preprocessing
    mean = [0.485, 0.456, 0.406]  # Mean values for normalization (standard for ImageNet)
    std = [0.229, 0.224, 0.225]  # Standard deviation for normalization

    # Load the training data from the pre-saved pickle file
    with open(HOME + '/dataset_hub/skin_dataset/skin_dataset_train.pickle', 'rb') as train_file:
        train_data = pickle.load(train_file)  # Load training data from the pickle file

    # Data augmentation and preprocessing transformations for training data
    train_apply_transformations = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
        transforms.RandomVerticalFlip(),  # Randomly flip the image vertically
        transforms.RandomHorizontalFlip(),  # Repeated horizontal flip (kept from the original pipeline)
        transforms.RandomAdjustSharpness(random.uniform(0, 4.0)),  # Randomly adjust image sharpness
        transforms.RandomAutocontrast(),  # Automatically adjust image contrast
        transforms.Pad(3),  # Pad the image by 3 pixels
        transforms.RandomRotation(10),  # Random rotation by up to 10 degrees
        transforms.CenterCrop(64),  # Crop the center to a size of 64x64
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.Normalize(mean=mean, std=std)  # Normalize using the predefined mean and std
    ])

    # Create the training dataset with augmentation and transformation
    train_dataset = SkinData(train_data, apply_transformation=train_apply_transformations, split_factor=1)
    # Create a DataLoader for the training dataset
    train_data_global = data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True  # Shuffle data and drop incomplete batches
    )

    # Load the test data from the pre-saved pickle file
    with open(HOME + "/dataset_hub/skin_dataset/skin_dataset_test.pickle", 'rb') as test_file:
        test_data = pickle.load(test_file)  # Load test data from the pickle file

    # Preprocessing transformations for validation/testing data (without augmentation)
    val_apply_transformations = transforms.Compose([
        transforms.Pad(3),  # Pad the image by 3 pixels
        transforms.CenterCrop(64),  # Crop the center to a size of 64x64
        transforms.ToTensor(),  # Convert the image to a tensor
        transforms.Normalize(mean=mean, std=std)  # Normalize using the predefined mean and std
    ])

    # Create the validation/test dataset with the preprocessing transformations
    val_dataset = SkinData(test_data, apply_transformation=val_apply_transformations, split_factor=1)

    # Split the test data across clients. Each client gets approximately equal data.
    images_per_client = len(val_dataset) // client_number  # Number of images each client will get
    logger.info(f"Images per client: {images_per_client}")  # Log the number of images per client

    # Create a list that determines the size of the data split for each client
    data_split = [images_per_client] * (client_number - 1)  # Distribute data equally to all but the last client
    data_split.append(len(val_dataset) - images_per_client * (client_number - 1))  # The last client gets the remaining data
    logger.info(f"Data split: {data_split}")  # Log the data split

    # Randomly split the test data for each client using the data_split list
    testdata_split = torch.utils.data.random_split(
        val_dataset, data_split, generator=torch.Generator().manual_seed(68)  # Fix the random seed for reproducibility
    )

    # Create a DataLoader for each client's test data
    # (the original referenced an undefined train_sampler here; shuffling is simply enabled)
    test_data_local_dict = [
        torch.utils.data.DataLoader(
            x, batch_size=32, shuffle=True, drop_last=True  # DataLoader for each client's split
        ) for x in testdata_split
    ]

    # Other variables that are currently unused (placeholders for future implementation)
    class_num = 7  # Number of classes in the dataset
    test_data_global = None  # Placeholder for global test data (currently not used)
    data_local_num_dict = None  # Placeholder for the number of samples per client (currently not used)
    train_data_local_dict = None  # Placeholder for local training data per client (currently not used)

    # Return key values including the number of samples, data loaders, and class number
    return train_data_num, test_data_num, train_data_global, test_data_global, \
        data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num
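

if __name__ == "__main__":
    # Minimal sketch of the client-split arithmetic used above: every client
    # gets len(dataset) // client_number samples and the last client absorbs
    # the remainder. A toy TensorDataset stands in for the real SkinData.
    toy = data.TensorDataset(torch.arange(2003))
    clients = 4
    per_client = len(toy) // clients
    sizes = [per_client] * (clients - 1) + [len(toy) - per_client * (clients - 1)]
    parts = torch.utils.data.random_split(toy, sizes, generator=torch.Generator().manual_seed(68))
    print(sizes)                    # [500, 500, 500, 503]
    print([len(p) for p in parts])  # [500, 500, 500, 503]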
BIN
EdgeFLite/fedml_service/decentralized/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
EdgeFLite/fedml_service/decentralized/federated_gkt/.DS_Store
vendored
Normal file
Binary file not shown.
@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import torch
from torch import nn, optim
from fedml_service.decentralized.federated_gkt import utils

# Class for training a GKT client in a federated learning setup
class GKTTrainer:
    def __init__(self, client_index, local_training_data, local_test_data, device, client_model, args):
        # Initialize the client trainer with various parameters
        self.client_index = client_index  # Index of the current client
        self.local_training_data = local_training_data[client_index]  # Local training dataset specific to the client
        self.local_test_data = local_test_data[client_index]  # Local test dataset specific to the client
        self.device = device  # Device (CPU/GPU) where the computation will take place
        self.client_model = client_model.to(self.device)  # Model assigned to the client
        self.args = args  # Arguments for configuring the training process

        logging.info(f"Client device = {self.device}")

        # Model parameters used for optimization
        self.model_params = self.master_params = self.client_model.parameters()
        optim_params = self.master_params

        # Configure the optimizer based on the provided arguments
        if self.args.optimizer == "SGD":
            # SGD optimizer with learning rate, Nesterov momentum, and weight decay
            self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9, nesterov=True, weight_decay=self.args.wd)
        elif self.args.optimizer == "Adam":
            # Adam optimizer with learning rate, weight decay, and the AMSGrad variant
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True)

        # Define loss functions: CrossEntropy for true-label prediction, KL divergence for knowledge distillation
        # (the loss class is named KLDivergenceLoss in the accompanying utils module)
        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = utils.KLDivergenceLoss(self.args.temperature)

        # Dictionary to hold logits received from the server (used for knowledge distillation)
        self.server_logits_dict = {}

        logging.info(f"Client device = {self.device} - Initialization Complete")

    # Update server logits for knowledge distillation
    def update_large_model_logits(self, logits):
        self.server_logits_dict = logits

    # Main training function for the client
    def train(self):
        # Dictionaries to store extracted features, logits, and labels during training and testing
        extracted_feature_dict, logits_dict, labels_dict = {}, {}, {}
        extracted_feature_dict_test, labels_dict_test = {}, {}

        # Only train if training on the client is enabled
        if self.args.whether_training_on_client:
            self.client_model.train()  # Set the model to training mode
            epoch_loss = []  # Track loss for each epoch

            # Loop over the specified number of federated epochs
            for epoch in range(self.args.fed_epochs):
                batch_loss = []  # Track loss for each batch

                # Loop through the local training data in batches
                for batch_idx, (images, labels) in enumerate(self.local_training_data):
                    # Move images and labels to the specified device
                    images, labels = images.to(self.device), labels.to(self.device)

                    # Forward pass through the client model
                    log_probs, _ = self.client_model(images)

                    # Compute the loss with respect to the true labels
                    loss_true = self.criterion_CE(log_probs, labels)

                    # If server logits are available, calculate the distillation loss using KL divergence
                    if self.server_logits_dict:
                        large_model_logits = torch.from_numpy(self.server_logits_dict[batch_idx]).to(self.device)
                        loss_kd = self.criterion_KL(log_probs, large_model_logits)
                        # Combine the true-label loss and the distillation loss
                        loss = loss_true + self.args.alpha * loss_kd
                    else:
                        # Use only the true-label loss if no server logits are available
                        loss = loss_true

                    # Perform backpropagation and an optimization step
                    self.optimizer.zero_grad()  # Reset gradients
                    loss.backward()  # Backpropagate the loss
                    self.optimizer.step()  # Update model parameters

                    # Log progress for each batch
                    logging.info(f'Client {self.client_index} - Update Epoch: {epoch} '
                                 f'[{batch_idx * len(images)}/{len(self.local_training_data.dataset)} '
                                 f'({100. * batch_idx / len(self.local_training_data):.0f}%)]')
                    batch_loss.append(loss.item())  # Store the loss for the current batch

                # Calculate and store the average loss for the epoch
                epoch_loss.append(sum(batch_loss) / len(batch_loss))

        # Switch to evaluation mode after training
        self.client_model.eval()

        # Extract features, logits, and labels from the training data for the server
        for batch_idx, (images, labels) in enumerate(self.local_training_data):
            images, labels = images.to(self.device), labels.to(self.device)
            log_probs, extracted_features = self.client_model(images)

            # Store the extracted features, logits, and labels for this batch
            extracted_feature_dict[batch_idx] = extracted_features.cpu().detach().numpy()
            logits_dict[batch_idx] = log_probs.cpu().detach().numpy()
            labels_dict[batch_idx] = labels.cpu().detach().numpy()

        # Extract features and labels from the test data for evaluation
        for batch_idx, (images, labels) in enumerate(self.local_test_data):
            test_images, test_labels = images.to(self.device), labels.to(self.device)
            _, extracted_features_test = self.client_model(test_images)

            # Store the extracted test features and labels for this batch
            extracted_feature_dict_test[batch_idx] = extracted_features_test.cpu().detach().numpy()
            labels_dict_test[batch_idx] = test_labels.cpu().detach().numpy()

        # Return the extracted features, logits, and labels from both training and test datasets
        return extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test
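
# Construction sketch (illustrative, not part of the committed file): the
# trainer expects per-client loader lists and an argparse-style namespace;
# every name below is a placeholder.
#
#   trainer = GKTTrainer(client_index=0,
#                        local_training_data=train_loaders,
#                        local_test_data=test_loaders,
#                        device=torch.device('cuda:0'),
#                        client_model=client_net,
#                        args=args)  # needs .optimizer, .lr, .wd, .temperature,
#                                    # .alpha, .fed_epochs, .whether_training_on_client
#   feats, logits, labels, feats_test, labels_test = trainer.train()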
@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def load_state_dict(file):
    """Load a state dict from a file, handling any potential device-location issues."""
    try:
        return torch.load(file)
    except AssertionError:
        return torch.load(file, map_location=lambda storage, location: storage)


def flatten_parameters(model):
    """Flatten the parameters of the model into a single tensor."""
    return torch.cat([param.data.view(-1) for param in model.parameters()])


def set_flattened_parameters(model, flat_params):
    """Set the model's parameters from a flattened tensor."""
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(param.size()))
        param.data.copy_(flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size


class RollingAverage:
    """Class to maintain a running average of a quantity."""
    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def value(self):
        return self.total / float(self.steps) if self.steps > 0 else 0


def compute_accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size) for k in topk]


class KLDivergenceLoss(nn.Module):
    """Kullback-Leibler divergence loss with temperature scaling."""
    def __init__(self, temperature=1):
        super(KLDivergenceLoss, self).__init__()
        self.temperature = temperature

    def forward(self, output_batch, teacher_outputs):
        output_batch = F.log_softmax(output_batch / self.temperature, dim=1)
        teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1) + 1e-7
        return self.temperature ** 2 * nn.KLDivLoss(reduction='batchmean')(output_batch, teacher_outputs)


class CELoss(nn.Module):
    """Cross-entropy loss against soft teacher targets, with temperature scaling."""
    def __init__(self, temperature=1):
        super(CELoss, self).__init__()
        self.temperature = temperature

    def forward(self, output_batch, teacher_outputs):
        output_batch = F.log_softmax(output_batch / self.temperature, dim=1)
        teacher_outputs = F.softmax(teacher_outputs / self.temperature, dim=1)
        return -self.temperature ** 2 * torch.sum(output_batch * teacher_outputs) / teacher_outputs.size(0)


def save_dict_to_json(data, json_path):
    """Save a dictionary of floats to a JSON file."""
    with open(json_path, 'w') as f:
        json.dump({k: float(v) for k, v in data.items()}, f, indent=4)


def get_optimized_params(model, model_params, master_params):
    """Exclude batch-norm parameters from weight decay to improve accuracy."""
    bn_params, remaining_params = split_bn_params(model, model_params, master_params)
    return [{'params': bn_params, 'weight_decay': 0}, {'params': remaining_params}]


def split_bn_params(model, model_params, master_params):
    """Split parameters into batch-norm and non-batch-norm groups."""
    def get_bn_params(module):
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            return set(module.parameters())
        return {p for child in module.children() for p in get_bn_params(child)}

    mod_bn_params = get_bn_params(model)
    # Materialize the pairs once; a bare zip would be exhausted after the first comprehension
    zipped_params = list(zip(model_params, master_params))

    mas_bn_params = [p_mast for p_mod, p_mast in zipped_params if p_mod in mod_bn_params]
    mas_rem_params = [p_mast for p_mod, p_mast in zipped_params if p_mod not in mod_bn_params]

    return mas_bn_params, mas_rem_params
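

if __name__ == "__main__":
    # Self-check sketch (illustrative): KL distillation loss between random
    # student/teacher logits, plus top-1/top-5 accuracy on the same batch.
    student = torch.randn(8, 10)
    teacher = torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))
    kl = KLDivergenceLoss(temperature=3.0)
    print(kl(student, teacher).item())  # scalar distillation loss
    top1, top5 = compute_accuracy(student, labels, topk=(1, 5))
    print(top1.item(), top5.item())     # accuracies in percent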
@ -0,0 +1,274 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import logging
import os
import shutil

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import metric
from fedml_service.decentralized.federated_gkt import utils

# List to store filenames of saved checkpoints
saved_ckpt_filenames = []

class GKTServerTrainer:
    def __init__(self, client_num, device, server_model, args, writer):
        # Initialize the trainer with the number of clients, the device (CPU/GPU), the global server model,
        # the training arguments, and a writer for logging
        self.client_num = client_num
        self.device = device
        self.args = args
        self.writer = writer

        """
        Note: Using data parallelism requires adjusting the batch size accordingly.
        For example, with a single GPU (batch_size = 64), an epoch takes 1:03;
        with 4 GPUs (batch_size = 256), it takes 38 seconds; and with 4 GPUs (batch_size = 64), it takes 1:00.
        If the batch size is not adjusted, communication between CPU and GPU can slow down training.
        """

        # Server model setup
        self.model_global = server_model
        self.model_global.train()  # Set the model to training mode
        self.model_global.to(self.device)  # Move the model to the specified device (CPU or GPU)

        # Model parameters for optimization
        self.model_params = self.master_params = self.model_global.parameters()
        optim_params = self.master_params

        # Choose the optimizer based on the arguments (SGD or Adam)
        if self.args.optimizer == "SGD":
            self.optimizer = optim.SGD(optim_params, lr=self.args.lr, momentum=0.9, nesterov=True, weight_decay=self.args.wd)
        elif self.args.optimizer == "Adam":
            self.optimizer = optim.Adam(optim_params, lr=self.args.lr, weight_decay=0.0001, amsgrad=True)

        # Learning rate scheduler to reduce the learning rate when the accuracy plateaus
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'max')

        # Loss functions: CrossEntropy for classification, KL divergence for knowledge distillation
        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = utils.KLDivergenceLoss(self.args.temperature)

        # Best accuracy tracking
        self.best_acc = 0.0

        # Client data dictionaries to store features, logits, and labels
        self.client_extracted_feature_dict = {}
        self.client_logits_dict = {}
        self.client_labels_dict = {}
        self.server_logits_dict = {}

        # Testing data dictionaries
        self.client_extracted_feature_dict_test = {}
        self.client_labels_dict_test = {}

        # Miscellaneous dictionaries to store model info, sample numbers, training accuracy, and loss
        self.model_dict = {}
        self.sample_num_dict = {}
        self.train_acc_dict = {}
        self.train_loss_dict = {}
        self.test_acc_avg = 0.0
        self.test_loss_avg = 0.0

        # Dictionary to track whether each client's model has been uploaded
        self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}

    # Add results from a local client model after training
    def add_local_trained_result(self, index, extracted_feature_dict, logits_dict, labels_dict,
                                 extracted_feature_dict_test, labels_dict_test):
        logging.info(f"Adding model for client index = {index}")
        self.client_extracted_feature_dict[index] = extracted_feature_dict
        self.client_logits_dict[index] = logits_dict
        self.client_labels_dict[index] = labels_dict
        self.client_extracted_feature_dict_test[index] = extracted_feature_dict_test
        self.client_labels_dict_test[index] = labels_dict_test
        self.flag_client_model_uploaded_dict[index] = True

    # Check whether all clients have uploaded their models
    def check_whether_all_receive(self):
        if all(self.flag_client_model_uploaded_dict.values()):
            self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}
            return True
        return False

    # Get logits from the global model for a specific client
    def get_global_logits(self, client_index):
        return self.server_logits_dict.get(client_index)

    # Main training function based on the round index
    def train(self, round_idx):
        if self.args.sweep == 1:  # Sweep mode
            self.sweep(round_idx)
        else:  # Normal training process
            if self.args.whether_training_on_client == 1:  # Check if training occurs on the client
                self.train_and_distill_on_client(round_idx)
            else:  # No training on the client, just evaluate
                self.do_not_train_on_client(round_idx)

    # Training and knowledge distillation on the client side
    def train_and_distill_on_client(self, round_idx):
        # Set the number of server epochs (based on testing mode)
        epochs_server = 1 if not self.args.test else self.get_server_epoch_strategy_test()[0]
        self.train_and_eval(round_idx, epochs_server, self.writer, self.args)  # Train and evaluate
        self.scheduler.step(self.best_acc, epoch=round_idx)  # Update the learning rate scheduler

    # Skip client-side training
    def do_not_train_on_client(self, round_idx):
        self.train_and_eval(round_idx, 1, self.writer, self.args)
        self.scheduler.step(self.best_acc, epoch=round_idx)

    # Training with the sweeping strategy
    def sweep(self, round_idx):
        self.train_and_eval(round_idx, self.args.epochs_server, self.writer, self.args)
        self.scheduler.step(self.best_acc, epoch=round_idx)

    # Strategy for determining the number of epochs (used in testing)
    def get_server_epoch_strategy_test(self):
        return 1, True

    # Different strategies for determining the number of epochs based on the training round
    def get_server_epoch_strategy_reset56(self, round_idx):
        epochs = 20 if round_idx < 20 else 15 if round_idx < 30 else 10 if round_idx < 40 else 5 if round_idx < 50 else 3 if round_idx < 150 else 1
        whether_distill_back = round_idx < 150
        return epochs, whether_distill_back

    # Another variant of the epoch strategy
    def get_server_epoch_strategy_reset56_2(self, round_idx):
        return self.args.epochs_server, True

    # Main training and evaluation loop
    def train_and_eval(self, round_idx, epochs, val_writer, args):
        for epoch in range(epochs):
            logging.info(f"Train and evaluate. Round = {round_idx}, Epoch = {epoch}")
            train_metrics = self.train_large_model_on_the_server()  # Training step

            if epoch == epochs - 1:
                # Log metrics for the final epoch
                val_writer.add_scalar('average training loss', train_metrics['train_loss'], global_step=round_idx)
                test_metrics = self.eval_large_model_on_the_server()  # Evaluation step
                test_acc = test_metrics['test_accTop1']

                val_writer.add_scalar('test loss', test_metrics['test_loss'], global_step=round_idx)
                val_writer.add_scalar('test acc', test_metrics['test_accTop1'], global_step=round_idx)

                # Track the best accuracy so far
                if test_acc >= self.best_acc:
                    logging.info("- Found better accuracy")
                    self.best_acc = test_acc

                val_writer.add_scalar('best_acc1', self.best_acc, global_step=round_idx)

                # Save model checkpoints
                if args.save_weight:
                    filename = f"checkpoint_{round_idx}.pth.tar"
                    saved_ckpt_filenames.append(filename)
                    if len(saved_ckpt_filenames) > args.max_ckpt_nums:
                        os.remove(os.path.join(args.model_dir, saved_ckpt_filenames.pop(0)))

                    ckpt_dict = {
                        'round': round_idx + 1,
                        'arch': args.arch,
                        'state_dict': self.model_global.state_dict(),
                        'best_acc1': self.best_acc,
                        'optimizer': self.optimizer.state_dict(),
                    }
                    metric.save_checkpoint(ckpt_dict, test_acc >= self.best_acc, args.model_dir, filename=filename)

                # Print metrics for the current round
                print(f"{round_idx}-th round | Train Loss: {train_metrics['train_loss']:.3g} | Test Loss: {test_metrics['test_loss']:.3g} | Test Acc: {test_metrics['test_accTop1']:.3f}")

    # Train the model on the server side
    def train_large_model_on_the_server(self):
        # Clear the logits dictionary and set the model to training mode
        self.server_logits_dict.clear()
        self.model_global.train()

        # Track loss and accuracy
        loss_avg = utils.RollingAverage()
        accTop1_avg = utils.RollingAverage()
        accTop5_avg = utils.RollingAverage()

        # Iterate over the clients' extracted features
        for client_index, extracted_feature_dict in self.client_extracted_feature_dict.items():
            logits_dict = self.client_logits_dict[client_index]
            labels_dict = self.client_labels_dict[client_index]

            s_logits_dict = {}
            self.server_logits_dict[client_index] = s_logits_dict

            # Iterate over batches of features for each client
            for batch_index, batch_feature_map_x in extracted_feature_dict.items():
                batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device)
                batch_logits = torch.from_numpy(logits_dict[batch_index]).float().to(self.device)
                batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                # Forward pass
                output_batch = self.model_global(batch_feature_map_x)

                # Knowledge distillation loss
                if self.args.whether_distill_on_the_server == 1:
                    loss_kd = self.criterion_KL(output_batch, batch_logits).to(self.device)
                    loss_true = self.criterion_CE(output_batch, batch_labels).to(self.device)
                    loss = loss_kd + self.args.alpha * loss_true
                else:
                    # Standard cross-entropy loss
                    loss = self.criterion_CE(output_batch, batch_labels).to(self.device)

                # Backward pass and optimization
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Compute accuracy metrics
                metrics = utils.compute_accuracy(output_batch, batch_labels, topk=(1, 5))
                accTop1_avg.update(metrics[0].item())
                accTop5_avg.update(metrics[1].item())
                loss_avg.update(loss.item())

                # Store the logits for the batch
                s_logits_dict[batch_index] = output_batch.cpu().detach().numpy()

        # Aggregate and log training metrics
        train_metrics = {'train_loss': loss_avg.value(),
                         'train_accTop1': accTop1_avg.value(),
                         'train_accTop5': accTop5_avg.value()}
        logging.info(f"- Train metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in train_metrics.items())}")
        return train_metrics

    # Evaluate the model on the server side
    def eval_large_model_on_the_server(self):
        # Set the model to evaluation mode
        self.model_global.eval()
        loss_avg = utils.RollingAverage()
        accTop1_avg = utils.RollingAverage()
        accTop5_avg = utils.RollingAverage()

        # Disable gradient computation for evaluation
        with torch.no_grad():
            # Iterate over the clients' extracted features for testing
            for client_index, extracted_feature_dict in self.client_extracted_feature_dict_test.items():
                labels_dict = self.client_labels_dict_test[client_index]

                # Iterate over batches for each client
                for batch_index, batch_feature_map_x in extracted_feature_dict.items():
                    batch_feature_map_x = torch.from_numpy(batch_feature_map_x).to(self.device)
                    batch_labels = torch.from_numpy(labels_dict[batch_index]).long().to(self.device)

                    # Forward pass
                    output_batch = self.model_global(batch_feature_map_x)
                    loss = self.criterion_CE(output_batch, batch_labels)

                    # Compute accuracy metrics
                    metrics = utils.compute_accuracy(output_batch, batch_labels, topk=(1, 5))
                    accTop1_avg.update(metrics[0].item())
                    accTop5_avg.update(metrics[1].item())
                    loss_avg.update(loss.item())

        # Aggregate and log test metrics
        test_metrics = {'test_loss': loss_avg.value(),
                        'test_accTop1': accTop1_avg.value(),
                        'test_accTop5': accTop5_avg.value()}
        logging.info(f"- Test metrics: {' ; '.join(f'{k}: {v:.3f}' for k, v in test_metrics.items())}")
        return test_metrics
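
# Round-loop sketch (illustrative, not part of the committed file): how a
# coordinator might drive one GKT round. All names are placeholders; the real
# wiring lives in the training scripts.
#
#   server = GKTServerTrainer(client_num, device, server_model, args, writer)
#   for idx, trainer in enumerate(client_trainers):
#       server.add_local_trained_result(idx, *trainer.train())
#   if server.check_whether_all_receive():
#       server.train(round_idx)
#       for idx, trainer in enumerate(client_trainers):
#           trainer.update_large_model_logits(server.get_global_logits(idx))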
BIN
EdgeFLite/helpers/.DS_Store
vendored
Normal file
Binary file not shown.
190
EdgeFLite/helpers/evaluation_metrics.py
Normal file
@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import shutil
import torch

def store_model(state, best_model, directory, filename='checkpoint.pth'):
    """
    Stores the model checkpoint in the specified directory. If it is the best model,
    it saves another copy named 'best_model.pth'.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
        filename (str): Name of the file to save the checkpoint (default 'checkpoint.pth').
    """
    save_path = os.path.join(directory, filename)
    torch.save(state, save_path)
    if best_model:
        # If the current model is the best, save another copy as 'best_model.pth'
        shutil.copy(save_path, os.path.join(directory, 'best_model.pth'))

def save_main_client_model(state, best_model, directory):
    """
    Saves the model for the main client if it is the best one.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    if best_model:
        print("Saving the best main client model")
        torch.save(state, os.path.join(directory, 'main_client_best.pth'))

def save_proxy_clients_model(state, best_model, directory):
    """
    Saves the model for proxy clients if it is the best one.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    if best_model:
        print("Saving the best proxy client model")
        torch.save(state, os.path.join(directory, 'proxy_clients_best.pth'))

def save_individual_client_model(state, best_model, directory):
    """
    Saves the model for individual clients if it is the best one.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    if best_model:
        print("Saving the best client model")
        torch.save(state, os.path.join(directory, 'client_best.pth'))

def save_server_model(state, best_model, directory):
    """
    Saves the model for the server if it is the best one.

    Args:
        state (dict): Model's state dictionary.
        best_model (bool): Flag indicating if the current model is the best.
        directory (str): Directory where the model is saved.
    """
    if best_model:
        print("Saving the best server model")
        torch.save(state, os.path.join(directory, 'server_best.pth'))

class MetricTracker(object):
    """
    A helper class to track and compute the average of a given metric.

    Args:
        metric_name (str): Name of the metric to track.
        fmt (str): Format for printing metric values (default ':f').
    """
    def __init__(self, metric_name, fmt=':f'):
        self.metric_name = metric_name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Resets all metric counters."""
        self.current_value = 0
        self.total_sum = 0
        self.count = 0
        self.average = 0

    def update(self, value, n=1):
        """
        Updates the metric value.

        Args:
            value (float): New value of the metric.
            n (int): Weight or count for the value (default 1).
        """
        self.current_value = value
        self.total_sum += value * n
        self.count += n
        self.average = self.total_sum / self.count

    def __str__(self):
        """Returns the formatted metric string showing the current value and the average."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(name=self.metric_name, val=self.current_value, avg=self.average)

class ProgressLogger(object):
    """
    A class to log and display the progress of training/testing over multiple batches.

    Args:
        total_batches (int): Total number of batches.
        *metrics (MetricTracker): Metrics to log during the process.
        prefix (str): Prefix for the progress log (default "Progress:").
    """
    def __init__(self, total_batches, *metrics, prefix="Progress:"):
        self.batch_format = self._get_batch_format(total_batches)
        self.metrics = metrics
        self.prefix = prefix

    def log(self, batch_idx):
        """
        Logs the current progress of training/testing.

        Args:
            batch_idx (int): The current batch index.
        """
        output = [self.prefix + self.batch_format.format(batch_idx)]
        output += [str(metric) for metric in self.metrics]
        print(' | '.join(output))

    def _get_batch_format(self, total_batches):
        """Creates a format string such as '[ 12/300]' for the batch index."""
        num_digits = len(str(total_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(total_batches) + ']'

def compute_accuracy(prediction, target, top_k=(1,)):
    """
    Computes the accuracy for the top-k predictions.

    Args:
        prediction (Tensor): Model predictions.
        target (Tensor): Ground truth labels.
        top_k (tuple): Tuple of top-k values to consider for accuracy (default (1,)).

    Returns:
        List[Tensor]: List of accuracies for each top-k value.
    """
    with torch.no_grad():
        max_k = max(top_k)
        batch_size = target.size(0)

        # Get the top-k predictions
        _, top_predictions = prediction.topk(max_k, 1, largest=True, sorted=True)
        top_predictions = top_predictions.t()

        # Compare the top-k predictions with the targets
        correct_predictions = top_predictions.eq(target.view(1, -1).expand_as(top_predictions))

        accuracy_results = []
        for k in top_k:
            # Count the number of correct predictions within the top-k
            # (reshape instead of view, since the transposed slice is non-contiguous)
            correct_k = correct_predictions[:k].reshape(-1).float().sum(0, keepdim=True)
            accuracy_results.append(correct_k.mul_(100.0 / batch_size))
        return accuracy_results

def count_model_parameters(model, trainable_only=False):
    """
    Counts the total number of parameters in the model.

    Args:
        model (nn.Module): The PyTorch model.
        trainable_only (bool): Whether to count only trainable parameters (default False).

    Returns:
        int: Total number of parameters in the model.
    """
    if trainable_only:
        # Count only the parameters that require gradients (trainable parameters)
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Count all parameters (trainable and non-trainable)
    return sum(p.numel() for p in model.parameters())
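
if __name__ == "__main__":
    # Usage sketch (illustrative): track a running loss over three batches and
    # log progress with the helpers above.
    loss_meter = MetricTracker('loss', fmt=':.4f')
    progress = ProgressLogger(3, loss_meter, prefix='Epoch [0]')
    for i, batch_loss in enumerate([0.9, 0.7, 0.5]):
        loss_meter.update(batch_loss, n=32)  # n = batch size
        progress.log(i)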
131
EdgeFLite/helpers/normalization.py
Normal file
@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn


class PassThrough(nn.Module):
    """
    A placeholder module that simply returns the input tensor unchanged.
    """
    def __init__(self, **kwargs):
        super(PassThrough, self).__init__()

    def forward(self, input_tensor):
        return input_tensor


class LayerNormalization2D(nn.Module):
    """
    A custom layer normalization module for 2D inputs (typically used for
    convolutional layers). It optionally applies learned scaling (weight)
    and shifting (bias) parameters.

    Arguments:
        epsilon: A small value to avoid division by zero.
        use_weight: Whether to learn and apply weight parameters.
        use_bias: Whether to learn and apply bias parameters.
    """
    def __init__(self, epsilon=1e-05, use_weight=True, use_bias=True, **kwargs):
        super(LayerNormalization2D, self).__init__()

        self.epsilon = epsilon
        self.use_weight = use_weight
        self.use_bias = use_bias

    def forward(self, input_tensor):
        # On the first forward pass, replace the boolean flags with nn.Parameter
        # instances (or None), since layer_norm expects tensors or None
        if (not isinstance(self.use_weight, nn.parameter.Parameter) and
                not isinstance(self.use_bias, nn.parameter.Parameter)):
            self._initialize_parameters(input_tensor)

        # Apply layer normalization
        return nn.functional.layer_norm(input_tensor, input_tensor.shape[1:],
                                        weight=self.use_weight, bias=self.use_bias,
                                        eps=self.epsilon)

    def _initialize_parameters(self, input_tensor):
        """
        Initialize the weight and bias parameters for layer normalization.
        Arguments:
            input_tensor: The input tensor to the normalization layer.
        """
        channels, height, width = input_tensor.shape[1:]
        param_shape = [channels, height, width]

        # Initialize the weight parameter if applicable
        if self.use_weight:
            self.use_weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        else:
            self.register_parameter('use_weight', None)

        # Initialize the bias parameter if applicable
        if self.use_bias:
            self.use_bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        else:
            self.register_parameter('use_bias', None)


class NormalizationLayer(nn.Module):
    """
    A flexible normalization layer that supports different types of normalization
    (batch, group, layer, instance, or none). This class is a wrapper that selects
    the appropriate normalization technique based on the norm_type argument.

    Arguments:
        norm_type: The type of normalization to apply ('batch', 'group', 'layer', 'instance', or 'none').
        epsilon: A small value to avoid division by zero (default: 1e-05).
        momentum: Momentum for updating running statistics (default: 0.1, applicable to batch norm).
        use_weight: Whether to learn weight parameters (default: True).
        use_bias: Whether to learn bias parameters (default: True).
        track_stats: Whether to track running statistics (default: True, applicable to batch norm).
        group_norm_groups: Number of groups to use for group normalization (default: 32).
    """
    def __init__(self, norm_type='batch', epsilon=1e-05, momentum=0.1,
                 use_weight=True, use_bias=True, track_stats=True, group_norm_groups=32, **kwargs):
        super(NormalizationLayer, self).__init__()

        if norm_type not in ['batch', 'group', 'layer', 'instance', 'none']:
            raise ValueError('Unsupported norm_type: {}. Supported options: '
                             '"batch" | "group" | "layer" | "instance" | "none".'.format(norm_type))

        self.norm_type = norm_type
        self.epsilon = epsilon
        self.momentum = momentum
        self.use_weight = use_weight
        self.use_bias = use_bias
        self.affine = self.use_weight and self.use_bias  # Whether an affine transformation is needed
        self.track_stats = track_stats
        self.group_norm_groups = group_norm_groups

    def forward(self, num_features):
        """
        Select and return the appropriate normalization module based on norm_type.

        Arguments:
            num_features: The number of input channels or features.
        Returns:
            A normalization layer corresponding to norm_type.
        """
        if self.norm_type == 'batch':
            # Apply Batch Normalization
            normalizer = nn.BatchNorm2d(num_features=num_features, eps=self.epsilon,
                                        momentum=self.momentum, affine=self.affine,
                                        track_running_stats=self.track_stats)
        elif self.norm_type == 'group':
            # Apply Group Normalization
            normalizer = nn.GroupNorm(self.group_norm_groups, num_features,
                                      eps=self.epsilon, affine=self.affine)
        elif self.norm_type == 'layer':
            # Apply Layer Normalization
            normalizer = LayerNormalization2D(epsilon=self.epsilon, use_weight=self.use_weight, use_bias=self.use_bias)
        elif self.norm_type == 'instance':
            # Apply Instance Normalization
            normalizer = nn.InstanceNorm2d(num_features, eps=self.epsilon, affine=self.affine)
        else:
            # No normalization applied; just pass the input through
            normalizer = PassThrough()

        return normalizer
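

if __name__ == "__main__":
    # Usage sketch (illustrative): the factory object is configured once and
    # then called with a channel count to produce the actual layer.
    factory = NormalizationLayer(norm_type='group', group_norm_groups=4)
    norm = factory(num_features=16)  # returns nn.GroupNorm(4, 16)
    x = torch.randn(2, 16, 8, 8)
    print(norm(x).shape)             # torch.Size([2, 16, 8, 8])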
129
EdgeFLite/helpers/optimizer_rmsprop.py
Normal file
@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
from torch.optim import Optimizer


class CustomRMSprop(Optimizer):
    """
    Implements a modified version of the RMSprop algorithm with TensorFlow-style epsilon handling.

    Main differences in this implementation:
    1. Epsilon is incorporated within the square root operation.
    2. The moving average of squared gradients is initialized to 1.
    3. The momentum buffer accumulates updates scaled by the learning rate.
    """

    def __init__(self, params, lr=0.01, alpha=0.99, eps=1e-8, momentum=0, weight_decay=0, centered=False, decoupled_decay=False, lr_in_momentum=True):
        """
        Initializes the optimizer with the provided parameters.

        Arguments:
        - params: iterable of parameters to optimize or dicts defining parameter groups
        - lr: learning rate (default: 0.01)
        - alpha: smoothing constant for the moving average (default: 0.99)
        - eps: small value to prevent division by zero (default: 1e-8)
        - momentum: momentum factor (default: 0)
        - weight_decay: weight decay (L2 penalty) (default: 0)
        - centered: if True, compute centered RMSprop (default: False)
        - decoupled_decay: if True, decouples weight decay from the gradient update (default: False)
        - lr_in_momentum: if True, applies the learning rate within the momentum buffer (default: True)
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay: {weight_decay}")
        if alpha < 0.0:
            raise ValueError(f"Invalid alpha value: {alpha}")

        # Store the optimizer defaults
        defaults = {
            'lr': lr,
            'alpha': alpha,
            'eps': eps,
            'momentum': momentum,
            'centered': centered,
            'weight_decay': weight_decay,
            'decoupled_decay': decoupled_decay,
            'lr_in_momentum': lr_in_momentum
        }
        super().__init__(params, defaults)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
        - closure: A closure that reevaluates the model and returns the loss.
        """
        # Get the loss value if a closure is provided
        loss = closure() if closure is not None else None

        # Iterate over parameter groups
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            alpha = group['alpha']
            eps = group['eps']

            # Iterate over parameters in the group
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data  # Get the gradient data
                if grad.is_sparse:
                    raise RuntimeError("RMSprop does not support sparse gradients.")

                # Get the state of the parameter
                state = self.state[p]

                # Initialize the state if it doesn't exist
                if not state:
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(p.data)  # Initialize the moving average of squared gradients to 1
                    if momentum > 0:
                        state['momentum_buffer'] = torch.zeros_like(p.data)  # Initialize the momentum buffer
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p.data)  # Initialize the moving average of gradients if centered

                square_avg = state['square_avg']
                one_minus_alpha = 1 - alpha
                state['step'] += 1  # Update the step count

                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        p.data.mul_(1 - lr * weight_decay)  # Apply decoupled weight decay
                    else:
                        grad.add_(p.data, alpha=weight_decay)  # Apply traditional weight decay

                # Update the moving average of squared gradients
                square_avg.add_((grad ** 2) - square_avg, alpha=one_minus_alpha)

                # Compute the denominator for the gradient update
                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    avg = (square_avg - grad_avg ** 2).add_(eps).sqrt_()  # Centered RMSprop
                else:
                    # Out-of-place add so that eps is not accumulated into the running state
                    avg = square_avg.add(eps).sqrt_()  # Standard RMSprop

                # Apply momentum if needed
                if momentum > 0:
                    buf = state['momentum_buffer']
                    if group['lr_in_momentum']:
                        buf.mul_(momentum).addcdiv_(grad, avg, value=lr)  # Apply the learning rate inside the momentum buffer
                        p.data.add_(-buf)
                    else:
                        buf.mul_(momentum).addcdiv_(grad, avg)  # Standard momentum update
                        p.data.add_(buf, alpha=-lr)
                else:
                    p.data.addcdiv_(grad, avg, value=-lr)  # Update the parameter without momentum

        return loss  # Return the loss if a closure was provided
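

if __name__ == "__main__":
    # Usage sketch (illustrative): a few steps of CustomRMSprop on a toy
    # regression, showing the expected zero_grad()/backward()/step() cycle.
    model = torch.nn.Linear(4, 1)
    opt = CustomRMSprop(model.parameters(), lr=0.01, momentum=0.9)
    x, y = torch.randn(32, 4), torch.randn(32, 1)
    for _ in range(3):
        loss = torch.nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(loss.item())  # loss should trend downward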
146
EdgeFLite/helpers/pace_controller.py
Normal file
@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import math

class CustomScheduler:
    def __init__(self, mode='cosine',
                 initial_lr=0.1,
                 num_epochs=100,
                 iters_per_epoch=300,
                 lr_milestones=None,
                 lr_step=100,
                 step_multiplier=0.1,
                 slow_start_epochs=0,
                 slow_start_lr=1e-4,
                 min_lr=1e-3,
                 multiplier=1.0,
                 lower_bound=-6.0,
                 upper_bound=3.0,
                 decay_factor=0.97,
                 decay_epochs=0.8,
                 staircase=True):
        """
        Initialize the learning rate scheduler.

        Parameters:
        mode (str): Mode for learning rate adjustment ('cosine', 'poly', 'HTD', 'step', 'exponential').
        initial_lr (float): Initial learning rate.
        num_epochs (int): Total number of epochs.
        iters_per_epoch (int): Number of iterations per epoch.
        lr_milestones (list): Epoch milestones for learning rate decay in 'step' mode.
        lr_step (int): Epoch step size for learning rate reduction in 'step' mode.
        step_multiplier (float): Multiplication factor for learning rate reduction in 'step' mode.
        slow_start_epochs (int): Number of slow-start epochs for warm-up.
        slow_start_lr (float): Learning rate during warm-up.
        min_lr (float): Minimum learning rate limit.
        multiplier (float): Multiplication factor applied to later parameter groups.
        lower_bound (float): Lower bound for the tanh function in 'HTD' mode.
        upper_bound (float): Upper bound for the tanh function in 'HTD' mode.
        decay_factor (float): Factor by which the learning rate decays in 'exponential' mode.
        decay_epochs (float): Number of epochs over which the learning rate decays in 'exponential' mode.
        staircase (bool): If True, apply step-wise learning rate decay in 'exponential' mode.
        """
        # Ensure a valid mode is selected
        assert mode in ['cosine', 'poly', 'HTD', 'step', 'exponential'], "Invalid mode."

        # Initialize learning rate settings
        self.initial_lr = initial_lr
        self.current_lr = initial_lr
        self.min_lr = min_lr
        self.mode = mode
        self.num_epochs = num_epochs
        self.iters_per_epoch = iters_per_epoch
        self.total_iterations = (num_epochs - slow_start_epochs) * iters_per_epoch
        self.slow_start_iters = slow_start_epochs * iters_per_epoch
        self.slow_start_lr = slow_start_lr
        self.multiplier = multiplier
        self.lr_step = lr_step
        self.lr_milestones = lr_milestones
        self.step_multiplier = step_multiplier
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.decay_factor = decay_factor
        self.decay_steps = decay_epochs * iters_per_epoch
        self.staircase = staircase

        print(f"INFO: Using {self.mode} learning rate scheduler with {slow_start_epochs} warm-up epochs.")

    def update_lr(self, optimizer, iteration, epoch):
        """Update the learning rate based on the current iteration and epoch."""
        current_iter = epoch * self.iters_per_epoch + iteration

        # During the slow start, linearly increase the learning rate
        # (the warm-up branch is guarded so a zero-length warm-up cannot divide by zero)
        if self.slow_start_iters > 0 and current_iter <= self.slow_start_iters:
            lr = self.slow_start_lr + (self.initial_lr - self.slow_start_lr) * (current_iter / self.slow_start_iters)
        else:
            # After the slow start, calculate the learning rate based on the selected mode
            lr = self._calculate_lr(current_iter - self.slow_start_iters)

        # Ensure the learning rate does not fall below the minimum limit
        self.current_lr = max(lr, self.min_lr)
        self._apply_lr(optimizer, self.current_lr)

    def _calculate_lr(self, adjusted_iter):
        """Calculate the learning rate based on the selected scheduling mode."""
        if self.mode == 'cosine':
            # Cosine annealing schedule
            return 0.5 * self.initial_lr * (1 + math.cos(math.pi * adjusted_iter / self.total_iterations))
        elif self.mode == 'poly':
            # Polynomial decay schedule
            return self.initial_lr * (1 - adjusted_iter / self.total_iterations) ** 0.9
        elif self.mode == 'HTD':
            # Hyperbolic tangent decay schedule
            ratio = adjusted_iter / self.total_iterations
            return 0.5 * self.initial_lr * (1 - math.tanh(self.lower_bound + (self.upper_bound - self.lower_bound) * ratio))
        elif self.mode == 'step':
            # Step decay schedule
            return self._step_lr(adjusted_iter)
        elif self.mode == 'exponential':
            # Exponential decay schedule
            power = math.floor(adjusted_iter / self.decay_steps) if self.staircase else adjusted_iter / self.decay_steps
            return self.initial_lr * (self.decay_factor ** power)
        else:
            raise NotImplementedError("Unknown learning rate mode.")

    def _step_lr(self, adjusted_iter):
        """Calculate the learning rate for the 'step' mode."""
        epoch = adjusted_iter // self.iters_per_epoch
        # Count how many milestones or steps have passed
        if self.lr_milestones:
            num_steps = sum(1 for milestone in self.lr_milestones if epoch >= milestone)
        else:
            num_steps = epoch // self.lr_step
        return self.initial_lr * (self.step_multiplier ** num_steps)

    def _apply_lr(self, optimizer, lr):
        """Apply the calculated learning rate to the optimizer."""
        for i, param_group in enumerate(optimizer.param_groups):
            # Apply the multiplier to parameter groups after the first two
            param_group['lr'] = lr * (self.multiplier if i > 1 else 1.0)
def adjust_hyperparameters(args):
|
||||
"""Adjust the learning rate and momentum based on the batch size."""
|
||||
print(f'Adjusting LR and momentum. Original LR: {args.lr}, Original momentum: {args.momentum}')
|
||||
# Set standard batch size for scaling
|
||||
standard_batch_size = 128 if 'cifar' in args.dataset else NotImplementedError
|
||||
# Scale momentum and learning rate
|
||||
args.momentum = args.momentum ** (args.batch_size / standard_batch_size)
|
||||
args.lr *= (args.batch_size / standard_batch_size)
|
||||
print(f'Adjusted LR: {args.lr}, Adjusted momentum: {args.momentum}')
|
||||
return args
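# Illustrative check (not part of the repository): with the scaling rules above,
# going from the CIFAR reference batch size of 128 to a batch size of 256
# doubles the learning rate and squares the momentum:
lr, momentum, batch_size, standard = 0.1, 0.9, 256, 128
print(lr * (batch_size / standard))         # 0.2   (linear LR scaling)
print(momentum ** (batch_size / standard))  # 0.81  (momentum rescaling)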
def separate_parameters(model, weight_decay_for_norm=0):
    """Separate the model parameters into two groups: regular parameters and norm/bias parameters."""
    regular_params, norm_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            # Parameters belonging to normalization layers and biases are treated separately
            if 'norm' in name or 'bias' in name:
                norm_params.append(param)
            else:
                regular_params.append(param)
    # Return parameter groups, applying the given weight decay to the norm/bias group
    return [{'params': regular_params}, {'params': norm_params, 'weight_decay': weight_decay_for_norm}]
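# Usage sketch (illustrative, not part of the repository): driving CustomScheduler
# from a training loop. The toy Linear model and SGD settings are placeholders;
# separate_parameters() exempts norm/bias parameters from weight decay:
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(separate_parameters(model), lr=0.1, weight_decay=1e-4)
lr_sched = CustomScheduler(mode='cosine', initial_lr=0.1, num_epochs=100,
                           iters_per_epoch=300, slow_start_epochs=5)
for epoch in range(100):
    for iteration in range(300):
        lr_sched.update_lr(optimizer, iteration, epoch)  # set the LR before each step
        # ... forward, backward, optimizer.step() ...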
39
EdgeFLite/helpers/preloader_module.py
Normal file
@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch


class DataPrefetcher:
    def __init__(self, dataloader):
        # Initialize with the dataloader and create an iterator
        self.dataloader = iter(dataloader)
        # Create a CUDA stream for asynchronous data transfer
        self.cuda_stream = torch.cuda.Stream()
        # Preload the first batch of data
        self._load_next_batch()

    def _load_next_batch(self):
        try:
            # Fetch the next batch from the dataloader iterator
            self.batch_input, self.batch_target = next(self.dataloader)
        except StopIteration:
            # If there is no more data, set inputs and targets to None
            self.batch_input, self.batch_target = None, None
            return

        # Transfer the data to the GPU asynchronously on the dedicated CUDA stream
        with torch.cuda.stream(self.cuda_stream):
            self.batch_input = self.batch_input.cuda(non_blocking=True)
            self.batch_target = self.batch_target.cuda(non_blocking=True)

    def get_next_batch(self):
        # Synchronize the current stream with the prefetching stream to ensure the data is ready
        torch.cuda.current_stream().wait_stream(self.cuda_stream)

        # Return the preloaded batch of input and target data
        current_input, current_target = self.batch_input, self.batch_target

        # Preload the next batch in the background while the current batch is processed
        self._load_next_batch()

        return current_input, current_target
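# Usage sketch (illustrative, not part of the repository): consuming a DataLoader
# through DataPrefetcher. For the copy to genuinely overlap with compute, the
# DataLoader should be created with pin_memory=True; train_loader and model are
# placeholders here:
prefetcher = DataPrefetcher(train_loader)
images, targets = prefetcher.get_next_batch()
while images is not None:
    outputs = model(images)  # compute on this batch while the next one is copied
    images, targets = prefetcher.get_next_batch()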
186
EdgeFLite/helpers/report_summary.py
Normal file
@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

__all__ = ['model_summary']

import torch
import torch.nn as nn
import numpy as np
import os
import json
from collections import OrderedDict

# Format a FLOPs value with an appropriate unit (T, G, M, K)
def format_flops(flops):
    units = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')]
    for scale, suffix in units:
        if flops >= scale:
            return f"{flops / scale:.1f}{suffix}"
    return f"{flops:.1f}"

# Split a parameter count into trainable and non-trainable parts
def calculate_grad_params(param_count, param):
    if param.requires_grad:
        return param_count, 0
    else:
        return 0, param_count

# Compute FLOPs and parameters for a convolutional layer
def compute_conv_flops(layer, input, output):
    oh, ow = output.shape[-2:]  # Output height and width
    kh, kw = layer.kernel_size  # Kernel height and width
    ic, oc = layer.in_channels, layer.out_channels  # Input/output channels
    groups = layer.groups  # Number of groups for grouped convolution

    total_trainable = 0
    total_non_trainable = 0
    flops = 0

    # Compute parameters and FLOPs for the weight
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        param_count = np.prod(layer.weight.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.weight)
        total_trainable += trainable
        total_non_trainable += non_trainable
        flops += (2 * ic * kh * kw - 1) * oh * ow * (oc // groups)

    # Compute parameters and FLOPs for the bias
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        param_count = np.prod(layer.bias.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.bias)
        total_trainable += trainable
        total_non_trainable += non_trainable
        flops += oh * ow * (oc // groups)

    return total_trainable, total_non_trainable, flops

# Compute FLOPs and parameters for normalization layers (BatchNorm, GroupNorm)
def compute_norm_flops(layer, input, output):
    total_trainable = 0
    total_non_trainable = 0

    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        param_count = np.prod(layer.weight.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.weight)
        total_trainable += trainable
        total_non_trainable += non_trainable

    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        param_count = np.prod(layer.bias.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.bias)
        total_trainable += trainable
        total_non_trainable += non_trainable

    if hasattr(layer, 'running_mean'):
        total_non_trainable += np.prod(layer.running_mean.shape)

    if hasattr(layer, 'running_var'):
        total_non_trainable += np.prod(layer.running_var.shape)

    # FLOPs for the normalization operations
    flops = np.prod(input[0].shape)
    if layer.affine:
        flops *= 2

    return total_trainable, total_non_trainable, flops

# Compute FLOPs and parameters for linear (fully connected) layers
def compute_linear_flops(layer, input, output):
    ic, oc = layer.in_features, layer.out_features  # Input/output features
    total_trainable = 0
    total_non_trainable = 0
    flops = 0

    # Compute parameters and FLOPs for the weight
    if hasattr(layer, 'weight') and hasattr(layer.weight, 'shape'):
        param_count = np.prod(layer.weight.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.weight)
        total_trainable += trainable
        total_non_trainable += non_trainable
        flops += (2 * ic - 1) * oc

    # Compute parameters and FLOPs for the bias
    if hasattr(layer, 'bias') and hasattr(layer.bias, 'shape'):
        param_count = np.prod(layer.bias.shape)
        trainable, non_trainable = calculate_grad_params(param_count, layer.bias)
        total_trainable += trainable
        total_non_trainable += non_trainable
        flops += oc

    return total_trainable, total_non_trainable, flops

# Model summary function: calculates the total parameters and FLOPs for a model
@torch.no_grad()
def model_summary(model, input_data, target_data=None, is_coremodel=True, return_data=False):
    model.eval()

    summary_info = OrderedDict()
    hooks = []

    # Hook function to register each layer and compute its parameters/FLOPs
    def register_layer_hook(layer):
        def hook(layer, input, output):
            layer_name = f"{layer.__class__.__name__}-{len(summary_info) + 1}"
            summary_info[layer_name] = OrderedDict()
            summary_info[layer_name]['input_shape'] = list(input[0].shape)
            summary_info[layer_name]['output_shape'] = list(output.shape) if not isinstance(output, (list, tuple)) else [list(o.shape) for o in output]

            if isinstance(layer, nn.Conv2d):
                trainable, non_trainable, flops = compute_conv_flops(layer, input, output)
            elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                trainable, non_trainable, flops = compute_norm_flops(layer, input, output)
            elif isinstance(layer, nn.Linear):
                trainable, non_trainable, flops = compute_linear_flops(layer, input, output)
            else:
                trainable, non_trainable, flops = 0, 0, 0

            summary_info[layer_name]['trainable_params'] = trainable
            summary_info[layer_name]['non_trainable_params'] = non_trainable
            summary_info[layer_name]['total_params'] = trainable + non_trainable
            summary_info[layer_name]['flops'] = flops

        if not isinstance(layer, (nn.Sequential, nn.ModuleList, nn.Identity)):
            hooks.append(layer.register_forward_hook(hook))

    model.apply(register_layer_hook)

    if is_coremodel:
        model(input_data, target=target_data, mode='summary')
    else:
        model(input_data)

    for hook in hooks:
        hook.remove()

    total_params, trainable_params, total_flops = 0, 0, 0
    for layer_name, layer_info in summary_info.items():
        total_params += layer_info['total_params']
        trainable_params += layer_info['trainable_params']
        total_flops += layer_info['flops']

    param_size_mb = total_params * 4 / (1024 ** 2)
    print(f"Total parameters: {total_params:,} ({format_flops(total_params)})")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {total_params - trainable_params:,}")
    print(f"Total FLOPs: {total_flops:,} ({format_flops(total_flops)})")
    print(f"Model size: {param_size_mb:.2f} MB")

    if return_data:
        return total_params, total_flops

# Example usage with a convolutional layer
if __name__ == '__main__':
    conv_layer = nn.Conv2d(50, 10, 3, padding=1, groups=5, bias=True)
    model_summary(conv_layer, torch.rand((1, 50, 10, 10)), target_data=torch.ones(1, dtype=torch.long), is_coremodel=False)

    for name, param in conv_layer.named_parameters():
        print(f"{name}: {param.size()}")

# Save the model's summary details to disk
def save_model_as_json(args, model_content):
    """Save the model's details to a plain-text file (despite the name, no JSON encoding is applied)."""
    os.makedirs(args.model_dir, exist_ok=True)
    filename = os.path.join(args.model_dir, f"model_{args.split_factor}.txt")

    with open(filename, 'w') as f:
        f.write(str(model_content))
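# Illustrative check (not part of the repository): for a conv layer, the weight
# term above counts (2*ic*kh*kw - 1) FLOPs per output element for each of the
# oc//groups weight-holding channels, plus one add per output element for the
# bias. For the example layer nn.Conv2d(50, 10, 3, padding=1, groups=5) applied
# to a 10x10 input, the hook should therefore report:
ic, oc, kh, kw, groups, oh, ow = 50, 10, 3, 3, 5, 10, 10
weight_flops = (2 * ic * kh * kw - 1) * oh * ow * (oc // groups)  # 179,800
bias_flops = oh * ow * (oc // groups)                             # 200
print(weight_flops + bias_flops)                                  # 180,000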
48
EdgeFLite/helpers/smoothing_labels.py
Normal file
@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the SmoothEntropyLoss class, which inherits from nn.Module
class SmoothEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.1, reduction='mean'):
        # Initialize the parent class (nn.Module) and set the smoothing factor and reduction method
        super(SmoothEntropyLoss, self).__init__()
        self.smoothing = smoothing  # Label smoothing factor
        self.reduction_method = reduction  # Reduction method to apply to the loss

    def forward(self, predictions, targets):
        # Ensure that the batch sizes of predictions and targets match
        if predictions.shape[0] != targets.shape[0]:
            raise ValueError(f"Batch size of predictions ({predictions.shape[0]}) does not match targets ({targets.shape[0]}).")

        # Ensure that the predictions tensor has at least 2 dimensions (batch_size x num_classes)
        if predictions.dim() < 2:
            raise ValueError(f"Predictions should have at least 2 dimensions, got {predictions.dim()}.")

        # Get the number of classes from the last dimension of predictions
        num_classes = predictions.size(-1)

        # Convert targets (class indices) to one-hot encoded format
        target_one_hot = F.one_hot(targets, num_classes=num_classes).type_as(predictions)

        # Apply label smoothing: redistribute some probability mass from the true class across all classes
        smooth_targets = target_one_hot * (1.0 - self.smoothing) + (self.smoothing / num_classes)

        # Compute the log probabilities of the predictions (log-softmax for numerical stability)
        log_probabilities = F.log_softmax(predictions, dim=-1)

        # Compute the per-sample loss: cross-entropy against the smoothed targets
        loss_per_sample = -torch.sum(log_probabilities * smooth_targets, dim=-1)

        # Apply the specified reduction method to the computed loss
        if self.reduction_method == 'none':
            return loss_per_sample  # Return the unreduced loss for each sample
        elif self.reduction_method == 'sum':
            return torch.sum(loss_per_sample)  # Return the sum of the losses over all samples
        elif self.reduction_method == 'mean':
            return torch.mean(loss_per_sample)  # Return the mean loss over all samples
        else:
            raise ValueError(f"Invalid reduction option: {self.reduction_method}. Expected 'none', 'sum', or 'mean'.")
68
EdgeFLite/info_map.csv
Normal file
@ -0,0 +1,68 @@
import pandas as pd
import os
from glob import glob
from PIL import Image
import torch
from sklearn.model_selection import train_test_split
import pickle
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torchvision import transforms

# Load the metadata for the skin_dataset dataset
info_mapdata = pd.read_csv('dataset_hub/skin_dataset/data/skin_info_map.csv')
print(info_mapdata.head())

# Map lesion abbreviations to their full names
lesion_labels = {
    'nv': 'Melanocytic nevi',
    'mel': 'Melanoma',
    'bkl': 'Benign keratosis-like lesions',
    'bcc': 'Basal cell carcinoma',
    'akiec': 'Actinic keratoses',
    'vasc': 'Vascular lesions',
    'df': 'Dermatofibroma'
}

# Combine images from both dataset parts into one dictionary
image_paths = {os.path.splitext(os.path.basename(img))[0]: img
               for img in glob(os.path.join("dataset_hub/skin_dataset/data", '*', '*.jpg'))}

# Map the image paths and cell types onto the DataFrame
info_mapdata['image_path'] = info_mapdata['image_id'].map(image_paths.get)
info_mapdata['cell_type'] = info_mapdata['dx'].map(lesion_labels.get)
info_mapdata['label'] = pd.Categorical(info_mapdata['cell_type']).codes

# Display the count of each cell type and their encoded labels
print(info_mapdata['cell_type'].value_counts())
print(info_mapdata['label'].value_counts())

# Custom Dataset class for PyTorch
class SkinDataset(Dataset):
    def __init__(self, dataframe, apply_transformation=None):
        self.dataframe = dataframe
        self.apply_transformation = apply_transformation

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img = Image.open(self.dataframe.loc[idx, 'image_path']).resize((64, 64))
        label = torch.tensor(self.dataframe.loc[idx, 'label'], dtype=torch.long)

        if self.apply_transformation:
            img = self.apply_transformation(img)

        return img, label

# Split the data into train and test sets
train_data, test_data = train_test_split(info_mapdata, test_size=0.2, random_state=42)
train_data = train_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

# Save the train and test data to pickle files
with open("skin_dataset_train.pkl", "wb") as train_file:
    pickle.dump(train_data, train_file)

with open("skin_dataset_test.pkl", "wb") as test_file:
    pickle.dump(test_data, test_file)
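# Usage sketch (illustrative, not part of the repository; assumes the image
# files referenced by the metadata are present on disk): wiring SkinDataset
# into a DataLoader with a torchvision transform; the batch size is arbitrary:
train_dataset = SkinDataset(train_data, apply_transformation=transforms.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 64, 64]) torch.Size([32])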
47
EdgeFLite/process_data.py
Normal file
@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import glob
import numpy as np

# Define paths to the training and testing datasets
train_data_path = '/media/skydata/alpha0012/workspace/EdgeFLite/dataset_hub/pill/train_images'
test_data_path = '/media/skydata/alpha0012/workspace/EdgeFLite/dataset_hub/pill/test_images'

def list_image_files_by_class(directory):
    """
    Return a list of image file paths and their corresponding class indices.

    Args:
        directory (str): The path to the directory containing class folders.

    Returns:
        list: A list of [image file path, class index] pairs.
    """
    # Get the sorted list of class labels (folder names)
    class_labels = sorted(os.listdir(directory))
    # Create a mapping from class names to indices
    class_to_idx = {class_name: idx for idx, class_name in enumerate(class_labels)}

    image_dataset = []  # Initialize an empty list to store image data

    # Iterate through each class
    for class_name in class_labels:
        class_folder = os.path.join(directory, class_name)  # Path to the class folder
        # Find all JPG images in the class folder and its subfolders
        image_files = glob.glob(os.path.join(class_folder, '**', '*.jpg'), recursive=True)

        # Append image file paths and their class indices to the dataset
        for image_file in image_files:
            image_dataset.append([image_file, class_to_idx[class_name]])

    return image_dataset

if __name__ == "__main__":
    # Retrieve and print the number of files in the training and testing datasets
    train_images = list_image_files_by_class(train_data_path)
    test_images = list_image_files_by_class(test_data_path)

    print(f"Training dataset size: {len(train_images)}")  # Output the size of the training dataset
    print(f"Testing dataset size: {len(test_images)}")  # Output the size of the testing dataset
161
EdgeFLite/resnet_federated.py
Normal file
@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import argparse
import os
import torch
from dataset import factory
from params import train_params
from fedml_service.data_cleaning.cifar10.data_loader import load_partition_data_cifar10
from fedml_service.data_cleaning.cifar100.data_loader import load_partition_data_cifar100
from fedml_service.data_cleaning.skin_dataset.data_loader import load_partition_data_skin_dataset
from fedml_service.data_cleaning.pillbase.data_loader import load_partition_data_pillbase
from fedml_service.model.cv.resnet_gkt.resnet import wide_resnet16_8_gkt, wide_resnet_model_50_2_gkt, resnet110_gkt
from fedml_service.decentralized.fedgkt.GKTTrainer import GKTTrainer
from fedml_service.decentralized.fedgkt.GKTServerTrainer import GKTServerTrainer
from params.train_params import save_hp_to_json
from config import HOME
from tensorboardX import SummaryWriter

# Set the CUDA device to be used for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0")

# Initialize TensorBoard writers for logging
def initialize_writers(args):
    log_dir = os.path.join(args.model_dir, 'val')  # Create a log directory inside the model directory
    return SummaryWriter(log_dir=log_dir)  # Initialize a SummaryWriter for TensorBoard logging

# Initialize the dataset and data loaders
def initialize_dataset(args, data_split_factor):
    # Fetch the training data and sampler based on the input parameters
    train_data_local_dict, train_sampler = factory.obtain_data_loader(
        args.data,
        split_factor=data_split_factor,
        batch_size=args.batch_size,
        crop_size=args.crop_size,
        dataset=args.dataset,
        split="train",  # Split data for training
        is_decentralized=args.is_decentralized,
        is_autoaugment=args.is_autoaugment,
        randaa=args.randaa,
        is_cutout=args.is_cutout,
        erase_p=args.erase_p,
        num_workers=args.workers,
        is_fed=args.is_fed,
        num_clusters=args.num_clusters,
        cifar10_non_iid=args.cifar10_non_iid,
        cifar100_non_iid=args.cifar100_non_iid
    )

    # Fetch the global test data
    test_data_global = factory.obtain_data_loader(
        args.data,
        batch_size=args.eval_batch_size,
        crop_size=args.crop_size,
        dataset=args.dataset,
        split="val",  # Split data for validation
        num_workers=args.workers,
        cifar10_non_iid=args.cifar10_non_iid,
        cifar100_non_iid=args.cifar100_non_iid
    )
    return train_data_local_dict, test_data_global  # Return both train and test data loaders

# Set up models based on the dataset
def setup_models(args):
    if args.dataset == "cifar10":
        return load_partition_data_cifar10, wide_resnet16_8_gkt()  # Model for CIFAR-10
    elif args.dataset == "cifar100":
        return load_partition_data_cifar100, resnet110_gkt()  # Model for CIFAR-100
    elif args.dataset == "skin_dataset":
        return load_partition_data_skin_dataset, wide_resnet_model_50_2_gkt()  # Model for the skin dataset
    elif args.dataset == "pill_base":
        return load_partition_data_pillbase, wide_resnet_model_50_2_gkt()  # Model for the pill base dataset
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")  # Raise an error for unsupported datasets

# Initialize trainers for each client in the federated learning setup
def initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict):
    client_trainers = []
    # Initialize a trainer for each client
    for i in range(client_number):
        trainer = GKTTrainer(
            client_idx=i,
            train_data_local_dict=train_data_local_dict,
            test_data_local_dict=test_data_local_dict,
            device=device,
            model_client=model_client,
            args=args
        )
        client_trainers.append(trainer)  # Add the client trainer to the list
    return client_trainers

# Main function to initialize and run the federated learning process
def main(args):
    args.model_dir = os.path.join(str(HOME), "models/coremodel", str(args.spid))  # Set the model directory based on the home directory and spid

    # Save hyperparameters if not in summary or evaluation mode
    if not args.is_summary and not args.evaluate:
        save_hp_to_json(args)

    # Initialize the TensorBoard writer for logging
    val_writer = initialize_writers(args)
    data_split_factor = args.loop_factor if args.is_diff_data_train else 1  # Set the data split factor based on the training mode
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized  # Check whether decentralized learning is needed

    print(f"INFO: PyTorch: => The number of views of train data is '{data_split_factor}'")

    # Load the dataset and initialize data loaders
    train_data_local_dict, test_data_global = initialize_dataset(args, data_split_factor)

    # Set up the models for the clients and the server
    data_loader, (model_client, model_server) = setup_models(args)
    client_number = args.num_clusters * args.split_factor  # Calculate the number of clients

    # Load the data for federated learning
    train_data_num, test_data_num, train_data_global, _, _, _, test_data_local_dict, class_num = data_loader(
        args.dataset, args.data, 'homo', 0.5, client_number, args.batch_size
    )

    dataset_info = [train_data_num, test_data_num, train_data_global, test_data_global, train_data_local_dict, test_data_local_dict, class_num]

    print("Server and clients initialized.")
    round_idx = 0  # Initialize the training round index

    # Initialize the client trainers and the server trainer
    client_trainers = initialize_trainers(client_number, device, model_client, args, train_data_local_dict, test_data_local_dict)
    server_trainer = GKTServerTrainer(client_number, device, model_server, args, val_writer)

    # Start the federated training rounds
    for current_round in range(args.num_rounds):
        # Each client performs local training and sends its results to the server
        for client_idx in range(client_number):
            extracted_features, logits, labels, test_features, test_labels = client_trainers[client_idx].train()
            print(f"Client {client_idx} finished training.")
            server_trainer.add_local_trained_result(client_idx, extracted_features, logits, labels, test_features, test_labels)

        # Check whether the server has received all clients' results
        if server_trainer.check_whether_all_receive():
            print("All clients' results received by server.")
            server_trainer.train(round_idx)  # The server trains using the aggregated results
            round_idx += 1

        # Send the global model updates back to the clients
        for client_idx in range(client_number):
            global_logits = server_trainer.get_global_logits(client_idx)
            client_trainers[client_idx].update_large_model_logits(global_logits)
        print("Server sent updated logits back to clients.")

# Entry point of the script
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = train_params.add_parser_params(parser)

    # Ensure that federated learning mode is enabled
    assert args.is_fed == 1, "Federated learning requires 'args.is_fed' to be set to 1."

    # Create the model directory if it does not exist
    os.makedirs(args.model_dir, exist_ok=True)

    print(args)  # Print the parsed arguments for verification
    main(args)  # Start the main process
158
EdgeFLite/run_federated.py
Normal file
@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.cuda.amp as amp
from torch.backends import cudnn
from tensorboardX import SummaryWriter
import warnings
import argparse
import os
import numpy as np
from tqdm import tqdm
from dataset import factory
from model import coremodel
from utils import metric, label_smoothing, lr_scheduler, prefetch
from params.train_params import save_hp_to_json
from params import train_params

# Global variable to track the best accuracy
best_accuracy = 0

def calculate_average(values):
    """Calculate the average of a list of values."""
    return sum(values) / len(values)

def initialize_processes(rank, world_size, args):
    """
    Initialize the training processes.
    This function sets up distributed training across multiple GPUs.
    """
    ngpus = torch.cuda.device_count()
    args.ngpus = ngpus
    args.is_decentralized = world_size > 1

    if args.multiprocessing_decentralized:
        # When running distributed training with multiple GPUs, spawn one process per GPU
        mp.spawn(train_single_worker, nprocs=ngpus, args=(ngpus, args))
    else:
        print(f"INFO:PyTorch: Using {ngpus} GPUs")
        # For a single GPU, start the training worker directly
        train_single_worker(args.gpu, ngpus, args)

def client_training_step(args, current_round, model, optimizer, scheduler, dataloader, epochs=5, scaler=None):
    """
    Perform training for a single client model in the federated learning setup.
    The model is trained for the given number of local epochs.
    """
    model.train()  # Set the model to training mode
    for epoch in range(epochs):
        # Prefetch data to improve efficiency
        prefetcher = prefetch.data_prefetcher(dataloader)
        images, targets = prefetcher.next()
        step = 0

        while images is not None:
            # Update the learning rate using the scheduler
            scheduler(optimizer, step, current_round)
            optimizer.zero_grad()  # Clear the gradients

            # Enable mixed precision training to reduce memory use and speed up computation
            with amp.autocast(enabled=args.is_amp):
                outputs, ce_loss, cot_loss = model(images, target=targets, mode='train')

            # Combine the losses and normalize by the number of accumulation steps
            loss = (ce_loss + cot_loss) / args.accumulation_steps
            loss.backward()  # Backpropagate the gradients

            # Perform an optimization step after enough accumulation
            if step % args.accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()  # Clear the gradients after the step

            images, targets = prefetcher.next()  # Get the next batch of images and targets
            step += 1

    return loss.item()  # Return the final loss value

def combine_model_parameters(global_model, client_models):
    """
    Aggregate the weights of multiple client models to update the global model.
    This is the core of the Federated Averaging (FedAvg) algorithm.
    """
    global_state = global_model.state_dict()
    for key in global_state.keys():
        # Average the weights of the corresponding layers from all client models
        global_state[key] = torch.stack([client.state_dict()[key].float() for client in client_models], dim=0).mean(dim=0)

    # Load the averaged weights into the global model
    global_model.load_state_dict(global_state)
    # Update the client models with the new global model weights
    for client in client_models:
        client.load_state_dict(global_model.state_dict())
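# Illustrative check (not part of the repository): FedAvg on two tiny linear
# models. After aggregation, the global weight is the elementwise mean of the
# client weights, and every client is synchronized to it:
tiny_global = nn.Linear(2, 1, bias=False)
tiny_clients = [nn.Linear(2, 1, bias=False) for _ in range(2)]
nn.init.constant_(tiny_clients[0].weight, 1.0)
nn.init.constant_(tiny_clients[1].weight, 3.0)
combine_model_parameters(tiny_global, tiny_clients)
print(tiny_global.weight)  # tensor([[2., 2.]]) — the mean of 1.0 and 3.0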
def validate_model(validation_loader, model, args):
    """
    Perform model validation on the validation dataset.
    Calculate and return the average accuracy across the dataset.
    """
    model.eval()  # Set the model to evaluation mode
    accuracy_values = []

    with torch.no_grad():
        for images, targets in validation_loader:
            if args.gpu is not None:
                images, targets = images.cuda(args.gpu), targets.cuda(args.gpu)

            # Use mixed precision for inference
            with amp.autocast(enabled=args.is_amp):
                ensemble_output, outputs, ce_loss = model(images, target=targets, mode='val')

            # Calculate the top-1 accuracy for the current batch
            avg_acc1 = metric.accuracy(ensemble_output, targets, topk=(1,))
            accuracy_values.append(avg_acc1)

    return calculate_average(accuracy_values)  # Return the average accuracy

def train_single_worker(gpu, ngpus, args):
    """
    Training worker function that runs on a single GPU.
    This function handles the entire federated learning workflow for the assigned GPU.
    """
    global best_accuracy
    args.gpu = gpu
    cudnn.benchmark = True  # Enable cuDNN auto-tuning for better performance

    # Optionally resume from a checkpoint if one is provided
    if args.resume:
        checkpoint = torch.load(args.resume)
        args.start_round = checkpoint['round']
        best_accuracy = checkpoint['best_acc1']

    # Initialize the global and client models
    model = coremodel.coremodel(args).cuda()
    client_models = [coremodel.coremodel(args).cuda() for _ in range(args.num_clients)]
    optimizers = [torch.optim.SGD(client.parameters(), lr=args.lr) for client in client_models]

    # Training and validation loop
    for round_num in range(args.start_round, args.num_rounds):
        # Perform local training for each client model
        for client_num in range(args.num_clients):
            client_training_step(args, round_num, client_models[client_num], optimizers[client_num], lr_scheduler, args.train_loader)

        # Aggregate the client models to update the global model
        combine_model_parameters(model, client_models)

        # Validate the updated global model and track the best accuracy
        validation_accuracy = validate_model(args.val_loader, model, args)
        best_accuracy = max(best_accuracy, validation_accuracy)
        print(f"Round {round_num}: Best Accuracy: {best_accuracy:.2f}")

if __name__ == "__main__":
    # Parse the command-line arguments
    parser = argparse.ArgumentParser(description='FedAvg Decentralized Training')
    args = train_params.add_parser_params(parser)
    initialize_processes(0, args.world_size, args)  # Initialize distributed training
279
EdgeFLite/run_local.py
Normal file
@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import argparse
import warnings
import setproctitle
from torch import nn, distributed  # distributed is used for multi-GPU training
from torch.backends import cudnn  # Optimizes performance for convolutional networks
from tensorboardX import SummaryWriter  # For logging metrics and results to TensorBoard
import torch.cuda.amp as amp  # For mixed precision training
from config import *  # Custom configuration module
from params import train_params  # Training parameters
from utils import label_smoothing, norm, summary, metric, lr_scheduler, prefetch  # Utility functions
from model import coremodel  # Core model implementation
from dataset import factory  # Dataset and data loader factory
from params.train_params import save_hp_to_json  # Function to save hyperparameters to JSON

# Global variable to store the best accuracy obtained during training
best_acc1 = 0

def main(args):
    # Warn if a specific GPU is chosen, as this disables data parallelism
    if args.gpu is not None:
        warnings.warn("Selecting a specific GPU will disable data parallelism.")

    # Adjust the loop factor based on the training configuration
    args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor
    # Check whether decentralized training is needed
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    # Get the number of GPUs available on this machine
    num_gpus = torch.cuda.device_count()
    args.ngpus_per_node = num_gpus
    print(f"INFO:PyTorch: GPUs available on this node: {num_gpus}")

    # If multiprocessing is needed for decentralized training
    if args.multiprocessing_decentralized:
        # Adjust the world size to account for multiple GPUs
        args.world_size *= num_gpus
        # Spawn one process per GPU
        torch.multiprocessing.spawn(execute_worker_process, nprocs=num_gpus, args=(num_gpus, args))
    else:
        # Single-GPU training
        print("INFO:PyTorch: Using GPU 0 for single GPU training")
        args.gpu = 0
        # Call the main worker for a single GPU
        execute_worker_process(args.gpu, num_gpus, args)

def execute_worker_process(gpu, num_gpus, args):
    global best_acc1
    args.gpu = gpu
    # Set the directory where models will be saved
    args.model_dir = os.path.join(HOME, "models", "coremodel", str(args.spid))

    # Initialize the process group for decentralized training if needed
    if args.is_decentralized:
        print("INFO:PyTorch: Initializing process group for decentralized training.")
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_decentralized:
            args.rank = args.rank * num_gpus + gpu
        distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    # Report the GPU used for training or evaluation
    if args.gpu is not None:
        print(f"INFO:PyTorch: GPU {args.gpu} in use for training (Rank: {args.rank})" if not args.evaluate else f"INFO:PyTorch: GPU {args.gpu} in use for evaluation (Rank: {args.rank})")

    # Set the process title for easier identification in system process monitors
    setproctitle.setproctitle(f"{args.proc_name}centralized_rank{args.rank}")

    # Initialize a SummaryWriter for TensorBoard logging
    val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val'))

    # Use label smoothing if enabled, otherwise use the standard cross-entropy loss
    criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing else nn.CrossEntropyLoss()

    # Instantiate the model
    model = coremodel.coremodel(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion)
    print(f"INFO:PyTorch: Model '{args.arch}' has {metric.get_the_number_of_params(model)} parameters")

    # If only a summary is requested, print the model and exit
    if args.is_summary:
        print(model)
        return

    # Save the model configuration and hyperparameters
    summary.save_model_to_json(args, model)

    # Convert BatchNorm layers to synchronized BatchNorm for decentralized training
    if args.is_decentralized and args.world_size > 1 and args.is_syncbn:
        print("INFO:PyTorch: Converting BatchNorm to SyncBatchNorm")
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Set up the model for GPU-based training
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        args.batch_size = int(args.batch_size / num_gpus)  # Adjust the batch size for multiple GPUs
        args.workers = int((args.workers + num_gpus - 1) / num_gpus)  # Adjust the number of workers
        # Use a distributed data parallel model
        model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    else:
        # Use standard DataParallel for multi-GPU training
        model = nn.DataParallel(model).cuda()

    # Create the optimizer
    optimizer = create_optimizer(args, model)
    # Set up the gradient scaler for mixed precision training, if enabled
    scaler = amp.GradScaler() if args.is_amp else None

    # If resuming from a checkpoint, load the model and optimizer state
    if args.resume:
        load_checkpoint(args, model, optimizer, scaler)

    cudnn.benchmark = True  # Enable cuDNN auto-tuning

    # Set up the data loader parameters
    data_loader_params = {
        'split_factor': args.loop_factor if args.is_diff_data_train else 1,
        'batch_size': args.batch_size,
        'crop_size': args.crop_size,
        'dataset': args.dataset,
        'is_decentralized': args.is_decentralized,
        'num_workers': args.workers,
        'randaa': args.randaa,
        'is_autoaugment': args.is_autoaugment,
        'is_cutout': args.is_cutout,
        'erase_p': args.erase_p,
    }

    # Get the training and validation data loaders
    train_loader, train_sampler = factory.obtain_data_loader(args.data, split="train", **data_loader_params)
    val_loader = factory.obtain_data_loader(args.data, split="val", batch_size=args.eval_batch_size, crop_size=args.crop_size, num_workers=args.workers)

    # Set up the learning rate scheduler
    scheduler = lr_scheduler.create_scheduler(args, len(train_loader))

    # If evaluating, run the validation function and exit
    if args.evaluate:
        validate(val_loader, model, args)
        return

    # Begin training and evaluation
    train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus)

# Function to create the optimizer
def create_optimizer(args, model):
    param_groups = model.parameters() if args.is_wd_all else lr_scheduler.get_parameter_groups(model)
    # Select the optimizer based on the input arguments
    if args.optimizer == 'SGD':
        return torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.is_nesterov)
    elif args.optimizer == 'AdamW':
        return torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-4, weight_decay=args.weight_decay)
    elif args.optimizer == 'RMSprop':
        return torch.optim.RMSprop(param_groups, lr=args.lr, alpha=0.9, momentum=0.9, weight_decay=args.weight_decay)
    else:
        # Raise an error if an unsupported optimizer is selected
        raise NotImplementedError(f"Optimizer {args.optimizer} not implemented")

# Function to load a checkpoint and resume training
def load_checkpoint(args, model, optimizer, scaler):
    global best_acc1
    if os.path.isfile(args.resume):
        print(f"INFO:PyTorch: Loading checkpoint from '{args.resume}'")
        loc = f'cuda:{args.gpu}' if args.gpu is not None else None
        checkpoint = torch.load(args.resume, map_location=loc)
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        print(f"INFO:PyTorch: Checkpoint loaded, epoch {checkpoint['epoch']}")
    else:
        print(f"INFO:PyTorch: No checkpoint found at '{args.resume}'")

# Function to train and evaluate the model over multiple epochs
def train_and_evaluate(train_loader, val_loader, train_sampler, model, optimizer, scheduler, scaler, val_writer, args, num_gpus):
    global best_acc1
    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.is_decentralized:
            train_sampler.set_epoch(epoch)

        train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args)

        # Evaluate the model every 'eval_per_epoch' epochs
        if (epoch + 1) % args.eval_per_epoch == 0:
            acc_all = validate(val_loader, model, args)
            is_best = acc_all[0] > best_acc1  # Track the best accuracy
            best_acc1 = max(acc_all[0], best_acc1)
            # Save the model checkpoint
            save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best)

# Function to perform one training epoch
def train_one_epoch(train_loader, model, optimizer, scheduler, epoch, scaler, val_writer, args):
    metric_storage = create_metric_storage(args.loop_factor)
    model.train()  # Set the model to training mode
    data_loader = prefetch.data_prefetcher(train_loader)  # Use data prefetching to improve efficiency
    images, target = data_loader.next()

    optimizer.zero_grad()  # Reset the gradients
    step = 0
    while images is not None:
        # Adjust the learning rate via the scheduler
        scheduler(optimizer, epoch)

        # Perform the forward pass, with mixed precision if enabled
        if args.is_amp:
            with amp.autocast():
                ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch)
        else:
            ensemble_output, outputs, ce_loss, cot_loss = model(images, target=target, mode='train', epoch=epoch)

        # Calculate the total loss and normalize it by the number of accumulation iterations
        total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate
        val_writer.add_scalar('average_training_loss', total_loss, global_step=epoch)

        # Backpropagate, scaling the loss when mixed precision is enabled
        if args.is_amp:
            scaler.scale(total_loss).backward()
        else:
            total_loss.backward()

        # Step the optimizer once enough gradients have been accumulated
        step += 1
        if step % args.iters_to_accumulate == 0:
            if args.is_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()  # Reset the gradients after each optimizer step

        images, target = data_loader.next()  # Fetch the next batch of data
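# Illustrative check (not part of the repository): dividing each loss by
# iters_to_accumulate keeps the accumulated gradient equal to that of one large
# batch, because backward() sums gradients in place:
w = torch.tensor(1.0, requires_grad=True)
for x in (1.0, 2.0):
    ((w * x) / 2).backward()
print(w.grad)  # tensor(1.5000) == mean of the per-sample gradients 1.0 and 2.0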
# Function to save the model checkpoint
def save_checkpoint(model, optimizer, scaler, epoch, best_acc1, args, is_best):
    ckpt = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
    }
    if args.is_amp:
        ckpt['scaler'] = scaler.state_dict()
    metric.save_checkpoint(ckpt, is_best, args.model_dir, filename=f"checkpoint_{epoch}.pth.tar")

# Function to validate the model on the validation dataset
def validate(val_loader, model, args):
    metric_storage = create_metric_storage(args.loop_factor)
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # Perform the forward pass, with mixed precision if enabled
            if args.is_amp:
                with amp.autocast():
                    ensemble_output, outputs, ce_loss = model(images, target=target, mode='val')
            else:
                ensemble_output, outputs, ce_loss = model(images, target=target, mode='val')

            batch_size = images.size(0)
            acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5))

            metric_storage.update(acc1, acc5, ce_loss, batch_size)

    return metric_storage.results()

# Helper function to create storage for metrics during training and validation
def create_metric_storage(loop_factor):
    # Initialize meters for per-branch and average top-1 accuracy
    top1_all = [metric.AverageMeter(f'Acc@1_{i}', ':6.2f') for i in range(loop_factor)]
    avg_top1 = metric.AverageMeter('Avg_Acc@1', ':6.2f')
    return metric.ProgressMeter(len(top1_all), top1_all, avg_top1)

# Main entry point for the script
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Centralized Training')
    args = train_params.add_parser_params(parser)  # Add parameters to the argument parser
    assert args.is_fed == 0, "Centralized training requires args.is_fed to be False"
    os.makedirs(args.model_dir, exist_ok=True)  # Create the model directory if it doesn't exist
    main(args)  # Call the main function
223
EdgeFLite/run_prox.py
Normal file
@ -0,0 +1,223 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import warnings
import torch
import torch.cuda.amp as autocast
from torch import nn
from torch.backends import cudnn
from tensorboardX import SummaryWriter
from config import *
from params import train_settings
from utils import label_smooth, metrics, scheduler, prefetch_loader
from model import net_splitter
from dataset import data_factory
import numpy as np
from tqdm import tqdm
from params.train_settings import save_hyperparams_to_json

# Set the visible GPU to use for training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Variable to store the best accuracy achieved during training
best_accuracy = 0


# Helper function to compute the average of a list
def compute_average(lst):
    return sum(lst) / len(lst)


# Main function to initialize the training process
def main(args):
    if args.gpu_index is not None:
        # Warn if a specific GPU is selected, disabling data parallelism
        warnings.warn("Specific GPU chosen, disabling data parallelism.")

    # Adjust the loop factor based on the training setup
    args.loop_factor = 1 if args.separate_training or args.single_branch else args.split_factor
    # Determine whether decentralized training is required
    args.decentralized_training = args.world_size > 1 or args.multiprocessing_decentralized
    num_gpus = torch.cuda.device_count()
    args.num_gpus = num_gpus

    # If decentralized multiprocessing is enabled, spawn multiple processes
    if args.multiprocessing_decentralized:
        args.world_size = num_gpus * args.world_size
        torch.multiprocessing.spawn(worker_process, nprocs=num_gpus, args=(num_gpus, args))
    else:
        # Otherwise, proceed with single-GPU training
        print(f"INFO:PyTorch: Detected {num_gpus} GPU(s) available.")
        args.gpu_index = 0
        worker_process(args.gpu_index, num_gpus, args)


# Client-side training function for federated learning updates
def client_train_update(args, round_num, client_model, global_model, sched, opt, train_loader, epochs=5, scaler=None):
    client_model.train()

    for epoch in range(epochs):
        # Prefetch data for training
        loader = prefetch_loader.DataPrefetcher(train_loader)
        images, targets = loader.next()
        batch_idx = 0
        opt.zero_grad()

        while images is not None:
            # Apply learning rate scheduling
            sched(opt, batch_idx)

            # Use automatic mixed precision if enabled
            if args.amp_enabled:
                with autocast.autocast():
                    ensemble_out, model_outputs, loss_ce, loss_cot = client_model(images, targets=targets, mode='train',
                                                                                  epoch=epoch)
            else:
                ensemble_out, model_outputs, loss_ce, loss_cot = client_model(images, targets=targets, mode='train',
                                                                              epoch=epoch)

            # Compute the top-1 accuracy for each branch
            batch_size = images.size(0)
            for j in range(args.loop_factor):
                top1_acc = metrics.accuracy(model_outputs[j], targets, topk=(1,))

            # Compute the proximal term for the FedProx loss
            prox_term = sum((param - global_param).norm(2) for param, global_param in
                            zip(client_model.parameters(), global_model.parameters()))
            # Compute the total loss (cross-entropy + co-training loss + proximal term)
            total_loss = (loss_ce + loss_cot) / args.accum_steps + (args.mu / 2) * prox_term

            # Backward pass with mixed precision scaling if enabled
            if args.amp_enabled:
                scaler.scale(total_loss).backward()
                if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)):
                    scaler.step(opt)
                    scaler.update()
                    opt.zero_grad()
            else:
                total_loss.backward()
                if (batch_idx % args.accum_steps == 0) or (batch_idx == len(train_loader)):
                    opt.step()
                    opt.zero_grad()

            batch_idx += 1  # Advance the batch counter used for scheduling and accumulation
            images, targets = loader.next()

    return total_loss.item()
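# Illustrative check (not part of the repository): FedProx regularizes each
# client objective with (mu/2) times the distance between the local and global
# weights, discouraging local drift. Mirroring the term computed above on two
# tiny models:
local_net, global_net = nn.Linear(2, 1, bias=False), nn.Linear(2, 1, bias=False)
nn.init.constant_(local_net.weight, 2.0)
nn.init.constant_(global_net.weight, 1.0)
prox = sum((p - gp).norm(2) for p, gp in zip(local_net.parameters(), global_net.parameters()))
print(prox)  # tensor(1.4142...) — the L2 distance between the two weight tensors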
|
||||
|
||||
|
||||
# Function to aggregate model weights from clients on the server
|
||||
def server_compute_average_weights(global_model, client_models):
|
||||
global_state_dict = global_model.state_dict()
|
||||
# Average weights across all clients
|
||||
for key in global_state_dict.keys():
|
||||
global_state_dict[key] = torch.stack(
|
||||
[client_models[i].state_dict()[key].float() for i in range(len(client_models))], 0).mean(0)
|
||||
global_model.load_state_dict(global_state_dict)
|
||||
|
||||
# Update clients with the averaged global model
|
||||
for model in client_models:
|
||||
model.load_state_dict(global_model.state_dict())
|
||||
|
||||
|
||||

# Function to validate the model on the validation set
def validate_model(val_loader, model, args):
    model.eval()
    acc1_list, acc5_list, loss_ce_list = [], [], []

    # Perform validation without gradient calculation
    with torch.no_grad():
        for images, targets in val_loader:
            if args.gpu_index is not None:
                images, targets = (images.cuda(args.gpu_index, non_blocking=True),
                                   targets.cuda(args.gpu_index, non_blocking=True))

            if args.amp_enabled:
                with autocast.autocast():
                    ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val')
            else:
                ensemble_out, model_outputs, loss_ce = model(images, target=targets, mode='val')

            for j in range(args.loop_factor):
                acc1, acc5 = metrics.accuracy(model_outputs[j], targets, topk=(1, 5))

            avg_acc1, avg_acc5 = metrics.accuracy(ensemble_out, targets, topk=(1, 5))
            acc1_list.append(avg_acc1)
            acc5_list.append(avg_acc5)
            loss_ce_list.append(loss_ce)

    return compute_average(loss_ce_list), compute_average(acc1_list)


# Function to handle the worker process for training on a specific GPU
def worker_process(gpu_index, num_gpus, args):
    global best_accuracy
    args.gpu_index = gpu_index
    args.model_path = os.path.join(HOME, "models", "coremodel", str(args.model_id))

    # Create summary writer for validation if not using decentralized training
    if not args.decentralized_training or (args.multiprocessing_decentralized and args.rank % num_gpus == 0):
        val_summary_writer = SummaryWriter(log_dir=os.path.join(args.model_path, 'validation'))

    # Set the loss function based on the label smoothing option
    criterion = label_smooth.smooth_ce_loss(reduction='mean') if args.use_label_smooth else nn.CrossEntropyLoss()
    # Initialize the global model and client models
    global_model = net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion)
    client_models = [net_splitter.coremodel(args, normalization=args.norm_mode, loss_function=criterion)
                     for _ in range(args.num_clients)]

    # Save hyperparameters to JSON if required
    if args.save_summary:
        save_hyperparams_to_json(args)
        return

    # Move models to GPU
    global_model = global_model.cuda()
    for model in client_models:
        model.cuda()
        model.load_state_dict(global_model.state_dict())

    # Create optimizers for each client
    opt_list = [torch.optim.SGD(client.parameters(), lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=args.use_nesterov)
                for client in client_models]

    # Initialize gradient scaler if AMP is enabled
    scaler = torch.cuda.amp.GradScaler() if args.amp_enabled else None
    cudnn.performance_test = True

    # Resume training from checkpoint if specified
    if args.resume_training:
        if os.path.isfile(args.resume_checkpoint):
            checkpoint = torch.load(args.resume_checkpoint,
                                    map_location=f'cuda:{args.gpu_index}' if args.gpu_index else None)
            args.start_round = checkpoint['round']
            best_accuracy = checkpoint['best_acc1']
            global_model.load_state_dict(checkpoint['state_dict'])
            for opt in opt_list:
                opt.load_state_dict(checkpoint['optimizer'])
            if "scaler" in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            for client_model in client_models:
                client_model.load_state_dict(global_model.state_dict())
        else:
            args.start_round = 0
    else:
        args.start_round = 0

    # Load training and validation data
    train_loader, _ = data_factory.load_data(args.data_dir, args.batch_size, args.split_factor,
                                             dataset_name=args.dataset_name, split="train",
                                             num_workers=args.num_workers,
                                             decentralized=args.decentralized_training)
    val_loader = data_factory.load_data(args.data_dir, args.eval_batch_size, args.split_factor,
                                        dataset_name=args.dataset_name, split="val",
                                        num_workers=args.num_workers)

    # Federated learning rounds
    for round_num in range(args.start_round, args.num_rounds + 1):
        if args.fixed_cluster:
            # Select clients from fixed clusters for each round
            selected_clusters = np.random.permutation(args.num_clusters)[:args.num_clients]
            for i in tqdm(range(args.num_clients)):
                selected_clients = np.arange(start=selected_clusters[i] * args.split_factor,
                                             stop=(selected_clusters[i] + 1) * args.split_factor)
                for client in selected_clients:
                    loss = client_train
210
EdgeFLite/run_splitfed.py
Normal file
@ -0,0 +1,210 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import torch
import torch.nn as nn
import numpy as np
import argparse
import warnings
from tqdm import tqdm
from tensorboardX import SummaryWriter
from dataset import factory
from config import *
from model import coremodelsl
from utils import label_smoothing, norm, metric, lr_scheduler, prefetch
from params import train_params
from params.train_params import save_hp_to_json

# Set the visible GPU devices for the training
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Global best accuracy to track performance across rounds
best_acc1 = 0


def average(values):
    """Calculate the average of a list of values."""
    return sum(values) / len(values)

def combine_model_weights(global_model_client, global_model_server, client_models, server_models):
    """
    Aggregate weights from client and server models using the mean method.

    This function updates the global model weights by averaging the weights
    from all client and server models.
    """
    # Get the state dictionaries (weights) for both client and server models
    client_state_dict = global_model_client.state_dict()
    server_state_dict = global_model_server.state_dict()

    # Average the weights across all client models
    for key in client_state_dict.keys():
        client_state_dict[key] = torch.stack(
            [model.state_dict()[key].float() for model in client_models], dim=0).mean(0)
    global_model_client.load_state_dict(client_state_dict)

    # Average the weights across all server models
    for key in server_state_dict.keys():
        server_state_dict[key] = torch.stack(
            [model.state_dict()[key].float() for model in server_models], dim=0).mean(0)
    global_model_server.load_state_dict(server_state_dict)

    # Load the updated global weights back into the client models
    for model in client_models:
        model.load_state_dict(global_model_client.state_dict())

    # Load the updated global weights back into the server models
    for model in server_models:
        model.load_state_dict(global_model_server.state_dict())

def client_training(args, round_num, client_model, server_model, scheduler_client, scheduler_server,
                    optimizer_client, optimizer_server, data_loader, epochs=5, streams=None):
    """
    Perform client-side model training for the given number of epochs.

    The client model performs the forward pass and sends intermediate outputs
    to the server model for further computation.
    """
    client_model.train()
    server_model.train()

    for epoch in range(epochs):
        # Prefetch data to improve data loading speed
        prefetcher = prefetch.data_prefetcher(data_loader)
        images, target = prefetcher.next()
        i = 0
        optimizer_client.zero_grad()
        optimizer_server.zero_grad()

        while images is not None:
            # Adjust learning rates using the schedulers
            scheduler_client(optimizer_client, i, round_num)
            scheduler_server(optimizer_server, i, round_num)
            i += 1

            # Forward pass on the client model
            outputs_client, y_a, y_b, lam = client_model(images, target=target, mode='train',
                                                         epoch=epoch, streams=streams)
            client_fx = [outputs.clone().detach().requires_grad_(True) for outputs in outputs_client]

            # Forward pass on the server model and compute losses
            ensemble_output, outputs_server, ce_loss, cot_loss = server_model(
                client_fx, y_a, y_b, lam, target=target, mode='train', epoch=epoch, streams=streams)
            total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate
            total_loss.backward()

            # Backpropagate the server-side gradients through the client model
            for fx, grad in zip(outputs_client, client_fx):
                fx.backward(grad.grad)

            # Perform an optimization step when the accumulation condition is met
            if i % args.iters_to_accumulate == 0 or i == len(data_loader):
                optimizer_client.step()
                optimizer_server.step()
                optimizer_client.zero_grad()
                optimizer_server.zero_grad()

            # Fetch the next batch of data
            images, target = prefetcher.next()

    return total_loss.item()

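The detach-and-reattach pattern in `client_training` is the core split-learning update: the client's intermediate outputs are detached before the server consumes them, the server's backward pass deposits gradients on those detached copies, and the client then backpropagates with exactly those gradients. A minimal sketch with toy modules (shapes and names are illustrative):

```python
import torch
import torch.nn as nn

client_part = nn.Linear(8, 4)   # runs on the edge device
server_part = nn.Linear(4, 2)   # runs on the server

x = torch.randn(16, 8)
smashed = client_part(x)                                      # client forward
smashed_copy = smashed.clone().detach().requires_grad_(True)  # cut the graph at the split point

loss = server_part(smashed_copy).sum()  # server forward
loss.backward()                         # fills smashed_copy.grad; client graph untouched

smashed.backward(smashed_copy.grad)     # client backward using the server's gradient
```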
def validate_model(val_loader, client_model, server_model, args, streams=None):
    """
    Validate the performance of client and server models.

    This function performs forward passes without updating the model weights
    and computes validation accuracy and loss.
    """
    client_model.eval()
    server_model.eval()

    acc1_list, acc5_list, ce_loss_list = [], [], []

    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass on the client model
            outputs_client = client_model(images, target=target, mode='val')
            client_fx = [output.clone().detach().requires_grad_(True) for output in outputs_client]

            # Forward pass on the server model
            ensemble_output, outputs_server, ce_loss = server_model(client_fx, target=target, mode='val')

            # Calculate accuracy and losses
            acc1, acc5 = metric.accuracy(ensemble_output, target, topk=(1, 5))
            acc1_list.append(acc1)
            acc5_list.append(acc5)
            ce_loss_list.append(ce_loss)

    # Calculate average accuracy and loss over the validation dataset
    avg_acc1 = average(acc1_list)
    avg_acc5 = average(acc5_list)
    avg_ce_loss = average(ce_loss_list)

    return avg_ce_loss, avg_acc1, avg_acc5

def main(args):
    """
    The main entry point for the federated learning process.

    Initializes models, handles multiprocessing setup, and starts training.
    """
    if args.gpu is not None:
        warnings.warn("A specific GPU has been chosen. Data parallelism is disabled.")

    # Set loop factor based on the training configuration
    args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor
    ngpus_per_node = torch.cuda.device_count()
    args.ngpus_per_node = ngpus_per_node

    if args.multiprocessing_decentralized:
        # Spawn a process for each GPU in the decentralized setup
        args.world_size = ngpus_per_node * args.world_size
        torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Use only a single GPU in the non-decentralized setup
        args.gpu = 0
        execute_worker_process(args.gpu, ngpus_per_node, args)

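For reference, the spawn pattern used above hands each worker its process index as the first argument, followed by the tuple passed via `args`. A minimal runnable sketch (the worker body and message are illustrative):

```python
import torch.multiprocessing as mp

def worker(rank, world_size, message):
    # rank is injected by mp.spawn; the rest comes from args=(...)
    print(f"worker {rank}/{world_size}: {message}")

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2, args=(2, "hello"))
```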
def execute_worker_process(gpu, ngpus_per_node, args):
    """
    Worker function that handles model initialization, training, and validation.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print(f"Using GPU {args.gpu} for training.")

    # Create a TensorBoard writer for logging validation metrics
    if not args.multiprocessing_decentralized or (args.multiprocessing_decentralized
                                                  and args.rank % ngpus_per_node == 0):
        val_writer = SummaryWriter(log_dir=os.path.join(args.model_dir, 'val'))

    # Define the loss criterion with label smoothing or plain cross-entropy
    criterion = label_smoothing.label_smoothing_CE(reduction='mean') if args.is_label_smoothing \
        else nn.CrossEntropyLoss()

    # Initialize global client and server models
    global_model_client = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode),
                                                      criterion=criterion)
    global_model_server = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode),
                                                           criterion=criterion)

    # Initialize client and server models for each selected client
    client_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode),
                                                 criterion=criterion) for _ in range(args.num_selected)]
    server_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode),
                                                      criterion=criterion) for _ in range(args.num_selected)]

    # Save hyperparameters to a JSON file
    save_hp_to_json(args)

    # Move global models and client/server models to the GPU
    global_model_client = global_model_client.cuda()
    global_model_server = global_model_server.cuda()
    for model in client_models + server_models:
        model.cuda()

    # Load global model weights into each client and server model
    for model in client_models:
        model.load_state_dict(global_model_client.state_dict())
    for model in server_models:
        model.load_state_dict(global_model_server.state_dict())

    # Initialize learning rate schedulers for clients and servers
    schedulers_clients = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds,
                                                    len(factory.obtain_data_loader(args.data)),
                                                    args.lr_milestones, args.lr_multiplier)
                          for _ in range(args.num_selected)]
    schedulers_servers = [lr_scheduler.lr_scheduler(args.lr_mode, args.lr, args.num_rounds,
                                                    len(factory.obtain_data_loader(args.data)),
                                                    args.lr_milestones, args.lr_multiplier)
                          for _ in range(args.num_selected)]

    # Start the training and validation loop for the specified number of rounds
    for r in range(args.start_round, args.num_rounds + 1):
        # Randomly select client indices for training in each round
        client_indices = np.random.permutation(args.num_clusters * args.loop_factor)[:args.num_selected * args.loop
BIN
EdgeFLite/scripts/.DS_Store
vendored
Normal file
Binary file not shown.
48
EdgeFLite/scripts/EdgeFLite_R110_100c_650r.sh
Normal file
@ -0,0 +1,48 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load necessary modules
# This section loads essential modules required for the execution environment
source /etc/profile.d/modules.sh   # Load the module environment configuration
module load gcc/11.2.0             # Load GCC (GNU Compiler Collection) version 11.2.0
module load openmpi/4.1.3          # Load OpenMPI version 4.1.3 for parallel computing
module load cuda/11.5/11.5.2       # Load CUDA version 11.5.2 for GPU computing
module load cudnn/8.3/8.3.3        # Load cuDNN version 8.3.3 for deep learning libraries
module load nccl/2.11/2.11.4-1     # Load NCCL version 2.11.4-1 for multi-GPU communication
module load python/3.10/3.10.4     # Load Python version 3.10.4 for executing Python scripts

# Activate virtual environment
# This activates the virtual environment that contains the required Python packages
source ~/venv/pytorch1.11+horovod/bin/activate

# Configure log directory
# Sets up the directory for storing logs related to the job execution
LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
mkdir -p ${LOG_PATH}   # Create the log directory if it doesn't exist

# Prepare dataset directory
# This section copies the dataset to job-local storage
TEMP_DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/"   # Temporary data path for the current job
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${TEMP_DATA_PATH}   # Copy the dataset to the temporary path

# Change to project directory
# Navigate to the project directory where the training script is located
cd EdgeFLite

# Execute training script
# Options are collected in an array so each flag can carry an inline comment;
# a comment after a trailing "\" would otherwise break the line continuation.
TRAIN_ARGS=(
    --is_fed=1                         # Enable federated learning mode
    --fixed_cluster=0                  # Do not use a fixed cluster configuration
    --split_factor=4                   # Data split factor for federated learning
    --num_clusters=25                  # Set the number of clusters to 25
    --num_selected=25                  # Select all 25 clusters for training
    --arch="resnet_model_110sl"        # Use the 'resnet_model_110sl' architecture
    --dataset="cifar100"               # Set the dataset to CIFAR-100
    --num_classes=100                  # Number of output classes (100 for CIFAR-100)
    --is_single_branch=0               # Enable multi-branch mode for model training
    --is_amp=0                         # Disable automatic mixed precision (AMP)
    --num_rounds=650                   # Total number of federated rounds
    --fed_epochs=1                     # Number of local epochs per round
    --spid="EdgeFLite_R110_100c_650r"  # Session/process ID for the current job
    --data=${TEMP_DATA_PATH}           # Dataset location (temporary directory)
)
python train_EdgeFLite.py "${TRAIN_ARGS[@]}"
60
EdgeFLite/scripts/EdgeFLite_R110_80c_650r.sh
Normal file
@ -0,0 +1,60 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load necessary environment modules
source /etc/profile.d/modules.sh   # Source the module environment setup script
module load gcc/11.2.0             # Load GCC compiler version 11.2.0
module load openmpi/4.1.3          # Load OpenMPI version 4.1.3 for distributed computing
module load cuda/11.5/11.5.2       # Load CUDA version 11.5.2 for GPU acceleration
module load cudnn/8.3/8.3.3        # Load cuDNN version 8.3.3 for deep learning operations
module load nccl/2.11/2.11.4-1     # Load NCCL version 2.11.4-1 for multi-GPU communication
module load python/3.10/3.10.4     # Load Python version 3.10.4

# Activate the Python virtual environment with PyTorch and Horovod installed
source ~/venv/pytorch1.11+horovod/bin/activate

# Set up the log directory for the experiment
LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"  # Define the log path
rm -rf ${LOG_PATH}    # Remove any existing logs in the directory
mkdir -p ${LOG_PATH}  # Create the log directory if it doesn't exist

# Set up the dataset directory, copying data for local use
DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/"  # Local directory for the dataset
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_PATH}  # Copy CIFAR-100 dataset to local storage

# Set experiment parameters for federated learning
OUTPUT_DIR="./EdgeFLite/models/coremodel/"  # Directory where model checkpoints will be saved
FED_MODE=1                        # Federated learning mode enabled
CLUSTER_FIXED=0                   # Clusters are dynamic, not fixed
SPLIT_RATIO=4                     # Split the dataset into 4 parts
TOTAL_CLUSTERS=20                 # Number of clusters (clients) in federated learning
SELECTED_CLIENTS=20               # Number of clients selected per round
MODEL_ARCH="resnet_model_110sl"   # Model architecture (ResNet-110 variant)
DATASET_NAME="cifar100"           # Dataset being used (CIFAR-100)
NUM_CLASS_LABELS=100              # Number of class labels (CIFAR-100 has 100 classes)
SINGLE_BRANCH=0                   # Multi-branch model architecture (not single-branch)
AMP_MODE=0                        # Disable Automatic Mixed Precision (AMP) for training
ROUNDS=650                        # Total number of federated learning rounds
EPOCHS_PER_ROUND=1                # Number of local epochs per round
EXP_ID="EdgeFLite_R110_80c_650r"  # Experiment ID for tracking

# Navigate to the project directory
cd EdgeFLite

# Execute the training process with the defined parameters; the flags are
# gathered in an array so each one can keep its explanatory comment.
TRAIN_ARGS=(
    --is_fed=${FED_MODE}                 # Enable federated learning mode
    --fixed_cluster=${CLUSTER_FIXED}     # Use dynamic clusters
    --split_factor=${SPLIT_RATIO}        # Set the dataset split ratio
    --num_clusters=${TOTAL_CLUSTERS}     # Total number of clusters (clients)
    --num_selected=${SELECTED_CLIENTS}   # Clients selected per federated round
    --arch=${MODEL_ARCH}                 # Model architecture (ResNet-110 variant)
    --dataset=${DATASET_NAME}            # Dataset name (CIFAR-100)
    --num_classes=${NUM_CLASS_LABELS}    # Number of classes in the dataset
    --is_single_branch=${SINGLE_BRANCH}  # Use the multi-branch model (set to 0)
    --is_amp=${AMP_MODE}                 # Disable automatic mixed precision
    --num_rounds=${ROUNDS}               # Total number of federated rounds
    --fed_epochs=${EPOCHS_PER_ROUND}     # Local epochs per round
    --spid=${EXP_ID}                     # Experiment ID for tracking
    --data=${DATA_PATH}                  # Dataset path
    --model_dir=${OUTPUT_DIR}            # Directory where the model will be saved
)
python train_EdgeFLite.py "${TRAIN_ARGS[@]}"
49
EdgeFLite/scripts/EdgeFLite_W168_96c_650r.sh
Normal file
@ -0,0 +1,49 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Initialize environment and load necessary modules
# Sets up GCC, OpenMPI, CUDA, cuDNN, NCCL, and Python for this run
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC version 11.2.0 for compiling
module load openmpi/4.1.3        # OpenMPI version 4.1.3 for distributed computing
module load cuda/11.5/11.5.2     # CUDA version 11.5.2 for GPU computing
module load cudnn/8.3/8.3.3      # cuDNN version 8.3.3 for deep learning frameworks
module load nccl/2.11/2.11.4-1   # NCCL version 2.11.4-1 for multi-GPU communication
module load python/3.10/3.10.4   # Python version 3.10.4

# Activate the Python virtual environment
# The pre-configured environment contains the required packages (e.g., PyTorch, Horovod)
source ~/venv/pytorch1.11+horovod/bin/activate

# Prepare the log directory and clean up any old records
LOG_DIRECTORY="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${LOG_DIRECTORY}    # Remove old logs if they exist
mkdir -p ${LOG_DIRECTORY}  # Create a new directory for current job logs

# Set up the local data directory and copy the dataset
DATA_STORAGE="${SGE_LOCALDIR}/${JOB_ID}/"
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_STORAGE}  # Copy CIFAR-100 to the local directory

# Change directory to the project location
cd EdgeFLite

# Execute the federated training run; flags live in an array so each can
# carry an inline comment without breaking line continuations.
TRAIN_ARGS=(
    --is_fed=1                        # Enable federated learning mode
    --fixed_cluster=0                 # Disable fixed clusters, allowing dynamic changes
    --split_factor=16                 # Set the data split factor to 16
    --num_clusters=6                  # Use 6 clusters for federated learning
    --num_selected=6                  # Select 6 clients for each training round
    --arch="wide_resnetsl16_8"        # Wide ResNet architecture, depth 16 and width 8
    --dataset="cifar100"              # CIFAR-100 dataset
    --num_classes=100                 # CIFAR-100 has 100 output classes
    --is_single_branch=0              # Use multi-branch (multi-head) learning
    --is_amp=0                        # Disable automatic mixed precision training
    --num_rounds=650                  # Train for 650 communication rounds
    --fed_epochs=1                    # Each client trains for 1 epoch per round
    --spid="EdgeFLite_W168_96c_650r"  # Unique identifier for the job
    --data=${DATA_STORAGE}            # Location of the dataset
)
python train_EdgeFLite.py "${TRAIN_ARGS[@]}"
44
EdgeFLite/scripts/EdgeFLite_W168_96c_650r2.sh
Normal file
@ -0,0 +1,44 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load necessary system modules for the environment
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC compiler version 11.2.0
module load openmpi/4.1.3        # OpenMPI version 4.1.3 for parallel processing
module load cuda/11.5/11.5.2     # CUDA version 11.5.2 for GPU acceleration
module load cudnn/8.3/8.3.3      # cuDNN version 8.3.3 for deep learning libraries
module load nccl/2.11/2.11.4-1   # NCCL version 2.11.4-1 for multi-GPU communication
module load python/3.10/3.10.4   # Python version 3.10.4

# Activate the virtual environment for PyTorch and Horovod
source ~/venv/pytorch1.11+horovod/bin/activate

# Set up the log directory and remove any previous log records
LOG_OUTPUT="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${LOG_OUTPUT}    # Clean previous logs
mkdir -p ${LOG_OUTPUT}  # Create new log directory

# Prepare local storage for the dataset
LOCAL_DATA_DIR="${SGE_LOCALDIR}/${JOB_ID}/"  # Local storage path
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${LOCAL_DATA_DIR}  # Copy CIFAR-100 data to local storage

# Move to the project directory
cd EdgeFLite

# Run the federated learning experiment; flags are kept in an array so each
# one can carry an inline comment.
RUN_ARGS=(
    --is_fed=1                          # Enable federated learning
    --fixed_cluster=0                   # Use dynamic clustering
    --split_factor=1                    # Set the split factor
    --num_clusters=20                   # Number of clusters in the federation
    --num_selected=20                   # Number of selected clients per round
    --arch=resnet_model_110sl           # Model architecture (ResNet-110 variant)
    --dataset=cifar100                  # Dataset: CIFAR-100
    --num_classes=100                   # Number of classes in the dataset
    --is_single_branch=0                # Enable the multi-branch model
    --is_amp=0                          # Disable automatic mixed precision
    --num_rounds=650                    # Total number of federated learning rounds
    --fed_epochs=1                      # Number of local epochs per round
    --cifar100_non_iid="quantity_skew"  # Non-IID scenario: quantity skew
    --spid="FGKT_R110_20c_skew"         # Experiment identifier
    --data=${LOCAL_DATA_DIR}            # Path to the local dataset
)
python run_gkt.py "${RUN_ARGS[@]}"
51
EdgeFLite/scripts/EdgeFLite_W168_96c_650r32.sh
Normal file
@ -0,0 +1,51 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load environment modules required for execution
# This block sets up necessary modules, including compilers and deep learning libraries
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC compiler version 11.2.0
module load openmpi/4.1.3        # OpenMPI for distributed computing
module load cuda/11.5/11.5.2     # CUDA 11.5 for GPU acceleration
module load cudnn/8.3/8.3.3      # cuDNN for deep neural network operations
module load nccl/2.11/2.11.4-1   # NCCL for multi-GPU communication
module load python/3.10/3.10.4   # Python version 3.10

# Activate the Python environment
# Activates a virtual environment with the required packages (e.g., PyTorch and Horovod)
source ~/venv/pytorch1.11+horovod/bin/activate

# Create and clean the log directory for this job
LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${LOG_PATH}    # Remove any pre-existing log directory
mkdir -p ${LOG_PATH}  # Create a new log directory

# Prepare the local dataset storage
# Copies the dataset to a local directory for faster access during training
DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/"
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_PATH}

# Change to the working directory of the federated training scripts
cd EdgeFLite

# Execute the federated training process with the specified configuration;
# flags are collected in an array so each one keeps its comment.
RUN_ARGS=(
    --is_fed=1                    # Enable federated learning mode
    --fixed_cluster=0             # Dynamic clusters during training
    --split_factor=1              # Data split factor for federated learning
    --num_clusters=20             # Number of clusters for federated training
    --num_selected=20             # Number of selected devices per round
    --arch="wide_resnetsl50_2"    # Model architecture (Wide ResNet-50-2 variant)
    --dataset="pill_base"         # Dataset used for training
    --num_classes=98              # Number of classes in the dataset
    --is_single_branch=0          # Multi-branch model
    --is_amp=0                    # Disable automatic mixed precision
    --num_rounds=350              # Total number of communication rounds
    --fed_epochs=1                # Number of local epochs per device
    --batch_size=32               # Batch size for training
    --crop_size=224               # Crop size for image preprocessing
    --spid="FGKT_W502_20c_350r"   # Unique identifier for this training experiment
    --data=${DATA_PATH}           # Path to the dataset used for training
)
python run_gkt.py "${RUN_ARGS[@]}"
43
EdgeFLite/scripts/EdgeFLite_W168_96c_650r4.sh
Normal file
@ -0,0 +1,43 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load necessary modules and dependencies
source /etc/profile.d/modules.sh
module load gcc/11.2.0
module load openmpi/4.1.3
module load cuda/11.5/11.5.2
module load cudnn/8.3/8.3.3
module load nccl/2.11/2.11.4-1
module load python/3.10/3.10.4

# Activate the Python environment
source ~/venv/pytorch1.11+horovod/bin/activate

# Configure the log directory and clean up any existing records
OUTPUT_LOG_DIR="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${OUTPUT_LOG_DIR}
mkdir -p ${OUTPUT_LOG_DIR}

# Copy the dataset to a local directory for processing
LOCAL_DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/"
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${LOCAL_DATA_PATH}

# Switch to the working directory containing the training scripts
cd EdgeFLite

# Run the training script; flags are collected in an array so each one can
# carry an inline comment.
RUN_ARGS=(
    --is_fed=1                 # Enable federated learning mode
    --fixed_cluster=0          # Use dynamic clustering
    --split_factor=1           # Split factor for distributed computation
    --num_clusters=20          # Number of clusters to create
    --num_selected=20          # Number of selected clients per round
    --arch="wide_resnet16_8"   # Architecture to use (Wide ResNet-16-8)
    --dataset="cifar10"        # Dataset to use (CIFAR-10)
    --num_classes=10           # Number of classes in the dataset
    --is_single_branch=0       # Disable single-branch training mode
    --is_amp=0                 # Disable automatic mixed precision
    --num_rounds=300           # Number of communication rounds
    --fed_epochs=1             # Local epochs per client per round
    --spid="fedgkt_wrn168_split1_cifar10_20clients_20choose_300rounds"  # Unique experiment ID
    --data=${LOCAL_DATA_PATH}  # Local path to the dataset
)
python run_gkt.py "${RUN_ARGS[@]}"
44
EdgeFLite/scripts/EdgeFLite_W168_96c_650r8.sh
Normal file
@ -0,0 +1,44 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load environment modules and required dependencies
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC version 11.2.0
module load openmpi/4.1.3        # OpenMPI version 4.1.3
module load cuda/11.5/11.5.2     # CUDA version 11.5.2
module load cudnn/8.3/8.3.3      # cuDNN version 8.3.3
module load nccl/2.11/2.11.4-1   # NCCL version 2.11.4-1
module load python/3.10/3.10.4   # Python version 3.10.4

# Activate the virtual Python environment for PyTorch and Horovod
source ~/venv/pytorch1.11+horovod/bin/activate

# Define the log directory, clean up old records if any, and recreate it
LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${LOG_PATH}    # Remove any existing log directory
mkdir -p ${LOG_PATH}  # Create a new log directory

# Set up the local data directory and copy the dataset into it
DATA_STORAGE="${SGE_LOCALDIR}/${JOB_ID}/"  # Local data directory for the job
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_STORAGE}  # Copy the dataset to the local directory

# Navigate to the working directory where the training scripts are located
cd EdgeFLite

# Execute the training script; flags are kept in an array so that each one
# can carry an inline comment.
RUN_ARGS=(
    --is_fed=1                         # Enable federated learning
    --fixed_cluster=0                  # Allow dynamic cluster formation
    --split_factor=1                   # Data split factor
    --num_clusters=20                  # Number of clusters
    --num_selected=20                  # Number of selected clients per round
    --arch="wide_resnet16_8"           # Network architecture: Wide ResNet-16-8
    --dataset="cifar10"                # Use the CIFAR-10 dataset
    --num_classes=10                   # Number of classes in CIFAR-10
    --is_single_branch=0               # Multi-branch network
    --is_amp=0                         # Disable Automatic Mixed Precision (AMP)
    --num_rounds=300                   # Number of federated learning rounds
    --fed_epochs=1                     # Local training epochs per round
    --cifar10_non_iid="quantity_skew"  # Non-IID data distribution: quantity skew
    --spid="FGKT_W168_20c_skew"        # Job identifier
    --data=${DATA_STORAGE}             # Path to the dataset
)
python run_gkt.py "${RUN_ARGS[@]}"
43
EdgeFLite/scripts/FGKT_R110_20c_650r.sh
Normal file
@ -0,0 +1,43 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load necessary system modules for the job
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC compiler
module load openmpi/4.1.3        # OpenMPI for distributed computing
module load cuda/11.5/11.5.2     # CUDA for GPU acceleration
module load cudnn/8.3/8.3.3      # cuDNN for deep learning frameworks
module load nccl/2.11/2.11.4-1   # NCCL for multi-GPU communication
module load python/3.10/3.10.4   # Python 3.10 environment

# Activate the required Python virtual environment (PyTorch 1.11 + Horovod)
source ~/venv/pytorch1.11+horovod/bin/activate

# Define the log directory and clean up any existing records before starting
LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"  # Log path
rm -rf ${LOG_PATH}    # Remove any existing log directory
mkdir -p ${LOG_PATH}  # Create a new log directory

# Copy the dataset to the local temporary directory
DATA_DIR="${SGE_LOCALDIR}/${JOB_ID}/"  # Local directory for the dataset
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_DIR}  # Copy CIFAR-100 to the local directory

# Move to the directory containing the training scripts
cd EdgeFLite

# Start the federated training run; flags are gathered in an array so each
# can keep its inline comment.
RUN_ARGS=(
    --is_fed=1                   # Enable federated learning
    --fixed_cluster=0            # Use dynamic clustering
    --split_factor=1             # Set the data split factor
    --num_clusters=20            # Number of clusters
    --num_selected=20            # Number of selected clients per round
    --arch="resnet_model_110sl"  # Model architecture (ResNet-110 variant)
    --dataset="cifar100"         # Dataset used (CIFAR-100)
    --num_classes=100            # Number of classes in the dataset
    --is_single_branch=0         # Enable the multi-branch model
    --is_amp=0                   # Disable automatic mixed precision
    --num_rounds=650             # Number of federated learning rounds
    --fed_epochs=1               # Local epochs per federated round
    --spid="FGKT_R110_20c_650r"  # Experiment ID for logging and tracking
    --data=${DATA_DIR}           # Path to the dataset
)
python run_gkt.py "${RUN_ARGS[@]}"
56
EdgeFLite/scripts/FGKT_R110_20c_skew.sh
Normal file
@ -0,0 +1,56 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load necessary system modules
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC version 11.2.0
module load openmpi/4.1.3        # OpenMPI version 4.1.3
module load cuda/11.5/11.5.2     # CUDA version 11.5.2
module load cudnn/8.3/8.3.3      # cuDNN version 8.3.3
module load nccl/2.11/2.11.4-1   # NCCL version 2.11.4-1
module load python/3.10/3.10.4   # Python version 3.10.4

# Activate the virtual environment for PyTorch and Horovod
source ~/venv/pytorch1.11+horovod/bin/activate

# Set up the log directory and clean previous records if they exist
LOG_OUTPUT="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${LOG_OUTPUT}    # Remove previous log files
mkdir -p ${LOG_OUTPUT}  # Create a new directory for logs

# Prepare local storage for the dataset by copying it to a local directory
LOCAL_DATA_DIR="${SGE_LOCALDIR}/${JOB_ID}/"
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${LOCAL_DATA_DIR}

# Navigate to the EdgeFLite project directory
cd EdgeFLite

# Run the federated learning experiment; flags are collected in an array so
# each one keeps its inline comment.
RUN_ARGS=(
    --is_fed=1                          # Enable federated learning
    --fixed_cluster=0                   # Disable fixed cluster settings
    --split_factor=1                    # Use a split factor of 1
    --num_clusters=20                   # Set the number of clusters to 20
    --num_selected=20                   # Select 20 clients for each round
    --arch=resnet_model_110sl           # ResNet-110 architecture (split-learning variant)
    --dataset=cifar100                  # Use the CIFAR-100 dataset
    --num_classes=100                   # Set the number of classes to 100
    --is_single_branch=0                # Use multiple branches in the model
    --is_amp=0                          # Disable automatic mixed precision
    --num_rounds=650                    # Set the number of communication rounds to 650
    --fed_epochs=1                      # Set the number of local epochs per round to 1
    --cifar100_non_iid="quantity_skew"  # Apply non-IID data partitioning (quantity skew)
    --spid="FGKT_R110_20c_skew"         # Set the experiment ID
    --data=${LOCAL_DATA_DIR}            # Path to the dataset in local storage
)
python run_gkt.py "${RUN_ARGS[@]}"
61
EdgeFLite/scripts/FGKT_W168_20c_300r.sh
Normal file
@ -0,0 +1,61 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load necessary modules and dependencies
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC version 11.2.0
module load openmpi/4.1.3        # OpenMPI version 4.1.3 for distributed computing
module load cuda/11.5/11.5.2     # CUDA version 11.5.2 for GPU acceleration
module load cudnn/8.3/8.3.3      # cuDNN version 8.3.3 for deep learning operations
module load nccl/2.11/2.11.4-1   # NCCL version 2.11.4-1 for multi-GPU communication
module load python/3.10/3.10.4   # Python version 3.10.4

# Activate the Python virtual environment for PyTorch 1.11 + Horovod
source ~/venv/pytorch1.11+horovod/bin/activate

# Configure the output log directory and clean up any existing records
OUTPUT_LOG_DIR="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${OUTPUT_LOG_DIR}    # Remove any previous log files
mkdir -p ${OUTPUT_LOG_DIR}  # Create a fresh directory for storing logs

# Copy the dataset to a local directory for processing during training
LOCAL_DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/"
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${LOCAL_DATA_PATH}

# Switch to the working directory containing the EdgeFLite training scripts
cd EdgeFLite

# Run the federated training script; flags are collected in an array so each
# keeps its inline comment.
RUN_ARGS=(
    --is_fed=1                 # Enable federated learning
    --fixed_cluster=0          # Disable fixed clusters
    --split_factor=1           # Set the data split factor
    --num_clusters=20          # Number of clusters
    --num_selected=20          # Number of selected clients
    --arch="wide_resnet16_8"   # Wide ResNet-16-8 architecture
    --dataset="cifar10"        # Set the dataset to CIFAR-10
    --num_classes=10           # Number of classes
    --is_single_branch=0       # Use multi-branch training
    --is_amp=0                 # Disable automatic mixed precision (AMP)
    --num_rounds=300           # Number of training rounds
    --fed_epochs=1             # Federated learning epochs per round
    --spid="fedgkt_wrn168_split1_cifar10_20clients_20choose_300rounds"  # Session ID
    --data=${LOCAL_DATA_PATH}  # Path to the local dataset
)
python run_gkt.py "${RUN_ARGS[@]}"
44
EdgeFLite/scripts/FGKT_W168_20c_skew.sh
Normal file
@ -0,0 +1,44 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load environment modules and required dependencies
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC compiler version 11.2.0
module load openmpi/4.1.3        # OpenMPI version 4.1.3 for distributed computing
module load cuda/11.5/11.5.2     # CUDA version 11.5.2 for GPU acceleration
module load cudnn/8.3/8.3.3      # cuDNN version 8.3.3 for deep learning libraries
module load nccl/2.11/2.11.4-1   # NCCL version 2.11.4-1 for multi-GPU communication
module load python/3.10/3.10.4   # Python version 3.10.4

# Activate the virtual Python environment with PyTorch 1.11 and Horovod
source ~/venv/pytorch1.11+horovod/bin/activate

# Define the log directory, clean up old records if any, and recreate it
LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${LOG_PATH}    # Remove the existing log directory if it exists
mkdir -p ${LOG_PATH}  # Create the log directory

# Set up the local data directory and copy the dataset into it
DATA_STORAGE="${SGE_LOCALDIR}/${JOB_ID}/"
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_STORAGE}  # Copy CIFAR-100 into local storage

# Navigate to the working directory where the training scripts are located
cd EdgeFLite

# Execute the training script; flags are gathered in an array so each one
# can carry an inline comment.
RUN_ARGS=(
    --is_fed=1                         # Enable federated learning mode
    --fixed_cluster=0                  # Allow dynamic cluster selection
    --split_factor=1                   # Set the split factor for cluster selection
    --num_clusters=20                  # Number of clusters for federated learning
    --num_selected=20                  # Number of selected clusters per round
    --arch="wide_resnet16_8"           # Use the Wide ResNet-16-8 architecture
    --dataset="cifar10"                # Specify the dataset as CIFAR-10
    --num_classes=10                   # Number of classes for classification
    --is_single_branch=0               # Use multiple branches (not single-branch)
    --is_amp=0                         # Disable automatic mixed precision (AMP)
    --num_rounds=300                   # Number of federated learning rounds
    --fed_epochs=1                     # Epochs per round for federated learning
    --cifar10_non_iid="quantity_skew"  # Non-IID data distribution with quantity skew
    --spid="FGKT_W168_20c_skew"        # Process ID for tracking
    --data=${DATA_STORAGE}             # Local data storage path
)
python run_gkt.py "${RUN_ARGS[@]}"
60
EdgeFLite/scripts/FGKT_W502_20c_350r.sh
Normal file
@ -0,0 +1,60 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Load environment modules required for execution
source /etc/profile.d/modules.sh
module load gcc/11.2.0           # GCC compiler version 11.2.0
module load openmpi/4.1.3        # OpenMPI version 4.1.3 for distributed computing
module load cuda/11.5/11.5.2     # CUDA version 11.5.2 for GPU acceleration
module load cudnn/8.3/8.3.3      # cuDNN version 8.3.3 for deep learning libraries
module load nccl/2.11/2.11.4-1   # NCCL version 2.11.4-1 for multi-GPU communication
module load python/3.10/3.10.4   # Python version 3.10.4

# Activate the Python virtual environment for PyTorch and Horovod
source ~/venv/pytorch1.11+horovod/bin/activate

# Create and clean the log directory for this job
LOG_PATH="/home/projadmin/Federated_Learning/project_EdgeFLite/records/${JOB_NAME}_${JOB_ID}"
rm -rf ${LOG_PATH}    # Remove any existing log directory to avoid conflicts
mkdir -p ${LOG_PATH}  # Create a fresh log directory for the current job

# Prepare the local dataset storage
DATA_PATH="${SGE_LOCALDIR}/${JOB_ID}/"
cp -r ../summit2024/simpleFL/performance_test/cifar100/data ${DATA_PATH}  # Copy the dataset locally for better performance

# Change to the working directory of the federated training scripts
cd EdgeFLite

# Execute the federated training process; flags are collected in an array so
# each keeps its inline comment.
RUN_ARGS=(
    --is_fed=1                   # Enable federated training mode
    --fixed_cluster=0            # Do not fix clusters
    --split_factor=1             # Set the split factor to 1
    --num_clusters=20            # Number of clusters used in federated training
    --num_selected=20            # Number of selected clusters per round
    --arch="wide_resnetsl50_2"   # Wide ResNet-50-2 architecture
    --dataset="pill_base"        # Dataset to use (Pill Base)
    --num_classes=98             # Number of classes in the dataset
    --is_single_branch=0         # Enable multi-branch training
    --is_amp=0                   # Disable automatic mixed precision training
    --num_rounds=350             # Number of federated training rounds
    --fed_epochs=1               # Epochs per federated round
    --batch_size=32              # Batch size for training
    --crop_size=224              # Image crop size
    --spid="FGKT_W502_20c_350r"  # Unique session ID for logging
    --data=${DATA_PATH}          # Path to the dataset
)
python run_gkt.py "${RUN_ARGS[@]}"
7
EdgeFLite/settings.py
Normal file
@ -0,0 +1,7 @@

import os  # Provides functions for interacting with the operating system

current_directory = os.getcwd()

# Path to the directory where data will be stored.
# It is set to the current working directory, so data is kept alongside the code.
data_directory = current_directory
BIN
EdgeFLite/thop/.DS_Store
vendored
Normal file
Binary file not shown.
38
EdgeFLite/thop/helper_utils.py
Normal file
@ -0,0 +1,38 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan
from collections.abc import Iterable


def clever_format(nums, fmt="%.2f"):
    """Format a number or list of numbers with T/G/M/K suffixes.

    Args:
        nums: a single number or an iterable of numbers to format.
        fmt: format string applied to each scaled number (default "%.2f",
            i.e. two decimal places).
    """
    # If a single number was passed, wrap it in a list for uniform processing.
    if not isinstance(nums, Iterable):
        nums = [nums]

    formatted_nums = []
    for num in nums:
        if num > 1e12:    # trillions
            formatted_nums.append(fmt % (num / 1e12) + "T")
        elif num > 1e9:   # billions
            formatted_nums.append(fmt % (num / 1e9) + "G")
        elif num > 1e6:   # millions
            formatted_nums.append(fmt % (num / 1e6) + "M")
        elif num > 1e3:   # thousands
            formatted_nums.append(fmt % (num / 1e3) + "K")
        else:             # below one thousand: base units
            formatted_nums.append(fmt % num + "B")

    # Return a single string for one input, otherwise a tuple of strings.
    return formatted_nums[0] if len(formatted_nums) == 1 else tuple(formatted_nums)
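A quick usage sketch of the helper on raw operation counts:

```python
print(clever_format(1234567))        # -> "1.23M"
print(clever_format([512, 2.5e9]))   # -> ("512.00B", "2.50G")
```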
91
EdgeFLite/thop/hooks_basic.py
Normal file
@ -0,0 +1,91 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import argparse
import logging

import torch
import torch.nn as nn
from torch.nn.modules.conv import _ConvNd

multiply_adds = 1


def count_parameters(m, x, y):
    """Counts the number of parameters in a model."""
    total_params = sum(p.numel() for p in m.parameters())
    m.total_params[0] = torch.DoubleTensor([total_params])


def zero_ops(m, x, y):
    """Sets total operations to zero."""
    m.total_ops += torch.DoubleTensor([0])


def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    """Counts operations for convolutional layers."""
    x = x[0]
    kernel_ops = m.weight[0][0].numel()  # Kw x Kh
    bias_ops = 1 if m.bias is not None else 0
    total_ops = y.nelement() * (m.in_channels // m.groups * kernel_ops + bias_ops)
    m.total_ops += torch.DoubleTensor([total_ops])


def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor):
    """Alternative method for counting operations for convolutional layers."""
    x = x[0]
    output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel()
    kernel_ops = m.weight.numel() + (m.bias.numel() if m.bias is not None else 0)
    m.total_ops += torch.DoubleTensor([output_size * kernel_ops])


def count_bn(m, x, y):
    """Counts operations for batch normalization layers."""
    x = x[0]
    nelements = x.numel()
    if not m.training:
        total_ops = 2 * nelements
        m.total_ops += torch.DoubleTensor([total_ops])


def count_relu(m, x, y):
    """Counts operations for ReLU activation."""
    x = x[0]
    nelements = x.numel()
    m.total_ops += torch.DoubleTensor([nelements])


def count_softmax(m, x, y):
    """Counts operations for softmax."""
    x = x[0]
    batch_size, nfeatures = x.size()
    total_ops = batch_size * (2 * nfeatures - 1)
    m.total_ops += torch.DoubleTensor([total_ops])


def count_avgpool(m, x, y):
    """Counts operations for average pooling layers."""
    num_elements = y.numel()
    m.total_ops += torch.DoubleTensor([num_elements])


def count_adap_avgpool(m, x, y):
    """Counts operations for adaptive average pooling layers."""
    kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor(list((m.output_size,))).squeeze()
    kernel_ops = torch.prod(kernel) + 1
    num_elements = y.numel()
    m.total_ops += torch.DoubleTensor([kernel_ops * num_elements])


def count_upsample(m, x, y):
    """Counts operations for upsample layers."""
    if m.mode not in ("nearest", "linear", "bilinear", "bicubic"):
        logging.warning(f"Mode {m.mode} is not implemented yet, assuming zero ops")
        return zero_ops(m, x, y)

    if m.mode == "nearest":
        return zero_ops(m, x, y)

    total_ops = {
        "linear": 5,
        "bilinear": 11,
        "bicubic": 259,    # 224 muls + 35 adds
        "trilinear": 31    # 2 * bilinear + 1 * linear
    }.get(m.mode, 0) * y.nelement()

    m.total_ops += torch.DoubleTensor([total_ops])


def count_linear(m, x, y):
    """Counts operations for linear layers."""
    total_ops = m.in_features * y.numel()
    m.total_ops += torch.DoubleTensor([total_ops])
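These counters are normally attached as forward hooks by a profiler that also allocates each module's `total_ops` buffer; here is a minimal sketch of driving one counter by hand (the manual buffer setup is illustrative, not how the full library wires it up):

```python
import torch
import torch.nn as nn

fc = nn.Linear(128, 10)
fc.total_ops = torch.DoubleTensor([0])  # buffer the hook accumulates into

x = torch.randn(32, 128)
y = fc(x)
count_linear(fc, (x,), y)   # in_features * y.numel() = 128 * (32 * 10) = 40,960
print(fc.total_ops.item())  # 40960.0
```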
195
EdgeFLite/thop/hooks_rnn.py
Normal file
@ -0,0 +1,195 @@

# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import torch
import torch.nn as nn


def _count_rnn_cell(input_size, hidden_size, bias=True):
    """Calculate the total operations for a single RNN cell.

    Args:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        bias (bool, optional): Whether the RNN cell uses bias. Defaults to True.

    Returns:
        int: Total number of operations for the RNN cell.
    """
    ops = hidden_size * (input_size + hidden_size) + hidden_size
    if bias:
        ops += hidden_size * 2
    return ops

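As a quick sanity check of the formula: with input_size=10 and hidden_size=20, the matrix products and hidden update cost 20 * (10 + 20) + 20 = 620 ops, and bias adds 2 * 20 = 40 more, giving 660 ops per time step.

```python
assert _count_rnn_cell(10, 20, bias=True) == 660  # 620 matrix/update ops + 40 bias ops
```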
def count_rnn_cell(cell: nn.RNNCell, x: torch.Tensor):
|
||||
"""Count operations for the RNNCell over a batch of input.
|
||||
|
||||
Args:
|
||||
cell (nn.RNNCell): The RNNCell to count operations for.
|
||||
x (torch.Tensor): Input tensor.
|
||||
"""
|
||||
ops = _count_rnn_cell(cell.input_size, cell.hidden_size, cell.bias)
|
||||
batch_size = x[0].size(0)
|
||||
total_ops = ops * batch_size
|
||||
cell.total_ops += torch.DoubleTensor([int(total_ops)])
|
||||
|
||||
def _count_gru_cell(input_size, hidden_size, bias=True):
|
||||
"""Calculate the total operations for a single GRU cell.
|
||||
|
||||
Args:
|
||||
input_size (int): Size of the input.
|
||||
hidden_size (int): Size of the hidden state.
|
||||
bias (bool, optional): Whether the GRU cell uses bias. Defaults to True.
|
||||
|
||||
Returns:
|
||||
int: Total number of operations for the GRU cell.
|
||||
"""
|
||||
ops = (hidden_size + input_size) * hidden_size + hidden_size
|
||||
if bias:
|
||||
ops += hidden_size * 2
|
||||
ops *= 2 # For reset and update gates
|
||||
|
||||
ops += (hidden_size + input_size) * hidden_size + hidden_size # Calculate new gate
|
||||
if bias:
|
||||
ops += hidden_size * 2
|
||||
ops += hidden_size # Hadamard product
|
||||
ops += hidden_size * 3 # Final output
|
||||
|
||||
return ops
|
||||
|
||||
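
# Worked example for _count_gru_cell (numbers follow the formula above):
# input_size=4, hidden_size=8, bias=True
#   reset + update gates: 2 * ((4+8)*8 + 8 + 2*8) = 240
#   new gate:                  (4+8)*8 + 8 + 2*8  = 120
#   Hadamard product:                                8
#   final output:                             3 * 8 = 24
#   => _count_gru_cell(4, 8) == 392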


def count_gru_cell(cell: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):
    """Count operations for the GRUCell over a batch of input.

    Registered via register_forward_hook, which passes (module, input, output).

    Args:
        cell (nn.GRUCell): The GRUCell to count operations for.
        x (torch.Tensor): Input tensors (forward-hook input tuple).
        y (torch.Tensor): Output tensor (unused by this counter).
    """
    ops = _count_gru_cell(cell.input_size, cell.hidden_size, cell.bias)
    batch_size = x[0].size(0)
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_lstm_cell(input_size, hidden_size, bias=True):
    """Calculate the total operations for a single LSTM cell.

    Args:
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.
        bias (bool, optional): Whether the LSTM cell uses bias. Defaults to True.

    Returns:
        int: Total number of operations for the LSTM cell.
    """
    ops = (input_size + hidden_size) * hidden_size + hidden_size
    if bias:
        ops += hidden_size * 2
    ops *= 4  # For input, forget, output, and cell gates

    ops += hidden_size * 3  # Cell state update
    ops += hidden_size  # Final output

    return ops
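
# Worked example for _count_lstm_cell (numbers follow the formula above):
# input_size=4, hidden_size=8, bias=True
#   four gates:        4 * ((4+8)*8 + 8 + 2*8) = 480
#   cell state update:                   3 * 8 = 24
#   final output:                                8
#   => _count_lstm_cell(4, 8) == 512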


def count_lstm_cell(cell: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
    """Count operations for the LSTMCell over a batch of input.

    Registered via register_forward_hook, which passes (module, input, output).

    Args:
        cell (nn.LSTMCell): The LSTMCell to count operations for.
        x (torch.Tensor): Input tensors (forward-hook input tuple).
        y (torch.Tensor): Output tensor (unused by this counter).
    """
    ops = _count_lstm_cell(cell.input_size, cell.hidden_size, cell.bias)
    batch_size = x[0].size(0)
    total_ops = ops * batch_size
    cell.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_rnn_layers(model: nn.RNN, num_layers, input_size, hidden_size):
    """Calculate the total operations for RNN layers.

    Args:
        model (nn.RNN): The RNN model.
        num_layers (int): Number of layers in the RNN.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the RNN layers.
    """
    ops = _count_rnn_cell(input_size, hidden_size, model.bias)
    for _ in range(num_layers - 1):
        ops += _count_rnn_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias)
    if model.bidirectional:
        ops *= 2  # Each layer runs a forward and a backward cell
    return ops


def count_rnn(model: nn.RNN, x: torch.Tensor, y: torch.Tensor):
    """Count operations for the entire RNN over a batch of input.

    Registered via register_forward_hook, which passes (module, input, output).

    Args:
        model (nn.RNN): The RNN model.
        x (torch.Tensor): Input tensors (forward-hook input tuple).
        y (torch.Tensor): Output tensor (unused by this counter).
    """
    batch_size = x[0].size(0) if model.batch_first else x[0].size(1)
    num_steps = x[0].size(1) if model.batch_first else x[0].size(0)

    ops = _count_rnn_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_steps * batch_size
    model.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_gru_layers(model: nn.GRU, num_layers, input_size, hidden_size):
    """Calculate the total operations for GRU layers.

    Args:
        model (nn.GRU): The GRU model.
        num_layers (int): Number of layers in the GRU.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the GRU layers.
    """
    ops = _count_gru_cell(input_size, hidden_size, model.bias)
    for _ in range(num_layers - 1):
        ops += _count_gru_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias)
    if model.bidirectional:
        ops *= 2  # Each layer runs a forward and a backward cell
    return ops


def count_gru(model: nn.GRU, x: torch.Tensor, y: torch.Tensor):
    """Count operations for the entire GRU over a batch of input.

    Registered via register_forward_hook, which passes (module, input, output).

    Args:
        model (nn.GRU): The GRU model.
        x (torch.Tensor): Input tensors (forward-hook input tuple).
        y (torch.Tensor): Output tensor (unused by this counter).
    """
    batch_size = x[0].size(0) if model.batch_first else x[0].size(1)
    num_steps = x[0].size(1) if model.batch_first else x[0].size(0)

    ops = _count_gru_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_steps * batch_size
    model.total_ops += torch.DoubleTensor([int(total_ops)])


def _count_lstm_layers(model: nn.LSTM, num_layers, input_size, hidden_size):
    """Calculate the total operations for LSTM layers.

    Args:
        model (nn.LSTM): The LSTM model.
        num_layers (int): Number of layers in the LSTM.
        input_size (int): Size of the input.
        hidden_size (int): Size of the hidden state.

    Returns:
        int: Total number of operations for the LSTM layers.
    """
    ops = _count_lstm_cell(input_size, hidden_size, model.bias)
    for _ in range(num_layers - 1):
        ops += _count_lstm_cell(hidden_size * (2 if model.bidirectional else 1), hidden_size, model.bias)
    if model.bidirectional:
        ops *= 2  # Each layer runs a forward and a backward cell
    return ops


def count_lstm(model: nn.LSTM, x: torch.Tensor, y: torch.Tensor):
    """Count operations for the entire LSTM over a batch of input.

    Registered via register_forward_hook, which passes (module, input, output).

    Args:
        model (nn.LSTM): The LSTM model.
        x (torch.Tensor): Input tensors (forward-hook input tuple).
        y (torch.Tensor): Output tensor (unused by this counter).
    """
    batch_size = x[0].size(0) if model.batch_first else x[0].size(1)
    num_steps = x[0].size(1) if model.batch_first else x[0].size(0)

    ops = _count_lstm_layers(model, model.num_layers, model.input_size, model.hidden_size)
    total_ops = ops * num_steps * batch_size
    model.total_ops += torch.DoubleTensor([int(total_ops)])
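A minimal end-to-end check of the RNN hooks (a sketch, assuming the `thop` package layout added in this commit): for a one-layer unidirectional LSTM with `input_size=4` and `hidden_size=8`, the counters above give 512 ops per step, so a batch of 2 sequences of length 5 should report `512 * 5 * 2 = 5,120` MACs and 448 parameters:

```python
import torch
import torch.nn as nn

from thop.profiling import profile  # profiling.py is added later in this commit


class TinyRNN(nn.Module):
    """Wraps nn.LSTM so profile() can traverse it as a child module."""

    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

    def forward(self, x):
        out, _ = self.rnn(x)
        return out


macs, params = profile(TinyRNN(), inputs=(torch.randn(2, 5, 4),), verbose=False)
print(int(macs), int(params))  # 5120 448
```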
168
EdgeFLite/thop/profiling.py
Normal file
@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

# Importing necessary modules
import logging  # Used by the version check below and by several hooks

import torch
import torch.nn as nn

from distutils.version import LooseVersion  # Used for version comparisons
from .basic_hooks import *  # Importing basic hooks (functions for profiling operations)
from .hooks_rnn import *  # Importing RNN-specific hooks (module name matches hooks_rnn.py in this commit)

# Uncomment the following to configure a module-level logger
# logger = logging.getLogger(__name__)  # Creating a logger instance
# logger.setLevel(logging.INFO)  # Setting the log level to INFO


# Functions to print text in different colors
# Useful for visually differentiating output in terminal
def prRed(skk):
    print("\033[91m{}\033[00m".format(skk))  # Print red text


def prGreen(skk):
    print("\033[92m{}\033[00m".format(skk))  # Print green text


def prYellow(skk):
    print("\033[93m{}\033[00m".format(skk))  # Print yellow text


# Checking if the installed version of PyTorch is outdated
if LooseVersion(torch.__version__) < LooseVersion("1.0.0"):
    # If the version is below 1.0.0, print a warning
    logging.warning(
        f"You are using an old version of PyTorch {torch.__version__}, which THOP may not support in the future."
    )

# Setting the default data type for tensors
default_dtype = torch.float64  # Using 64-bit float as the default precision

# Register hooks for different layers in PyTorch
# Each layer type is mapped to its respective counting function
register_hooks = {
    nn.ZeroPad2d: zero_ops,
    nn.Conv1d: count_convNd, nn.Conv2d: count_convNd, nn.Conv3d: count_convNd,
    nn.ConvTranspose1d: count_convNd, nn.ConvTranspose2d: count_convNd, nn.ConvTranspose3d: count_convNd,
    nn.BatchNorm1d: count_bn, nn.BatchNorm2d: count_bn, nn.BatchNorm3d: count_bn, nn.SyncBatchNorm: count_bn,
    nn.ReLU: zero_ops, nn.ReLU6: zero_ops, nn.LeakyReLU: count_relu,
    nn.MaxPool1d: zero_ops, nn.MaxPool2d: zero_ops, nn.MaxPool3d: zero_ops,
    nn.AdaptiveMaxPool1d: zero_ops, nn.AdaptiveMaxPool2d: zero_ops, nn.AdaptiveMaxPool3d: zero_ops,
    nn.AvgPool1d: count_avgpool, nn.AvgPool2d: count_avgpool, nn.AvgPool3d: count_avgpool,
    nn.AdaptiveAvgPool1d: count_adap_avgpool, nn.AdaptiveAvgPool2d: count_adap_avgpool, nn.AdaptiveAvgPool3d: count_adap_avgpool,
    nn.Linear: count_linear, nn.Dropout: zero_ops,
    nn.Upsample: count_upsample, nn.UpsamplingBilinear2d: count_upsample, nn.UpsamplingNearest2d: count_upsample,
    nn.RNNCell: count_rnn_cell, nn.GRUCell: count_gru_cell, nn.LSTMCell: count_lstm_cell,
    nn.RNN: count_rnn, nn.GRU: count_gru, nn.LSTM: count_lstm,
}


# Function for profiling model operations and parameters
# This tracks how many operations (ops) and parameters (params) a model uses
def profile_origin(model, inputs, custom_ops=None, verbose=True):
    handler_collection = []  # Collection of hooks
    types_collection = set()  # Keep track of registered layer types
    custom_ops = custom_ops or {}  # Custom operation handling

    def add_hooks(m):
        # Ignore compound modules (those that contain other modules)
        if len(list(m.children())) > 0:
            return

        # Check if the module already has the required attributes
        if hasattr(m, "total_ops") or hasattr(m, "total_params"):
            logging.warning(f"Either .total_ops or .total_params is already defined in {str(m)}. Be cautious.")

        # Add buffers to store the total number of operations and parameters
        m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype))
        m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype))

        # Count the number of parameters for this module
        for p in m.parameters():
            m.total_params += torch.DoubleTensor([p.numel()])

        # Determine which function to use for counting operations
        m_type = type(m)
        fn = custom_ops.get(m_type, register_hooks.get(m_type, None))

        if fn:
            # If the function exists, register the forward hook
            if m_type not in types_collection and verbose:
                print(f"[INFO] {'Customize rule' if m_type in custom_ops else 'Register'} {fn.__qualname__} for {m_type}.")
            handler = m.register_forward_hook(fn)
            handler_collection.append(handler)
        else:
            # Warn if no counting rule is found
            if m_type not in types_collection and verbose:
                prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero MACs and zero Params.")

        types_collection.add(m_type)

    # Set the model to evaluation mode (no gradients)
    model.eval()
    model.apply(add_hooks)

    # Run a forward pass with no gradients
    with torch.no_grad():
        model(*inputs)

    # Sum up the total operations and parameters across all layers
    total_ops = sum(m.total_ops.item() for m in model.modules() if hasattr(m, 'total_ops'))
    total_params = sum(m.total_params.item() for m in model.modules() if hasattr(m, 'total_params'))

    # Restore the model to training mode and remove hooks
    model.train()
    for handler in handler_collection:
        handler.remove()
    for m in model.modules():
        if hasattr(m, "total_ops"):
            del m._buffers['total_ops']
        if hasattr(m, "total_params"):
            del m._buffers['total_params']

    return total_ops, total_params  # Return the total number of ops and params


# Updated profiling function with a different approach for hierarchical modules
def profile(model: nn.Module, inputs, custom_ops=None, verbose=True):
    handler_collection = {}  # Dictionary to store handlers
    types_collection = set()  # Store layer types that have been processed
    custom_ops = custom_ops or {}  # Custom operation handling

    def add_hooks(m: nn.Module):
        # Add buffers for storing total ops and params
        m.register_buffer('total_ops', torch.zeros(1, dtype=default_dtype))
        m.register_buffer('total_params', torch.zeros(1, dtype=default_dtype))

        # Find the appropriate counting function for this layer
        fn = custom_ops.get(type(m), register_hooks.get(type(m), None))
        if fn:
            # Register hooks for both operations and parameters
            handler_collection[m] = (m.register_forward_hook(fn), m.register_forward_hook(count_parameters))
            if type(m) not in types_collection and verbose:
                print(f"[INFO] {'Customize rule' if type(m) in custom_ops else 'Register'} {fn.__qualname__} for {type(m)}.")
        else:
            # Warn if no rule is found for this layer
            if type(m) not in types_collection and verbose:
                prRed(f"[WARN] Cannot find rule for {type(m)}. Treat it as zero MACs and zero Params.")
        types_collection.add(type(m))

    # Set the model to evaluation mode
    model.eval()
    model.apply(add_hooks)

    # Run a forward pass with no gradients
    with torch.no_grad():
        model(*inputs)

    # Recursive function to count ops and params for hierarchical models
    def dfs_count(module: nn.Module) -> (int, int):
        total_ops, total_params = 0, 0
        for m in module.children():
            if m in handler_collection:
                m_ops, m_params = m.total_ops.item(), m.total_params.item()
            else:
                m_ops, m_params = dfs_count(m)
            total_ops += m_ops
            total_params += m_params
        return total_ops, total_params

    total_ops, total_params = dfs_count(model)  # Perform the depth-first count

    # Restore the model to training mode and remove hooks
    model.train()
    for m, (op_handler, params_handler) in handler_collection.items():
        op_handler.remove()
        params_handler.remove()
        del m._buffers['total_ops']
        del m._buffers['total_params']

    return total_ops, total_params  # Return the total ops and params
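A usage sketch for `profile` on a small CNN (the model below is illustrative, not part of the repository; it assumes the `thop` package layout added in this commit):

```python
import torch
import torch.nn as nn

from thop.profiling import profile

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
macs, params = profile(net, inputs=(torch.randn(1, 3, 32, 32),), verbose=False)
print(f"MACs: {macs:,.0f}  Params: {params:,.0f}")
```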
205
EdgeFLite/train_EdgeFLite.py
Normal file
@@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
# @Author: Weisen Pan

import os
import argparse
import setproctitle
import torch
import torch.nn as nn
from config import *  # Import configuration
from params import train_params  # Import training parameters
from model import coremodel, coremodelsl  # Import models
from utils import (  # Import utility functions
    label_smoothing, norm, metric, lr_scheduler, prefetch,
    save_hp_to_json, profile, clever_format
)
from dataset import factory  # Import dataset factory

# Specify the GPU to be used
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Global variable for tracking the best accuracy
best_acc1 = 0


# Function to calculate the average of a list of values
def average(values):
    """Calculate average of a list."""
    return sum(values) / len(values)


# Function to aggregate the models from multiple clients into a global model
def merge_models(global_model_main, global_model_proxy, client_main_models, client_proxy_models):
    """Aggregates model weights using a simple (FedAvg-style) mean."""
    # Get the state dictionaries for the global models
    global_main_state = global_model_main.state_dict()
    global_proxy_state = global_model_proxy.state_dict()

    # Aggregate the main client models by averaging the weights
    for key in global_main_state.keys():
        global_main_state[key] = torch.stack([client.state_dict()[key].float() for client in client_main_models], 0).mean(0)
    global_model_main.load_state_dict(global_main_state)

    # Aggregate the proxy client models similarly
    for key in global_proxy_state.keys():
        global_proxy_state[key] = torch.stack([client.state_dict()[key].float() for client in client_proxy_models], 0).mean(0)
    global_model_proxy.load_state_dict(global_proxy_state)

    # Synchronize the client models with the updated global model
    for client in client_main_models:
        client.load_state_dict(global_model_main.state_dict())
    for client in client_proxy_models:
        client.load_state_dict(global_model_proxy.state_dict())


# Function to perform client-side training updates
def client_update(args, round_idx, main_model, proxy_models, schedulers_main, schedulers_proxy, optimizers_main, optimizers_proxy, train_loader, epochs=5, streams=None):
    """Client-side training update."""
    main_model.train()
    proxy_models.train()

    # Train for a given number of epochs
    for epoch in range(epochs):
        # Prefetch data for faster loading
        prefetcher = prefetch.data_prefetcher(train_loader)
        images, targets = prefetcher.next()
        batch_idx = 0

        # Zero the gradients
        optimizers_main.zero_grad()
        optimizers_proxy.zero_grad()

        # Process each batch of data
        while images is not None:
            # Adjust learning rates using the scheduler
            schedulers_main(optimizers_main, batch_idx, round_idx)
            schedulers_proxy(optimizers_proxy, batch_idx, round_idx)

            # Forward pass for the main model
            outputs, y_a, y_b, lam = main_model(images, target=targets, mode='train', epoch=epoch, streams=streams)
            main_fx = [output.clone().detach().requires_grad_(True) for output in outputs]

            # Forward pass for the proxy model with outputs from the main model
            ensemble_output, proxy_outputs, ce_loss, cot_loss = proxy_models(main_fx, y_a, y_b, lam, target=targets, mode='train', epoch=epoch, streams=streams)

            # Calculate total loss and perform backpropagation
            total_loss = (ce_loss + cot_loss) / args.iters_to_accumulate
            total_loss.backward()

            # Backpropagate gradients for the main model
            for j in range(len(main_fx)):
                outputs[j].backward(main_fx[j].grad)

            # Update the model weights periodically (batch_idx is 0-based,
            # so the last batch is batch_idx == len(train_loader) - 1)
            if batch_idx % args.iters_to_accumulate == 0 or batch_idx == len(train_loader) - 1:
                optimizers_main.step()
                optimizers_main.zero_grad()
                optimizers_proxy.step()
                optimizers_proxy.zero_grad()

            # Fetch the next batch of images
            images, targets = prefetcher.next()
            batch_idx += 1

    return total_loss.item()


# Function to validate the models on a validation set
def validate(val_loader, main_model, proxy_models, args, streams=None):
    """Validation function to evaluate models."""
    main_model.eval()
    proxy_models.eval()

    # Initialize metrics for accuracy tracking
    top1_metrics = [metric.AverageMeter(f"Acc@1_{i}", ":6.2f") for i in range(args.loop_factor)]
    acc1_list, acc5_list, ce_loss_list = [], [], []

    # Disable gradient computation for validation
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.cuda(), targets.cuda()

            # Forward pass for main model (no gradients are needed at
            # validation time, so the detached outputs are passed as-is)
            outputs = main_model(images, target=targets, mode='val')
            main_fx = [output.clone().detach() for output in outputs]

            # Forward pass for proxy model
            ensemble_output, proxy_outputs, ce_loss = proxy_models(main_fx, target=targets, mode='val')

            # Calculate accuracy
            acc1, acc5 = metric.accuracy(ensemble_output, targets, topk=(1, 5))
            acc1_list.append(acc1)
            acc5_list.append(acc5)
            ce_loss_list.append(ce_loss)

    # Calculate average metrics over the validation set
    avg_acc1 = average(acc1_list)
    avg_acc5 = average(acc5_list)
    avg_ce_loss = average(ce_loss_list)

    return avg_ce_loss, avg_acc1, top1_metrics


# Main function to set up and start decentralized training
def main(args):
    """Main function to set up decentralized training."""
    # Set loop factor based on training configuration
    args.loop_factor = 1 if args.is_train_sep or args.is_single_branch else args.split_factor
    # Determine if decentralized training is needed
    args.is_decentralized = args.world_size > 1 or args.multiprocessing_decentralized

    # Get the number of GPUs available
    ngpus_per_node = torch.cuda.device_count()
    args.ngpus_per_node = ngpus_per_node

    # If using decentralized training with multiprocessing
    if args.multiprocessing_decentralized:
        args.world_size *= ngpus_per_node
        torch.multiprocessing.spawn(execute_worker_process, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # If not using multiprocessing, proceed with a single GPU
        args.gpu = 0
        execute_worker_process(args.gpu, ngpus_per_node, args)


# Main worker function to handle training with multiple GPUs or single GPU
def execute_worker_process(gpu, ngpus_per_node, args):
    """Main worker function for multi-GPU or single-GPU training."""
    global best_acc1
    args.gpu = gpu

    # Set process title
    setproctitle.setproctitle(f"{args.proc_name}_EdgeFLite_rank{args.rank}")

    # Set the criterion for loss calculation
    if args.is_label_smoothing:
        criterion = label_smoothing.label_smoothing_CE(reduction='mean')
    else:
        criterion = nn.CrossEntropyLoss()

    # Create the main and proxy models for training
    main_model = coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()
    proxy_model = coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda()

    # Initialize client models for federated learning
    client_main_models = [coremodelsl.CoreModelClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() for _ in range(args.num_selected)]
    client_proxy_models = [coremodelsl.coremodelProxyClient(args, norm_layer=norm.norm(args.norm_mode), criterion=criterion).cuda() for _ in range(args.num_selected)]

    # Synchronize client models with the global models
    for client in client_main_models:
        client.load_state_dict(main_model.state_dict())
    for client in client_proxy_models:
        client.load_state_dict(proxy_model.state_dict())

    # Load training and validation data
    train_loader = factory.obtain_data_loader(args.data, batch_size=args.batch_size, dataset=args.dataset, split="train", num_workers=args.workers)
    val_loader = factory.obtain_data_loader(args.data, batch_size=args.eval_batch_size, dataset=args.dataset, split="val", num_workers=args.workers)

    # Loop over training rounds
    for r in range(args.start_round, args.num_rounds + 1):
        # Update each selected client locally. client_update() expects one
        # model and one optimizer instance per client, so the optimizers are
        # built here (the SGD hyperparameters are assumed to be exposed by
        # train_params: lr, momentum, weight_decay, fed_epochs).
        for client_main, client_proxy in zip(client_main_models, client_proxy_models):
            optimizer_main = torch.optim.SGD(client_main.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
            optimizer_proxy = torch.optim.SGD(client_proxy.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
            client_update(args, r, client_main, client_proxy,
                          lr_scheduler.lr_scheduler, lr_scheduler.lr_scheduler,
                          optimizer_main, optimizer_proxy,
                          train_loader, epochs=args.fed_epochs)

        # Aggregate the client weights into the global models and re-sync clients
        merge_models(main_model, proxy_model, client_main_models, client_proxy_models)

        # Validate the aggregated global models
        test_loss, acc, top1 = validate(val_loader, main_model, proxy_model, args)

        # Track the best accuracy achieved
        best_acc1 = max(acc, best_acc1)


# Entry point for the script
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Training EdgeFLite")
    args = train_params.add_parser_params(parser)
    main(args)
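The aggregation in `merge_models` is coordinate-wise averaging of client state dicts; a minimal standalone check of that pattern (with hypothetical two-element tensors standing in for model weights):

```python
import torch

client_a = {"w": torch.tensor([1.0, 2.0])}
client_b = {"w": torch.tensor([3.0, 4.0])}
merged = {k: torch.stack([client_a[k], client_b[k]], 0).mean(0) for k in client_a}
print(merged["w"])  # tensor([2., 3.])
```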