Skip to content

Incorporate safetensors support to TorchAO#13719

Open
hlky wants to merge 7 commits into
huggingface:mainfrom
hlky:torchao_safetensors
Open

Incorporate safetensors support to TorchAO#13719
hlky wants to merge 7 commits into
huggingface:mainfrom
hlky:torchao_safetensors

Conversation

@hlky
Copy link
Copy Markdown
Contributor

@hlky hlky commented May 11, 2026

What does this PR do?

Fixes #13713

Notes

Preexisting test failures

Some preexisting tests failed in my environment on main, the same tests fail on this PR, no new test failures are introduced.

Failed tests

FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_device_map - RuntimeError: cutlass cannot initialize
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_memory_footprint - AssertionError: False is not true
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_model_memory_usage - assert (34840064 / 34806272) >= 1.02
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_modules_to_not_convert - AssertionError: False is not true
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_quantization - TypeError: IntxWeightOnlyConfig.__init__() got an unexpected keyword argument 'dtype'
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_training - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Test coverage

More test coverage would be useful.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @DN6

Copy link
Copy Markdown

@wadeKeith wadeKeith left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good integration - safetensors support for TorchAO improves security and loading speed. Clean implementation. LGTM! Reviewed by Hermes Agent.

@sayakpaul
Copy link
Copy Markdown
Member

@hlky thanks.

I ran some of those tests on an H100 with the following env:

- 🤗 Diffusers version: 0.39.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.12.12
- PyTorch version (GPU?): 2.11.0+cu129 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 1.11.0
- Transformers version: 5.6.0.dev0
- Accelerate version: 1.12.0
- PEFT version: 0.18.2.dev0
- Bitsandbytes version: 0.49.0
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Some passed and some failed not because all of them are functionally incorrect. They failed because of the dependency on the hardware. Stuff like assertion errors on slices, memory ratio (expectation-based) are a bit flaky, which I think are safe to ignore. I have fixed some of them in #13330

Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments. LMK what you think.

Comment thread src/diffusers/hooks/group_offloading.py Outdated
Comment thread src/diffusers/hooks/group_offloading.py Outdated
Comment thread src/diffusers/models/modeling_utils.py Outdated
Comment thread src/diffusers/models/modeling_utils.py Outdated
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
safetensors.torch.save_file(shard, filepath, metadata=metadata)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're saving all metadata, then I'd restrict it to torchao only.

Comment thread src/diffusers/models/modeling_utils.py
Comment thread tests/quantization/torchao/test_torchao.py Outdated
@sayakpaul sayakpaul requested a review from DN6 May 12, 2026 09:52
@hlky hlky force-pushed the torchao_safetensors branch from 2124c6b to 9d8e4ac Compare May 28, 2026 08:03
@hlky
Copy link
Copy Markdown
Contributor Author

hlky commented May 28, 2026

@sayakpaul Thanks for the review. I rebased this onto the latest upstream/main and pushed the updated branch.

Changes made:

  • Simplified group-offload disk handling so the existing non-TorchAO path stays separate from the TorchAO safetensors path.
  • Restricted safetensors metadata save/load handling to TorchAO explicitly.
  • Removed the ambiguous state_dict() tuple fallback.
  • Disabled parallel shard loading only when TorchAO safetensors metadata is present, since reconstruction can carry pending tensor pieces across shards.
  • Moved the broader TorchAO safetensors serialization coverage into TorchAoTesterMixin and kept the Flux-specific disk group-offload test in test_torchao.py.
  • Gated TorchAO safetensors support on torchao >= 0.16.0; 0.15.0 exposes the helper APIs but does not support Int8Tensor in the safetensors allowlist.

@sayakpaul
Copy link
Copy Markdown
Member

@DN6 could you do a review?

@sayakpaul
Copy link
Copy Markdown
Member

@askserge could you do a review as well?

Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤗 Serge says:

(no overall summary provided)

19 LLM turns · 23 tool calls · 102.7s · 447704 in / 4500 out tokens

Copy link
Copy Markdown

@sergereview sergereview Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤗 Serge says:

This PR adds safetensors serialization support for TorchAO quantized models, enabling both save_pretrained/from_pretrained with safe_serialization=True and disk-based group offloading for TorchAO tensors. The overall approach is sound — flatten TorchAO tensor subclasses via torchao.prototype.safetensors for saving, reconstruct them on load, and carry incomplete pieces across shards.

Correctness

  • Silent fallback on reconstruction failure in group offloading (_load_torchao_disk_state_dict): The broad except Exception catches any error from unflatten_tensor_state_dict and falls back to raw (non-reconstructed) tensors with only a warning. Since this is a same-session round-trip (the file was just written by _offload_torchao_to_disk), a failure here indicates a real bug — falling back silently would produce garbage model output. This should re-raise.

  • No validation of pending reconstruction state after all shards are loaded: _pending_flattened_state_dict accumulates leftover tensor pieces across shards. If pieces remain after the last shard (e.g. due to a corrupted or incomplete checkpoint), they are silently lost. get_weight_reconstruction_pending_keys exists to detect this but is never called — consider adding a check in _process_model_after_weight_loading or after the shard loading loop.

Dead / unreachable code

  • get_weight_reconstruction_pending_keys is defined but never called anywhere.
  • The if self.metadata is None guard in set_metadata (line 274) is unreachable because _metadata is initialized to {} in __init__ and the property always returns _metadata.

Tests

  • Good coverage for single-file and sharded serialization, plus disk group offloading.
  • The serialization tests don't compare outputs against the original (pre-save) model, only check for NaN. A cosine-similarity check against the original model's output would be stronger.

48 LLM turns · 51 tool calls · 236.4s · 1694873 in / 10072 out tokens

Comment thread src/diffusers/hooks/group_offloading.py Outdated
Comment thread src/diffusers/quantizers/torchao/torchao_quantizer.py Outdated
Comment thread src/diffusers/quantizers/torchao/torchao_quantizer.py Outdated
@hlky
Copy link
Copy Markdown
Contributor Author

hlky commented May 28, 2026

@sayakpaul Pushed a follow-up commit: 8c1be50fc.

Addressed the latest review points:

  • TorchAO disk-offload reconstruction now raises if safetensors metadata reconstruction fails instead of falling back to raw flattened tensors.
  • The TorchAO disk key remap is now derived during load as well as save, so a fresh ModuleGroup can reuse an existing offload file safely.
  • Cleaned up the TorchAO quantizer metadata state:
    • removed unused reconstruction bookkeeping
    • removed the unused metadata setter
    • get_weight_names() now derives directly from safetensors metadata
  • Strengthened the safetensors serialization tests to compare original vs reloaded model outputs, rather than only checking for non-NaN outputs.
  • Added coverage for reusing an existing disk offload directory with a fresh quantized model.

Validation was run on a Runpod RTX 3090 with torch==2.9.1+cu128, torchao==0.17.0.

Focused PR coverage:

pytest -q \
  tests/quantization/torchao/test_torchao.py::TorchAoSerializationTest::test_group_offload_to_disk \
  tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_serialization[int8wo] \
  tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_serialization[int8dq] \
  tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_sharded_serialization[int8dq]

Result: 4 passed.

Full TorchAO test file:

pytest -q tests/quantization/torchao/test_torchao.py

Result: 13 passed, 10 skipped, 6 failed.

The 6 failures are the same known TorchAO/env failures unrelated to this PR:

  • int4 tests requiring mslk >= 1.0.0
  • IntxWeightOnlyConfig(dtype=...) API mismatch
  • existing modules_to_not_convert assertion drift
  • existing training grad behavior failure

Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Left some further comments.

Comment thread src/diffusers/hooks/group_offloading.py Outdated
Comment thread src/diffusers/hooks/group_offloading.py Outdated
Comment thread src/diffusers/hooks/group_offloading.py Outdated
"not setting `offload_to_disk_path`."
)

def _get_torchao_disk_state_dict(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have a short explainer for future reference as why we have to perform this.

Comment thread src/diffusers/hooks/group_offloading.py Outdated
Comment thread src/diffusers/hooks/group_offloading.py Outdated
Comment thread src/diffusers/models/modeling_utils.py Outdated
Comment thread src/diffusers/models/modeling_utils.py Outdated
Comment thread tests/quantization/torchao/test_torchao.py Outdated
@hlky
Copy link
Copy Markdown
Contributor Author

hlky commented May 29, 2026

@sayakpaul Thanks for the follow-up review. I addressed the requested cleanup:

  • Simplified the TorchAO metadata/load-key path so it is guarded directly by is_torchao_quantized. Since that already implies hf_quantizer is present and is the TorchAO quantizer, the extra hf_quantizer is not None / hasattr(...) checks were unnecessary.
  • Kept has_torchao_safetensors_metadata tied to actual metadata presence rather than TorchAO quantization in general. TorchAO can still be used for normal on-load quantization from checkpoints without TorchAO safetensors metadata, so we should only override loaded_keys and disable parallel shard loading when metadata was actually found.
  • Merged the duplicated TorchAO/non-TorchAO disk offload/onload paths in group offloading, leaving only the TorchAO-specific save/load/release pieces branched internally.
  • Restricted checkpoint_files plumbing to TorchAO preprocessing only.
  • Moved disk-offload bookkeeping initialization into the disk-offload branch.
  • Added a warning for malformed TorchAO safetensors metadata fallback and a short comment explaining why the TorchAO disk path goes through flatten/unflatten.

Validation:

  • Local ruff, compileall, and git diff --check passed.
  • Focused CUDA tests on an A40 passed: 4 passed.
  • Targeted regular TorchAO load path without TorchAO safetensors metadata passed.
  • Full tests/quantization/torchao/test_torchao.py keeps the same known profile: 13 passed, 10 skipped, 6 failed. The failures match the known TorchAO/upstream/environment issues already discussed, including missing mslk for int4 and the pending TorchAO config/API mismatch.

Copy link
Copy Markdown
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay!

Thanks! I left some minor comments and I think we're quite close to merging. @DN6 over to you.

Currently, running the tests (edit: the new tests introduced in this PR are passing).

Comment thread src/diffusers/models/modeling_utils.py Outdated
}
if is_torchao_quantized:
preprocess_kwargs["checkpoint_files"] = checkpoint_files
hf_quantizer.preprocess_model(**preprocess_kwargs)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to pass checkpoint_files to preprocess_model?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Seems to be bad assumption from agent like "checkpoint_files used in set_metadata only for torchao" -> "preprocess_model has kwargs" -> "pass checkpoint_files to preprocess_model for torchao". Negative result like this is just as good as positive result though, it can be codified into a rule like "You must validate anything passed as kwargs is really consumed by the function".


@staticmethod
def _get_quant_config(config_name):
def _get_quant_config(config_name, **config_kwargs):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this just for version? If so, we could always use version=2?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes looks like we can always use version=2. For context:

  • TorchAO docs show Int8WeightOnlyConfig still defaults to version=1, but its implementation emits a deprecation warning and tells users to use version=2.
  • Version 2 moves int8 configs toward the newer Int8Tensor/granularity path; version 1 is the older path using AffineQuantizedTensor-style internals.
  • Int4WeightOnlyConfig already defaults to version=2, so setting version=2 uniformly is mostly about avoiding int8 version-1 deprecation/noise and aligning test serialization with the newer tensor subclass path.

self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])

@pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
@pytest.mark.parametrize("quant_type", ["int8wo", "int8dq"], ids=["int8wo", "int8dq"])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to increase it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likely expanded while moving tests into TorchAoTesterMixin after the earlier comment. Sharded serialization already covers int8dq so should be fine to leave this as only int8wo.

@hlky hlky force-pushed the torchao_safetensors branch from b4da670 to ba21459 Compare June 5, 2026 06:43
@DN6
Copy link
Copy Markdown
Collaborator

DN6 commented Jun 5, 2026

@hlky It's looking good. My main concern is adding a lot of torchao specific checks into the core model loading/saving logic. I think it would be better to achieve the same functionality via methods in the hf_quantizer.

e.g

class DiffusersQuantizer(ABC): 

    def maybe_update_loaded_keys(self, loaded_keys: list[str], checkpoint_files: list[str]) -> list[str]:
        # Instead of: the `is_torchao_quantized` block in from_pretrained that calls
        #   `set_metadata(checkpoint_files)` + `get_weight_names()` and the
        #   `has_torchao_safetensors_metadata` bookkeeping — fold all of it into this method.
        return loaded_keys

    def maybe_update_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        # Instead of: `hf_quantizer.get_reconstructed_state_dict(state_dict)` in _load_shard_file,
        #   including the `hasattr(hf_quantizer, "get_reconstructed_state_dict")` probe.
        #   Just run it over each shard (or the single state dict) when a quantizer is present.
        return state_dict

    @property
    def supports_parallel_loading(self) -> bool:
        return True

    def get_state_dict_and_metadata(self, state_dict: dict[str, Any], safe_serialization: bool = False):
        return state_dict, {}

    @property
    def supports_safetensors_serialization(self) -> bool:
        return True

Also, would it be possible to move the group offloading support into a separate PR. I think the saving/loading component is much closer to merge, and I'd like to spend a bit more time to review the offloading component.

@hlky
Copy link
Copy Markdown
Contributor Author

hlky commented Jun 5, 2026

Thanks @DN6, applied the requested scope split and refactor.

What changed on this PR:

  • Removed the TorchAO disk group-offload changes from this branch, including the group-offload-specific test coverage.
  • Kept this PR focused on TorchAO safetensors save/load.
  • Added generic quantizer hooks on DiffusersQuantizer for loaded-key updates, state-dict updates, save metadata, parallel-loading support, and safetensors serialization support.
  • Moved the TorchAO-specific metadata/key/state-dict reconstruction behavior into TorchAoHfQuantizer.
  • Removed the explicit TorchAO branches from core model loading/saving code and replaced them with the generic quantizer hooks.
  • Kept checkpoint_files out of preprocess_model; it is only consumed through the TorchAO metadata hook.

I split the group-offload work into a separate draft follow-up PR based on upstream/main: #13875

Validation:

  • Local ruff, compileall, and git diff --check passed.
  • CUDA A40 focused serialization tests passed: 2 passed.
  • Full tests/quantization/torchao/test_torchao.py result after removing the group-offload test: 12 passed, 10 skipped, 6 known failures. The failures match the previously discussed TorchAO/upstream/environment issues.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Incorporate safetensors support to TorchAO

4 participants