Incorporate safetensors support to TorchAO #13719
wadeKeith
left a comment
Good integration - safetensors support for TorchAO improves security and loading speed. Clean implementation. LGTM! Reviewed by Hermes Agent.
|
@hlky thanks. I ran some of those tests on an H100 with the following env:
- 🤗 Diffusers version: 0.39.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.12.12
- PyTorch version (GPU?): 2.11.0+cu129 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 1.11.0
- Transformers version: 5.6.0.dev0
- Accelerate version: 1.12.0
- PEFT version: 0.18.2.dev0
- Bitsandbytes version: 0.49.0
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
Some tests passed and some failed, but not all of the failures mean the code is functionally incorrect; they failed because of a dependency on the hardware. Assertion errors on output slices and expectation-based memory-ratio checks are a bit flaky, and I think those are safe to ignore. I have fixed some of them in #13330.
sayakpaul
left a comment
Left some comments. LMK what you think.
| # At some point we will need to deal better with save_function (used for TPU and other distributed | ||
| # joyfulness), but for now this enough. | ||
| safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) | ||
| safetensors.torch.save_file(shard, filepath, metadata=metadata) |
If we're saving all metadata, then I'd restrict it to torchao only.
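One way to read this suggestion, as a minimal sketch: keep the default `{"format": "pt"}` header for every checkpoint and merge quantizer-provided entries only when the quantizer is TorchAO. The function name, the `is_torchao` flag, and the example key are hypothetical and not taken from the PR; the sketch only relies on safetensors metadata being a string-to-string mapping.

```python
import json
from typing import Mapping, Optional


def build_save_metadata(
    quantizer_metadata: Optional[Mapping[str, object]],
    is_torchao: bool,
) -> dict:
    """Return the metadata dict to hand to the safetensors writer (hypothetical helper)."""
    # Default header used for every checkpoint, quantized or not.
    metadata = {"format": "pt"}
    # Assumed policy: only TorchAO checkpoints carry extra header entries.
    if is_torchao and quantizer_metadata:
        for key, value in quantizer_metadata.items():
            # safetensors metadata values must be strings, so serialize anything else.
            metadata[str(key)] = value if isinstance(value, str) else json.dumps(value)
    return metadata
```

With this shape, non-TorchAO save paths keep writing exactly the header they wrote before, so the change in the diff above would only affect TorchAO checkpoints.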
|
@sayakpaul Thanks for the review. I rebased this onto the latest
Changes made:
|
|
@DN6 could you do a review? |
|
@askserge could you do a review as well? |
🤗 Serge says:
This PR adds safetensors serialization support for TorchAO quantized models, enabling both `save_pretrained`/`from_pretrained` with `safe_serialization=True` and disk-based group offloading for TorchAO tensors. The overall approach is sound: flatten TorchAO tensor subclasses via `torchao.prototype.safetensors` for saving, reconstruct them on load, and carry incomplete pieces across shards.
Correctness
- Silent fallback on reconstruction failure in group offloading (`_load_torchao_disk_state_dict`): The broad `except Exception` catches any error from `unflatten_tensor_state_dict` and falls back to raw (non-reconstructed) tensors with only a warning. Since this is a same-session round-trip (the file was just written by `_offload_torchao_to_disk`), a failure here indicates a real bug, and falling back silently would produce garbage model output. This should re-raise.
- No validation of pending reconstruction state after all shards are loaded: `_pending_flattened_state_dict` accumulates leftover tensor pieces across shards. If pieces remain after the last shard (e.g. due to a corrupted or incomplete checkpoint), they are silently lost. `get_weight_reconstruction_pending_keys` exists to detect this but is never called; consider adding a check in `_process_model_after_weight_loading` or after the shard loading loop.
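The second point can be illustrated with a toy accumulator. This is only a sketch under stated assumptions: the `PendingPieces` class, the `"<weight>:<piece>"` key format, and the `qdata`/`scale` piece names are invented for illustration and are not the PR's or TorchAO's actual layout. It shows pieces being carried across shards and a hard failure if anything is still pending after the last shard, which is the same fail-loudly behavior the first point asks for.

```python
class PendingPieces:
    """Toy accumulator for weights whose pieces may span several shards."""

    def __init__(self, required_suffixes):
        self.required = tuple(required_suffixes)
        self.pending = {}  # weight name -> {suffix: value}

    def add_shard(self, shard):
        """Absorb one shard and return the weights that are now complete."""
        for key, value in shard.items():
            name, suffix = key.rsplit(":", 1)
            self.pending.setdefault(name, {})[suffix] = value
        complete = {}
        for name in list(self.pending):
            pieces = self.pending[name]
            if all(suffix in pieces for suffix in self.required):
                complete[name] = tuple(pieces[s] for s in self.required)
                del self.pending[name]
        return complete

    def pending_keys(self):
        """List every piece that has not been merged into a complete weight."""
        return sorted(
            f"{name}:{suffix}"
            for name, pieces in self.pending.items()
            for suffix in pieces
        )


def load_shards(shards, required_suffixes=("qdata", "scale")):
    """Load all shards, then fail loudly if any piece is left unreconstructed."""
    accumulator = PendingPieces(required_suffixes)
    state_dict = {}
    for shard in shards:
        state_dict.update(accumulator.add_shard(shard))
    leftover = accumulator.pending_keys()
    if leftover:
        raise ValueError(f"Incomplete checkpoint, unreconstructed pieces: {leftover}")
    return state_dict
```

Calling `load_shards([{"w:qdata": 1}, {"w:scale": 2}])` completes `w` on the second shard, while dropping the second shard raises instead of returning a partial state dict.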
Dead / unreachable code
- `get_weight_reconstruction_pending_keys` is defined but never called anywhere.
- The `if self.metadata is None` guard in `set_metadata` (line 274) is unreachable because `_metadata` is initialized to `{}` in `__init__` and the property always returns `_metadata`.
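To see why such a guard is dead, here is a stripped-down stand-in. The class name and the body of `set_metadata` are hypothetical; only the initialization and the property mirror what the review describes.

```python
class QuantizerSketch:
    """Stand-in with the same initialization pattern described in the review."""

    def __init__(self):
        self._metadata = {}  # always a dict after construction, never None

    @property
    def metadata(self):
        return self._metadata

    def set_metadata(self, new_entries):
        # This branch can never run: the property returns the dict created in
        # __init__, and nothing in this class rebinds _metadata to None.
        if self.metadata is None:
            raise RuntimeError("metadata was not initialized")
        self._metadata.update(new_entries)
```

The guard would only become reachable if some other code path assigned `None` to `_metadata`, in which case the invariant itself, not the guard, is what deserves a comment or a test.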
Tests
- Good coverage for single-file and sharded serialization, plus disk group offloading.
- The serialization tests don't compare outputs against the original (pre-save) model, only check for NaN. A cosine-similarity check against the original model's output would be stronger.
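A minimal sketch of the suggested check, written in plain Python so it runs without a GPU. The `0.99` threshold and the helper names are assumptions; in the real tests the inputs would be flattened model outputs from the original and the reloaded model.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sequences of floats."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def assert_outputs_close(original, reloaded, threshold=0.99):
    """Fail if the reloaded model's output drifts away from the original's."""
    similarity = cosine_similarity(original, reloaded)
    assert similarity >= threshold, f"cosine similarity {similarity:.4f} < {threshold}"
```

Unlike a NaN check, this fails when a checkpoint reloads into finite but wrong weights, for example if quantized data were loaded without being reconstructed.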
48 LLM turns · 51 tool calls · 236.4s · 1694873 in / 10072 out tokens
|
@sayakpaul Pushed a follow-up commit:
Addressed the latest review points:
Validation was run on a Runpod RTX 3090 with
Focused PR coverage:
pytest -q \
  tests/quantization/torchao/test_torchao.py::TorchAoSerializationTest::test_group_offload_to_disk \
  tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_serialization[int8wo] \
  tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_serialization[int8dq] \
  tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_quantization_sharded_serialization[int8dq]
Result:
Full TorchAO test file:
pytest -q tests/quantization/torchao/test_torchao.py
Result:
The 6 failures are the same known TorchAO/env failures unrelated to this PR:
|
sayakpaul
left a comment
Thanks! Left some further comments.
| "not setting `offload_to_disk_path`." | ||
| ) | ||
|
|
||
| def _get_torchao_disk_state_dict(self): |
It would be nice to have a short explainer, for future reference, on why we have to do this.
|
@sayakpaul Thanks for the follow-up review. I addressed the requested cleanup:
Validation:
|
| } | ||
| if is_torchao_quantized: | ||
| preprocess_kwargs["checkpoint_files"] = checkpoint_files | ||
| hf_quantizer.preprocess_model(**preprocess_kwargs) |
Why do we need to pass checkpoint_files to preprocess_model?
There was a problem hiding this comment.
Good catch. This seems to be a bad chain of assumptions by the agent, along the lines of "`checkpoint_files` is used in `set_metadata` only for torchao" -> "`preprocess_model` has kwargs" -> "pass `checkpoint_files` to `preprocess_model` for torchao". A negative result like this is just as useful as a positive one, though: it can be codified into a rule like "You must validate that anything passed as kwargs is really consumed by the function".
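As a rough illustration of how such a rule could be checked mechanically, the sketch below uses `inspect.signature` to list keyword arguments that a callee does not declare explicitly and would therefore be swallowed by `**kwargs`. The helper name and the stand-in `preprocess_model` signature are hypothetical and not copied from the codebase.

```python
import inspect


def unconsumed_kwargs(func, kwargs):
    """Return the names in kwargs that func does not declare as explicit parameters."""
    explicit = {
        name
        for name, param in inspect.signature(func).parameters.items()
        if param.kind
        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return sorted(set(kwargs) - explicit)


def preprocess_model(model, device_map=None, **kwargs):
    # Stand-in callee: it accepts **kwargs but never reads checkpoint_files.
    return model
```

Here `unconsumed_kwargs(preprocess_model, {"device_map": None, "checkpoint_files": []})` flags `checkpoint_files`. This is a heuristic only: a function may still read a value out of `**kwargs` by name, so a flagged argument is a prompt to inspect the body, not proof that it is unused.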
|
|
||
| @staticmethod | ||
| def _get_quant_config(config_name): | ||
| def _get_quant_config(config_name, **config_kwargs): |
Is this just for version? If so, we could always use version=2?
There was a problem hiding this comment.
Yes looks like we can always use version=2. For context:
- TorchAO docs show Int8WeightOnlyConfig still defaults to version=1, but its implementation emits a deprecation warning and tells users to use version=2.
- Version 2 moves int8 configs toward the newer Int8Tensor/granularity path; version 1 is the older path using AffineQuantizedTensor-style internals.
- Int4WeightOnlyConfig already defaults to version=2, so setting version=2 uniformly is mostly about avoiding int8 version-1 deprecation/noise and aligning test serialization with the newer tensor subclass path.
| self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) | ||
|
|
||
| @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) | ||
| @pytest.mark.parametrize("quant_type", ["int8wo", "int8dq"], ids=["int8wo", "int8dq"]) |
This was likely expanded while moving tests into `TorchAoTesterMixin` after the earlier comment. Sharded serialization already covers `int8dq`, so it should be fine to leave this as only `int8wo`.
|
@hlky It's looking good. My main concern is adding a lot of e.g
class DiffusersQuantizer(ABC):
    def maybe_update_loaded_keys(self, loaded_keys: list[str], checkpoint_files: list[str]) -> list[str]:
        # Instead of: the `is_torchao_quantized` block in from_pretrained that calls
        # `set_metadata(checkpoint_files)` + `get_weight_names()` and the
        # `has_torchao_safetensors_metadata` bookkeeping — fold all of it into this method.
        return loaded_keys

    def maybe_update_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        # Instead of: `hf_quantizer.get_reconstructed_state_dict(state_dict)` in _load_shard_file,
        # including the `hasattr(hf_quantizer, "get_reconstructed_state_dict")` probe.
        # Just run it over each shard (or the single state dict) when a quantizer is present.
        return state_dict

    @property
    def supports_parallel_loading(self) -> bool:
        return True

    def get_state_dict_and_metadata(self, state_dict: dict[str, Any], safe_serialization: bool = False):
        return state_dict, {}

    @property
    def supports_safetensors_serialization(self) -> bool:
        return True
Also, would it be possible to move the group offloading support into a separate PR? I think the saving/loading component is much closer to merge, and I'd like to spend a bit more time reviewing the offloading component.
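To make the intent of the hook-based design above concrete, here is a toy sketch of how a loader could call such hooks without any quantizer-specific branches. Everything here is hypothetical: the two sketch classes, the `"<name>:<piece>"` key format, and `load_shard` are stand-ins, not the actual Diffusers or TorchAO code.

```python
class BaseQuantizerSketch:
    """Default hooks are no-ops, so the loader needs no per-quantizer checks."""

    def maybe_update_loaded_keys(self, loaded_keys, checkpoint_files):
        return loaded_keys

    def maybe_update_state_dict(self, state_dict):
        return state_dict


class FlattenedQuantizerSketch(BaseQuantizerSketch):
    """Toy override: on-disk keys look like '<name>:qdata' and '<name>:scale'."""

    def maybe_update_loaded_keys(self, loaded_keys, checkpoint_files):
        # Map flattened on-disk keys back to logical weight names, keeping order.
        return list(dict.fromkeys(key.split(":", 1)[0] for key in loaded_keys))

    def maybe_update_state_dict(self, state_dict):
        # Group the pieces of each weight under its logical name.
        merged = {}
        for key, value in state_dict.items():
            name, _, piece = key.partition(":")
            merged.setdefault(name, {})[piece or "value"] = value
        return merged


def load_shard(shard, quantizer=None):
    """Loader-side call site: one generic hook call, no isinstance or hasattr probes."""
    if quantizer is not None:
        shard = quantizer.maybe_update_state_dict(shard)
    return shard
```

The design choice is that quantizers which need no special handling inherit the pass-through defaults, so the generic loading path stays identical for them.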
|
Thanks @DN6, applied the requested scope split and refactor. What changed on this PR:
I split the group-offload work into a separate draft follow-up PR based on
Validation:
|
What does this PR do?
Fixes #13713
Notes
Preexisting test failures
Some preexisting tests failed in my environment on `main`; the same tests fail on this PR, and no new test failures are introduced.
Failed tests
Test coverage
More test coverage would be useful.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul @DN6