diff --git a/CHANGELOG.md b/CHANGELOG.md
index 12fb2a0..7c093a4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,4 @@
-1.0.6.1
+1.0.6.2
 
 ## 1.0.6
 ## 2020-07
@@ -7,6 +7,7 @@
 - Refactor torch backend models
 - Add `--caching_dataset` to cache transformed data into memory (ignored when `memory_limit` set).
 - Fix FastMetrics multi-threads issue
+- Fix loading issue when inferring with VSR models
 
 ## 1.0.5
 ## 2020-05
diff --git a/Train/eval.py b/Train/eval.py
index 07955cb..0c03362 100644
--- a/Train/eval.py
+++ b/Train/eval.py
@@ -56,7 +56,7 @@ def overwrite_from_env(flags):
 
 def main():
   flags, args = parser.parse_known_args()
-  opt = Config(depth=-1)
+  opt = Config(depth=1)
   for pair in flags._get_kwargs():
     opt.setdefault(*pair)
   overwrite_from_env(opt)
diff --git a/VSR/Backend/TF/Framework/Trainer.py b/VSR/Backend/TF/Framework/Trainer.py
index a9f7aa7..0d1c91e 100644
--- a/VSR/Backend/TF/Framework/Trainer.py
+++ b/VSR/Backend/TF/Framework/Trainer.py
@@ -365,7 +365,7 @@ def infer(self, loader, config, **kwargs):
     """
     v = self.query_config(config, **kwargs)
     self._restore()
-    it = loader.make_one_shot_iterator([1, -1, -1, -1], -1)
+    it = loader.make_one_shot_iterator(v.batch_shape, -1)
     if hasattr(it, '__len__'):
       if len(it):
         LOG.info('Inferring {} at epoch {}'.format(
diff --git a/VSR/Backend/Torch/Framework/Trainer.py b/VSR/Backend/Torch/Framework/Trainer.py
index 14b8286..edf93bb 100644
--- a/VSR/Backend/Torch/Framework/Trainer.py
+++ b/VSR/Backend/Torch/Framework/Trainer.py
@@ -168,7 +168,7 @@ def infer(self, loader, config, **kwargs):
     """
     v = self.query_config(config, **kwargs)
     self._restore(config.epoch, v.map_location)
-    it = loader.make_one_shot_iterator([1, -1, -1, -1], -1)
+    it = loader.make_one_shot_iterator(v.batch_shape, -1)
     if hasattr(it, '__len__'):
       if len(it) == 0:
         return