diff --git a/only_for_me/narval/train.sh b/only_for_me/narval/train.sh
index c3bbdd61..0c5a1121 100644
--- a/only_for_me/narval/train.sh
+++ b/only_for_me/narval/train.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 #SBATCH --mem=80G
 #SBATCH --nodes=1
-#SBATCH --time=0:20:0
+#SBATCH --time=0:40:0
 #SBATCH --tasks-per-node=2
 #SBATCH --cpus-per-task=12
 #SBATCH --gres=gpu:a100:2
diff --git a/zoobot/pytorch/training/webdatamodule.py b/zoobot/pytorch/training/webdatamodule.py
index dd01bad6..a359c635 100644
--- a/zoobot/pytorch/training/webdatamodule.py
+++ b/zoobot/pytorch/training/webdatamodule.py
@@ -83,7 +83,8 @@ def make_loader(self, urls, mode="train"):
             .map_tuple(transform_image, transform_label)
             # torch collate stacks dicts nicely while webdataset only lists them
             # so use the torch collate instead
-            .batched(self.batch_size, torch.utils.data.default_collate, partial=False)
+            .batched(self.batch_size, torch.utils.data.default_collate, partial=False)
+            .repeat(2)
         )
 
         # from itertools import islice
@@ -97,8 +98,9 @@ def make_loader(self, urls, mode="train"):
         loader = wds.WebLoader(
             dataset,
             batch_size=None,  # already batched
-            shuffle=False,
+            shuffle=False,  # already shuffled
             num_workers=self.num_workers,
+            pin_memory=True
         )
 
         # print('sampling')
@@ -134,8 +136,8 @@ def val_dataloader(self):
 #     parser.add_argument("--valshards", default="imagenet-val-{000000..000006}.tar")
 #     return parser
 
-def nodesplitter_func(urls):
-    print(urls)
+def nodesplitter_func(urls):  # SimpleShardList
+    # print(urls)
     try:
         node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
         return list(urls)[node_id::node_count]
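
For context, the pieces changed above fit together as in the minimal sketch below: a nodesplitter that deals whole shards round-robin across DDP ranks, a batched-and-repeated pipeline, and a WebLoader with pinned memory. The shard names, sample keys, and batch size here are hypothetical, chosen purely for illustration; the real pipeline lives in make_loader.

    import torch
    import torch.distributed
    import torch.utils.data
    import webdataset as wds


    def nodesplitter_func(urls):
        # Assign whole shards to each DDP node: rank k of n takes every n-th URL,
        # so nodes read disjoint data without any coordination.
        try:
            node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
            return list(urls)[node_id::node_count]
        except RuntimeError:
            # Process group not initialised (e.g. single-process run): use all shards
            return list(urls)


    # Hypothetical shard pattern, keys, and batch size, for illustration only
    urls = [f"shards/train-{i:06d}.tar" for i in range(8)]
    dataset = (
        wds.WebDataset(urls, nodesplitter=nodesplitter_func)
        .decode("rgb")
        .to_tuple("jpg", "cls")
        # torch collate stacks dicts nicely while webdataset only lists them
        .batched(16, torch.utils.data.default_collate, partial=False)
        .repeat(2)  # as in the diff: cycle the per-node shard list twice
    )
    loader = wds.WebLoader(
        dataset,
        batch_size=None,  # already batched upstream
        num_workers=4,
        pin_memory=True,  # page-locked host buffers speed up host-to-GPU copies
    )

Splitting at shard granularity (rather than per-sample) keeps each worker streaming sequentially through its tar files, which is the usual WebDataset pattern for distributed training.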