[Tests] Skip layerwise casting tests on devices without float8_e4m3fn support#14073
[Tests] Skip layerwise casting tests on devices without float8_e4m3fn support#14073GiGiKoneti wants to merge 5 commits into
Conversation
Alternative approaches consideredThis PR uses an inline 1.
|
sayakpaul
left a comment
There was a problem hiding this comment.
Would it not be better to use a skip_mps decorator directly to skip those tests?
|
Using the u want me to update PR ? |
|
Future-proof would |
|
@sayakpaul |
|
@sayakpaul This way, if MPS adds float8 casting support in the future, the tests will unexpectedly pass (XPASS) and fail the suite, alerting us to remove the decorator. |
What does this PR do?
torch.float8_e4m3fnis a storage-only dtype that is not supported on all hardware backends. On Apple Silicon (MPS), attempting to cast a tensor to this dtype raises aTypeError:This causes five layerwise-casting tests to crash instead of being skipped:
tests/pipelines/test_pipelines_common.pytest_layerwise_casting_inferencetests/models/testing_utils/memory.pytest_layerwise_casting_memorytests/models/testing_utils/memory.pytest_layerwise_casting_trainingtests/lora/utils.pytest_layerwise_casting_inference_denoisertests/lora/utils.pytest_layerwise_casting_peft_input_autocast_denoiserFix
Each affected test now includes a runtime check at the very top:
If the device cannot handle
float8_e4m3fn, the test is skipped with a clear message rather than crashing the entire test suite.Fixes #14072
Before submitting
.ai/review-rules.md?Who can review?
@pcuenca (MPS), @sayakpaul @DN6 (General functionalities)