Commit b9144d4

Fixed issues in mssim and requirements

Committed on Dec 22, 2021
1 parent: 27b559f

File tree: 4 files changed (+59, -65 lines)

    dataset.py
    models/mssim_vae.py
    requirements.txt
    run.py

dataset.py  (+54, -61)

@@ -11,39 +11,32 @@
 import zipfile
 
 
-# class MyDataset(Dataset):
-#     def __init__(self):
-#         pass
+# Add your custom dataset class here
+class MyDataset(Dataset):
+    def __init__(self):
+        pass
 
 
-#     def __len__(self):
-#         pass
+    def __len__(self):
+        pass
 
-#     def __getitem__(self, idx):
-#         pass
+    def __getitem__(self, idx):
+        pass
+
 
 class MyCelebA(CelebA):
     """
+    A work-around to address issues with pytorch's celebA dataset class.
+
     Download and Extract
     URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
     """
 
     def _check_integrity(self) -> bool:
         return True
-
-class Food101(Dataset):
-    """
-    URL : https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/
-    """
-    def __init__(self):
-        pass
 
-    def __len__(self):
-        pass
-
-    def __getitem__(self, idx):
-        pass
 
+
 class OxfordPets(Dataset):
     """
     URL = https://www.robots.ox.ac.uk/~vgg/data/pets/

@@ -53,7 +46,7 @@ def __init__(self,
                  split: str,
                  transform: Callable,
                  **kwargs):
-        self.data_dir = Path(data_path)
+        self.data_dir = Path(data_path) / "OxfordPets"
         self.transforms = transform
         imgs = sorted([f for f in self.data_dir.iterdir() if f.suffix == '.jpg'])
 

@@ -107,59 +100,58 @@ def __init__(
     def setup(self, stage: Optional[str] = None) -> None:
         # ========================= OxfordPets Dataset =========================
 
-        # train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
-        #                                        transforms.CenterCrop(self.patch_size),
-        #                                        # transforms.Resize(self.patch_size),
-        #                                        transforms.ToTensor(),
-        #                                        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
-
-        # val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
-        #                                      transforms.CenterCrop(self.patch_size),
-        #                                      # transforms.Resize(self.patch_size),
-        #                                      transforms.ToTensor(),
-        #                                      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
-
-        # self.train_dataset = OxfordPets(
-        #     self.data_dir,
-        #     split='train',
-        #     transform=train_transforms,
-        # )
-
-        # self.val_dataset = OxfordPets(
-        #     self.data_dir,
-        #     split='val',
-        #     transform=val_transforms,
-        # )
-
-        # ========================= CelebA Dataset =========================
-
         train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
-                                               transforms.CenterCrop(148),
-                                               transforms.Resize(self.patch_size),
-                                               transforms.ToTensor(),])
+                                               transforms.CenterCrop(self.patch_size),
+                                               # transforms.Resize(self.patch_size),
+                                               transforms.ToTensor(),
+                                               transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
 
         val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
-                                             transforms.CenterCrop(148),
-                                             transforms.Resize(self.patch_size),
-                                             transforms.ToTensor(),])
-
-        self.train_dataset = MyCelebA(
+                                             transforms.CenterCrop(self.patch_size),
+                                             # transforms.Resize(self.patch_size),
+                                             transforms.ToTensor(),
+                                             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
+
+        self.train_dataset = OxfordPets(
             self.data_dir,
             split='train',
             transform=train_transforms,
-            download=False,
         )
 
-        # Replace CelebA with your dataset
-        self.val_dataset = MyCelebA(
+        self.val_dataset = OxfordPets(
             self.data_dir,
-            split='test',
+            split='val',
             transform=val_transforms,
-            download=False,
         )
+
+        # ========================= CelebA Dataset =========================
+
+        # train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
+        #                                        transforms.CenterCrop(148),
+        #                                        transforms.Resize(self.patch_size),
+        #                                        transforms.ToTensor(),])
+
+        # val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
+        #                                      transforms.CenterCrop(148),
+        #                                      transforms.Resize(self.patch_size),
+        #                                      transforms.ToTensor(),])
+
+        # self.train_dataset = MyCelebA(
+        #     self.data_dir,
+        #     split='train',
+        #     transform=train_transforms,
+        #     download=False,
+        # )
+
+        # # Replace CelebA with your dataset
+        # self.val_dataset = MyCelebA(
+        #     self.data_dir,
+        #     split='test',
+        #     transform=val_transforms,
+        #     download=False,
+        # )
         # ===============================================================
 
-
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
             self.train_dataset,

@@ -185,4 +177,5 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
             num_workers=self.num_workers,
             shuffle=True,
             pin_memory=self.pin_memory,
-        )
+        )
+
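
The commit also turns MyDataset from a commented-out stub into an empty skeleton that users are expected to fill in with their own data. As a rough illustration only (the flat folder layout, the transform argument, and the dummy label below are assumptions, not part of this commit), a completed version might look like:

from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Hypothetical example: every .jpg in a single folder, no labels."""

    def __init__(self, data_path: str, transform=None):
        self.img_paths = sorted(Path(data_path).glob("*.jpg"))  # assumed flat layout
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, 0.0  # dummy label so loaders yield (image, label) pairs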

models/mssim_vae.py  (+1, -1)

@@ -231,7 +231,7 @@ def ssim(self,
         sigma2_sq = F.conv2d(img2 * img2, window, padding = window_size//2, groups=in_channel) - mu2_sq
         sigma12 = F.conv2d(img1 * img2, window, padding = window_size//2, groups=in_channel) - mu1_mu2
 
-        img_range = img1.max() - img1.min()
+        img_range = 1.0  # img1.max() - img1.min()  # Dynamic range
         C1 = (0.01 * img_range) ** 2
         C2 = (0.03 * img_range) ** 2
 
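
For context on the mssim change: SSIM's stabilizing constants are conventionally C1 = (k1 * L) ** 2 and C2 = (k2 * L) ** 2 with k1 = 0.01 and k2 = 0.03, where L is the dynamic range of the pixel values. Computing L per batch as img1.max() - img1.min() made the constants depend on the batch content; pinning it to 1.0 matches inputs scaled to [0, 1] (e.g. after ToTensor()). A quick sketch of the resulting values, assuming that [0, 1] range:

# With a fixed dynamic range L = 1.0 the constants no longer vary with the batch.
img_range = 1.0
C1 = (0.01 * img_range) ** 2  # approximately 1e-4
C2 = (0.03 * img_range) ** 2  # approximately 9e-4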

requirements.txt  (+2, -2)

@@ -1,6 +1,6 @@
 pytorch-lightning==1.5.6
 PyYAML==6.0
-tensorboard==2.1.0
+tensorboard>=2.2.0
 torch>=1.6.1
 torchsummary==1.5.1
-torchvision>=0.11.2
+torchvision>=0.10.1

run.py  (+2, -1)

@@ -29,7 +29,6 @@
         print(exc)
 
 
-
 tb_logger = TensorBoardLogger(save_dir=config['logging_params']['save_dir'],
                               name=config['model_params']['name'],)
 

@@ -42,6 +41,7 @@
 
 data = VAEDataset(**config["data_params"], pin_memory=len(config['trainer_params']['gpus']) != 0)
 
+data.setup()
 runner = Trainer(logger=tb_logger,
                  callbacks=[
                      LearningRateMonitor(),

@@ -57,5 +57,6 @@
 Path(f"{tb_logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True)
 Path(f"{tb_logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True)
 
+
 print(f"======= Training {config['model_params']['name']} =======")
 runner.fit(experiment, datamodule=data)