Skip to content

Commit

Permalink
added drop last and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mooniean committed Mar 6, 2024
1 parent 160a0eb commit 5cda120
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/caked/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,20 @@ def process(self, paths: list[str], datatype: str):
shiftmin=shiftmin,
)

def get_loader(self, batch_size: int, split_size: float | None = None):
def get_loader(
self,
batch_size: int,
split_size: float | None = None,
no_val_drop: bool = False,
):
"""
Retrieve the data loader.
Args:
batch_size (int): The batch size for the data loader.
split_size (float | None, optional): The percentage of data to be used for validation set.
If None, the entire dataset will be used for training. Defaults to None.
no_val_drop (bool, optional): If True, the last batch of validation data will not be dropped if it is smaller than batch size. Defaults to False.
Returns:
DataLoader or Tuple[DataLoader, DataLoader]: The data loader(s) for testing or training/validation, according to whether training is True or False.
Expand Down Expand Up @@ -216,12 +222,14 @@ def get_loader(self, batch_size: int, split_size: float | None = None):
batch_size=batch_size,
num_workers=0,
shuffle=True,
drop_last=True,
)
loader_val = DataLoader(
val_data,
batch_size=batch_size,
num_workers=0,
shuffle=True,
drop_last=(not no_val_drop),
)
return loader_train, loader_val

Expand Down
23 changes: 23 additions & 0 deletions tests/test_disk_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,26 @@ def test_processing_after_load():
post_image, post_label = next(iter(post_dataset))
assert pre_label == post_label
assert not torch.equal(pre_image, post_image)


def test_drop_last():
    """
    Check how ``no_val_drop`` is reflected in the ``drop_last`` flag of the loaders.

    ``get_loader`` is called twice on the same ``DiskDataLoader``: once with
    ``no_val_drop=True`` and once with ``no_val_drop=False``. The training loader
    must report ``drop_last`` as truthy in both cases, whereas the validation
    loader must report it as truthy only when ``no_val_drop`` is False.
    """
    disk_loader = DiskDataLoader(
        pipeline=DISK_PIPELINE,
        classes=DISK_CLASSES_FULL_MRC,
        dataset_size=DATASET_SIZE_ALL,
        training=True,
    )
    disk_loader.load(datapath=TEST_DATA_MRC, datatype=DATATYPE_MRC)

    # Case 1: the validation loader is asked to keep its last (possibly short) batch.
    train_keep_last, val_keep_last = disk_loader.get_loader(
        batch_size=64,
        split_size=0.7,
        no_val_drop=True,
    )
    assert train_keep_last.drop_last
    assert not val_keep_last.drop_last

    # Case 2: default behaviour, both loaders drop the last batch.
    train_drop_last, val_drop_last = disk_loader.get_loader(
        batch_size=64,
        split_size=0.7,
        no_val_drop=False,
    )
    assert train_drop_last.drop_last
    assert val_drop_last.drop_last

0 comments on commit 5cda120

Please sign in to comment.