Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,6 +2106,11 @@ def test_correct_lora_configs_with_different_ranks(self):
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

def test_layerwise_casting_inference_denoiser(self):
try:
torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn)
except TypeError:
self.skipTest(f"Device {torch_device} does not support float8 storage dtype.")

from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

Expand Down Expand Up @@ -2164,6 +2169,10 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):

See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
"""
try:
torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn)
except TypeError:
self.skipTest(f"Device {torch_device} does not support float8 storage dtype.")

from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import (
Expand Down
10 changes: 10 additions & 0 deletions tests/models/testing_utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,11 @@ class LayerwiseCastingTesterMixin:

@torch.no_grad()
def test_layerwise_casting_memory(self):
try:
torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn)
except TypeError:
pytest.skip(f"Device {torch_device} does not support float8 storage dtype.")

MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0

Expand Down Expand Up @@ -437,6 +442,11 @@ def get_memory_usage(storage_dtype, compute_dtype):
), "Peak memory should be lower or within tolerance with fp8 storage"

def test_layerwise_casting_training(self):
try:
torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn)
except TypeError:
pytest.skip(f"Device {torch_device} does not support float8 storage dtype.")

def test_fn(storage_dtype, compute_dtype):
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
Expand Down
5 changes: 5 additions & 0 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,6 +2296,11 @@ def test_layerwise_casting_inference(self):
if not self.test_layerwise_casting:
return

try:
torch.zeros(1, device=torch_device).to(dtype=torch.float8_e4m3fn)
except TypeError:
self.skipTest(f"Device {torch_device} does not support float8 storage dtype.")

components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device, dtype=torch.bfloat16)
Expand Down
Loading