Arm backend: Lower MXFP Linear to TOSA#20065
Conversation
Add fake TOSA dialect support and serializer lowering for CAST_TO_BLOCK_SCALED. Co-authored-by: Sebastian Larsson <sebastian.larsson@arm.com> Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: Ic7cdab5134f0fb9502f5985563f0662286ef5fb7
Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Sebastian Larsson <sebastian.larsson@arm.com> Change-Id: Iab2e1cf2ed21047bbc2a7a51604b9230fe2f2819
This commit is exact suggestion that kirklandsign posted in pytorch#19969 Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Suggested-by: @kirklandsign Change-Id: Iacce4dd11e61e4d79296e37e59ddb214072dd2ef
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20065
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c87b5df with merge base 9400da1 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label ciflow/trunk |
|
@pytorchbot label "partner: arm" |
|
@pytorchbot label "release notes: arm" |
There was a problem hiding this comment.
Pull request overview
This PR adds end-to-end support in the Arm backend to lower the tosa_mxfp.linear custom op into explicit TOSA 1.1 MXFP block-scaled operators (cast + matmul), including dtype plumbing, serializer visitors, a rewrite pass, and updated tests/build targets.
Changes:
- Extend Arm TOSA dtype/spec support to cover MXFP-related FP8 dtypes and allow FP8E4M3/E5M2 when the
mxfpextension is enabled. - Introduce new TOSA dialect fake ops (
CAST_TO_BLOCK_SCALED,MATMUL_T_BLOCK_SCALED) and serializer visitors to emit corresponding TOSA operators. - Add an Arm pass to rewrite
tosa_mxfp.linearinto the explicit block-scaled TOSA operator sequence, and update/expand tests accordingly.
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| backends/arm/tosa/mapping.py | Adds FP8 dtype mapping and updates spec validation logic to treat FP8 as allowed under mxfp. |
| backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py | New fake TOSA op implementation/validation for MXFP block-scaled matmul. |
| backends/arm/tosa/dialect/ops/cast_to_block_scaled.py | New fake TOSA op implementation/validation for MXFP block-scaled cast. |
| backends/arm/tosa/dialect/init.py | Registers the new MXFP dialect op modules. |
| backends/arm/test/targets.bzl | Moves MXFP linear op test path, adds new dialect tests, and updates test deps; also drops several VGF/custom-shader tests from the suite. |
| backends/arm/test/passes/test_rewrite_mxfp_linear_pass.py | New unit tests validating the rewrite pass replaces the custom op with TOSA MXFP ops. |
| backends/arm/test/ops/mxfp/test_mxfp_linear.py | Refactors MXFP linear tests into shared pipelines; adds dim-order coverage and VGF xfail scaffolding. |
| backends/arm/test/ops/mxfp/common.py | New shared pipeline helpers/stage for converting models to MXFP in tests. |
| backends/arm/test/ops/mxfp/init.py | New package marker for MXFP op tests. |
| backends/arm/test/misc/tosa_dialect/test_tosa_dialect_mxfp_linear.py | New tests for the MATMUL_T_BLOCK_SCALED fake dialect op behavior/validation. |
| backends/arm/test/misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py | New tests for the CAST_TO_BLOCK_SCALED fake dialect op behavior/validation. |
| backends/arm/test/BUCK | Removes VGF/custom shader helper targets and adds a shared MXFP test helper target. |
| backends/arm/process_node.py | Adds serialization support for torch.float8_e8m0fnu constants via ml_dtypes. |
| backends/arm/operators/op_tosa_matmul_t_block_scaled.py | New serializer visitor to emit MATMUL_T_BLOCK_SCALED into TOSA graphs (requires mxfp). |
| backends/arm/operators/op_tosa_cast_to_block_scaled.py | New serializer visitor to emit CAST_TO_BLOCK_SCALED into TOSA graphs (requires mxfp) with multi-output workaround. |
| backends/arm/operators/init.py | Ensures new MXFP visitors are imported/registered. |
| backends/arm/operator_support/tosa_supported_operators.py | Allows MXFP custom op under mxfp spec and relaxes FP8 dtype gating when mxfp is enabled. |
| backends/arm/_passes/rewrite_mxfp_linear.py | New Arm pass rewriting tosa_mxfp.linear into explicit cast+matmul block-scaled ops. |
| backends/arm/_passes/arm_pass_manager.py | Adds the new rewrite pass into the TOSA lowering pipeline. |
| backends/arm/_passes/init.py | Exposes the new rewrite pass in the passes package. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| case ts.DType.FP8E4M3: | ||
| if not tosa_spec.support_extension("fp8e4m3"): | ||
| if not ( | ||
| tosa_spec.support_extension("fp8e4m3") | ||
| or tosa_spec.support_extension("mxfp") | ||
| ): | ||
| return False |
| output_data_tensor, output_scale_tensor = node.meta["val"] | ||
|
|
||
| # TODO(MLETORCH-2018): This is a local workaround for multi-output TOSA ops. | ||
| # Remove it once twe can handle multiple outputs generally. |
| "ops/test_log10.py", | ||
| "ops/test_max_pool1d.py", | ||
| "ops/test_mul.py", | ||
| "ops/test_mxfp_linear.py", | ||
| "ops/mxfp/test_mxfp_linear.py", | ||
| "ops/test_permute.py", |
| input_tensor = inputs[0] | ||
| block_size = inputs[1].number | ||
| output_data_tensor, output_scale_tensor = node.meta["val"] | ||
|
|
| torch.bfloat16: ts.DType.BF16, | ||
| torch.float8_e4m3fn: ts.DType.FP8E4M3, | ||
| torch.float8_e5m2: ts.DType.FP8E5M2, | ||
| torch.float8_e8m0fnu: ts.DType.FP8UE8M0, |
| disallowed_dtypes = [torch.float64] | ||
| if not tosa_spec.support_extension("bf16"): | ||
| disallowed_dtypes.append(torch.bfloat16) | ||
| if not tosa_spec.support_extension("fp8e4m3"): | ||
| if not ( | ||
| tosa_spec.support_extension("fp8e4m3") or tosa_spec.support_extension("mxfp") | ||
| ): | ||
| disallowed_dtypes.append(torch.float8_e4m3fn) | ||
| if not tosa_spec.support_extension("fp8e5m2"): | ||
| if not ( | ||
| tosa_spec.support_extension("fp8e5m2") or tosa_spec.support_extension("mxfp") | ||
| ): | ||
| disallowed_dtypes.append(torch.float8_e5m2) |
Resubmit the MXFP Linear lowering after the original change was reverted. This version includes the Buck file updates suggested by @kirklandsign in #19969.
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani