From 629d1f35eea38c2e5acbd0c407f22c94e3f8d7a7 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 24 Jun 2025 10:45:19 +0000 Subject: [PATCH 1/2] reset deterministic in tearDownClass Signed-off-by: jiqing-feng --- tests/quantization/bnb/test_4bit.py | 4 ++++ tests/quantization/bnb/test_mixed_int8.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index bdb8920a399e..a35d5d43fd89 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -100,6 +100,10 @@ class Base4bitTests(unittest.TestCase): def setUpClass(cls): torch.use_deterministic_algorithms(True) + @classmethod + def tearDownClass(cls): + torch.use_deterministic_algorithms(False) + def get_dummy_inputs(self): prompt_embeds = load_pt( "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt", diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index d048b0b7db46..a789a423fd72 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -101,6 +101,10 @@ class Base8bitTests(unittest.TestCase): def setUpClass(cls): torch.use_deterministic_algorithms(True) + @classmethod + def tearDownClass(cls): + torch.use_deterministic_algorithms(False) + def get_dummy_inputs(self): prompt_embeds = load_pt( "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt", From 11d3e103cf54ebca0a35ccd4e654623fd0505ab8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 24 Jun 2025 13:14:24 +0000 Subject: [PATCH 2/2] fix deterministic setting Signed-off-by: jiqing-feng --- tests/quantization/bnb/test_4bit.py | 7 +++++-- tests/quantization/bnb/test_mixed_int8.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index a35d5d43fd89..bf9a8e937c43 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -98,11 +98,14 @@ class Base4bitTests(unittest.TestCase): @classmethod def setUpClass(cls): - torch.use_deterministic_algorithms(True) + cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled() + if not cls.is_deterministic_enabled: + torch.use_deterministic_algorithms(True) @classmethod def tearDownClass(cls): - torch.use_deterministic_algorithms(False) + if not cls.is_deterministic_enabled: + torch.use_deterministic_algorithms(False) def get_dummy_inputs(self): prompt_embeds = load_pt( diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index a789a423fd72..314396d516ef 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -99,11 +99,14 @@ class Base8bitTests(unittest.TestCase): @classmethod def setUpClass(cls): - torch.use_deterministic_algorithms(True) + cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled() + if not cls.is_deterministic_enabled: + torch.use_deterministic_algorithms(True) @classmethod def tearDownClass(cls): - torch.use_deterministic_algorithms(False) + if not cls.is_deterministic_enabled: + torch.use_deterministic_algorithms(False) def get_dummy_inputs(self): prompt_embeds = load_pt(