Incorporate safetensors support to TorchAO #13719
Conversation
wadeKeith
left a comment
Good integration - safetensors support for TorchAO improves security and loading speed. Clean implementation. LGTM! Reviewed by Hermes Agent.
@hlky thanks. I ran some of those tests on an H100 with the following env:

- 🤗 Diffusers version: 0.39.0.dev0
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.12.12
- PyTorch version (GPU?): 2.11.0+cu129 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 1.11.0
- Transformers version: 5.6.0.dev0
- Accelerate version: 1.12.0
- PEFT version: 0.18.2.dev0
- Bitsandbytes version: 0.49.0
- Safetensors version: 0.7.0
- xFormers version: not installed
- Accelerator: NVIDIA H100 80GB HBM3, 81559 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Some passed and some failed, not because they are functionally incorrect but because they depend on the hardware. Things like assertion errors on slices and expectation-based memory-ratio checks are a bit flaky, which I think is safe to ignore. I have fixed some of them in #13330.
sayakpaul
left a comment
Left some comments. LMK what you think.
```python
all_tensors = []
for module in self.modules:
    all_tensors.extend(list(module.parameters()))
    all_tensors.extend(list(module.buffers()))
all_tensors.extend(self.parameters)
all_tensors.extend(self.buffers)
all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates

self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self._torchao_disk_key_remap: dict[str, str] = {}
```
We don't need to hold it outside of the `self.offload_to_disk_path` branch, no?
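A minimal sketch of what that could look like, assuming the remap is only consumed on the disk-offload path (the guard placement is an assumption, not the PR's actual code):

```python
# Sketch: only allocate the remap when disk offloading is configured,
# since it is only consumed on that path.
if self.offload_to_disk_path is not None:
    # Maps on-disk safetensors keys back to the original tensor keys.
    self._torchao_disk_key_remap: dict[str, str] = {}
```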
| "Disk offloading is not supported for TorchAO quantized tensors because safetensors " | ||
| "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not " | ||
| "setting `offload_to_disk_path`." | ||
| def _get_disk_state_dict(self): |
This diff seems complicated to me. I think the code would simplify a bit if we aim for a pattern like:

```python
if has_torchao:
    handle_for_torchao_safetensors_group_offloading_with_disk()
else:
    keep_the_existing_code
```

Once this pattern is established, we could look into what needs to be factored out into utilities.
If it helps, I wouldn't mind doing it in a separate PR.
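A rough sketch of that shape (all helper names below are illustrative, not actual functions in the codebase):

```python
# Illustrative only: isolate the TorchAO-specific disk handling in one branch.
def _get_disk_state_dict(self):
    if self._has_torchao_tensors:  # hypothetical flag/check
        # TorchAO subclass tensors need flattening plus metadata to
        # round-trip through safetensors.
        return self._torchao_disk_state_dict()  # hypothetical helper
    # Existing, unchanged path for plain tensors.
    return self._default_disk_state_dict()  # hypothetical helper
```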
```python
# Save the model
state_dict = model_to_save.state_dict()
quantization_metadata = {}
if safe_serialization and hf_quantizer is not None:
```
Should it also include a check that the quantizer is of TorchAO type, just to be explicit?
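For example, something along these lines (a sketch; `QuantizationMethod.TORCHAO` is the enum value referenced elsewhere in this diff):

```python
# Sketch: make the TorchAO requirement explicit in the guard.
if (
    safe_serialization
    and hf_quantizer is not None
    and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
):
    ...
```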
```python
else:
    state_dict_and_metadata = model_to_save.state_dict()
    if isinstance(state_dict_and_metadata, tuple):
        state_dict, quantization_metadata = state_dict_and_metadata
    else:
        state_dict = state_dict_and_metadata
```
Where do we get a `model_to_save.state_dict()` that is of type tuple?
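Presumably the TorchAO path overrides `state_dict()` to also carry serialization metadata; a hypothetical sketch of that shape (not confirmed against the PR):

```python
# Hypothetical override: return (state_dict, metadata) when TorchAO
# metadata must accompany the tensors; a plain state_dict otherwise.
def state_dict(self, *args, **kwargs):
    sd = super().state_dict(*args, **kwargs)
    if getattr(self, "_torchao_metadata", None):  # assumed attribute
        return sd, self._torchao_metadata
    return sd
```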
```diff
 # At some point we will need to deal better with save_function (used for TPU and other distributed
 # joyfulness), but for now this enough.
-safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+safetensors.torch.save_file(shard, filepath, metadata=metadata)
```
If we're saving all metadata, then I'd restrict it to torchao only.
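That could look roughly like this (a sketch; `quantization_metadata` follows the variable from the earlier hunk, and the guard placement is an assumption):

```python
# Sketch: keep the default safetensors metadata, and only attach the extra
# quantization metadata when the quantizer is TorchAO.
metadata = {"format": "pt"}
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
    metadata.update(quantization_metadata)
safetensors.torch.save_file(shard, filepath, metadata=metadata)
```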
```python
)

if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
    is_parallel_loading_enabled = False
```
```python
# limitations under the License.

import gc
import os
```
We could move this test to
Also, I don't think we have to be exhaustive across the different quant formats offered by TorchAO. Just 2-3 (int8, fp8, etc.) would be enough IMO.
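For instance, a non-exhaustive parametrized round-trip test could cover just a couple of formats (a sketch; the `quant_type` strings and the `MODEL_CLS` / `MODEL_ID` names are placeholders, not verified against the test harness):

```python
import pytest
import torch

from diffusers import TorchAoConfig

# quant_type strings are placeholders; use whatever TorchAoConfig accepts.
@pytest.mark.parametrize("quant_type", ["int8wo", "float8wo"])
def test_torchao_safetensors_roundtrip(tmp_path, quant_type):
    # MODEL_CLS / MODEL_ID stand in for the test suite's model fixture.
    model = MODEL_CLS.from_pretrained(
        MODEL_ID,
        quantization_config=TorchAoConfig(quant_type),
        torch_dtype=torch.bfloat16,
    )
    model.save_pretrained(tmp_path, safe_serialization=True)
    reloaded = MODEL_CLS.from_pretrained(tmp_path, torch_dtype=torch.bfloat16)
    # Verify the round trip preserves the quantized state dict keys.
    for key in model.state_dict():
        assert key in reloaded.state_dict()
```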
What does this PR do?
Fixes #13713
Notes
Preexisting test failures
Some preexisting tests failed in my environment on `main`; the same tests fail on this PR, so no new test failures are introduced.

Failed tests
Test coverage
More test coverage would be useful.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul @DN6