Fix group offloading for quanto-quantized models and the use_stream path for quantized tensor subclasses#14038
Sunt-ing wants to merge 1 commit into
Group offloading should have been fixed with #13276, though. Can you check again?
Hi @sayakpaul, thanks. Yes, I rechecked against #13276 before opening this. #13276 makes group offloading work for torchao by swapping the subclass (
Both #12610 and #13281 are still open. I confirmed on current main (with #13276) vs this PR
On approach: I deliberately mirrored the existing
What does this PR do?
Fixes #12610
Fixes #13281
Group offloading moves a group's parameters between CPU and the accelerator by reassigning `param.data`.

This is correct for plain tensors but wrong for tensor subclasses (quantized weights), whose real payload lives in internal sub-tensors (quanto `WeightQBytesTensor`: `_data` / `_scale`; torchao `AffineQuantizedTensor`: `qdata` / `scale` / ...). Reassigning `.data` only swaps the outer wrapper and leaves the inner tensors on the source device, so the next matmul fails with `mat2 is on cpu, different from cuda:0`.

#13276 fixed this for torchao by swapping the whole subclass via `torch.utils.swap_tensors` and restoring inner attributes one by one. Two gaps remained:

- With quanto, `enable_group_offload` hits the wrapper-only `.data =` path and crashes with a device mismatch on the first forward, for both `leaf_level` and `block_level`.
- With `use_stream=True`, `_to_cpu` / `_pinned_memory_tensors` call `pin_memory()` / `is_pinned()`, which neither subclass supports: quanto silently loses the subclass identity, and torchao raises `NotImplementedError: ... aten.is_pinned`. So torchao + `use_stream=True` crashes even though its non-stream path was already fixed.

Changes (`src/diffusers/hooks/group_offloading.py`)

- Add `_is_quanto_tensor` plus quanto helpers, and handle quanto next to the existing torchao branch in `_transfer_tensor_to_device` (onload), `_offload_to_memory` (restore / offload), and the `record_stream` path. Inner tensor names come from the standard subclass protocol `__tensor_flatten__()`; quanto onload uses `torch.utils.swap_tensors` instead of `.data =`.
- In `_to_cpu` and `_pinned_memory_tensors`, skip `pin_memory()` / `is_pinned()` for quanto and torchao subclasses.

Tests
Added `test_group_offloading` to the quanto and torchao quantization suites. Each loads a quantized tiny Flux transformer, offloads it across `leaf_level` / `block_level` and non-stream / `use_stream`, and asserts the output matches the non-offloaded quantized baseline.

- `tests/quantization/quanto/test_quanto.py` (int8 and float8): both fail on `main` with the device mismatch, pass here.
- `tests/quantization/torchao/test_torchao.py::TorchAoTest::test_group_offloading`: the `use_stream=True` cases fail on `main` with the `aten.is_pinned` error, pass here.

Reproduction and before/after
Environment: NVIDIA RTX 4090, `torch==2.8.0+cu128`, `diffusers@2d0110f`, `optimum-quanto==0.2.7`, `torchao==0.17.0`.

Minimal standalone repro for #12610 (quanto):
Running the new tests (`RUN_NIGHTLY=1 RUN_SLOW=1`):

Across `leaf_level` / `block_level` × non-stream / `use_stream` / `record_stream`, the offloaded output is bit-identical (max abs diff = 0.0) to the fully-on-accelerator quantized baseline. A non-quantized group-offload equivalence sweep stays at `0.0` (plain-tensor path unchanged).

Relationship to other work
- … `mx` / `nvfp4` tensors. `Int8WeightOnlyConfig` `AffineQuantizedTensor` still raises `aten.is_pinned` on `torchao==0.17.0`, so the streamed path is still broken for the common int8 case. Skipping pinning on the diffusers side fixes it regardless of the torchao version, and is also required for quanto, whose subclass tensors do not implement torch pinning at all.
- … (`_to_cpu`, `_pinned_memory_tensors`, `_swap_torchao_tensor`) to add disk offload. They are orthogonal in intent (disk vs the memory device-mismatch / stream-pin crash here) but touch the same region, so this PR will need a rebase around whichever lands first.

Who can review?
cc @sayakpaul
Before submitting
.ai/review-rules.md?