Skip to content

Commit a5d2d12

Browse files
exclamafortefacebook-github-bot
authored andcommitted
Fix test more generally (#3165)
Summary: Pull Request resolved: #3165 https://www.internalfb.com/diff/D77614983 attempted to fix a test, but I still see it showing up in other tests, so this fixes it in general. Reviewed By: huydhn Differential Revision: D77758554 fbshipit-source-id: bd390081b68fa650f1cfd6d2a93a1fbf206aaff7
1 parent 0919506 commit a5d2d12

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

torchrec/distributed/test_utils/multi_process.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,15 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None:
9090
dist.destroy_process_group(self.pg)
9191
torch.use_deterministic_algorithms(False)
9292
if torch.cuda.is_available() and self.disable_cuda_tf_32:
93+
# torch/testing/_internal/common_utils.py calls `disable_global_flags()`
94+
# workaround RuntimeError: not allowed to set ... after disable_global_flags
95+
setattr( # noqa: B010
96+
torch.backends, "__allow_nonbracketed_mutation_flag", True
97+
)
9398
torch.backends.cudnn.allow_tf32 = True
99+
setattr( # noqa: B010
100+
torch.backends, "__allow_nonbracketed_mutation_flag", False
101+
)
94102

95103

96104
class MultiProcessTestBase(unittest.TestCase):

torchrec/distributed/tests/test_pt2_multiprocess.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -749,11 +749,6 @@ def test_compile_multiprocess(
749749
model_type, sharding_type, input_type, tovb, compile_backend, config = (
750750
given_config_tuple
751751
)
752-
# torch/testing/_internal/common_utils.py calls `disable_global_flags()`
753-
# workaround RuntimeError: not allowed to set ... after disable_global_flags
754-
setattr( # noqa: B010
755-
torch.backends, "__allow_nonbracketed_mutation_flag", True
756-
)
757752
self._run_multi_process_test(
758753
callable=_test_compile_rank_fn,
759754
test_model_type=model_type,
@@ -766,9 +761,6 @@ def test_compile_multiprocess(
766761
config=config,
767762
torch_compile_backend=compile_backend,
768763
)
769-
setattr( # noqa: B010
770-
torch.backends, "__allow_nonbracketed_mutation_flag", False
771-
)
772764

773765
# pyre-ignore
774766
@unittest.skipIf(

0 commit comments

Comments
 (0)