diff --git a/README.md b/README.md index 11e134171d..350fac9418 100644 --- a/README.md +++ b/README.md @@ -155,16 +155,16 @@ You can specify which model file to use with `--model MODEL.pth`. The training progress can be visualized in real-time using [Weights & Biases](https://wandb.ai/). Loss curves, validation curves, weights and gradient histograms, as well as predicted masks are logged to the platform. When launching a training, a link will be printed in the console. Click on it to go to your dashboard. If you have an existing W&B account, you can link it - by setting the `WANDB_API_KEY` environment variable. + by setting the `WANDB_API_KEY` environment variable. If not, it will create an anonymous run which is automatically deleted after 7 days. ## Pretrained model -A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v2.0) is available for the Carvana dataset. It can also be loaded from torch.hub: +A [pretrained model](https://github.com/milesial/Pytorch-UNet/releases/tag/v3.0) is available for the Carvana dataset. It can also be loaded from torch.hub: ```python -net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True) +net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=True, scale=0.5) ``` -The training was done with a 50% scale and bilinear upsampling. +Available scales are 0.5 and 1.0. ## Data The Carvana data is available on the [Kaggle website](https://www.kaggle.com/c/carvana-image-masking-challenge/data). diff --git a/hubconf.py b/hubconf.py index fa6f7a070c..d9c39d9b8e 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,14 +1,20 @@ import torch from unet import UNet as _UNet -def unet_carvana(pretrained=False): +def unet_carvana(pretrained=False, scale=0.5): """ UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ). Set the scale to 0.5 (50%) when predicting. """ - net = _UNet(n_channels=3, n_classes=2, bilinear=True) + net = _UNet(n_channels=3, n_classes=2, bilinear=False) if pretrained: - checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v2.0/unet_carvana_scale0.5_epoch1.pth' + if scale == 0.5: + checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth' + elif scale == 1.0: + checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale1.0_epoch2.pth' + else: + raise RuntimeError('Only 0.5 and 1.0 scales are available') + net.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True)) return net diff --git a/predict.py b/predict.py index fdb95ea1c4..9cfe1874f4 100755 --- a/predict.py +++ b/predict.py @@ -57,6 +57,7 @@ def get_args(): help='Minimum probability value to consider a mask pixel white') parser.add_argument('--scale', '-s', type=float, default=0.5, help='Scale factor for the input images') + parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling') return parser.parse_args() @@ -81,7 +82,7 @@ def mask_to_image(mask: np.ndarray): in_files = args.input out_files = get_output_filenames(args) - net = UNet(n_channels=3, n_classes=2) + net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') logging.info(f'Loading model {args.model}') diff --git a/train.py b/train.py index cae9a8607b..bffe17959b 100644 --- a/train.py +++ b/train.py @@ -25,7 +25,7 @@ def train_net(net, device, epochs: int = 5, batch_size: int = 1, - learning_rate: float = 0.001, + learning_rate: float = 1e-5, val_percent: float = 0.1, save_checkpoint: bool = True, img_scale: float = 0.5, @@ -147,13 +147,14 @@ def get_args(): parser = argparse.ArgumentParser(description='Train the UNet on images and target masks') parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs') parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size') - parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.00001, + parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5, help='Learning rate', dest='lr') parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file') parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images') parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0, help='Percent of the data that is used as validation (0-100)') parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision') + parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling') return parser.parse_args() @@ -168,7 +169,7 @@ def get_args(): # Change here to adapt to your data # n_channels=3 for RGB images # n_classes is the number of probabilities you want to get per pixel - net = UNet(n_channels=3, n_classes=2, bilinear=True) + net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear) logging.info(f'Network:\n' f'\t{net.n_channels} input channels\n' diff --git a/unet/unet_model.py b/unet/unet_model.py index efa7108c84..20c35b52cc 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -4,7 +4,7 @@ class UNet(nn.Module): - def __init__(self, n_channels, n_classes, bilinear=True): + def __init__(self, n_channels, n_classes, bilinear=False): super(UNet, self).__init__() self.n_channels = n_channels self.n_classes = n_classes