From fdfc84dbbb83184c68910c94c65c1ea365f8b7ed Mon Sep 17 00:00:00 2001 From: Mike Walmsley Date: Fri, 3 Nov 2023 20:42:09 -0400 Subject: [PATCH] unpack generators --- zoobot/pytorch/training/webdatamodule.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/zoobot/pytorch/training/webdatamodule.py b/zoobot/pytorch/training/webdatamodule.py index a8cc1c35..ae23d2d6 100644 --- a/zoobot/pytorch/training/webdatamodule.py +++ b/zoobot/pytorch/training/webdatamodule.py @@ -1,4 +1,5 @@ import os +import types import torch.utils.data import numpy as np @@ -12,6 +13,11 @@ class WebDataModule(pl.LightningDataModule): def __init__(self, train_urls, val_urls, train_size=None, val_size=None, label_cols=None, batch_size=64, num_workers=4, cache_dir=None): super().__init__() + + if isinstance(train_urls, types.GeneratorType): + train_urls = list(train_urls) + if isinstance(val_urls, types.GeneratorType): + val_urls = list(val_urls) self.train_urls = train_urls self.val_urls = val_urls