
Commit e5fe8ce

Add an easy way to run a script for a few steps only
I've wanted this tool for a while, so I figured I should just propose it. Often I need to test out a script or Colab I did not write, and just want to run a few train steps for every fit call without finding every call to fit in the script. This adds a debugging tool to do just that:

```
KERAS_MAX_EPOCHS=1 KERAS_MAX_STEPS=5 python train.py
```
1 parent 81c5097 commit e5fe8ce
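
For context, the same limits can also be applied from Python through the config functions this commit adds, rather than through the environment. A minimal sketch (the toy model and data below are placeholders for illustration, not part of the commit):

```python
import numpy as np
import keras

# Equivalent to launching with KERAS_MAX_EPOCHS=1 KERAS_MAX_STEPS=5:
keras.config.set_max_epochs(1)
keras.config.set_max_steps(5)

# Any small model will do for demonstration purposes.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(loss="mse", optimizer="sgd")
x, y = np.ones((80, 4)), np.ones((80, 1))

# 10 epochs over 10 batches are requested, but the caps limit the run
# to 1 epoch of 5 steps.
model.fit(x, y, batch_size=8, epochs=10)

# Passing None removes the caps again.
keras.config.set_max_epochs(None)
keras.config.set_max_steps(None)
```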

File tree: 8 files changed, +130 −2 lines


keras/api/_tf_keras/keras/config/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -17,11 +17,15 @@
 from keras.src.backend.config import (
     is_flash_attention_enabled as is_flash_attention_enabled,
 )
+from keras.src.backend.config import max_epochs as max_epochs
+from keras.src.backend.config import max_steps as max_steps
 from keras.src.backend.config import set_epsilon as set_epsilon
 from keras.src.backend.config import set_floatx as set_floatx
 from keras.src.backend.config import (
     set_image_data_format as set_image_data_format,
 )
+from keras.src.backend.config import set_max_epochs as set_max_epochs
+from keras.src.backend.config import set_max_steps as set_max_steps
 from keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy
 from keras.src.dtype_policies.dtype_policy import (
     set_dtype_policy as set_dtype_policy,
```

keras/api/config/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -17,11 +17,15 @@
 from keras.src.backend.config import (
     is_flash_attention_enabled as is_flash_attention_enabled,
 )
+from keras.src.backend.config import max_epochs as max_epochs
+from keras.src.backend.config import max_steps as max_steps
 from keras.src.backend.config import set_epsilon as set_epsilon
 from keras.src.backend.config import set_floatx as set_floatx
 from keras.src.backend.config import (
     set_image_data_format as set_image_data_format,
 )
+from keras.src.backend.config import set_max_epochs as set_max_epochs
+from keras.src.backend.config import set_max_steps as set_max_steps
 from keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy
 from keras.src.dtype_policies.dtype_policy import (
     set_dtype_policy as set_dtype_policy,
```

keras/src/backend/config.py

Lines changed: 71 additions & 1 deletion
```diff
@@ -15,6 +15,10 @@
 # Default backend: TensorFlow.
 _BACKEND = "tensorflow"
 
+# Cap run duration for debugging.
+_MAX_EPOCHS = None
+_MAX_STEPS = None
+
 
 @keras_export(["keras.config.floatx", "keras.backend.floatx"])
 def floatx():
@@ -304,7 +308,10 @@ def keras_home():
     _backend = os.environ["KERAS_BACKEND"]
     if _backend:
         _BACKEND = _backend
-
+if "KERAS_MAX_EPOCHS" in os.environ:
+    _MAX_EPOCHS = int(os.environ["KERAS_MAX_EPOCHS"])
+if "KERAS_MAX_STEPS" in os.environ:
+    _MAX_STEPS = int(os.environ["KERAS_MAX_STEPS"])
 
 if _BACKEND != "tensorflow":
     # If we are not running on the tensorflow backend, we should stop tensorflow
@@ -333,3 +340,66 @@ def backend():
 
     """
     return _BACKEND
+
+
+@keras_export(["keras.config.set_max_epochs"])
+def set_max_epochs(max_epochs):
+    """Limit the maximum number of epochs for any call to fit.
+
+    This will cap the number of epochs for any training run using `model.fit()`.
+    This is purely for debugging, and can also be set via the `KERAS_MAX_EPOCHS`
+    environment variable to quickly run a script without modifying its source.
+
+    Args:
+        max_epochs: The integer limit on the number of epochs or `None`. If
+            `None`, no limit is applied.
+    """
+    global _MAX_EPOCHS
+    _MAX_EPOCHS = max_epochs
+
+
+@keras_export(["keras.config.set_max_steps"])
+def set_max_steps(max_steps):
+    """Limit the maximum number of steps for any call to fit/evaluate/predict.
+
+    This will cap the number of steps for a single epoch of a call to `fit()`,
+    `evaluate()`, or `predict()`. This is purely for debugging, and can also be
+    set via the `KERAS_MAX_STEPS` environment variable to quickly run a script
+    without modifying its source.
+
+    Args:
+        max_steps: The integer limit on the number of steps or `None`. If
+            `None`, no limit is applied.
+    """
+    global _MAX_STEPS
+    _MAX_STEPS = max_steps
+
+
+@keras_export(["keras.config.max_epochs"])
+def max_epochs():
+    """Get the maximum number of epochs for any call to fit.
+
+    Retrieves the limit on the number of epochs set by
+    `keras.config.set_max_epochs` or the `KERAS_MAX_EPOCHS` environment
+    variable.
+
+    Returns:
+        The integer limit on the number of epochs or `None`, if no limit has
+        been set.
+    """
+    return _MAX_EPOCHS
+
+
+@keras_export(["keras.config.max_steps"])
+def max_steps():
+    """Get the maximum number of steps for any call to fit/evaluate/predict.
+
+    Retrieves the limit on the number of steps set by
+    `keras.config.set_max_steps` or the `KERAS_MAX_STEPS` environment
+    variable.
+
+    Returns:
+        The integer limit on the number of steps or `None`, if no limit has
+        been set.
+    """
+    return _MAX_STEPS
```
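
One consequence of the change above is that the environment variables are read only once, when the config module is imported. A small hedged sketch of what that implies (assuming the process was started without `KERAS_MAX_EPOCHS` set):

```python
import os

import keras  # KERAS_MAX_EPOCHS / KERAS_MAX_STEPS are parsed here, at import time

# Setting the variable after import has no effect on the limits,
# because the module-level parsing has already run.
os.environ["KERAS_MAX_EPOCHS"] = "1"
print(keras.config.max_epochs())  # still None

# Once Keras is imported, use the setter instead.
keras.config.set_max_epochs(1)
print(keras.config.max_epochs())  # 1
```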

keras/src/backend/jax/trainer.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 from keras.src import callbacks as callbacks_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.backend import config
 from keras.src.backend import distribution_lib as jax_distribution_lib
 from keras.src.distribution import distribution_lib
 from keras.src.trainers import trainer as base_trainer
@@ -341,6 +342,7 @@ def fit(
         validation_freq=1,
     ):
         self._assert_compile_called("fit")
+        epochs = config.max_epochs() or epochs
         # TODO: respect compiled trainable state
         self._eval_epoch_iterator = None
         if validation_split and validation_data is None:
```
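
The override in each backend trainer is a single line, `epochs = config.max_epochs() or epochs`: the global cap replaces the requested epoch count whenever one is set, and the same line appears in the TensorFlow and Torch trainers below. A tiny standalone illustration of the pattern, using a hypothetical helper rather than Keras internals:

```python
def effective_epochs(requested, max_epochs=None):
    """Mirror `epochs = config.max_epochs() or epochs` from the trainers."""
    return max_epochs or requested


assert effective_epochs(100) == 100              # no cap set, request is kept
assert effective_epochs(100, max_epochs=2) == 2  # cap wins over a larger request
assert effective_epochs(1, max_epochs=2) == 2    # ...and also over a smaller one
```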

keras/src/backend/tensorflow/trainer.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 from keras.src import metrics as metrics_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.backend import config
 from keras.src.losses import loss as loss_module
 from keras.src.trainers import trainer as base_trainer
 from keras.src.trainers.data_adapters import array_slicing
@@ -309,6 +310,7 @@ def fit(
         validation_freq=1,
     ):
         self._assert_compile_called("fit")
+        epochs = config.max_epochs() or epochs
         # TODO: respect compiled trainable state
         self._eval_epoch_iterator = None
         if validation_split and validation_data is None:
```

keras/src/backend/torch/trainer.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@
 from keras.src import callbacks as callbacks_module
 from keras.src import optimizers as optimizers_module
 from keras.src import tree
+from keras.src.backend import config
 from keras.src.trainers import trainer as base_trainer
 from keras.src.trainers.data_adapters import array_slicing
 from keras.src.trainers.data_adapters import data_adapter_utils
@@ -187,6 +188,7 @@ def fit(
             raise ValueError(
                 "You must call `compile()` before calling `fit()`."
             )
+        epochs = config.max_epochs() or epochs
 
         # TODO: respect compiled trainable state
         self._eval_epoch_iterator = None
```

keras/src/trainers/epoch_iterator.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -42,6 +42,7 @@
 import contextlib
 import warnings
 
+from keras.src.backend import config
 from keras.src.trainers import data_adapters
 
 
@@ -57,7 +58,7 @@ def __init__(
         class_weight=None,
         steps_per_execution=1,
     ):
-        self.steps_per_epoch = steps_per_epoch
+        self.steps_per_epoch = config.max_steps() or steps_per_epoch
         self.steps_per_execution = steps_per_execution
         self._current_iterator = None
         self._epoch_iterator = None
```
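
Capping steps inside the epoch iterator means a single setting limits `fit()`, `evaluate()`, and `predict()` alike, since all three drive their loops through this class. A short hedged sketch of the resulting behaviour (the toy model and data are illustrative only):

```python
import numpy as np
import keras

keras.config.set_max_steps(3)

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(loss="mse", optimizer="sgd")
x, y = np.ones((80, 4)), np.ones((80, 1))  # 10 batches of size 8

# Each call below now processes at most 3 batches per epoch / pass.
model.fit(x, y, batch_size=8, epochs=1, verbose=0)
model.evaluate(x, y, batch_size=8, verbose=0)
model.predict(x, batch_size=8, verbose=0)

keras.config.set_max_steps(None)  # restore the default unlimited behaviour
```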

keras/src/trainers/trainer_test.py

Lines changed: 43 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 from keras.src import ops
 from keras.src import optimizers
 from keras.src import testing
+from keras.src.backend import config
 from keras.src.backend.common.symbolic_scope import in_symbolic_scope
 from keras.src.callbacks.callback import Callback
 from keras.src.optimizers.rmsprop import RMSprop
@@ -1506,6 +1507,48 @@ def test_steps_per_epoch(self, steps_per_epoch_test, mode):
         )
         self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))
 
+    @pytest.mark.requires_trainable_backend
+    def test_max_epochs_and_steps(self):
+        batch_size = 8
+        epochs = 4
+        num_batches = 10
+        data_size = num_batches * batch_size
+        x, y = np.ones((data_size, 4)), np.ones((data_size, 1))
+        model = ExampleModel(units=1)
+        model.compile(
+            loss="mse",
+            optimizer="sgd",
+            metrics=[EpochAgnosticMeanSquaredError()],
+        )
+        step_observer = StepObserver()
+        model.fit(
+            x=x,
+            y=y,
+            batch_size=batch_size,
+            epochs=epochs,
+            callbacks=[step_observer],
+            verbose=0,
+        )
+        self.assertEqual(step_observer.epoch_begin_count, epochs)
+        self.assertEqual(step_observer.begin_count, num_batches * epochs)
+        try:
+            config.set_max_epochs(2)
+            config.set_max_steps(3)
+            step_observer = StepObserver()
+            model.fit(
+                x=x,
+                y=y,
+                batch_size=batch_size,
+                epochs=epochs,
+                callbacks=[step_observer],
+                verbose=0,
+            )
+            self.assertEqual(step_observer.epoch_begin_count, 2)
+            self.assertEqual(step_observer.begin_count, 6)
+        finally:
+            config.set_max_epochs(None)
+            config.set_max_steps(None)
+
     @parameterized.named_parameters(
         named_product(
             steps_per_epoch_test=[
```
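
Since the caps are process-wide globals, the test resets them in a `finally` block so later tests are unaffected. When writing similar tests, a pytest fixture is a hedged alternative for the same cleanup (hypothetical, not part of this commit):

```python
import pytest

from keras.src.backend import config


@pytest.fixture
def capped_training():
    """Temporarily cap runs to 2 epochs of 3 steps, always resetting after."""
    config.set_max_epochs(2)
    config.set_max_steps(3)
    try:
        yield
    finally:
        config.set_max_epochs(None)
        config.set_max_steps(None)


def test_fit_respects_caps(capped_training):
    # A test using this fixture would call model.fit(...) here and assert
    # that only 2 epochs of 3 steps each were executed.
    assert config.max_epochs() == 2
    assert config.max_steps() == 3
```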
