Skip to content

Commit cfae04b

Browse files
exclamafortefacebook-github-bot
authored andcommitted
Fix pt2 multiprocess test (#3151)
Summary: Pull Request resolved: #3151 `common_utils.py` disables global flags, which causes this test to error. https://www.internalfb.com/code/fbsource/[1caeee2345d0]/fbcode/caffe2/torch/testing/_internal/common_utils.py?lines=215 Test error: ``` Process ForkServerProcess-17: Traceback (most recent call last): File "/usr/local/fbcode/platform010/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap self.run() File "/usr/local/fbcode/platform010/lib/python3.10/multiprocessing/process.py", line 108, in run self._target(*self._args, **self._kwargs) File "/re_cwd/buck-out/v2/gen/fbcode/1121f3f340860407/torchrec/distributed/tests/__test_pt2_multiprocess__/test_pt2_multiprocess#link-tree/torchrec/distributed/tests/test_pt2_multiprocess.py", line 244, in _test_compile_rank_fn with MultiProcessContext(rank, world_size, backend, local_size) as ctx: File "/re_cwd/buck-out/v2/gen/fbcode/1121f3f340860407/torchrec/distributed/tests/__test_pt2_multiprocess__/test_pt2_multiprocess#link-tree/torchrec/distributed/test_utils/multi_process.py", line 93, in __exit__ torch.backends.cudnn.allow_tf32 = True File "/re_cwd/buck-out/v2/gen/fbcode/1121f3f340860407/torchrec/distributed/tests/__test_pt2_multiprocess__/test_pt2_multiprocess#link-tree/torch/backends/__init__.py", line 48, in __set__ raise RuntimeError( RuntimeError: not allowed to set torch.backends.cudnn flags after disable_global_flags; please use flags() context manager instead ``` Reviewed By: huydhn Differential Revision: D77614983 fbshipit-source-id: d1eab581fc8fcb961ca375c3017498cd954f0edd
1 parent 8b9c461 commit cfae04b

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

torchrec/distributed/tests/test_pt2_multiprocess.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,11 @@ def test_compile_multiprocess(
749749
model_type, sharding_type, input_type, tovb, compile_backend, config = (
750750
given_config_tuple
751751
)
752+
# torch/testing/_internal/common_utils.py calls `disable_global_flags()`
753+
# workaround RuntimeError: not allowed to set ... after disable_global_flags
754+
setattr( # noqa: B010
755+
torch.backends, "__allow_nonbracketed_mutation_flag", True
756+
)
752757
self._run_multi_process_test(
753758
callable=_test_compile_rank_fn,
754759
test_model_type=model_type,
@@ -761,6 +766,9 @@ def test_compile_multiprocess(
761766
config=config,
762767
torch_compile_backend=compile_backend,
763768
)
769+
setattr( # noqa: B010
770+
torch.backends, "__allow_nonbracketed_mutation_flag", False
771+
)
764772

765773
# pyre-ignore
766774
@unittest.skipIf(

0 commit comments

Comments
 (0)