diff --git a/UniTrain/models/classification.py b/UniTrain/models/classification.py index b3cf3c7..c38e005 100644 --- a/UniTrain/models/classification.py +++ b/UniTrain/models/classification.py @@ -14,11 +14,20 @@ def __init__(self, num_classes): self.softmax = F.softmax self.num_classes = num_classes - - def conv_block(self, xb, inp_filter_size, hidden_filter_size, out_filter_size, pool = False): - layers = nn.Sequential(nn.Conv2d(inp_filter_size, hidden_filter_size, padding=0, kernel_size=1), nn.BatchNorm2d(hidden_filter_size), nn.ReLU(inplace=True), - nn.Conv2d(hidden_filter_size, hidden_filter_size, padding=1, kernel_size=3), nn.BatchNorm2d(hidden_filter_size), nn.ReLU(inplace=True), - nn.Conv2d(hidden_filter_size, out_filter_size, padding=0, kernel_size=1), nn.BatchNorm2d(out_filter_size), nn.ReLU(inplace=True)) + def conv_block( + self, xb, inp_filter_size, hidden_filter_size, out_filter_size, pool=False + ): + layers = nn.Sequential( + nn.Conv2d(inp_filter_size, hidden_filter_size, padding=0, kernel_size=1), + nn.BatchNorm2d(hidden_filter_size), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_filter_size, hidden_filter_size, padding=1, kernel_size=3), + nn.BatchNorm2d(hidden_filter_size), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_filter_size, out_filter_size, padding=0, kernel_size=1), + nn.BatchNorm2d(out_filter_size), + nn.ReLU(inplace=True), + ) layers.to(xb.device) return layers(xb) @@ -38,7 +47,7 @@ def forward(self, xb): y = self.conv_block(y, 512, 256, 1024) for i in range(0, 22): y = self.conv_block(y, 1024, 256, 1024) + y - i+=1 + i += 1 y = self.conv_block(y, 1024, 512, 2048) y = self.conv_block(y, 2048, 512, 2048) + y @@ -52,7 +61,7 @@ def forward(self, xb): y = linear_layer(y) return y - + # Define the ResNet-9 model in a single class class ResNet9(nn.Module): @@ -120,28 +129,32 @@ def forward(self, x): x = x.view(x.size(0), -1) x = self.fc(x) - #ResNet50 functionality addition + +# ResNet50 functionality addition class ResNet9_50(nn.Module): def __init__(self, num_classes): super(ResNet9_50, self).__init__() - + self.resnet9 = ResNet9(num_classes) self.resnet50 = models.resnet50(pretrained=True) def forward(self, x): - x = self.resnet50(x) -#GoogLeNet functionality addition + x = self.resnet50(x) + + +# GoogLeNet functionality addition import torch import torch.nn as nn import torchvision.models as models + class GoogleNetModel(nn.Module): def __init__(self, num_classes): super(GoogleNetModel, self).__init() - + # Load the pre-trained GoogleNet model self.googlenet = models.inception_v3(pretrained=True) - + # Modify the classification head to match the number of classes in your dataset num_ftrs = self.googlenet.fc.in_features self.googlenet.fc = nn.Linear(num_ftrs, num_classes) @@ -156,10 +169,15 @@ def forward(self, x): # Making a custom transfer learning model -def create_transfer_learning_model(num_classes, model = torchvision.models.resnet18, feature_extract=True, use_pretrained=True): +def create_transfer_learning_model( + num_classes, + model=torchvision.models.resnet18, + feature_extract=True, + use_pretrained=True, +): """ Create a transfer learning model with a custom output layer. - + Args: num_classes (int): Number of classes in the custom output layer. model(torchvision.models.): Pre-trained model you want to use. 
@@ -171,18 +189,19 @@ def create_transfer_learning_model(num_classes, model = torchvision.models.resne """ # Load a pre-trained model, for example, ResNet-18 model = model(pretrained=use_pretrained) - + # Freeze the pre-trained weights if feature_extract is True if feature_extract: for param in model.parameters(): param.requires_grad = False - + # Modify the output layer to match the number of classes num_features = model.fc.in_features model.fc = nn.Linear(num_features, num_classes) - + return model + # Define the ResNet-18 model in a single class class ResNet34(nn.Module): def __init__(self, num_classes): @@ -210,15 +229,31 @@ def make_layer(self, out_channels, num_blocks, stride): layers.append(self.build_residual_block(self.in_channels, out_channels, stride)) self.in_channels = out_channels for _ in range(1, num_blocks): - layers.append(self.build_residual_block(self.in_channels, out_channels, stride=1)) + layers.append( + self.build_residual_block(self.in_channels, out_channels, stride=1) + ) return nn.Sequential(*layers) def build_residual_block(self, in_channels, out_channels, stride): return nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), ) @@ -237,7 +272,6 @@ def forward(self, x): return x - class ResNet50(nn.Module): def __init__(self, num_classes): super(ResNet50, self).__init__() @@ -264,15 +298,31 @@ def make_layer(self, out_channels, num_blocks, stride): layers.append(self.build_residual_block(self.in_channels, out_channels, stride)) self.in_channels = out_channels for _ in range(1, num_blocks): - layers.append(self.build_residual_block(self.in_channels, out_channels, stride=1)) + layers.append( + self.build_residual_block(self.in_channels, out_channels, stride=1) + ) return nn.Sequential(*layers) def build_residual_block(self, in_channels, out_channels, stride): return nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), ) @@ -317,15 +367,31 @@ def make_layer(self, out_channels, num_blocks, stride): layers.append(self.build_residual_block(self.in_channels, out_channels, stride)) self.in_channels = out_channels for _ in range(1, num_blocks): - layers.append(self.build_residual_block(self.in_channels, out_channels, stride=1)) + layers.append( + self.build_residual_block(self.in_channels, out_channels, stride=1) + ) return nn.Sequential(*layers) def build_residual_block(self, in_channels, out_channels, stride): return nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), - 
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), nn.BatchNorm2d(out_channels), ) @@ -350,7 +416,9 @@ def __init__(self, num_classes, growth_rate=12, num_blocks=3, num_layers=4): self.in_channels = 64 # Initial convolution layer - self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.conv1 = nn.Conv2d( + 3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn1 = nn.BatchNorm2d(self.in_channels) self.relu = nn.ReLU(inplace=True) @@ -367,11 +435,20 @@ def make_dense_block(self, growth_rate, num_layers): layers = [] in_channels = self.in_channels for _ in range(num_layers): - layers.extend([ - nn.BatchNorm2d(in_channels), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels, growth_rate, kernel_size=3, stride=1, padding=1, bias=False), - ]) + layers.extend( + [ + nn.BatchNorm2d(in_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + in_channels, + growth_rate, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + ] + ) in_channels += growth_rate self.in_channels = in_channels return nn.Sequential(*layers) @@ -387,15 +464,28 @@ def forward(self, x): x = self.fc(x) return x + class LightVisionTransformer(nn.Module): - def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=256, depth=6, heads=4, mlp_dim=512, dropout=0.1): + def __init__( + self, + image_size=224, + patch_size=16, + num_classes=1000, + dim=256, + depth=6, + heads=4, + mlp_dim=512, + dropout=0.1, + ): super(LightVisionTransformer, self).__init() num_patches = (image_size // patch_size) ** 2 patch_dim = 3 * patch_size * patch_size # 3 for RGB channels # Patch embedding layer - self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size) + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size + ) self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.transformer = nn.Transformer( @@ -411,10 +501,11 @@ def forward(self, x): B, C, H, W = x.shape x = self.patch_embedding(x) x = x.permute(0, 2, 3, 1).view(B, -1, x.size(1)) # Flatten and transpose - x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1) # Prepend the classification token + x = torch.cat( + [self.cls_token.expand(B, -1, -1), x], dim=1 + ) # Prepend the classification token x = x + self.positional_embedding x = self.transformer(x) x = x.mean(dim=1) # Global average pooling x = self.fc(x) return x - diff --git a/UniTrain/models/segmentation.py b/UniTrain/models/segmentation.py index aec8015..778c90d 100644 --- a/UniTrain/models/segmentation.py +++ b/UniTrain/models/segmentation.py @@ -244,10 +244,11 @@ def forward(self, x): return out + class VGGNet(nn.Module): def __init__(self, n_class): super(VGGNet, self).__init() - + # Encoder # VGG-like convolutional blocks self.encoder = nn.Sequential( @@ -256,31 +257,27 @@ def __init__(self, n_class): nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), - 
nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(512, 1024, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(1024, 1024, kernel_size=3, padding=1), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) - + # Decoder self.decoder = nn.Sequential( nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2), @@ -288,37 +285,204 @@ def __init__(self, n_class): nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2), nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2), nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) # Output layer self.output_layer = nn.Conv2d(64, n_class, kernel_size=1) - + def forward(self, x): # Pass input through the encoder enc = self.encoder(x) - + # Pass the encoded features through the decoder dec = self.decoder(enc) - + # Pass the decoder output through the final output layer out = self.output_layer(dec) - + + return out + + +def passthrough(x, **kwargs): + return x + + +def ELUCons(elu, nchan): + if elu: + return nn.ELU(inplace=True) + else: + return nn.PReLU(nchan) + + +# normalization between sub-volumes is necessary +# for good performance +class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm): + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError("expected 5D input (got {}D input)".format(input.dim())) + super(ContBatchNorm3d, self)._check_input_dim(input) + + def forward(self, input): + self._check_input_dim(input) + return F.batch_norm( + input, + self.running_mean, + self.running_var, + self.weight, + self.bias, + True, + self.momentum, + self.eps, + ) + + +class LUConv(nn.Module): + def __init__(self, nchan, elu): + super(LUConv, self).__init__() + self.relu1 = ELUCons(elu, nchan) + self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2) + self.bn1 = ContBatchNorm3d(nchan) + + def forward(self, x): + out = self.relu1(self.bn1(self.conv1(x))) + return out + + +def _make_nConv(nchan, depth, elu): + return nn.Sequential(*[LUConv(nchan, elu) for _ in range(depth)]) + + +class InputTransition(nn.Module): + def __init__(self, outChans, elu): + super(InputTransition, self).__init__() + self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2) + self.bn1 = ContBatchNorm3d(16) + self.relu1 = nn.PReLU(16) + + def forward(self, x): + out = self.bn1(self.conv1(x)) + # split input in to 16 channels + x16 = torch.cat(tuple(x for _ in range(16)), 0) + out = self.relu1(torch.add(out, x16)) + return out + + +class DownTransition(nn.Module): + def __init__(self, inChans, nConvs, elu, dropout=False): + super(DownTransition, self).__init__() + outChans = 2 * inChans + self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2) + self.bn1 = ContBatchNorm3d(outChans) + self.do1 = passthrough + self.relu1 = ELUCons(elu, outChans) + self.relu2 = ELUCons(elu, outChans) + if dropout: + self.do1 = nn.Dropout3d() + self.ops = _make_nConv(outChans, nConvs, elu) + + 
def forward(self, x): + down = self.relu1(self.bn1(self.down_conv(x))) + out = self.do1(down) + out = self.ops(out) + out = self.relu2(torch.add(out, down)) + return out + + +class UpTransition(nn.Module): + def __init__(self, inChans, outChans, nConvs, elu, dropout=False): + super(UpTransition, self).__init__() + self.up_conv = nn.ConvTranspose3d( + inChans, outChans // 2, kernel_size=2, stride=2 + ) + self.bn1 = ContBatchNorm3d(outChans // 2) + self.do1 = passthrough + self.do2 = nn.Dropout3d() + self.relu1 = ELUCons(elu, outChans // 2) + self.relu2 = ELUCons(elu, outChans) + if dropout: + self.do1 = nn.Dropout3d() + self.ops = _make_nConv(outChans, nConvs, elu) + + def forward(self, x, skipx): + out = self.do1(x) + skipxdo = self.do2(skipx) + out = self.relu1(self.bn1(self.up_conv(out))) + xcat = torch.cat((out, skipxdo), 1) + out = self.ops(xcat) + out = self.relu2(torch.add(out, xcat)) return out + +class OutputTransition(nn.Module): + def __init__(self, inChans, elu, nll): + super(OutputTransition, self).__init__() + self.conv1 = nn.Conv3d(inChans, 2, kernel_size=5, padding=2) + self.bn1 = ContBatchNorm3d(2) + self.conv2 = nn.Conv3d(2, 2, kernel_size=1) + self.relu1 = ELUCons(elu, 2) + if nll: + self.softmax = F.log_softmax + else: + self.softmax = F.softmax + + def forward(self, x): + # convolve 32 down to 2 channels + out = self.relu1(self.bn1(self.conv1(x))) + out = self.conv2(out) + + # make channels the last axis + out = out.permute(0, 2, 3, 4, 1).contiguous() + # flatten + out = out.view(out.numel() // 2, 2) + out = self.softmax(out) + # treat channel 0 as the predicted output + return out + + +class VNet(nn.Module): + def __init__(self, elu=True, nll=False): + super(VNet, self).__init__() + # Input + self.in_tr = InputTransition(16, elu) + + # Encoding + self.down_tr32 = DownTransition(16, 1, elu) + self.down_tr64 = DownTransition(32, 2, elu) + self.down_tr128 = DownTransition(64, 3, elu, dropout=True) + self.down_tr256 = DownTransition(128, 2, elu, dropout=True) + + # Decoding + self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True) + self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True) + self.up_tr64 = UpTransition(128, 64, 1, elu) + self.up_tr32 = UpTransition(64, 32, 1, elu) + + # Output + self.out_tr = OutputTransition(32, elu, nll) + + def forward(self, x): + in16 = self.in_tr(x) + in32 = self.down_tr32(in16) + in64 = self.down_tr64(in32) + in128 = self.down_tr128(in64) + in256 = self.down_tr256(in128) + out256 = self.up_tr256(in256, in128) + out128 = self.up_tr128(out256, in64) + out64 = self.up_tr64(out128, in32) + out32 = self.up_tr32(out64, in16) + out16 = self.out_tr(out32) + return out16 diff --git a/UniTrain/utils/DCGAN.py b/UniTrain/utils/DCGAN.py index 162bab0..22c62ed 100644 --- a/UniTrain/utils/DCGAN.py +++ b/UniTrain/utils/DCGAN.py @@ -176,7 +176,6 @@ def train_model( logger=None, iou=False, ): - os.makedirs(checkpoint_dir + "/discriminator_checkpoint", exist_ok=True) os.makedirs(checkpoint_dir + "/generator_checkpoint", exist_ok=True) @@ -213,20 +212,40 @@ def train_model( device="cpu", ) - progress_bar = tqdm(train_data_loader, desc=f'Epoch {epoch + 1}/{epochs}', leave=False, dynamic_ncols=True) + progress_bar = tqdm( + train_data_loader, + desc=f"Epoch {epoch + 1}/{epochs}", + leave=False, + dynamic_ncols=True, + ) for real_images, _ in progress_bar: # Train discriminator i += 1 - loss_d, real_score, fake_score = train_discriminator(discriminator_model, generator_model, real_images, opt_d, 128, 128, device='cpu') + loss_d, real_score, 
fake_score = train_discriminator( + discriminator_model, + generator_model, + real_images, + opt_d, + 128, + 128, + device="cpu", + ) # Train generator loss_g = train_generator( opt_g, discriminator_model, generator_model, batch_size, device="cpu" ) - progress_bar.set_postfix({'Loss D': loss_d, 'Loss G': loss_g, 'Real Score': real_score, 'Fake Score': fake_score}) - + progress_bar.set_postfix( + { + "Loss D": loss_d, + "Loss G": loss_g, + "Real Score": real_score, + "Fake Score": fake_score, + } + ) + progress_bar.close() # Record losses & scores @@ -239,8 +258,6 @@ def train_model( save_samples(epoch + epoch, generator_model, fixed_latent, show=False) - - def evaluate_model(discriminator_model, dataloader): discriminator_model.eval() # Set the model to evaluation mode correct = 0 diff --git a/UniTrain/utils/StyleTransfer.py b/UniTrain/utils/StyleTransfer.py index add02bf..4af2eab 100644 --- a/UniTrain/utils/StyleTransfer.py +++ b/UniTrain/utils/StyleTransfer.py @@ -8,8 +8,10 @@ class ContentLoss(nn.Module): - - def __init__(self, target,): + def __init__( + self, + target, + ): super(ContentLoss, self).__init__() # we 'detach' the target content from the tree used # to dynamically compute the gradient: this is a stated value, @@ -20,9 +22,9 @@ def __init__(self, target,): def forward(self, input): self.loss = F.mse_loss(input, self.target) return input - -class StyleLoss(nn.Module): + +class StyleLoss(nn.Module): def __init__(self, target_feature): super(StyleLoss, self).__init__() self.target = self.gram_matrix(target_feature).detach() @@ -31,7 +33,7 @@ def forward(self, input): G = self.gram_matrix(input) self.loss = F.mse_loss(G, self.target) return input - + def gram_matrix(self, input): a, b, c, d = input.size() # a=batch size(=1) # b=number of feature maps @@ -45,6 +47,7 @@ def gram_matrix(self, input): # by dividing by the number of element in each feature maps. return G.div(a * b * c * d) + # create a module to normalize input image so we can easily put it in a # ``nn.Sequential`` class Normalization(nn.Module): @@ -56,11 +59,11 @@ def __init__(self, mean, std): self.mean = mean.clone().detach().view(-1, 1, 1) self.std = std.clone().detach().view(-1, 1, 1) - def forward(self, img): # normalize ``img`` return (img - self.mean) / self.std + def parse_folder(dataset_path): print(dataset_path) print(os.getcwd()) @@ -68,8 +71,8 @@ def parse_folder(dataset_path): try: if os.path.exists(dataset_path): # Store paths to train, test, and eval folders if they exist - content_path = os.path.join(dataset_path, 'images', 'content') - style_path = os.path.join(dataset_path, 'images', 'style') + content_path = os.path.join(dataset_path, "images", "content") + style_path = os.path.join(dataset_path, "images", "style") if os.path.exists(content_path) & os.path.exists(style_path): print("Content Data folder path:", content_path) @@ -86,22 +89,24 @@ def parse_folder(dataset_path): else: print("Either content or style directory does not exist") return None - + else: - print(f"The '{dataset_path}' folder does not exist in the current directory.") + print( + f"The '{dataset_path}' folder does not exist in the current directory." 
+ ) return None - + except Exception as e: print("An error occurred:", str(e)) return None + def image_loader(image_name, device): # desired size of the output image imsize = 512 if torch.cuda.is_available() else 128 # use small size if no GPU - loader = transforms.Compose([ - transforms.Resize(imsize), # scale imported image - transforms.ToTensor()] + loader = transforms.Compose( + [transforms.Resize(imsize), transforms.ToTensor()] # scale imported image ) # transform it into a torch tensor image = Image.open(image_name) @@ -111,11 +116,16 @@ def image_loader(image_name, device): return image.to(device, torch.float) -def get_style_model_and_losses(style_img, content_img, cnn, normalization_mean, normalization_std, - content_layers = ['conv_4'], # desired depth layers to compute style/content losses - style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] - ): +def get_style_model_and_losses( + style_img, + content_img, + cnn, + normalization_mean, + normalization_std, + content_layers=["conv_4"], # desired depth layers to compute style/content losses + style_layers=["conv_1", "conv_2", "conv_3", "conv_4", "conv_5"], +): # normalization module normalization = Normalization(normalization_mean, normalization_std) @@ -132,19 +142,19 @@ def get_style_model_and_losses(style_img, content_img, cnn, normalization_mean, for layer in cnn.children(): if isinstance(layer, nn.Conv2d): i += 1 - name = f'conv_{i}' + name = f"conv_{i}" elif isinstance(layer, nn.ReLU): - name = f'relu_{i}' + name = f"relu_{i}" # The in-place version doesn't play very nicely with the ``ContentLoss`` # and ``StyleLoss`` we insert below. So we replace with out-of-place # ones here. layer = nn.ReLU(inplace=False) elif isinstance(layer, nn.MaxPool2d): - name = f'pool_{i}' + name = f"pool_{i}" elif isinstance(layer, nn.BatchNorm2d): - name = f'bn_{i}' + name = f"bn_{i}" else: - raise RuntimeError(f'Unrecognized layer: {layer.__class__.__name__}') + raise RuntimeError(f"Unrecognized layer: {layer.__class__.__name__}") model.add_module(name, layer) @@ -167,18 +177,31 @@ def get_style_model_and_losses(style_img, content_img, cnn, normalization_mean, if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss): break - model = model[:(i + 1)] + model = model[: (i + 1)] return model, style_losses, content_losses -def run_style_transfer(cnn, content_img, style_img, input_img, normalization_mean = torch.tensor([0.485, 0.456, 0.406]), # VGG networks are trained on images with each channel normalized by - normalization_std = torch.tensor([0.229, 0.224, 0.225]), # mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] - num_steps=300, - style_weight=1000000, content_weight=1): + +def run_style_transfer( + cnn, + content_img, + style_img, + input_img, + normalization_mean=torch.tensor( + [0.485, 0.456, 0.406] + ), # VGG networks are trained on images with each channel normalized by + normalization_std=torch.tensor( + [0.229, 0.224, 0.225] + ), # mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] + num_steps=300, + style_weight=1000000, + content_weight=1, +): """Run the style transfer.""" - print('Building the style transfer model..') - model, style_losses, content_losses = get_style_model_and_losses(style_img, content_img, cnn, - normalization_mean, normalization_std) + print("Building the style transfer model..") + model, style_losses, content_losses = get_style_model_and_losses( + style_img, content_img, cnn, normalization_mean, normalization_std + ) # We want to optimize the input and not the model 
parameters so we # update all the requires_grad fields accordingly @@ -190,7 +213,7 @@ def run_style_transfer(cnn, content_img, style_img, input_img, normalization_mea optimizer = optim.LBFGS([input_img]) - print('Optimizing..') + print("Optimizing..") run = [0] while run[0] <= num_steps: @@ -218,8 +241,11 @@ def closure(): run[0] += 1 if run[0] % 50 == 0: print("run {}:".format(run)) - print('Style Loss : {:4f} Content Loss: {:4f}'.format( - style_score.item(), content_score.item())) + print( + "Style Loss : {:4f} Content Loss: {:4f}".format( + style_score.item(), content_score.item() + ) + ) print() return style_score + content_score @@ -232,11 +258,12 @@ def closure(): return input_img -def imshow(tensor, unloader = transforms.ToPILImage(), title=None): + +def imshow(tensor, unloader=transforms.ToPILImage(), title=None): image = tensor.cpu().clone() # we clone the tensor to not do changes on it - image = image.squeeze(0) # remove the fake batch dimension + image = image.squeeze(0) # remove the fake batch dimension image = unloader(image) plt.imshow(image) if title is not None: plt.title(title) - plt.pause(0.001) # pause a bit so that plots are updated + plt.pause(0.001) # pause a bit so that plots are updated diff --git a/UniTrain/utils/StyleTransferVGG.py b/UniTrain/utils/StyleTransferVGG.py index e0bcdb3..0b3907c 100644 --- a/UniTrain/utils/StyleTransferVGG.py +++ b/UniTrain/utils/StyleTransferVGG.py @@ -1,4 +1,4 @@ -#importing the required libraries +# importing the required libraries import torch import torchvision.transforms as transforms from PIL import Image @@ -9,79 +9,95 @@ import matplotlib.pyplot as plt -#Loading the model vgg19 that will serve as the base model +# Loading the model vgg19 that will serve as the base model model = models.vgg19(pretrained=True).features # the vgg19 model has three components : - #features: Containg all the conv, relu and maxpool - #avgpool: Containing the avgpool layer - #classifier: Contains the Dense layer(FC part of the model) +# features: Containg all the conv, relu and maxpool +# avgpool: Containing the avgpool layer +# classifier: Contains the Dense layer(FC part of the model) -#Assigning the GPU to the variable device -device=torch.device( "cuda" if (torch.cuda.is_available()) else 'cpu') +# Assigning the GPU to the variable device +device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu") -#Defining a class that for the model + +# Defining a class that for the model class VGG(nn.Module): def __init__(self): - super(VGG,self).__init__() - #Here we will use the following layers and make an array of their indices + super(VGG, self).__init__() + # Here we will use the following layers and make an array of their indices # 0: block1_conv1 # 5: block2_conv1 # 10: block3_conv1 # 19: block4_conv1 # 28: block5_conv1 - self.req_features= ['0','5','10','19','28'] - #Since we need only the 5 layers in the model so we will be dropping all the rest layers from the features of the model - self.model=models.vgg19(pretrained=True).features[:29] #model will contain the first 29 layers - - - #x holds the input tensor(image) that will be fed to each layer - def forward(self,x): - #initialize an array that wil hold the activations from the chosen layers - features=[] - #Iterate over all the layers of the mode - for layer_num,layer in enumerate(self.model): - #activation of the layer will stored in x - x=layer(x) - #appending the activation of the selected layers and return the feature array - if (str(layer_num) in self.req_features): + 
self.req_features = ["0", "5", "10", "19", "28"] + # Since we need only the 5 layers in the model so we will be dropping all the rest layers from the features of the model + self.model = models.vgg19(pretrained=True).features[ + :29 + ] # model will contain the first 29 layers + + # x holds the input tensor(image) that will be fed to each layer + def forward(self, x): + # initialize an array that wil hold the activations from the chosen layers + features = [] + # Iterate over all the layers of the mode + for layer_num, layer in enumerate(self.model): + # activation of the layer will stored in x + x = layer(x) + # appending the activation of the selected layers and return the feature array + if str(layer_num) in self.req_features: features.append(x) return features - -#defing a function that will load the image and perform the required preprocessing and put it on the GPU +# defing a function that will load the image and perform the required preprocessing and put it on the GPU def image_loader(path): - image=Image.open(path) - #defining the image transformation steps to be performed before feeding them to the model - loader=transforms.Compose([transforms.Resize((512,512)),transforms.ToTensor()]) - #The preprocessing steps involves resizing the image and then converting it to a tensor + image = Image.open(path) + # defining the image transformation steps to be performed before feeding them to the model + loader = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()]) + # The preprocessing steps involves resizing the image and then converting it to a tensor + + image = loader(image).unsqueeze(0) + return image.to(device, torch.float) - image=loader(image).unsqueeze(0) - return image.to(device,torch.float) -def calc_content_loss(gen_feat,orig_feat): - #calculating the content loss of each layer by calculating the MSE between the content and generated features and adding it to content loss - content_l=torch.mean((gen_feat-orig_feat)**2)#*0.5 +def calc_content_loss(gen_feat, orig_feat): + # calculating the content loss of each layer by calculating the MSE between the content and generated features and adding it to content loss + content_l = torch.mean((gen_feat - orig_feat) ** 2) # *0.5 return content_l -def calc_style_loss(gen,style): - #Calculating the gram matrix for the style and the generated image - batch_size,channel,height,width=gen.shape - G=torch.mm(gen.view(channel,height*width),gen.view(channel,height*width).t()) - A=torch.mm(style.view(channel,height*width),style.view(channel,height*width).t()) +def calc_style_loss(gen, style): + # Calculating the gram matrix for the style and the generated image + batch_size, channel, height, width = gen.shape - #Calcultating the style loss of each layer by calculating the MSE between the gram matrix of the style image and the generated image and adding it to style loss - style_l=torch.mean((G-A)**2)#/(4*channel*(height*width)**2) + G = torch.mm( + gen.view(channel, height * width), gen.view(channel, height * width).t() + ) + A = torch.mm( + style.view(channel, height * width), style.view(channel, height * width).t() + ) + + # Calcultating the style loss of each layer by calculating the MSE between the gram matrix of the style image and the generated image and adding it to style loss + style_l = torch.mean((G - A) ** 2) # /(4*channel*(height*width)**2) return style_l -def train_style_transfer_model(original_image_path, style_image_path, final_output_path, epoch = 1000, optimizer = optim.Adam, lr = 0.004, alpha = 8, beta = 70): - 
- '''Train a PyTorch model for a style transfer task. + +def train_style_transfer_model( + original_image_path, + style_image_path, + final_output_path, + epoch=1000, + optimizer=optim.Adam, + lr=0.004, + alpha=8, + beta=70, +): + """Train a PyTorch model for a style transfer task. Args: - + original_image_path (path): Path on your device/drive where original image is stored. style_image_path (path): Path on your device/drive where style image is stored. final_output_path (path): Path on your device/drive where output image is to be saved. @@ -93,48 +109,47 @@ def train_style_transfer_model(original_image_path, style_image_path, final_outp Returns: None - ''' - - - #Loading the original and the style image - original_image = image_loader(original_image_path) - style_image = image_loader(style_image_path) - - #Creating the generated image from the original image - generated_image = original_image.clone().requires_grad_(True) - - #Load the model to the GPU - model=VGG().to(device).eval() - - #using adam optimizer and it will update the generated image not the model parameter - optimizer = optimizer([generated_image],lr=lr) - - def calculate_loss(gen_features, orig_feautes, style_featues): - style_loss=content_loss=0 - for gen,cont,style in zip(gen_features,orig_feautes,style_featues): - #extracting the dimensions from the generated image - content_loss+=calc_content_loss(gen,cont) - style_loss+=calc_style_loss(gen,style) - - #calculating the total loss of e th epoch - total_loss=alpha*content_loss + beta*style_loss - return total_loss - - for e in range (epoch): - #extracting the features of generated, content and the original required for calculating the loss - gen_features=model(generated_image) - orig_feautes=model(original_image) - style_featues=model(style_image) - - #iterating over the activation of each layer and calculate the loss and add it to the content and the style loss - total_loss=calculate_loss(gen_features, orig_feautes, style_featues) - #optimize the pixel values of the generated image and backpropagate the loss - optimizer.zero_grad() - total_loss.backward() - optimizer.step() - - #print the image and save it after each 100 epoch - if(not (e%100)): - print(total_loss) - - save_image(generated_image,final_output_path) \ No newline at end of file + """ + + # Loading the original and the style image + original_image = image_loader(original_image_path) + style_image = image_loader(style_image_path) + + # Creating the generated image from the original image + generated_image = original_image.clone().requires_grad_(True) + + # Load the model to the GPU + model = VGG().to(device).eval() + + # using adam optimizer and it will update the generated image not the model parameter + optimizer = optimizer([generated_image], lr=lr) + + def calculate_loss(gen_features, orig_feautes, style_featues): + style_loss = content_loss = 0 + for gen, cont, style in zip(gen_features, orig_feautes, style_featues): + # extracting the dimensions from the generated image + content_loss += calc_content_loss(gen, cont) + style_loss += calc_style_loss(gen, style) + + # calculating the total loss of e th epoch + total_loss = alpha * content_loss + beta * style_loss + return total_loss + + for e in range(epoch): + # extracting the features of generated, content and the original required for calculating the loss + gen_features = model(generated_image) + orig_feautes = model(original_image) + style_featues = model(style_image) + + # iterating over the activation of each layer and calculate the loss and add it 
to the content and the style loss + total_loss = calculate_loss(gen_features, orig_feautes, style_featues) + # optimize the pixel values of the generated image and backpropagate the loss + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + # print the image and save it after each 100 epoch + if not (e % 100): + print(total_loss) + + save_image(generated_image, final_output_path) diff --git a/UniTrain/utils/__init__.py b/UniTrain/utils/__init__.py index ba441da..84c7d89 100644 --- a/UniTrain/utils/__init__.py +++ b/UniTrain/utils/__init__.py @@ -1,4 +1,9 @@ from .classification import get_data_loader, parse_folder, train_model from .DCGAN import get_data_loader, parse_folder, train_model -from .segmentation import (generate_model_summary, get_data_loader, iou_score, - parse_folder, train_unet) +from .segmentation import ( + generate_model_summary, + get_data_loader, + iou_score, + parse_folder, + train_unet, +) diff --git a/UniTrain/utils/classification.py b/UniTrain/utils/classification.py index 98924cf..5260c49 100644 --- a/UniTrain/utils/classification.py +++ b/UniTrain/utils/classification.py @@ -10,7 +10,8 @@ import tqdm from PIL import Image -def get_data_loader(data_dir, batch_size, shuffle=True, transform = None, split='train'): + +def get_data_loader(data_dir, batch_size, shuffle=True, transform=None, split="train"): """ Create and return a data loader for a custom dataset. @@ -24,33 +25,31 @@ def get_data_loader(data_dir, batch_size, shuffle=True, transform = None, split= """ # Define data transformations (adjust as needed) - if split == 'train': - data_dir = os.path.join(data_dir, 'train') - elif split == 'test': - data_dir = os.path.join(data_dir, 'test') - elif split == 'eval': - data_dir = os.path.join(data_dir, 'eval') + if split == "train": + data_dir = os.path.join(data_dir, "train") + elif split == "test": + data_dir = os.path.join(data_dir, "test") + elif split == "eval": + data_dir = os.path.join(data_dir, "eval") else: raise ValueError(f"Invalid split choice: {split}") - - if transform is None: - transform = transforms.Compose([ - transforms.Resize((224, 224)), # Resize images to a fixed size - transforms.ToTensor(), # Convert images to PyTorch tensors - transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # Normalize with ImageNet stats - ]) + if transform is None: + transform = transforms.Compose( + [ + transforms.Resize((224, 224)), # Resize images to a fixed size + transforms.ToTensor(), # Convert images to PyTorch tensors + transforms.Normalize( + (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + ), # Normalize with ImageNet stats + ] + ) # Create a custom dataset dataset = ClassificationDataset(data_dir, transform=transform) # Create a data loader - data_loader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=shuffle - ) + data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) return data_loader @@ -59,11 +58,15 @@ def parse_folder(dataset_path): try: if os.path.exists(dataset_path): # Store paths to train, test, and eval folders if they exist - train_path = os.path.join(dataset_path, 'train') - test_path = os.path.join(dataset_path, 'test') - eval_path = os.path.join(dataset_path, 'eval') - - if os.path.exists(train_path) and os.path.exists(test_path) and os.path.exists(eval_path): + train_path = os.path.join(dataset_path, "train") + test_path = os.path.join(dataset_path, "test") + eval_path = os.path.join(dataset_path, "eval") + + if ( + os.path.exists(train_path) + and os.path.exists(test_path) + and os.path.exists(eval_path)
+ ): print("Train folder path:", train_path) print("Test folder path:", test_path) print("Eval folder path:", eval_path) @@ -81,16 +84,29 @@ def parse_folder(dataset_path): print("One or more of the train, test, or eval folders does not exist.") return None else: - print(f"The '{dataset_path}' folder does not exist in the current directory.") + print( + f"The '{dataset_path}' folder does not exist in the current directory." + ) return None except Exception as e: print("An error occurred:", str(e)) return None -def train_model(model, train_data_loader, test_data_loader, num_epochs, learning_rate=0.001, criterion_fn = nn.CrossEntropyLoss, optimizer_fn = optim.Adam, checkpoint_dir='checkpoints', wnb_dir='wnb', logger=None, device=torch.device('cpu')): - - '''Train a PyTorch model for a classification task. +def train_model( + model, + train_data_loader, + test_data_loader, + num_epochs, + learning_rate=0.001, + criterion_fn=nn.CrossEntropyLoss, + optimizer_fn=optim.Adam, + checkpoint_dir="checkpoints", + wnb_dir="wnb", + logger=None, + device=torch.device("cpu"), +): + """Train a PyTorch model for a classification task. Args: model (nn.Module): Torch model to train. train_data_loader (DataLoader): Training data loader. @@ -108,20 +124,21 @@ def train_model(model, train_data_loader, test_data_loader, num_epochs, learning Returns: None - ''' + """ if logger: - logging.basicConfig(level=logging.INFO, format='%(asctime)s - Epoch %(epoch)d - Train Acc: %(train_acc).4f - Val Acc: %(val_acc).4f - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', filename=logger, filemode='w') + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - Epoch %(epoch)d - Train Acc: %(train_acc).4f - Val Acc: %(val_acc).4f - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + filename=logger, + filemode="w", + ) logger = logging.getLogger(__name__) - - # Setting the optimizer and criterion optimizer = optimizer_fn(model.parameters(), lr=learning_rate) criterion = criterion_fn() - - # Initialize optimizer, loss and accuracy optimizer = optimizer(model.parameters(), lr=learning_rate) @@ -150,34 +167,41 @@ def train_model(model, train_data_loader, test_data_loader, num_epochs, learning running_loss += loss.item() loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]") - loop.set_postfix(loss= running_loss / (batch_idx + 1)) + loop.set_postfix(loss=running_loss / (batch_idx + 1)) if batch_idx % 100 == 99: # Print and log every 100 batches avg_loss = running_loss / 100 - + # Save the weights and biases and log the path. - wnb_path = os.path.join(wnb_dir, f'model_epoch_{epoch + 1}_batch{batch_idx + 1}.pth') + wnb_path = os.path.join( + wnb_dir, f"model_epoch_{epoch + 1}_batch{batch_idx + 1}.pth" + ) torch.save(model.state_dict(), wnb_path) if logger: - logger.info(f'Epoch {epoch + 1}, Batch {batch_idx + 1}, Loss: {avg_loss:.4f}, wnbPath: {wnb_path}') - + logger.info( + f"Epoch {epoch + 1}, Batch {batch_idx + 1}, Loss: {avg_loss:.4f}, wnbPath: {wnb_path}" + ) accuracy = evaluate_model(model, test_data_loader) - + # Save the weights and biases and log the path for current epoch. 
- wnb_path = os.path.join(wnb_dir, f'model_epoch_{epoch + 1}.pth') + wnb_path = os.path.join(wnb_dir, f"model_epoch_{epoch + 1}.pth") torch.save(model.state_dict(), wnb_path) if logger: - logger.info(f'Epoch {epoch + 1}, Validation Accuracy: {accuracy:.2f}%, wnbPath: {wnb_path}') + logger.info( + f"Epoch {epoch + 1}, Validation Accuracy: {accuracy:.2f}%, wnbPath: {wnb_path}" + ) # Save model checkpoint if accuracy improves if accuracy > best_accuracy: best_accuracy = accuracy - checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch + 1}.pth') + checkpoint_path = os.path.join( + checkpoint_dir, f"model_epoch_{epoch + 1}.pth" + ) torch.save(model.state_dict(), checkpoint_path) if logger: - logger.info(f'Saved checkpoint to {checkpoint_path}') + logger.info(f"Saved checkpoint to {checkpoint_path}") - print('Finished Training') + print("Finished Training") def evaluate_model(model, dataloader): @@ -197,12 +221,14 @@ def evaluate_model(model, dataloader): return accuracy -def infer_class(model: nn.Module, image_path: str, device: torch.device, dataloader: DataLoader) -> str: +def infer_class( + model: nn.Module, image_path: str, device: torch.device, dataloader: DataLoader +) -> str: """Perform inference on a single image. Args: model (nn.Module): Model to perform inference with. - image_path (str): Path to image to perform inference on. + image_path (str): Path to image to perform inference on. device (torch.device): Device to run inference on (GPU or CPU). dataloader (DataLoader): Data loader for the dataset. @@ -215,20 +241,21 @@ def infer_class(model: nn.Module, image_path: str, device: torch.device, dataloa # Define transformations for the image if transform is None: - transform = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - ]) + transform = transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ] + ) - image = Image.open(image_path).convert('RGB') + image = Image.open(image_path).convert("RGB") image_tensor = transform(image) # Add an extra batch dimension since pytorch treats all images as batches image_tensor = image_tensor.unsqueeze_(0) - + with torch.no_grad(): output = model(image_tensor.to(device)) @@ -238,4 +265,3 @@ def infer_class(model: nn.Module, image_path: str, device: torch.device, dataloa classes = dataloader.dataset.classes return classes[predicted] - diff --git a/UniTrain/utils/segmentation.py b/UniTrain/utils/segmentation.py index 2bbfeb5..fc4da1a 100644 --- a/UniTrain/utils/segmentation.py +++ b/UniTrain/utils/segmentation.py @@ -102,7 +102,6 @@ def parse_folder(dataset_path): return False - def train_unet( model, train_data_loader, @@ -117,23 +116,23 @@ def train_unet( device=torch.device("cpu"), ) -> None: """ - Args: - -def train_unet(model, train_data_loader, test_data_loader, num_epochs, learning_rate, checkpoint_dir, optimizer = optim.Adam, loss_criterion = nn.CrossEntropyLoss, logger=None, iou=False, device=torch.device('cpu')) -> None: - ''' - Args: - model (nn.Module): PyTorch model to train. - train_data_loader (DataLoader): Data loader of the training dataset. - test_data_loader (DataLoader): Data loader of the test dataset. - num_epochs (int): Number of epochs to train the model. - learning_rate (float): Learning rate for the optimizer. - checkpoint_dir (str): Directory to save model checkpoints. 
- logger (Logger): Logger to log training information. - iou (bool): Whether to calculate the IOU score. - device (torch.device): Device to run training on (GPU or CPU). - - Returns: - None + Args: + model (nn.Module): PyTorch model to train. + train_data_loader (DataLoader): Data loader of the training dataset. + test_data_loader (DataLoader): Data loader of the test dataset. + num_epochs (int): Number of epochs to train the model. + learning_rate (float): Learning rate for the optimizer. + checkpoint_dir (str): Directory to save model checkpoints. + logger (Logger): Logger to log training information. + iou (bool): Whether to calculate the IOU score. + device (torch.device): Device to run training on (GPU or CPU). + + Returns: + None """ if logger: @@ -242,26 +241,28 @@ def iou_score(output, target): iou = (intersection + smooth) / (union + smooth) return iou.mean().item() + def do_inference(image: PIL.Image, device: torch.device(), model) -> PIL.Image: """ Function is used for inference for segmentation of an Image. - Function takes PIL.Image object as input and return a segmented PIL.Image object. - + Function takes a PIL.Image object as input and returns a segmented PIL.Image object. + Args: image(PIL.Image) : Image to do inference device(torch.device) : Device to run inference model: Model to inference """ - model.eval() # Evaulation mode.. + model.eval() # Evaluation mode. # Convert Image.PIL into tensor form. - transform = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize((0.485, 0.456, 0.406), - (0.229, 0.224, 0.225)) - ]) + transform = transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ] + ) image = transform(image) # Move the image to device and adding extra dimension for batch. @@ -273,5 +274,7 @@ def do_inference(image: PIL.Image, device: torch.device(), model) -> PIL.Image: # Post-process the segmentation mask. output = output.squeeze(0).cpu().numpy() - output = nn.functional.softmax(torch.from_numpy(output), dim=0).argmax(0).cpu().numpy() + output = ( + nn.functional.softmax(torch.from_numpy(output), dim=0).argmax(0).cpu().numpy() + ) return transforms.ToPILImage(output)
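Usage sketches (illustrative only, not part of the patch). First, a minimal sketch of calling the reformatted create_transfer_learning_model helper; the import path, dataset size, and class count below are assumptions.

import torch
import torchvision

from UniTrain.models.classification import create_transfer_learning_model  # import path assumed

# Swap the ResNet-18 head for a 10-class linear layer; with feature_extract=True
# the pre-trained backbone weights are frozen and only the new head is trainable.
model = create_transfer_learning_model(
    num_classes=10,
    model=torchvision.models.resnet18,
    feature_extract=True,
    use_pretrained=True,
)

# Quick shape check with a dummy batch of two 224x224 RGB images.
dummy = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([2, 10])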
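Likewise, a sketch of invoking train_style_transfer_model from UniTrain/utils/StyleTransferVGG.py; the image paths are placeholders and the import path is an assumption.

from UniTrain.utils.StyleTransferVGG import train_style_transfer_model  # import path assumed

# Runs the VGG-19 based optimisation loop defined above, printing the loss every
# 100 epochs and saving the stylised result to final_output_path.
train_style_transfer_model(
    original_image_path="images/content/photo.jpg",   # placeholder paths
    style_image_path="images/style/painting.jpg",
    final_output_path="output/stylised.png",
    epoch=500,
    lr=0.004,
    alpha=8,
    beta=70,
)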
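Finally, a sketch of single-image inference with do_inference from UniTrain/utils/segmentation.py. The model class, checkpoint name, and import paths are assumptions; note that the helper's final line passes the mask to transforms.ToPILImage directly rather than to an instance of it, so the return value may need an extra conversion step.

import torch
from PIL import Image

from UniTrain.models.segmentation import UNet          # assumed model class
from UniTrain.utils.segmentation import do_inference   # import path assumed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet(n_class=2)  # assumed constructor signature
model.load_state_dict(torch.load("checkpoints/unet_best.pth", map_location=device))
model = model.to(device)

# do_inference resizes and normalises the image, runs the model, then applies
# softmax and a per-pixel argmax to produce a class-index mask.
image = Image.open("sample.jpg").convert("RGB")
mask = do_inference(image, device, model)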