
Incorporate safetensors support to TorchAO #13719

Open
hlky wants to merge 1 commit into huggingface:main from hlky:torchao_safetensors

Conversation

@hlky
Contributor

@hlky hlky commented May 11, 2026

What does this PR do?

Fixes #13713
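
For illustration, a hedged end-to-end sketch of what this PR enables: saving a TorchAO-quantized model with safetensors and loading it back. The repo id and quant type are illustrative, and it assumes this branch is installed; it is a sketch, not code from the PR.

import tempfile

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Quantize on load with TorchAO int8 weight-only quantization.
quant_config = TorchAoConfig("int8wo")
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative repo id
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

with tempfile.TemporaryDirectory() as tmpdir:
    # Previously TorchAO checkpoints required safe_serialization=False.
    model.save_pretrained(tmpdir, safe_serialization=True)
    reloaded = FluxTransformer2DModel.from_pretrained(tmpdir, torch_dtype=torch.bfloat16)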

Notes

Preexisting test failures

Some preexisting tests failed in my environment on main; the same tests fail on this PR, so no new test failures are introduced.

Failed tests

FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_device_map - RuntimeError: cutlass cannot initialize
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_memory_footprint - AssertionError: False is not true
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_model_memory_usage - assert (34840064 / 34806272) >= 1.02
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_modules_to_not_convert - AssertionError: False is not true
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_quantization - TypeError: IntxWeightOnlyConfig.__init__() got an unexpected keyword argument 'dtype'
FAILED tests/quantization/torchao/test_torchao.py::TorchAoTest::test_training - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Test coverage

More test coverage would be useful.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @DN6


@wadeKeith left a comment


Good integration - safetensors support for TorchAO improves security and loading speed. Clean implementation. LGTM! Reviewed by Hermes Agent.

@sayakpaul
Member

@hlky thanks.

I ran some of those tests on an H100 with the following env:

- 🤗 Diffusers version: 0.39.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.12.12
- PyTorch version (GPU?): 2.11.0+cu129 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 1.11.0
- Transformers version: 5.6.0.dev0
- Accelerate version: 1.12.0
- PEFT version: 0.18.2.dev0
- Bitsandbytes version: 0.49.0
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Some passed and some failed, not because they are functionally incorrect but because of the dependency on the hardware. Checks like assertions on output slices and the (expectation-based) memory ratio are a bit flaky, which I think is safe to ignore. I have fixed some of them in #13330

Member

@sayakpaul left a comment


Left some comments. LMK what you think.

Comment on lines +160 to +169
all_tensors = []
for module in self.modules:
    all_tensors.extend(list(module.parameters()))
    all_tensors.extend(list(module.buffers()))
all_tensors.extend(self.parameters)
all_tensors.extend(self.buffers)
all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates

self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self._torchao_disk_key_remap: dict[str, str] = {}

We don't need to hold it outside of the `self.offload_to_disk_path` case, no?
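
As an aside on the snippet above, `dict.fromkeys` gives an order-preserving dedup: unlike `set()`, it keeps first occurrences in insertion order, which the sequential `tensor_{i}` keys rely on. A minimal demo:

# dict.fromkeys preserves first-seen order while removing duplicates.
items = ["a", "b", "a", "c", "b"]
assert list(dict.fromkeys(items)) == ["a", "b", "c"]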

"Disk offloading is not supported for TorchAO quantized tensors because safetensors "
"cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
"setting `offload_to_disk_path`."
def _get_disk_state_dict(self):

This diff seems complicated to me. I think the code would simplify if we aim for a pattern like:

if has_torchao:
    handle_for_torchao_safetensors_group_offloading_with_disk()
else:
    keep_the_existing_code

Once this pattern is established, we could look into what needs to be factored out in utilities.

If it helps, I wouldn't mind doing it in a separate PR.
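
A hedged sketch of the suggested dispatch, with illustrative placeholder names rather than code from this PR:

def _get_disk_state_dict(self):
    if self._is_torchao_quantized:  # assumed flag, set when a TorchAO quantizer is active
        # hypothetical helper for TorchAO-aware safetensors disk offloading
        return self._torchao_safetensors_disk_state_dict()
    # existing code path, unchanged
    return self._default_disk_state_dict()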

# Save the model
state_dict = model_to_save.state_dict()
quantization_metadata = {}
if safe_serialization and hf_quantizer is not None:

Should it also include a check that the quantizer is of the TorchAO type, just to be explicit?
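
For instance, a hedged sketch of the more explicit guard, reusing the `QuantizationMethod` check that appears later in this diff:

if (
    safe_serialization
    and hf_quantizer is not None
    and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
):
    ...  # collect quantization_metadata as the diff currently does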

Comment on lines +767 to +772
else:
    state_dict_and_metadata = model_to_save.state_dict()
    if isinstance(state_dict_and_metadata, tuple):
        state_dict, quantization_metadata = state_dict_and_metadata
    else:
        state_dict = state_dict_and_metadata

Where do we get a model_to_save.state_dict() that is of type tuple?
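 
One hedged guess at how a tuple could arise, if the PR overrides `state_dict()` on quantized models; the class and helper names here are hypothetical, not taken from the diff:

import torch

class QuantizedModel(torch.nn.Module):  # hypothetical
    def state_dict(self, *args, **kwargs):
        sd = super().state_dict(*args, **kwargs)
        if getattr(self, "hf_quantizer", None) is not None:
            # hypothetical helper returning string-valued metadata for the checkpoint
            return sd, self.hf_quantizer.get_metadata(sd)
        return sd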

# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
- safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+ safetensors.torch.save_file(shard, filepath, metadata=metadata)

If we're saving all metadata, then I'd restrict it to torchao only.
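
A hedged sketch of that restriction, leaving the default metadata untouched except for TorchAO (safetensors metadata values must be strings):

metadata = {"format": "pt"}
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
    # attach the extra quantization metadata only for TorchAO checkpoints
    metadata.update(quantization_metadata)
safetensors.torch.save_file(shard, filepath, metadata=metadata)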

)

if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
    is_parallel_loading_enabled = False

Why is that?

# limitations under the License.

import gc
import os

We could move this test to

class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):

Also, I don't think we have to be exhaustive across the different quant formats offered by TorchAO. Just 2-3 (int8, fp8, etc.) are enough IMO.
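
A hedged sketch of such a slimmed-down test, parametrized over just two representative quant types; the repo id and quant-type strings are illustrative, not taken from the PR:

import pytest
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

@pytest.mark.parametrize("quant_type", ["int8wo", "float8wo_e4m3"])
def test_torchao_safetensors_roundtrip(quant_type, tmp_path):
    model = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # illustrative
        subfolder="transformer",
        quantization_config=TorchAoConfig(quant_type),
        torch_dtype=torch.bfloat16,
    )
    # Roundtrip through safetensors: save with safe_serialization=True, then reload.
    model.save_pretrained(tmp_path, safe_serialization=True)
    FluxTransformer2DModel.from_pretrained(tmp_path, torch_dtype=torch.bfloat16)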

@sayakpaul requested a review from DN6 May 12, 2026 09:52