[Tests] Skip layerwise casting tests on devices without float8_e4m3fn support#14073

Open
GiGiKoneti wants to merge 5 commits into
huggingface:mainfrom
GiGiKoneti:fix/skip-float8-tests-on-unsupported-devices

GiGiKoneti commented Jun 26, 2026

What does this PR do?

torch.float8_e4m3fn is a storage-only dtype that is not supported on all hardware backends. On Apple Silicon (MPS), attempting to cast a tensor to this dtype raises a TypeError:

TypeError: Cannot convert a MPS Tensor to float8_e4m3fn dtype as the MPS framework doesn't support float8 types.

This causes five layerwise-casting tests to crash instead of being skipped:

| File | Test |
| --- | --- |
| tests/pipelines/test_pipelines_common.py | test_layerwise_casting_inference |
| tests/models/testing_utils/memory.py | test_layerwise_casting_memory |
| tests/models/testing_utils/memory.py | test_layerwise_casting_training |
| tests/lora/utils.py | test_layerwise_casting_inference_denoiser |
| tests/lora/utils.py | test_layerwise_casting_peft_input_autocast_denoiser |

Fix

Each affected test now includes a runtime check at the very top:

try:
    torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn)
except TypeError:
    self.skipTest(f"Device {torch_device} does not support float8 storage dtype.")

If the device cannot handle float8_e4m3fn, the test is skipped with a clear message rather than crashing the entire test suite.
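
To illustrate the difference between a skip and a crash, here is a minimal sketch that uses only the standard-library unittest module. The `fake_float8_cast` function is a hypothetical stand-in for the `torch.zeros(...).to(dtype=torch.float8_e4m3fn)` probe so that the example runs without PyTorch or an MPS device. It is not part of diffusers, and `make_case` and `run_case` are likewise illustrative names.

```python
import unittest


def fake_float8_cast(device):
    """Hypothetical stand-in for the real float8 cast probe."""
    if device == "mps":
        raise TypeError("Cannot convert a MPS Tensor to float8_e4m3fn dtype")
    return "ok"


def make_case(device):
    """Build a TestCase whose single test carries the inline guard."""

    class LayerwiseCastingCase(unittest.TestCase):
        def test_layerwise_casting_inference(self):
            try:
                fake_float8_cast(device)
            except TypeError:
                self.skipTest(f"Device {device} does not support float8 storage dtype.")
            # The real test body would run here on supported devices.
            self.assertEqual(fake_float8_cast(device), "ok")

    return LayerwiseCastingCase


def run_case(device):
    """Run the guarded test for `device` and return the unittest result."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(make_case(device))
    result = unittest.TestResult()
    suite.run(result)
    return result
```

With this stand-in, `run_case("mps")` records the test under `result.skipped` with no errors, while `run_case("cpu")` executes the test body.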

Fixes #14072

Who can review?

@pcuenca (MPS), @sayakpaul @DN6 (General functionalities)

github-actions bot added the fixes-issue, size/S (PR with diff < 50 LOC), and tests labels and removed the size/S (PR with diff < 50 LOC) label on Jun 26, 2026
GiGiKoneti commented Jun 26, 2026

Author

Alternative approaches considered

This PR uses an inline try/except guard at the top of each affected test. It works, but I wanted to flag two cleaner alternatives for your consideration:

1. @require_float8_support decorator (preferred)

Instead of repeating the same 4-line block across 5 tests, a reusable decorator in the test utilities would be more consistent with the existing @require_* pattern in diffusers:

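import functools
import unittest

import torch

# torch_device is the device string provided by the test utilities.
def require_float8_support(test_func):
    """Skip the decorated test when torch_device cannot hold float8_e4m3fn tensors."""
    @functools.wraps(test_func)
    def wrapper(*args, **kwargs):
        try:
            # Probe with a one-element tensor; unsupported backends raise TypeError.
            torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn)
        except TypeError:
            raise unittest.SkipTest(f"Device {torch_device} does not support float8_e4m3fn")
        return test_func(*args, **kwargs)
    return wrapper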

Each test then becomes:

@require_float8_support
def test_layerwise_casting_inference(self):
    ...
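
A variant of the same idea, sketched under the assumption that the suite's existing @require_* helpers are thin wrappers around `unittest.skipUnless`: the probe runs once and its result is cached, instead of running on every test call. The names `make_require_float8` and `fake_probe` are hypothetical, and the probe is injected so the example runs without PyTorch.

```python
import functools
import unittest


def make_require_float8(probe, device):
    """Build a require_*-style decorator from a capability probe.

    `probe(device)` is a hypothetical callable that raises TypeError when
    the device cannot hold float8 tensors.
    """

    @functools.lru_cache(maxsize=None)
    def supported():
        # Evaluated once; later calls reuse the cached answer.
        try:
            probe(device)
        except TypeError:
            return False
        return True

    def decorator(test_case):
        reason = f"Device {device} does not support float8_e4m3fn"
        return unittest.skipUnless(supported(), reason)(test_case)

    return decorator


def fake_probe(device):
    """Stand-in for the torch cast; pretends only 'mps' lacks float8."""
    if device == "mps":
        raise TypeError("float8 unsupported")


require_float8_mps = make_require_float8(fake_probe, "mps")
require_float8_cpu = make_require_float8(fake_probe, "cpu")


class Demo(unittest.TestCase):
    @require_float8_mps
    def test_on_mps(self):
        pass

    @require_float8_cpu
    def test_on_cpu(self):
        pass


def run_demo():
    """Run both demo tests and return the unittest result."""
    result = unittest.TestResult()
    unittest.defaultTestLoader.loadTestsFromTestCase(Demo).run(result)
    return result
```

Because `unittest.skipUnless` also accepts classes, a decorator built this way could be applied to a whole test class as well as to individual methods.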

2. Guard in source code (enable_layerwise_casting)

A more complete fix would be to also add a check inside enable_layerwise_casting() itself, raising a clear ValueError when the device does not support the requested storage_dtype, rather than letting it crash deep inside the pipeline forward pass. The test-side skip would still be needed, but users calling the API directly would get a much better error message.
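
A rough sketch of what such a source-side guard might look like. This is not the diffusers implementation: `check_storage_dtype_supported` and `torch_cast_probe` are hypothetical names, and where such a check would sit inside enable_layerwise_casting() is an assumption. The probe is passed in as an argument so the validation logic can be exercised without PyTorch.

```python
def check_storage_dtype_supported(device, storage_dtype, cast_probe):
    """Raise ValueError early if `device` cannot store `storage_dtype`.

    Hypothetical helper, not part of diffusers. `cast_probe(device, dtype)`
    is expected to raise TypeError for unsupported combinations.
    """
    try:
        cast_probe(device, storage_dtype)
    except TypeError as err:
        raise ValueError(
            f"storage_dtype={storage_dtype} is not supported on device "
            f"'{device}', so layerwise casting cannot be enabled."
        ) from err


def torch_cast_probe(device, storage_dtype):
    """Probe with a one-element tensor, mirroring the check used in the tests."""
    import torch  # imported lazily so this sketch loads without PyTorch

    torch.zeros(1, device=device).to(dtype=storage_dtype)
```

A caller would run something like `check_storage_dtype_supported(device, torch.float8_e4m3fn, torch_cast_probe)` before installing any casting hooks, so the failure surfaces at configuration time with the original TypeError attached as the cause.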


Let me know which you prefer and I will update the PR. cc @pcuenca @DN6

@sayakpaul sayakpaul left a comment

Member

Would it not be better to use a skip_mps decorator directly to skip those tests?

def skip_mps(test_case):

@GiGiKoneti

Author

Using the @skip_mps decorator is much cleaner and fits the existing test codebase conventions perfectly.
The only reason I initially went with the dynamic try/except check was to make it future-proof and device-agnostic (in case other accelerators/configurations also lack float8 casting support and trigger a similar TypeError). But since MPS is the primary backend experiencing this crash and @skip_mps is idiomatic for the suite, using it is definitely better.

Do you want me to update the PR?

@sayakpaul

Member

The future-proof option would be to xfail those tests when torch_device == "mps". That way, if support is added in the future, the tests will start passing and the xfail will no longer apply.

GiGiKoneti commented Jun 26, 2026

Author

@sayakpaul
Since float8 support is a device capability (similar to bfloat16), would you prefer we xfail on MPS specifically to track when they add support, or use a generic @require_float8 decorator to skip on any unsupported device?

github-actions bot added the size/S (PR with diff < 50 LOC) label on Jun 26, 2026
@GiGiKoneti

Author

@sayakpaul
I have updated the PR to use @pytest.mark.xfail(condition=torch_device == "mps", reason="MPS does not support float8 casting.", strict=True) for all five affected layerwise casting tests instead of skipping them.

This way, if MPS adds float8 casting support in the future, the tests will unexpectedly pass (XPASS) and fail the suite, alerting us to remove the decorator.
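
The effect of strict=True can be made concrete with a small decision function. This is a simplified model of how pytest reports a conditional xfail marker, written for illustration. It is not pytest code, `xfail_outcome` is a hypothetical name, and the model ignores the marker's `raises` and `run` options.

```python
def xfail_outcome(test_passed, condition, strict):
    """Return the reported outcome for a test carrying a conditional xfail.

    test_passed: whether the test body succeeded.
    condition:   the marker's condition, e.g. torch_device == "mps".
    strict:      the marker's strict flag.
    """
    if not condition:
        # Marker is inactive: the test is reported normally.
        return "passed" if test_passed else "failed"
    if not test_passed:
        # Expected failure happened.
        return "xfailed"
    # Unexpected pass: strict turns it into a suite failure.
    return "failed" if strict else "xpassed"
```

On MPS today the tests fail and are reported as xfailed. If MPS later gains float8 casting, the same tests pass, and with strict=True that unexpected pass is reported as a failure rather than a quiet XPASS.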

Labels

fixes-issue, size/S (PR with diff < 50 LOC), tests

Development

Successfully merging this pull request may close these issues.

[Bug] Layerwise casting tests fail on MPS backend due to float8 unsupported dtype

2 participants