From bd0265871ce87e35bf2839c8520774f9a12e22e4 Mon Sep 17 00:00:00 2001 From: Michelle Casbon Date: Wed, 13 Dec 2023 09:55:34 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 590631662 --- seqio/dataset_providers_test.py | 8 ++++---- seqio/experimental.py | 2 ++ seqio/experimental_test.py | 3 +++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/seqio/dataset_providers_test.py b/seqio/dataset_providers_test.py index 21405e98..27bf3b7f 100644 --- a/seqio/dataset_providers_test.py +++ b/seqio/dataset_providers_test.py @@ -1178,7 +1178,7 @@ def test_tasks(self): self.add_task("task2", self.function_source) MixtureRegistry.add("test_mix1", [("task1", 1), ("task2", 1)]) mix = MixtureRegistry.get("test_mix1") - self.assertEqual(len(mix.tasks), 2) + self.assertLen(mix.tasks, 2) for task in mix.tasks: self.verify_task_matches_fake_datasets(task.name, use_cached=False) @@ -1200,7 +1200,7 @@ def test_task_objs(self): MixtureRegistry.add("test_mix1", [(task1, 1), (task2, 1)]) mix = MixtureRegistry.get("test_mix1") - self.assertEqual(len(mix.tasks), 2) + self.assertLen(mix.tasks, 2) for task in mix.tasks: self.verify_task_matches_fake_datasets(task=task, use_cached=False) @@ -1221,7 +1221,7 @@ def test_task_objs_default_rate(self): ) MixtureRegistry.add("test_mix1", [task1, task2], default_rate=1.0) mix = MixtureRegistry.get("test_mix1") - self.assertEqual(len(mix.tasks), 2) + self.assertLen(mix.tasks, 2) for task in mix.tasks: self.verify_task_matches_fake_datasets(task=task, use_cached=False) @@ -1250,7 +1250,7 @@ def test_tasks_with_tunable_rates(self): ) mix = MixtureRegistry.get("test_mix2") - self.assertEqual(len(mix.tasks), 3) + self.assertLen(mix.tasks, 3) automl_context = pg.hyper.DynamicEvaluationContext(require_hyper_name=True) with automl_context.collect(): diff --git a/seqio/experimental.py b/seqio/experimental.py index 1963e1c1..77f3907e 100644 --- a/seqio/experimental.py +++ b/seqio/experimental.py @@ -13,6 +13,7 @@ # limitations under the License. """Experimental utilities for SeqIO.""" + import functools import inspect from typing import Callable, Iterable, Mapping, Optional, Sequence @@ -86,6 +87,7 @@ def _no_op_mixture_registry_get(*args, **kwargs): def disable_registry(): + """Disables the seqio TaskRegistry and MixtureRegistry.""" _enfore_empty_registries() dataset_providers.TaskRegistry.add = _no_op_task_registry_add dataset_providers.TaskRegistry.add_provider = _no_op_task_registry_add diff --git a/seqio/experimental_test.py b/seqio/experimental_test.py index a17e029c..2f08521e 100644 --- a/seqio/experimental_test.py +++ b/seqio/experimental_test.py @@ -15,6 +15,8 @@ """Tests for seqio.preprocessors.""" import contextlib +from unittest import mock + from absl.testing import absltest from seqio import dataset_providers from seqio import experimental @@ -984,5 +986,6 @@ def test_mixture_registry_get_error(self): MixtureRegistry.get('dummy_mixture') + if __name__ == '__main__': absltest.main()