From 1c02cdf530bfbf1e843654c08259d9a677569a51 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Tue, 10 Dec 2024 15:59:17 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 704876709 --- .../_src/handlers/composite_checkpoint_handler.py | 7 +------ .../handlers/composite_checkpoint_handler_test.py | 1 - checkpoint/orbax/checkpoint/checkpoint_manager.py | 11 +---------- 3 files changed, 2 insertions(+), 17 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py index e0c07782b..8b70ca0d3 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py @@ -49,7 +49,6 @@ import concurrent.futures import dataclasses from typing import Any, Coroutine, Dict, List, Mapping, MutableSet, Optional, Tuple, Type -import uuid from absl import logging from etils import epath @@ -88,7 +87,6 @@ RESERVED_ITEM_NAMES = [] - # TODO(b/295899152) Clean up when users are all registering `CheckpointArgs`. class _LegacyCheckpointHandlerWrapper(checkpoint_handler.CheckpointHandler): """Wrapper for `CheckpointHandler`s without registered `CheckpointArgs`.""" @@ -661,14 +659,11 @@ def _get_item_temporary_paths( self._get_item_temporary_directory(directory, item_name) for item_name in item_names ] - result = { + return { item_name: item_directory for item_name, item_directory in zip(item_names, item_temporary_paths) } - return result - - async def async_save( self, directory: epath.Path, args: CompositeArgs ) -> Optional[List[Future]]: diff --git a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py index adbc24628..77fa43084 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py @@ -941,7 +941,6 @@ def test_close(self): state_handler.close.assert_called_once() metadata_handler.close.assert_called_once() - def test_items_exist_final(self): handler = CompositeCheckpointHandler('state', 'metadata') state = {'a': 1, 'b': 2} diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index bf87a9565..554dcbb64 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -83,10 +83,9 @@ FileOptions = options_lib.FileOptions DEFAULT_ITEM_NAME = 'default' -DESCRIPTOR_ITEM_NAME = 'descriptor' METRIC_ITEM_NAME = 'metrics' METADATA_ITEM_NAME = 'metadata' -RESERVED_ITEM_NAMES = [DESCRIPTOR_ITEM_NAME, METRIC_ITEM_NAME] +RESERVED_ITEM_NAMES = [METRIC_ITEM_NAME] _INIT_TIME = datetime.datetime.now(tz=datetime.timezone.utc) @@ -99,14 +98,6 @@ def _metrics_file_exists(metrics_item_path: epath.Path) -> bool: ) -def _descriptor_file_exists(descriptor_item_path: epath.Path) -> bool: - """True if item directory AND actual file both exist.""" - return ( - descriptor_item_path.exists() - and (descriptor_item_path / f'{DESCRIPTOR_ITEM_NAME}.pbtxt').exists() - ) - - class StepAlreadyExistsError(ValueError): """Raised when a step is already present for a save request."""