Skip to content

Commit

Permalink
Improve conda build
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian Wolny committed Jan 19, 2020
1 parent c462636 commit 9d90e48
Show file tree
Hide file tree
Showing 37 changed files with 70 additions and 260 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,8 @@ __pycache__/
logs/

# PyTorch models
*.pytorch
*.pytorch

# Python packages
pytorch3dunet.egg-info/
dist/
12 changes: 6 additions & 6 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package:
name: pytorch-3dunet
version: 1.0.2
version: {{ GIT_DESCRIBE_TAG }}

source:
git_rev: v1.0.2
git_url: https://github.com/wolny/pytorch-3dunet.git
path: ..

build:
entry_points:
- predict3dunet = main.predict:main
- train3dunet = main.train:main
- predict3dunet = pytorch3dunet.predict:main
- train3dunet = pytorch3dunet.train:main

requirements:
build:
Expand Down Expand Up @@ -37,4 +36,5 @@ test:
- pytest tests/

about:
home: https://github.com/wolny/pytorch-3dunet
home: https://github.com/wolny/pytorch-3dunet
license: MIT
66 changes: 0 additions & 66 deletions predict.py

This file was deleted.

1 change: 1 addition & 0 deletions pytorch3dunet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .__version__ import __version__
1 change: 1 addition & 0 deletions pytorch3dunet/__version__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = 'v1.0.3'
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def weight_transform(self):

@staticmethod
def _transformer_class(class_name):
m = importlib.import_module('augment.transforms')
m = importlib.import_module('pytorch3dunet.augment.transforms')
clazz = getattr(m, class_name)
return clazz

Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions datasets/hdf5.py → pytorch3dunet/datasets/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import augment.transforms as transforms
from unet3d.utils import get_logger
import pytorch3dunet.augment.transforms as transforms
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('HDF5Dataset')

Expand Down Expand Up @@ -362,7 +362,7 @@ def _check_dimensionality(raws, labels):


def _get_slice_builder_cls(class_name):
m = importlib.import_module('datasets.hdf5')
m = importlib.import_module('pytorch3dunet.datasets.hdf5')
clazz = getattr(m, class_name)
return clazz

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch import nn

from unet3d.utils import expand_as_one_hot
from pytorch3dunet.unet3d.utils import expand_as_one_hot


class ContrastiveLoss(nn.Module):
Expand Down
10 changes: 5 additions & 5 deletions main/predict.py → pytorch3dunet/predict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import importlib
import os

from datasets.hdf5 import get_test_loaders
from unet3d import utils
from unet3d.config import load_config
from unet3d.model import get_model
from pytorch3dunet.datasets.hdf5 import get_test_loaders
from pytorch3dunet.unet3d import utils
from pytorch3dunet.unet3d.config import load_config
from pytorch3dunet.unet3d.model import get_model

logger = utils.get_logger('UNet3DPredictor')

Expand All @@ -31,7 +31,7 @@ def _get_predictor(model, loader, output_file, config):
predictor_config = config.get('predictor', {})
class_name = predictor_config.get('name', 'StandardPredictor')

m = importlib.import_module('unet3d.predictor')
m = importlib.import_module('pytorch3dunet.unet3d.predictor')
predictor_class = getattr(m, class_name)

return predictor_class(model, loader, output_file, config, **predictor_config)
Expand Down
16 changes: 8 additions & 8 deletions main/train.py → pytorch3dunet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from datasets.hdf5 import get_train_loaders
from unet3d.config import load_config
from unet3d.losses import get_loss_criterion
from unet3d.metrics import get_evaluation_metric
from unet3d.model import get_model
from unet3d.trainer import UNet3DTrainer
from unet3d.utils import get_logger, get_tensorboard_formatter
from unet3d.utils import get_number_of_learnable_parameters
from pytorch3dunet.datasets.hdf5 import get_train_loaders
from pytorch3dunet.unet3d.config import load_config
from pytorch3dunet.unet3d.losses import get_loss_criterion
from pytorch3dunet.unet3d.metrics import get_evaluation_metric
from pytorch3dunet.unet3d.model import get_model
from pytorch3dunet.unet3d.trainer import UNet3DTrainer
from pytorch3dunet.unet3d.utils import get_logger, get_tensorboard_formatter
from pytorch3dunet.unet3d.utils import get_number_of_learnable_parameters


def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, logger):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions unet3d/losses.py → pytorch3dunet/unet3d/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from torch.autograd import Variable
from torch.nn import MSELoss, SmoothL1Loss, L1Loss

from embeddings.contrastive_loss import ContrastiveLoss
from unet3d.utils import expand_as_one_hot
from pytorch3dunet.embeddings.contrastive_loss import ContrastiveLoss
from pytorch3dunet.unet3d.utils import expand_as_one_hot


def compute_per_channel_dice(input, target, epsilon=1e-5, ignore_index=None, weight=None):
Expand Down
6 changes: 3 additions & 3 deletions unet3d/metrics.py → pytorch3dunet/unet3d/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from skimage import measure
from sklearn.cluster import MeanShift

from unet3d.losses import compute_per_channel_dice
from unet3d.utils import get_logger, adapted_rand, expand_as_one_hot, plot_segm
from pytorch3dunet.unet3d.losses import compute_per_channel_dice
from pytorch3dunet.unet3d.utils import get_logger, adapted_rand, expand_as_one_hot, plot_segm

LOGGER = get_logger('EvalMetric')

Expand Down Expand Up @@ -688,7 +688,7 @@ def get_evaluation_metric(config):
"""

def _metric_class(class_name):
m = importlib.import_module('unet3d.metrics')
m = importlib.import_module('pytorch3dunet.unet3d.metrics')
clazz = getattr(m, class_name)
return clazz

Expand Down
6 changes: 3 additions & 3 deletions unet3d/model.py → pytorch3dunet/unet3d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch.nn as nn

from unet3d.buildingblocks import Encoder, Decoder, DoubleConv, ExtResNetBlock, SingleConv
from unet3d.utils import create_feature_maps
from pytorch3dunet.unet3d.buildingblocks import Encoder, Decoder, DoubleConv, ExtResNetBlock, SingleConv
from pytorch3dunet.unet3d.utils import create_feature_maps


class UNet3D(nn.Module):
Expand Down Expand Up @@ -299,7 +299,7 @@ def forward(self, x):

def get_model(config):
def _model_class(class_name):
m = importlib.import_module('unet3d.model')
m = importlib.import_module('pytorch3dunet.unet3d.model')
clazz = getattr(m, class_name)
return clazz

Expand Down
6 changes: 3 additions & 3 deletions unet3d/predictor.py → pytorch3dunet/unet3d/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import torch
from sklearn.cluster import MeanShift

from datasets.hdf5 import SliceBuilder
from unet3d.utils import get_logger
from unet3d.utils import unpad
from pytorch3dunet.datasets.hdf5 import SliceBuilder
from pytorch3dunet.unet3d.utils import get_logger
from pytorch3dunet.unet3d.utils import unpad

logger = get_logger('UNet3DTrainer')

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion unet3d/utils.py → pytorch3dunet/unet3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def get_tensorboard_formatter(config):
return DefaultTensorboardFormatter()

class_name = config['name']
m = importlib.import_module('unet3d.utils')
m = importlib.import_module('pytorch3dunet.unet3d.utils')
clazz = getattr(m, class_name)
return clazz(**config)

Expand Down
2 changes: 1 addition & 1 deletion resources/test_config_4d_input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ loaders:
test:
# paths to the test datasets
file_paths:
- 'resources/random_raw4D.h5'
- '../resources/random_raw4D.h5'

# SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
slice_builder:
Expand Down
2 changes: 1 addition & 1 deletion resources/test_config_ce.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ loaders:
test:
# paths to the test datasets
file_paths:
- 'resources/random_label3D.h5'
- '../resources/random_label3D.h5'

# SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
slice_builder:
Expand Down
2 changes: 1 addition & 1 deletion resources/test_config_dice.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ loaders:
test:
# paths to the test datasets
file_paths:
- 'resources/random_label3D.h5'
- '../resources/random_label3D.h5'

# SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
slice_builder:
Expand Down
4 changes: 2 additions & 2 deletions resources/train_config_4d_input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ loaders:
train:
# paths to the training datasets
file_paths:
- 'resources/random_raw4D.h5'
- '../resources/random_raw4D.h5'

# SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
slice_builder:
Expand Down Expand Up @@ -116,7 +116,7 @@ loaders:
val:
# paths to the validation datasets
file_paths:
- 'resources/random_raw4D.h5'
- '../resources/random_raw4D.h5'

# SliceBuilder configuration
slice_builder:
Expand Down
4 changes: 2 additions & 2 deletions resources/train_config_ce.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ loaders:
train:
# paths to the training datasets
file_paths:
- 'resources/random_label3D.h5'
- '../resources/random_label3D.h5'

# SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
slice_builder:
Expand Down Expand Up @@ -116,7 +116,7 @@ loaders:
val:
# paths to the validation datasets
file_paths:
- 'resources/random_label3D.h5'
- '../resources/random_label3D.h5'

# SliceBuilder configuration
slice_builder:
Expand Down
4 changes: 2 additions & 2 deletions resources/train_config_dice.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ loaders:
train:
# paths to the training datasets
file_paths:
- 'resources/random_label3D.h5'
- '../resources/random_label3D.h5'

# SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
slice_builder:
Expand Down Expand Up @@ -116,7 +116,7 @@ loaders:
val:
# paths to the validation datasets
file_paths:
- 'resources/random_label3D.h5'
- '../resources/random_label3D.h5'

# SliceBuilder configuration
slice_builder:
Expand Down
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from setuptools import setup, find_packages

_version_ = "1.0.2"

exec(open('pytorch3dunet/__version__.py').read())
setup(
name="pytorch3dunet",
packages=find_packages(exclude=["tests"]),
version=_version_,
version=__version__,
author="Adrian Wolny, Lorenzo Cerrone",
url="https://github.com/wolny/pytorch-3dunet",
license="MIT",
Expand Down
9 changes: 4 additions & 5 deletions tests/test_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import torch.nn as nn
from skimage import measure

from augment.transforms import LabelToAffinities, StandardLabelToBoundary
from unet3d.losses import GeneralizedDiceLoss, WeightedCrossEntropyLoss, BCELossWrapper, \
from pytorch3dunet.augment.transforms import LabelToAffinities, StandardLabelToBoundary
from pytorch3dunet.embeddings.contrastive_loss import ContrastiveLoss
from pytorch3dunet.unet3d.losses import GeneralizedDiceLoss, WeightedCrossEntropyLoss, BCELossWrapper, \
DiceLoss, TagsAngularLoss
from unet3d.metrics import DiceCoefficient, MeanIoU, BoundaryAveragePrecision, AdaptedRandError, \
from pytorch3dunet.unet3d.metrics import DiceCoefficient, MeanIoU, BoundaryAveragePrecision, AdaptedRandError, \
BoundaryAdaptedRandError, EmbeddingsAdaptedRandError

from embeddings.contrastive_loss import ContrastiveLoss


def _compute_criterion(criterion, n_times=100):
shape = [1, 0, 30, 30, 30]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from torch.utils.data import DataLoader

from datasets.hdf5 import HDF5Dataset
from pytorch3dunet.datasets.hdf5 import HDF5Dataset


class TestHDF5Dataset:
Expand Down
Loading

0 comments on commit 9d90e48

Please sign in to comment.