Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prioritize user-defined
train()
function over the staged forward()
(
#2174) Summary: It gives more user-friendly error messages upon unimplemented train tests. Fixes #2166 Pull Request resolved: #2174 Test Plan: ``` $ python -u run.py -d cuda -t train --bs 4 --metrics None hf_Whisper /home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/runner/miniconda3/envs/torchbench/lib/python3.11/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( Running train method from hf_Whisper on cuda in eager mode with input batch size 4 and precision fp32. Traceback (most recent call last): File "/workspace/benchmark/run.py", line 623, in <module> main() # pragma: no cover ^^^^^^ File "/workspace/benchmark/run.py", line 593, in main run_one_step( File "/workspace/benchmark/run.py", line 173, in run_one_step func() File "/workspace/benchmark/torchbenchmark/util/model.py", line 315, in invoke self.train() File "/workspace/benchmark/torchbenchmark/models/hf_Whisper/__init__.py", line 20, in train raise NotImplementedError("Training is not implemented.") NotImplementedError: Training is not implemented. ``` Reviewed By: aaronenyeshi Differential Revision: D54012510 Pulled By: xuzhao9 fbshipit-source-id: bb27bd5adb0bcd778c2c58db7ef5a7b8cc9b2c20
- Loading branch information