Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590631662
  • Loading branch information
texasmichelle authored and SeqIO committed Dec 13, 2023
1 parent 8012bf6 commit bd02658
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
8 changes: 4 additions & 4 deletions seqio/dataset_providers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,7 @@ def test_tasks(self):
self.add_task("task2", self.function_source)
MixtureRegistry.add("test_mix1", [("task1", 1), ("task2", 1)])
mix = MixtureRegistry.get("test_mix1")
self.assertEqual(len(mix.tasks), 2)
self.assertLen(mix.tasks, 2)

for task in mix.tasks:
self.verify_task_matches_fake_datasets(task.name, use_cached=False)
Expand All @@ -1200,7 +1200,7 @@ def test_task_objs(self):

MixtureRegistry.add("test_mix1", [(task1, 1), (task2, 1)])
mix = MixtureRegistry.get("test_mix1")
self.assertEqual(len(mix.tasks), 2)
self.assertLen(mix.tasks, 2)

for task in mix.tasks:
self.verify_task_matches_fake_datasets(task=task, use_cached=False)
Expand All @@ -1221,7 +1221,7 @@ def test_task_objs_default_rate(self):
)
MixtureRegistry.add("test_mix1", [task1, task2], default_rate=1.0)
mix = MixtureRegistry.get("test_mix1")
self.assertEqual(len(mix.tasks), 2)
self.assertLen(mix.tasks, 2)

for task in mix.tasks:
self.verify_task_matches_fake_datasets(task=task, use_cached=False)
Expand Down Expand Up @@ -1250,7 +1250,7 @@ def test_tasks_with_tunable_rates(self):
)

mix = MixtureRegistry.get("test_mix2")
self.assertEqual(len(mix.tasks), 3)
self.assertLen(mix.tasks, 3)

automl_context = pg.hyper.DynamicEvaluationContext(require_hyper_name=True)
with automl_context.collect():
Expand Down
2 changes: 2 additions & 0 deletions seqio/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Experimental utilities for SeqIO."""

import functools
import inspect
from typing import Callable, Iterable, Mapping, Optional, Sequence
Expand Down Expand Up @@ -86,6 +87,7 @@ def _no_op_mixture_registry_get(*args, **kwargs):


def disable_registry():
"""Disables the seqio TaskRegistry and MixtureRegistry."""
_enfore_empty_registries()
dataset_providers.TaskRegistry.add = _no_op_task_registry_add
dataset_providers.TaskRegistry.add_provider = _no_op_task_registry_add
Expand Down
3 changes: 3 additions & 0 deletions seqio/experimental_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""Tests for seqio.preprocessors."""

import contextlib
from unittest import mock

from absl.testing import absltest
from seqio import dataset_providers
from seqio import experimental
Expand Down Expand Up @@ -984,5 +986,6 @@ def test_mixture_registry_get_error(self):
MixtureRegistry.get('dummy_mixture')



if __name__ == '__main__':
absltest.main()

0 comments on commit bd02658

Please sign in to comment.