import zipfile


-# class MyDataset(Dataset):
-#     def __init__(self):
-#         pass
+# Add your custom dataset class here
+class MyDataset(Dataset):
+    def __init__(self):
+        pass


-#     def __len__(self):
-#         pass
+    def __len__(self):
+        pass

-#     def __getitem__(self, idx):
-#         pass
+    def __getitem__(self, idx):
+        pass
+


class MyCelebA(CelebA):
    """
+    A work-around to address issues with pytorch's celebA dataset class.
+
    Download and Extract
    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """

    def _check_integrity(self) -> bool:
        return True
-
-class Food101(Dataset):
-    """
-    URL : https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/
-    """
-    def __init__(self):
-        pass

-    def __len__(self):
-        pass
-
-    def __getitem__(self, idx):
-        pass

+
class OxfordPets(Dataset):
    """
    URL = https://www.robots.ox.ac.uk/~vgg/data/pets/
@@ -53,7 +46,7 @@ def __init__(self,
                 split: str,
                 transform: Callable,
                 **kwargs):
-        self.data_dir = Path(data_path)
+        self.data_dir = Path(data_path) / "OxfordPets"
        self.transforms = transform
        imgs = sorted([f for f in self.data_dir.iterdir() if f.suffix == '.jpg'])

@@ -107,59 +100,58 @@ def __init__(
    def setup(self, stage: Optional[str] = None) -> None:
#       =========================  OxfordPets Dataset  =========================

-#       train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
-#                                              transforms.CenterCrop(self.patch_size),
-# #                                            transforms.Resize(self.patch_size),
-#                                              transforms.ToTensor(),
-#                                              transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
-
-#       val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
-#                                            transforms.CenterCrop(self.patch_size),
-# #                                          transforms.Resize(self.patch_size),
-#                                            transforms.ToTensor(),
-#                                            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
-
-#       self.train_dataset = OxfordPets(
-#           self.data_dir,
-#           split='train',
-#           transform=train_transforms,
-#       )
-
-#       self.val_dataset = OxfordPets(
-#           self.data_dir,
-#           split='val',
-#           transform=val_transforms,
-#       )
-
-#       =========================  CelebA Dataset  =========================
-
        train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
-                                              transforms.CenterCrop(148),
-                                              transforms.Resize(self.patch_size),
-                                              transforms.ToTensor(),])
+                                              transforms.CenterCrop(self.patch_size),
+#                                              transforms.Resize(self.patch_size),
+                                              transforms.ToTensor(),
+                                              transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

        val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
-                                            transforms.CenterCrop(148),
-                                            transforms.Resize(self.patch_size),
-                                            transforms.ToTensor(),])
-
-        self.train_dataset = MyCelebA(
+                                            transforms.CenterCrop(self.patch_size),
+#                                            transforms.Resize(self.patch_size),
+                                            transforms.ToTensor(),
+                                            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
+
+        self.train_dataset = OxfordPets(
            self.data_dir,
            split='train',
            transform=train_transforms,
-            download=False,
        )

-        # Replace CelebA with your dataset
-        self.val_dataset = MyCelebA(
+        self.val_dataset = OxfordPets(
            self.data_dir,
-            split='test',
+            split='val',
            transform=val_transforms,
-            download=False,
        )
+
+#       =========================  CelebA Dataset  =========================
+
+#       train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
+#                                              transforms.CenterCrop(148),
+#                                              transforms.Resize(self.patch_size),
+#                                              transforms.ToTensor(),])
+
+#       val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
+#                                            transforms.CenterCrop(148),
+#                                            transforms.Resize(self.patch_size),
+#                                            transforms.ToTensor(),])
+
+#       self.train_dataset = MyCelebA(
+#           self.data_dir,
+#           split='train',
+#           transform=train_transforms,
+#           download=False,
+#       )
+
+#       # Replace CelebA with your dataset
+#       self.val_dataset = MyCelebA(
+#           self.data_dir,
+#           split='test',
+#           transform=val_transforms,
+#           download=False,
+#       )
#       ===============================================================

-
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
@@ -185,4 +177,5 @@ def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
                          num_workers=self.num_workers,
                          shuffle=True,
                          pin_memory=self.pin_memory,
-                         )
+                         )
+
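For reference, a minimal sketch of how the newly uncommented MyDataset skeleton might be filled in, following the same pattern as OxfordPets above (a sorted listing of .jpg files and an (image, label) pair per item). The class name MyImageFolder, the flat directory layout, and the dummy label are illustrative assumptions, not part of this commit.

from pathlib import Path
from typing import Callable, Optional

from PIL import Image
from torch.utils.data import Dataset


class MyImageFolder(Dataset):
    def __init__(self, data_path: str, transform: Optional[Callable] = None):
        self.data_dir = Path(data_path)
        self.transform = transform
        # Collect every .jpg in the directory; adjust the suffix as needed.
        self.imgs = sorted(f for f in self.data_dir.iterdir() if f.suffix == '.jpg')

    def __len__(self) -> int:
        return len(self.imgs)

    def __getitem__(self, idx):
        img = Image.open(self.imgs[idx]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # Dummy label, mirroring the (image, label) convention of the other datasets.
        return img, 0.0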