Skip to content

Commit

Permalink
unpack generators
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Nov 4, 2023
1 parent a621755 commit fdfc84d
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions zoobot/pytorch/training/webdatamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import types

import torch.utils.data
import numpy as np
Expand All @@ -12,6 +13,11 @@
class WebDataModule(pl.LightningDataModule):
def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_cols=None, batch_size=64, num_workers=4, cache_dir=None):
super().__init__()

if isinstance(train_urls, types.GeneratorType):
train_urls = list(train_urls)
if isinstance(val_urls, types.GeneratorType):
val_urls = list(val_urls)
self.train_urls = train_urls
self.val_urls = val_urls

Expand Down

0 comments on commit fdfc84d

Please sign in to comment.