From 15413f07a1a8275d4b4493db9edd1e3803fe99e2 Mon Sep 17 00:00:00 2001 From: GiGiKoneti Date: Fri, 26 Jun 2026 14:42:48 +0530 Subject: [PATCH 1/2] Skip layerwise casting tests on devices without float8_e4m3fn support --- tests/lora/utils.py | 9 +++++++++ tests/models/testing_utils/memory.py | 10 ++++++++++ tests/pipelines/test_pipelines_common.py | 5 +++++ 3 files changed, 24 insertions(+) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index d6cb50bc52e4..af145c356aef 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2106,6 +2106,11 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) def test_layerwise_casting_inference_denoiser(self): + try: + torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) + except TypeError: + self.skipTest(f"Device {torch_device} does not support float8 storage dtype.") + from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN @@ -2164,6 +2169,10 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self): See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details. """ + try: + torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) + except TypeError: + self.skipTest(f"Device {torch_device} does not support float8 storage dtype.") from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from diffusers.hooks.layerwise_casting import ( diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index 84c3e23133a1..9d1ad0b480c0 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -385,6 +385,11 @@ class LayerwiseCastingTesterMixin: @torch.no_grad() def test_layerwise_casting_memory(self): + try: + torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) + except TypeError: + pytest.skip(f"Device {torch_device} does not support float8 storage dtype.") + MB_TOLERANCE = 0.2 LEAST_COMPUTE_CAPABILITY = 8.0 @@ -437,6 +442,11 @@ def get_memory_usage(storage_dtype, compute_dtype): ), "Peak memory should be lower or within tolerance with fp8 storage" def test_layerwise_casting_training(self): + try: + torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) + except TypeError: + pytest.skip(f"Device {torch_device} does not support float8 storage dtype.") + def test_fn(storage_dtype, compute_dtype): if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16: pytest.skip("Skipping test because CPU doesn't go well with bfloat16.") diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index fcd8ab24bab8..2a4c8819048f 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2296,6 +2296,11 @@ def test_layerwise_casting_inference(self): if not self.test_layerwise_casting: return + try: + torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) + except TypeError: + self.skipTest(f"Device {torch_device} does not support float8 storage dtype.") + components = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.to(torch_device, dtype=torch.bfloat16) From 30d835667824071d89d88a5d94ea10d557026dd0 Mon Sep 17 00:00:00 2001 From: GiGiKoneti Date: Sun, 28 Jun 2026 12:20:09 +0530 Subject: [PATCH 2/2] Xfail float8 layerwise casting tests on MPS --- tests/lora/utils.py | 19 ++++++++++--------- tests/models/testing_utils/memory.py | 20 ++++++++++---------- tests/pipelines/test_pipelines_common.py | 10 +++++----- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index af145c356aef..a3b0fe0aa2f2 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2105,12 +2105,12 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + @pytest.mark.xfail( + condition=torch_device == "mps", + reason="MPS does not support float8 casting.", + strict=True, + ) def test_layerwise_casting_inference_denoiser(self): - try: - torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) - except TypeError: - self.skipTest(f"Device {torch_device} does not support float8 storage dtype.") - from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN @@ -2154,6 +2154,11 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] + @pytest.mark.xfail( + condition=torch_device == "mps", + reason="MPS does not support float8 casting.", + strict=True, + ) @require_peft_version_greater("0.14.0") def test_layerwise_casting_peft_input_autocast_denoiser(self): r""" @@ -2169,10 +2174,6 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self): See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details. """ - try: - torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) - except TypeError: - self.skipTest(f"Device {torch_device} does not support float8 storage dtype.") from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS from diffusers.hooks.layerwise_casting import ( diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index 9d1ad0b480c0..9cb28919efb8 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -383,13 +383,13 @@ class LayerwiseCastingTesterMixin: - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass """ + @pytest.mark.xfail( + condition=torch_device == "mps", + reason="MPS does not support float8 casting.", + strict=True, + ) @torch.no_grad() def test_layerwise_casting_memory(self): - try: - torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) - except TypeError: - pytest.skip(f"Device {torch_device} does not support float8 storage dtype.") - MB_TOLERANCE = 0.2 LEAST_COMPUTE_CAPABILITY = 8.0 @@ -441,12 +441,12 @@ def get_memory_usage(storage_dtype, compute_dtype): or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE ), "Peak memory should be lower or within tolerance with fp8 storage" + @pytest.mark.xfail( + condition=torch_device == "mps", + reason="MPS does not support float8 casting.", + strict=True, + ) def test_layerwise_casting_training(self): - try: - torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) - except TypeError: - pytest.skip(f"Device {torch_device} does not support float8 storage dtype.") - def test_fn(storage_dtype, compute_dtype): if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16: pytest.skip("Skipping test because CPU doesn't go well with bfloat16.") diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 2a4c8819048f..d6a09028157f 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2292,15 +2292,15 @@ def test_save_load_dduf(self, atol=1e-4, rtol=1e-4): elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor): assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol) + @pytest.mark.xfail( + condition=torch_device == "mps", + reason="MPS does not support float8 casting.", + strict=True, + ) def test_layerwise_casting_inference(self): if not self.test_layerwise_casting: return - try: - torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn) - except TypeError: - self.skipTest(f"Device {torch_device} does not support float8 storage dtype.") - components = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.to(torch_device, dtype=torch.bfloat16)