From 0eac979826d2aabef4610b5cab0c5ad98592ea16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Thu, 21 May 2026 11:58:39 +0200 Subject: [PATCH 1/3] Arm backend: Add TOSA block-scaled cast MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add fake TOSA dialect support and serializer lowering for CAST_TO_BLOCK_SCALED. Co-authored-by: Sebastian Larsson Signed-off-by: Martin Lindström Change-Id: Ic7cdab5134f0fb9502f5985563f0662286ef5fb7 --- .../tosa_supported_operators.py | 8 +- backends/arm/operators/__init__.py | 1 + .../operators/op_tosa_cast_to_block_scaled.py | 78 +++++++++++++++++++ backends/arm/process_node.py | 9 ++- .../test_tosa_dialect_cast_to_block_scaled.py | 63 +++++++++++++++ backends/arm/test/targets.bzl | 1 + backends/arm/tosa/dialect/__init__.py | 1 + .../tosa/dialect/ops/cast_to_block_scaled.py | 73 +++++++++++++++++ backends/arm/tosa/mapping.py | 13 +++- 9 files changed, 241 insertions(+), 6 deletions(-) create mode 100644 backends/arm/operators/op_tosa_cast_to_block_scaled.py create mode 100644 backends/arm/test/misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py create mode 100644 backends/arm/tosa/dialect/ops/cast_to_block_scaled.py diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 2d064ed298c..ed0ddc1cfa9 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -296,9 +296,13 @@ def tosa_support_factory( disallowed_dtypes = [torch.float64] if not tosa_spec.support_extension("bf16"): disallowed_dtypes.append(torch.bfloat16) - if not tosa_spec.support_extension("fp8e4m3"): + if not ( + tosa_spec.support_extension("fp8e4m3") or tosa_spec.support_extension("mxfp") + ): disallowed_dtypes.append(torch.float8_e4m3fn) - if not tosa_spec.support_extension("fp8e5m2"): + if not ( + tosa_spec.support_extension("fp8e5m2") 
or tosa_spec.support_extension("mxfp") + ): disallowed_dtypes.append(torch.float8_e5m2) if tosa_spec.is_U55_subset: disallowed_dtypes.append(torch.bool) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 32809eed847..d4100695b29 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -47,6 +47,7 @@ op_tanh, op_to_dim_order_copy, op_tosa_avg_pool2d, + op_tosa_cast_to_block_scaled, op_tosa_conv2d, op_tosa_conv3d, op_tosa_custom, diff --git a/backends/arm/operators/op_tosa_cast_to_block_scaled.py b/backends/arm/operators/op_tosa_cast_to_block_scaled.py new file mode 100644 index 00000000000..454c28ddfe2 --- /dev/null +++ b/backends/arm/operators/op_tosa_cast_to_block_scaled.py @@ -0,0 +1,78 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Provide a visitor for lowering block-scaled casts to TOSA.""" + +import operator +from typing import Any, cast, List + +import torch +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, +) +from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification + + +def _ordered_getitem_output_names(node: torch.fx.Node) -> list[str]: + getitem_users = [ + user + for user in node.users + if user.op == "call_function" and user.target == operator.getitem + ] + + ordered_users = sorted(getitem_users, key=lambda user: cast(int, user.args[1])) + if len(ordered_users) != 2: + raise ValueError( + f"{CastToBlockScaledVisitor.target}: Expected exactly two getitem outputs, got {len(ordered_users)}" + ) + + return [user.name for user in ordered_users] + + +@register_node_visitor +class 
CastToBlockScaledVisitor(NodeVisitor): + """Serialize TOSA ``CAST_TO_BLOCK_SCALED``.""" + + target = "tosa.CAST_TO_BLOCK_SCALED.default" + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.1+FP")] + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 2) + # The tosa_specs attribute cannot express extension requirements. + # Therefore, check for the extension explicitly here. + if not self.tosa_spec.support_extension("mxfp"): + raise ValueError(f"{self.target} requires the TOSA mxfp extension") + + input_tensor = inputs[0] + block_size = inputs[1].number + output_data_tensor, output_scale_tensor = node.meta["val"] + + # TODO(MLETORCH-2018): This is a local workaround for multi-output TOSA ops. + # Remove it once we can handle multiple outputs generally. + output_names = _ordered_getitem_output_names(node) + + attr = ts.TosaSerializerAttribute() + attr.CastToBlockScaledAttribute(block_size) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.CAST_TO_BLOCK_SCALED, + [input_tensor.name], + output_names, + attr, + ) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index f86df9627ff..5f9c3e3938c 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -30,7 +30,12 @@ def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: tensor = tensor.detach().cpu().contiguous() - if tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): + if tensor.dtype in ( + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e8m0fnu, + ): try: import ml_dtypes # type: ignore[import-not-found] except ImportError as e: @@ -38,11 +43,11 @@ def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: f"ml_dtypes is required to serialize {tensor.dtype} tensors for TOSA. " "Have you run setup.sh?" 
) from e - ml_dtype_map = { torch.bfloat16: (torch.uint16, ml_dtypes.bfloat16), torch.float8_e4m3fn: (torch.uint8, ml_dtypes.float8_e4m3fn), torch.float8_e5m2: (torch.uint8, ml_dtypes.float8_e5m2), + torch.float8_e8m0fnu: (torch.uint8, ml_dtypes.float8_e8m0fnu), } storage_dtype, ml_dtype = ml_dtype_map[tensor.dtype] return tensor.view(storage_dtype).numpy().view(ml_dtype) diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py b/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py new file mode 100644 index 00000000000..940023fa624 --- /dev/null +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py @@ -0,0 +1,63 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops import cast_to_block_scaled # noqa: F401 +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_cast_to_block_scaled_requires_mxfp_extension() -> None: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP") + sample_input = torch.randn((2, 32), dtype=torch.float32) + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + with pytest.raises( + TosaValueError, + match="doesn't support MXFP block-scaled casts", + ): + exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default( + mode.from_tensor(sample_input), + 32, + output_dtype=torch.float8_e4m3fn, + ) + + +def test_cast_to_block_scaled_tosa_fp_mxfp() -> None: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+mxfp") + sample_input = torch.randn((2, 32), dtype=torch.float32) + + 
with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + output_data, output_scale = exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default( + mode.from_tensor(sample_input), + 32, + output_dtype=torch.float8_e4m3fn, + ) + + assert output_data.dtype == torch.float8_e4m3fn + assert tuple(output_data.shape) == (2, 32) + assert output_scale.dtype == torch.float8_e8m0fnu + assert tuple(output_scale.shape) == (2, 1) + + +def test_cast_to_block_scaled_invalid_shape() -> None: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+mxfp") + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + with pytest.raises( + TosaValueError, + match="Last dim 30 must be divisible by block_size 32", + ): + exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default( + mode.from_tensor(torch.randn((2, 30), dtype=torch.float32)), + 32, + output_dtype=torch.float8_e4m3fn, + ) diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index 5704f229726..e00725f7ace 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -57,6 +57,7 @@ def define_arm_tests(): "misc/test_compile_spec.py", # "misc/test_evaluate_model.py", "misc/test_pass_pipeline_config.py", + "misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py", "misc/tosa_dialect/test_tosa_resize.py", "misc/test_tosa_spec.py", "misc/test_bn_relu_folding_qat.py", diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index 4678da4d118..daa971215dc 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -7,6 +7,7 @@ activation, avg_pool2d, avg_pool2d_adaptive, + cast_to_block_scaled, conv2d, conv3d, custom, diff --git a/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py b/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py new file mode 100644 index 00000000000..ed109be6124 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/cast_to_block_scaled.py @@ -0,0 +1,73 @@ +# Copyright 2026 Arm Limited 
and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import torch + +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) + + +@register_fake_tosa_op( + "CAST_TO_BLOCK_SCALED(Tensor input, SymInt block_size, ScalarType output_dtype) -> (Tensor, Tensor)", + [TosaSpecification.create_from_string("TOSA-1.1+FP")], +) +def CAST_TO_BLOCK_SCALED( + input: torch.Tensor, + block_size: int, + output_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + tosa_spec = get_context_spec() + + if not tosa_spec.support_float() or not tosa_spec.support_extension("mxfp"): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support MXFP block-scaled casts", + op="CAST_TO_BLOCK_SCALED", + ) + + if input.dtype not in (torch.float32, torch.bfloat16): + raise TosaValueError( + f"Unsupported input dtype {input.dtype} for CAST_TO_BLOCK_SCALED", + op="CAST_TO_BLOCK_SCALED", + ) + if input.dtype == torch.bfloat16 and not ( + tosa_spec.support_extension("bf16") or tosa_spec.support_extension("mxfp") + ): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support bf16", + op="CAST_TO_BLOCK_SCALED", + ) + + if input.ndim < 1: + raise TosaValueError( + "CAST_TO_BLOCK_SCALED requires rank >= 1", + op="CAST_TO_BLOCK_SCALED", + ) + if block_size != 32: + raise TosaValueError( + f"Unsupported block_size {block_size} (must be 32)", + op="CAST_TO_BLOCK_SCALED", + ) + if input.shape[-1] % block_size != 0: + raise TosaValueError( + f"Last dim {input.shape[-1]} must be divisible by block_size {block_size}", + op="CAST_TO_BLOCK_SCALED", + ) + + scale_tensor_dtype = torch.float8_e8m0fnu + if output_dtype not in (torch.float8_e4m3fn, 
torch.float8_e5m2): + raise TosaValueError( + f"Unsupported block-scaled output dtype {output_dtype}", + op="CAST_TO_BLOCK_SCALED", + ) + scale_shape = (*input.shape[:-1], input.shape[-1] // block_size) + output_data = torch.empty_like(input, dtype=output_dtype) + output_scale = input.new_empty(scale_shape, dtype=scale_tensor_dtype) + return output_data, output_scale diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 0e91120c3b8..245a9c00235 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -99,6 +99,9 @@ def map_dtype(data_type: torch.dtype) -> Any: torch.float16: ts.DType.FP16, torch.half: ts.DType.FP16, torch.bfloat16: ts.DType.BF16, + torch.float8_e4m3fn: ts.DType.FP8E4M3, + torch.float8_e5m2: ts.DType.FP8E5M2, + torch.float8_e8m0fnu: ts.DType.FP8UE8M0, torch.int8: ts.DType.INT8, # TOSA uses signless int8; unsigned semantics are expressed via RESCALE. torch.uint8: ts.DType.INT8, @@ -235,10 +238,16 @@ def __validate(self, tosa_spec: TosaSpecification) -> bool: if not tosa_spec.support_extension("bf16"): return False case ts.DType.FP8E4M3: - if not tosa_spec.support_extension("fp8e4m3"): + if not ( + tosa_spec.support_extension("fp8e4m3") + or tosa_spec.support_extension("mxfp") + ): return False case ts.DType.FP8E5M2: - if not tosa_spec.support_extension("fp8e5m2"): + if not ( + tosa_spec.support_extension("fp8e5m2") + or tosa_spec.support_extension("mxfp") + ): return False return True From 41adca8b662ab0fc59f8b716b713afd1c1e9767b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Thu, 21 May 2026 11:59:00 +0200 Subject: [PATCH 2/3] Arm backend: Lower MXFP Linear to TOSA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Martin Lindström Co-authored-by: Sebastian Larsson Change-Id: Iab2e1cf2ed21047bbc2a7a51604b9230fe2f2819 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + 
backends/arm/_passes/rewrite_mxfp_linear.py | 318 ++++++++++++++++++ .../tosa_supported_operators.py | 16 + backends/arm/operators/__init__.py | 1 + .../op_tosa_matmul_t_block_scaled.py | 94 ++++++ .../test_tosa_dialect_mxfp_linear.py | 56 +++ backends/arm/test/ops/mxfp/__init__.py | 4 + backends/arm/test/ops/mxfp/common.py | 122 +++++++ .../test/ops/{ => mxfp}/test_mxfp_linear.py | 123 +++++-- .../passes/test_rewrite_mxfp_linear_pass.py | 121 +++++++ backends/arm/test/targets.bzl | 11 +- backends/arm/tosa/dialect/__init__.py | 1 + .../tosa/dialect/ops/matmul_t_block_scaled.py | 130 +++++++ 14 files changed, 971 insertions(+), 29 deletions(-) create mode 100644 backends/arm/_passes/rewrite_mxfp_linear.py create mode 100644 backends/arm/operators/op_tosa_matmul_t_block_scaled.py create mode 100644 backends/arm/test/misc/tosa_dialect/test_tosa_dialect_mxfp_linear.py create mode 100644 backends/arm/test/ops/mxfp/__init__.py create mode 100644 backends/arm/test/ops/mxfp/common.py rename backends/arm/test/ops/{ => mxfp}/test_mxfp_linear.py (63%) create mode 100644 backends/arm/test/passes/test_rewrite_mxfp_linear_pass.py create mode 100644 backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 516c486690d..76f93edbab5 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -165,6 +165,7 @@ from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa from .rewrite_matmul import RewriteMatmulPass # noqa from .rewrite_max_pool2d_pass import RewriteMaxPool2dPass # noqa +from .rewrite_mxfp_linear import RewriteMXFPLinearPass # noqa from .rewrite_pad import RewritePadPass # noqa from .rewrite_slice import RewriteSlicePass # noqa from .rewrite_upsample import RewriteUpsamplePass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 521ddfe3ad7..bc20e13d2fc 100644 --- 
a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -141,6 +141,7 @@ RewriteLeLtToGeGtPass, RewriteMatmulPass, RewriteMaxPool2dPass, + RewriteMXFPLinearPass, RewritePadPass, RewriteSlicePass, RewriteUpsamplePass, @@ -524,6 +525,7 @@ def _tosa_pipeline( RewriteUpsamplePass(), RewriteMaxPool2dPass(), RewriteConvPass(exported_program), + RewriteMXFPLinearPass(exported_program), RewriteMatmulPass(), RewritePadPass(), FuseViewCopyTransformPass(), diff --git a/backends/arm/_passes/rewrite_mxfp_linear.py b/backends/arm/_passes/rewrite_mxfp_linear.py new file mode 100644 index 00000000000..d4ca436dc41 --- /dev/null +++ b/backends/arm/_passes/rewrite_mxfp_linear.py @@ -0,0 +1,318 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import operator +from functools import reduce +from typing import Any, cast, Sequence, Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RewriteMXFPLinearPass(ArmPass): + """Rewrite ``tosa_mxfp.linear`` into explicit TOSA MXFP operators. + + For each MXFP linear custom op, the pass: + 1. Reshapes activations and precomputed weight tensors to the rank expected + by the block-scaled TOSA ops. + 2. Inserts ``tosa.CAST_TO_BLOCK_SCALED`` for the activation input. + 3. Inserts ``tosa.MATMUL_T_BLOCK_SCALED`` using the cast activations and the + MXFP weight data/scale tensors. + 4. Restores the original output shape. + 5. Re-applies bias, reshaping it first to match the output rank when + needed. 
+ + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + def __init__(self, exported_program: torch.export.ExportedProgram, *args, **kwargs): + super().__init__(*args, **kwargs) + self.exported_program = exported_program + + def _get_linear_args( + self, node: torch.fx.Node + ) -> tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node, torch.fx.Node | None, int]: + """Extract the MXFP linear operands from a custom-op node.""" + input_node = cast(torch.fx.Node, node.args[0]) + weight_qdata_node = cast(torch.fx.Node, node.args[1]) + weight_scale_node = cast(torch.fx.Node, node.args[2]) + bias_node = cast( + torch.fx.Node | None, + node.args[3] if len(node.args) > 3 else node.kwargs.get("bias"), + ) + block_size = cast( + int, + node.args[4] if len(node.args) > 4 else node.kwargs.get("block_size", 32), + ) + return input_node, weight_qdata_node, weight_scale_node, bias_node, block_size + + def _reshape_with_view( + self, + graph_module: torch.fx.GraphModule, + input_node: torch.fx.Node, + shape: Sequence[int | torch.SymInt], + from_node: torch.fx.Node, + ) -> torch.fx.Node: + """Insert a ``view_copy`` node and update its fake-tensor metadata.""" + reshaped = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(input_node, shape), + kwargs={}, + from_node=from_node, + ) + reshaped.meta["val"] = exir_ops.edge.aten.view_copy.default( + get_first_fake_tensor(input_node), + shape, + ) + return reshaped + + def _create_block_scaled_inputs( + self, + graph_module: torch.fx.GraphModule, + mxfp_linear_node: torch.fx.Node, + input_node: torch.fx.Node, + weight_qdata_node: torch.fx.Node, + weight_scale_node: torch.fx.Node, + block_size: int, + ) -> tuple[torch.fx.Node, torch.fx.Node]: + """Create rank-3 inputs for the block-scaled cast and matmul ops.""" + graph = graph_module.graph + input_fake = get_first_fake_tensor(input_node) + weight_qdata_fake = get_first_fake_tensor(weight_qdata_node) + weight_scale_fake = 
get_first_fake_tensor(weight_scale_node) + + batches = reduce(operator.mul, input_fake.shape[:-1], 1) + input_reshape_shape = [1, batches, input_fake.shape[-1]] + + input_reshaped = self._reshape_with_view( + graph_module, + input_node, + input_reshape_shape, + mxfp_linear_node, + ) + if weight_qdata_fake.ndim != 3 or weight_scale_fake.ndim != 3: + raise RuntimeError( + "Expected pre-reshaped rank-3 MXFP weight placeholders in rewrite pass" + ) + + cast_node = create_node( + graph=graph, + op_target=exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default, + args=(input_reshaped, block_size), + kwargs={"output_dtype": weight_qdata_fake.dtype}, + from_node=mxfp_linear_node, + ) + cast_node.meta["val"] = exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default( + get_first_fake_tensor(input_reshaped), + block_size, + output_dtype=weight_qdata_fake.dtype, + ) + + input_qdata_node = create_node( + graph=graph, + op_target=cast(Any, operator.getitem), + args=(cast_node, 0), + kwargs={}, + from_node=mxfp_linear_node, + ) + input_qdata_node.meta["val"] = cast_node.meta["val"][0] + + input_scale_node = create_node( + graph=graph, + op_target=cast(Any, operator.getitem), + args=(cast_node, 1), + kwargs={}, + from_node=mxfp_linear_node, + ) + input_scale_node.meta["val"] = cast_node.meta["val"][1] + + return ( + input_qdata_node, + input_scale_node, + ) + + def _create_matmul_node( + self, + graph_module: torch.fx.GraphModule, + mxfp_linear_node: torch.fx.Node, + input_qdata_node: torch.fx.Node, + input_scale_node: torch.fx.Node, + weight_qdata_node: torch.fx.Node, + weight_scale_node: torch.fx.Node, + block_size: int, + ) -> torch.fx.Node: + """Insert ``MATMUL_T_BLOCK_SCALED`` with updated fake metadata.""" + matmul_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.backend.tosa.MATMUL_T_BLOCK_SCALED.default, + args=( + input_qdata_node, + input_scale_node, + weight_qdata_node, + weight_scale_node, + block_size, + ), + kwargs={}, + from_node=mxfp_linear_node, + ) + 
matmul_node.meta["val"] = exir_ops.backend.tosa.MATMUL_T_BLOCK_SCALED.default( + get_first_fake_tensor(input_qdata_node), + get_first_fake_tensor(input_scale_node), + get_first_fake_tensor(weight_qdata_node), + get_first_fake_tensor(weight_scale_node), + block_size, + ) + return matmul_node + + def _create_output_view( + self, + graph_module: torch.fx.GraphModule, + mxfp_linear_node: torch.fx.Node, + matmul_node: torch.fx.Node, + ) -> torch.fx.Node: + """Restore the original linear output shape after block matmul.""" + output_fake = get_first_fake_tensor(mxfp_linear_node) + output_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.view_copy.default, + args=(matmul_node, list(output_fake.shape)), + kwargs={}, + from_node=mxfp_linear_node, + ) + output_node.meta["val"] = exir_ops.edge.aten.view_copy.default( + get_first_fake_tensor(matmul_node), + list(output_fake.shape), + ) + return output_node + + def _create_bias_add( + self, + graph_module: torch.fx.GraphModule, + mxfp_linear_node: torch.fx.Node, + output_node: torch.fx.Node, + bias_node: torch.fx.Node, + ) -> torch.fx.Node: + """Reshape bias to match output rank and append the final add node.""" + output_fake = get_first_fake_tensor(mxfp_linear_node) + bias_fake = get_first_fake_tensor(bias_node) + bias_shape = [1] * (output_fake.dim() - 1) + [output_fake.shape[-1]] + bias_arg = bias_node + + if tuple(bias_fake.shape) != tuple(bias_shape): + # Match ranks by prepending singleton dimensions. + with graph_module.graph.inserting_after(output_node): + bias_arg = self._reshape_with_view( + graph_module, + bias_node, + bias_shape, + mxfp_linear_node, + ) + with graph_module.graph.inserting_after(bias_arg): + add_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.add.Tensor, + args=(output_node, bias_arg), + kwargs={}, + from_node=mxfp_linear_node, + ) + else: + # Bias already has the right shape, so add it directly. 
+ with graph_module.graph.inserting_after(output_node): + add_node = create_node( + graph=graph_module.graph, + op_target=exir_ops.edge.aten.add.Tensor, + args=(output_node, bias_arg), + kwargs={}, + from_node=mxfp_linear_node, + ) + add_node.meta["val"] = exir_ops.edge.aten.add.Tensor( + get_first_fake_tensor(output_node), + get_first_fake_tensor(bias_arg), + ) + + return add_node + + def _rewrite_mxfp_linear_node( + self, + graph_module: torch.fx.GraphModule, + mxfp_linear_node: torch.fx.Node, + ) -> torch.fx.Node: + """Rewrite one MXFP linear node to explicit TOSA MXFP ops.""" + graph = graph_module.graph + ( + input_node, + weight_qdata_node, + weight_scale_node, + bias_node, + block_size, + ) = self._get_linear_args(mxfp_linear_node) + + with graph.inserting_before(mxfp_linear_node): + ( + input_qdata_node, + input_scale_node, + ) = self._create_block_scaled_inputs( + graph_module, + mxfp_linear_node, + input_node, + weight_qdata_node, + weight_scale_node, + block_size, + ) + matmul_node = self._create_matmul_node( + graph_module, + mxfp_linear_node, + input_qdata_node, + input_scale_node, + weight_qdata_node, + weight_scale_node, + block_size, + ) + + with graph.inserting_after(matmul_node): + output_node = self._create_output_view( + graph_module, mxfp_linear_node, matmul_node + ) + + if bias_node is None: + return output_node + + return self._create_bias_add( + graph_module, + mxfp_linear_node, + output_node, + bias_node, + ) + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + graph = graph_module.graph + + for node in list(graph.nodes): + if node.op != "call_function" or node.target not in ( + torch.ops.tosa_mxfp.linear.default, + exir_ops.edge.tosa_mxfp.linear.default, + ): + continue + + modified = True + replacement = self._rewrite_mxfp_linear_node(graph_module, node) + node.replace_all_uses_with(replacement) + graph.erase_node(node) + + if modified: + graph.eliminate_dead_code() + graph_module.recompile() + graph_module = 
super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index ed0ddc1cfa9..2e640b758d2 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -237,6 +237,17 @@ def get_registered_tosa_support_checks( return checks +class MXOpsSupportList(OperatorSupportBase): + """Accept Arm MX custom ops when the active spec enables MX support.""" + + targets = (exir_ops.edge.tosa_mxfp.linear.default,) + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + return node.op == "call_function" and node.target in self.targets + + def tosa_support_factory( tosa_spec: TosaSpecification, exported_program: ExportedProgram, @@ -271,6 +282,8 @@ def tosa_support_factory( positive_checks.append(TOSAProINTSupportList()) elif tosa_spec.support_float(): positive_checks.append(TOSAProFPSupportList()) + if tosa_spec.support_extension("mxfp"): + positive_checks.append(MXOpsSupportList()) # TODO: Refactor to use TOSAProSupportLists + negtive checks positive_checks += [ check(tosa_spec, reporter) @@ -750,6 +763,9 @@ def is_node_supported( ): return True + if node.target in MXOpsSupportList.targets: + return True + floating_dtypes = set() for input_node in ( input_node diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index d4100695b29..ebb2c31c3ed 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -55,6 +55,7 @@ op_tosa_gather, op_tosa_identity, op_tosa_matmul, + op_tosa_matmul_t_block_scaled, op_tosa_max_pool2d, op_tosa_pad, op_tosa_rescale, diff --git a/backends/arm/operators/op_tosa_matmul_t_block_scaled.py b/backends/arm/operators/op_tosa_matmul_t_block_scaled.py new file mode 100644 index 00000000000..2f1bd88c2bb --- 
/dev/null +++ b/backends/arm/operators/op_tosa_matmul_t_block_scaled.py @@ -0,0 +1,94 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Provide a visitor for lowering block-scaled matmul to TOSA.""" + +from typing import Any, List + +import torch +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification + + +@register_node_visitor +class MatMulTBlockScaledVisitor(NodeVisitor): + """Serialize TOSA ``MATMUL_T_BLOCK_SCALED``.""" + + target = "tosa.MATMUL_T_BLOCK_SCALED.default" + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.1+FP")] + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + # The tosa_specs attribute cannot express extension requirements. + # Therefore, check for the extension explicitly here. 
+ if not self.tosa_spec.support_extension("mxfp"): + raise ValueError(f"{self.target} requires the TOSA mxfp extension") + + validate_num_inputs(self.target, inputs, 5) + + ( + A_data, + A_scale, + B_data, + B_scale, + ) = inputs[:4] + block_size = inputs[4].number + + validate_valid_dtype( + self.target, + [A_data, B_data], + [ts.DType.FP8E4M3, ts.DType.FP8E5M2], + self.tosa_spec, + ) + validate_valid_dtype( + self.target, + [A_scale, B_scale], + ts.DType.FP8UE8M0, + self.tosa_spec, + ) + validate_valid_dtype( + self.target, + output, + ts.DType.FP32, + self.tosa_spec, + ) + if block_size != 32: + raise ValueError(f"Invalid block size {block_size}") + + if A_data.dtype != B_data.dtype: + raise ValueError( + f"{self.target}: payload dtypes must match, got {inputs[0].dtype} and {inputs[2].dtype}" + ) + + attr = ts.TosaSerializerAttribute() + attr.MatMulTBlockScaledAttribute(block_size) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.MATMUL_T_BLOCK_SCALED, + [ + inputs[0].name, + inputs[1].name, + inputs[2].name, + inputs[3].name, + ], + [output.name], + attr, + ) diff --git a/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_mxfp_linear.py b/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_mxfp_linear.py new file mode 100644 index 00000000000..74ce04bf3c1 --- /dev/null +++ b/backends/arm/test/misc/tosa_dialect/test_tosa_dialect_mxfp_linear.py @@ -0,0 +1,56 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops import matmul_t_block_scaled # noqa: F401 +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_matmul_t_block_scaled_tosa_fp_mxfp() -> None: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+mxfp") + a_data = torch.randn((1, 4, 32), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale = torch.empty((1, 4, 1), dtype=torch.float8_e8m0fnu) + b_data = torch.randn((1, 8, 32), dtype=torch.float32).to(torch.float8_e4m3fn) + b_scale = torch.empty((1, 8, 1), dtype=torch.float8_e8m0fnu) + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + output = exir_ops.backend.tosa.MATMUL_T_BLOCK_SCALED.default( + mode.from_tensor(a_data), + mode.from_tensor(a_scale), + mode.from_tensor(b_data), + mode.from_tensor(b_scale), + 32, + ) + + assert output.dtype == torch.float32 + assert tuple(output.shape) == (1, 4, 8) + + +def test_matmul_t_block_scaled_invalid_scale_shape() -> None: + tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+mxfp") + a_data = torch.randn((1, 4, 32), dtype=torch.float32).to(torch.float8_e4m3fn) + a_scale = torch.empty((1, 4, 2), dtype=torch.float8_e8m0fnu) + b_data = torch.randn((1, 8, 32), dtype=torch.float32).to(torch.float8_e4m3fn) + b_scale = torch.empty((1, 8, 1), dtype=torch.float8_e8m0fnu) + + with TosaLoweringContext(tosa_spec), FakeTensorMode() as mode: + with pytest.raises( + TosaValueError, + match="A_scale shape \\(1, 4, 2\\) must match \\(1, 4, 1\\)", + ): + exir_ops.backend.tosa.MATMUL_T_BLOCK_SCALED.default( + mode.from_tensor(a_data), + mode.from_tensor(a_scale), + mode.from_tensor(b_data), + mode.from_tensor(b_scale), + 32, + ) diff --git a/backends/arm/test/ops/mxfp/__init__.py 
b/backends/arm/test/ops/mxfp/__init__.py new file mode 100644 index 00000000000..19ebb35e5f2 --- /dev/null +++ b/backends/arm/test/ops/mxfp/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/arm/test/ops/mxfp/common.py b/backends/arm/test/ops/mxfp/common.py new file mode 100644 index 00000000000..c57c8fbb03e --- /dev/null +++ b/backends/arm/test/ops/mxfp/common.py @@ -0,0 +1,122 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +from typing import Any, Callable, Generic, TypeVar + +import torch +from executorch.backends.arm.ao_ext import MXFPOpConfig, to_mxfp +from executorch.backends.arm.test.tester.analyze_output_utils import ( + compare_rel_frobenius_and_cosine_similarity, +) +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + VgfPipeline, +) +from executorch.backends.test.harness.stages import Stage, StageType + +T = TypeVar("T", bound=tuple[Any, ...]) + + +class ConvertToMXFP(Stage): + def __init__( + self, + config: MXFPOpConfig, + filter_fn: Callable[[torch.nn.Module, str], bool], + ) -> None: + self.config = config + self.filter_fn = filter_fn + self.converted_module: torch.nn.Module | None = None + + def stage_type(self) -> StageType: + return StageType.QUANTIZE + + def run(self, artifact: torch.nn.Module, inputs=None) -> None: + self.converted_module = copy.deepcopy(artifact) + to_mxfp(self.converted_module, self.config, filter_fn=self.filter_fn) + + @property + def artifact(self) -> torch.nn.Module: + assert self.converted_module is not None + return self.converted_module + + @property + def graph_module(self) -> torch.nn.Module: + assert self.converted_module is not None + return 
self.converted_module + + def run_artifact(self, inputs): + assert self.converted_module is not None + return self.converted_module.forward(*inputs) + + +def _configure_mxfp_pipeline( + pipeline: TosaPipelineFP | VgfPipeline, + config: MXFPOpConfig, + filter_fn: Callable[[torch.nn.Module, str], bool], + frobenius_threshold: float | None, + cosine_threshold: float | None, +) -> None: + pipeline.add_stage( + pipeline.tester.quantize, + ConvertToMXFP(config, filter_fn), + pos=0, + ) + if pipeline.has_stage("run_method_and_compare_outputs"): + compare_stage = pipeline._stages[ + pipeline.find_pos("run_method_and_compare_outputs") + ] + compare_stage.kwargs["reference_stage_type"] = StageType.INITIAL_MODEL + compare_stage.kwargs["compare_callback"] = lambda ref, test, qparams: ( + compare_rel_frobenius_and_cosine_similarity( + ref, + test, + qparams, + frobenius_threshold=frobenius_threshold, + cosine_threshold=cosine_threshold, + clean_reference=False, + ) + ) + + +class MXFPTosaPipelineFP(TosaPipelineFP[T], Generic[T]): + def __init__( + self, + *args, + filter_fn: Callable[[torch.nn.Module, str], bool], + frobenius_threshold: float | None, + cosine_threshold: float | None, + mxfp_config: MXFPOpConfig | None = None, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + _configure_mxfp_pipeline( + self, + mxfp_config if mxfp_config is not None else MXFPOpConfig(), + filter_fn, + frobenius_threshold, + cosine_threshold, + ) + + +class MXFPVgfPipeline(VgfPipeline[T], Generic[T]): + def __init__( + self, + *args, + filter_fn: Callable[[torch.nn.Module, str], bool], + frobenius_threshold: float | None, + cosine_threshold: float | None, + mxfp_config: MXFPOpConfig | None = None, + **kwargs, + ) -> None: + kwargs.setdefault("quantize", False) + super().__init__(*args, **kwargs) + _configure_mxfp_pipeline( + self, + mxfp_config if mxfp_config is not None else MXFPOpConfig(), + filter_fn, + frobenius_threshold, + cosine_threshold, + ) diff --git 
a/backends/arm/test/ops/test_mxfp_linear.py b/backends/arm/test/ops/mxfp/test_mxfp_linear.py similarity index 63% rename from backends/arm/test/ops/test_mxfp_linear.py rename to backends/arm/test/ops/mxfp/test_mxfp_linear.py index da1bbec3b83..5cdd44cf138 100644 --- a/backends/arm/test/ops/test_mxfp_linear.py +++ b/backends/arm/test/ops/mxfp/test_mxfp_linear.py @@ -6,14 +6,26 @@ # LICENSE file in the root directory of this source tree. import copy +from typing import Tuple import torch from executorch.backends.arm.ao_ext import MXFPOpConfig, to_mxfp -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common as arm_common +from executorch.backends.arm.test.ops.mxfp.common import ( + MXFPTosaPipelineFP, + MXFPVgfPipeline, +) from executorch.backends.arm.test.tester.analyze_output_utils import ( compare_rel_frobenius_and_cosine_similarity, ) +aten_op = "torch.ops.tosa_mxfp.linear.default" + +input_t1 = Tuple[torch.Tensor] + +_MXFP_FROBENIUS_THRESHOLD = 0.06 +_MXFP_COSINE_THRESHOLD = 0.995 + def _block_input_rank1() -> torch.Tensor: """Create a rank-1 input with distinct MXFP activation block scales.""" @@ -42,6 +54,12 @@ def _block_input_rank2() -> torch.Tensor: ) +def _channels_last_rank4_input() -> torch.Tensor: + """Create a rank-4 input with channels-last dim order.""" + + return torch.rand(1, 2, 2, 64).to(memory_format=torch.channels_last) + + _test_data_rank1_fp = { "mxfp_linear_rank1_zeros": lambda: ( torch.zeros(32 * 8), @@ -123,13 +141,33 @@ def _block_input_rank2() -> torch.Tensor: ), } +_test_data_dim_order_fp = { + "mxfp_linear_rank4_channels_last": lambda: ( + _channels_last_rank4_input(), + 8, + True, + False, + ), +} + test_data_fp = ( _test_data_rank1_fp | _test_data_rank2_fp | _test_data_rank3_fp | _test_data_rank4_fp | _test_data_block_fp + | _test_data_dim_order_fp +) + +test_data_vgf_fp = test_data_fp + +_vgf_xfail_reason = ( + "MXFP is not yet supported in the VGF toolchain. 
Enable this test when " + "toolchain support is available." ) +_vgf_xfails: dict[str, str | tuple[str, type[Exception]]] = { + test_case: _vgf_xfail_reason for test_case in test_data_vgf_fp +} class Linear(torch.nn.Module): @@ -177,12 +215,60 @@ def _is_linear(module: torch.nn.Module, _fqn: str) -> bool: return isinstance(module, torch.nn.Linear) -def _test_mxfp_linear_eager_cpu( - test_data: torch.Tensor, - config: MXFPOpConfig, - frobenius_threshold: float, - cosine_threshold: float, -) -> None: +@arm_common.parametrize("test_data", test_data_fp) +def test_mxfp_linear_tosa_FP(test_data) -> None: + test_input, out_features, has_bias, set_block_weights = test_data() + in_features = test_input.shape[-1] + module = Linear( + in_features=in_features, + out_features=out_features, + bias=has_bias, + ).eval() + + if set_block_weights: + module.set_block_test_weights() + + pipeline = MXFPTosaPipelineFP[input_t1]( + module, + (test_input,), + aten_op, + filter_fn=_is_linear, + frobenius_threshold=_MXFP_FROBENIUS_THRESHOLD, + cosine_threshold=_MXFP_COSINE_THRESHOLD, + tosa_version="1.1", + tosa_extensions=["mxfp"], + ) + pipeline.run() + + +@arm_common.parametrize("test_data", test_data_vgf_fp, xfails=_vgf_xfails) +@arm_common.SkipIfNoModelConverter +def test_mxfp_linear_vgf(test_data) -> None: + test_input, out_features, has_bias, set_block_weights = test_data() + in_features = test_input.shape[-1] + module = Linear( + in_features=in_features, + out_features=out_features, + bias=has_bias, + ).eval() + + if set_block_weights: + module.set_block_test_weights() + + pipeline = MXFPVgfPipeline[input_t1]( + module, + (test_input,), + aten_op, + filter_fn=_is_linear, + frobenius_threshold=_MXFP_FROBENIUS_THRESHOLD, + cosine_threshold=_MXFP_COSINE_THRESHOLD, + tosa_spec="TOSA-1.1+FP+mxfp", + ) + pipeline.run() + + +@arm_common.parametrize("test_data", test_data_fp) +def test_mxfp_linear_eager_cpu(test_data) -> None: test_input, out_features, has_bias, set_block_weights = 
test_data() in_features = test_input.shape[-1] ref_model = Linear( @@ -194,7 +280,7 @@ def _test_mxfp_linear_eager_cpu( ref_model.set_block_test_weights() test_model = copy.deepcopy(ref_model).eval() - to_mxfp(test_model, config, filter_fn=_is_linear) + to_mxfp(test_model, MXFPOpConfig(), filter_fn=_is_linear) test_output = test_model(test_input) ref_output = ref_model(test_input) @@ -203,24 +289,7 @@ def _test_mxfp_linear_eager_cpu( ref_output, test_output, quantization_parameters=None, - frobenius_threshold=frobenius_threshold, - cosine_threshold=cosine_threshold, + frobenius_threshold=_MXFP_FROBENIUS_THRESHOLD, + cosine_threshold=_MXFP_COSINE_THRESHOLD, clean_reference=False, ) - - -@common.parametrize("test_data", test_data_fp) -def test_mxfp_linear_eager_cpu(test_data: torch.Tensor) -> None: - """Check eager MXFP implementation. - - The Arm lowering tests compare lowered output against the eager CPU - implementation, so the eager implementation must be accurate for it to be - used as a reference in other tests. - - """ - _test_mxfp_linear_eager_cpu( - test_data, - MXFPOpConfig(), - frobenius_threshold=0.06, - cosine_threshold=0.995, - ) diff --git a/backends/arm/test/passes/test_rewrite_mxfp_linear_pass.py b/backends/arm/test/passes/test_rewrite_mxfp_linear_pass.py new file mode 100644 index 00000000000..572a2b247e9 --- /dev/null +++ b/backends/arm/test/passes/test_rewrite_mxfp_linear_pass.py @@ -0,0 +1,121 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import operator + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import torch +from executorch.backends.arm._passes.rewrite_mxfp_linear import RewriteMXFPLinearPass +from executorch.backends.arm.ao_ext import MXFPOpConfig, to_mxfp +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export import export + + +class _LinearModule(torch.nn.Module): + def __init__(self, bias: bool = True) -> None: + super().__init__() + self.linear = torch.nn.Linear(32, 8, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +class _DualLinearModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(32, 8, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + self.linear(x) + + +def _is_linear(module: torch.nn.Module, _fqn: str) -> bool: + return isinstance(module, torch.nn.Linear) + + +def _get_nodes_from_target( + graph_module: torch.fx.GraphModule, target_op +) -> list[torch.fx.Node]: + return [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == target_op + ] + + +def test_rewrite_mxfp_linear_replaces_custom_op() -> None: + model = _LinearModule(bias=True).eval() + to_mxfp(model, MXFPOpConfig(), filter_fn=_is_linear) + exported = export(model, (torch.randn(4, 5, 32),), strict=False) + tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+mxfp") + + with TosaLoweringContext(tosa_spec): + graph_module = ( + RewriteMXFPLinearPass(exported).call(exported.graph_module).graph_module + ) + + cast_nodes = _get_nodes_from_target( + graph_module, exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default + ) + matmul_nodes = _get_nodes_from_target( + graph_module, exir_ops.backend.tosa.MATMUL_T_BLOCK_SCALED.default + ) + + assert ( + len(_get_nodes_from_target(graph_module, 
torch.ops.tosa_mxfp.linear.default)) + == 0 + ) + assert len(cast_nodes) == 1 + assert len(matmul_nodes) == 1 + assert len(_get_nodes_from_target(graph_module, exir_ops.edge.aten.add.Tensor)) == 1 + # One getitem for each of the two outputs of CAST_TO_BLOCK_SCALED + assert len(_get_nodes_from_target(graph_module, operator.getitem)) == 2 + + cast_node = cast_nodes[0] + assert tuple(cast_node.meta["val"][0].shape) == (1, 4 * 5, 32) # Output data vector + assert tuple(cast_node.meta["val"][1].shape) == (1, 4 * 5, 1) # Output scale vector + + matmul_node = matmul_nodes[0] + assert tuple(matmul_node.meta["val"].shape) == (1, 4 * 5, 8) + + output_node = graph_module.graph.output_node() + assert tuple(output_node.meta["val"][0].shape) == (4, 5, 8) + + +def test_rewrite_mxfp_dual_linear() -> None: + model = _DualLinearModule().eval() + to_mxfp(model, MXFPOpConfig(), filter_fn=_is_linear) + exported = export(model, (torch.randn(4, 32),), strict=False) + tosa_spec = TosaSpecification.create_from_string("TOSA-1.1+FP+mxfp") + + with TosaLoweringContext(tosa_spec): + graph_module = ( + RewriteMXFPLinearPass(exported).call(exported.graph_module).graph_module + ) + + assert ( + len(_get_nodes_from_target(graph_module, torch.ops.tosa_mxfp.linear.default)) + == 0 + ) + assert ( + len( + _get_nodes_from_target( + graph_module, exir_ops.backend.tosa.CAST_TO_BLOCK_SCALED.default + ) + ) + == 2 + ) + assert ( + len( + _get_nodes_from_target( + graph_module, exir_ops.backend.tosa.MATMUL_T_BLOCK_SCALED.default + ) + ) + == 2 + ) diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index e00725f7ace..0a49046cac9 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -23,7 +23,7 @@ def define_arm_tests(): "ops/test_log10.py", "ops/test_max_pool1d.py", "ops/test_mul.py", - "ops/test_mxfp_linear.py", + "ops/mxfp/test_mxfp_linear.py", "ops/test_permute.py", "ops/test_rsqrt.py", "ops/test_slice.py", @@ -58,6 +58,7 @@ def define_arm_tests(): # 
"misc/test_evaluate_model.py", "misc/test_pass_pipeline_config.py", "misc/tosa_dialect/test_tosa_dialect_cast_to_block_scaled.py", + "misc/tosa_dialect/test_tosa_dialect_mxfp_linear.py", "misc/tosa_dialect/test_tosa_resize.py", "misc/test_tosa_spec.py", "misc/test_bn_relu_folding_qat.py", @@ -89,10 +90,16 @@ def define_arm_tests(): for test_file in test_files: test_file_name = paths.basename(test_file) test_name = test_file_name.replace("test_", "").replace(".py", "") + test_srcs = [test_file] + if test_file == "ops/mxfp/test_mxfp_linear.py": + test_srcs += [ + "ops/mxfp/__init__.py", + "ops/mxfp/common.py", + ] python_pytest( name = test_name, - srcs = [test_file], + srcs = test_srcs, pytest_config = "pytest.ini", resources = ["conftest.py"], compile = "with-source", diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index daa971215dc..6c51258618d 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -15,6 +15,7 @@ gather, identity, matmul, + matmul_t_block_scaled, max_pool2d, max_pool2d_adaptive, pad, diff --git a/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py b/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py new file mode 100644 index 00000000000..b42e2855e4c --- /dev/null +++ b/backends/arm/tosa/dialect/ops/matmul_t_block_scaled.py @@ -0,0 +1,130 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from __future__ import annotations + +import torch + +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) + + +def _validate_block_size(block_size: int) -> None: + if block_size <= 0: + raise TosaValueError( + f"block_size must be positive, got {block_size}", + op="MATMUL_T_BLOCK_SCALED", + ) + if block_size != 32: + raise TosaValueError( + f"Unsupported block_size {block_size}", + op="MATMUL_T_BLOCK_SCALED", + ) + + +def _validate_dtypes( + A_data: torch.Tensor, + A_scale: torch.Tensor, + B_data: torch.Tensor, + B_scale: torch.Tensor, +) -> None: + if A_data.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2): + raise TosaValueError( + f"Unsupported A_data dtype {A_data.dtype}", + op="MATMUL_T_BLOCK_SCALED", + ) + if B_data.dtype != A_data.dtype: + raise TosaValueError( + f"B_data dtype {B_data.dtype} must match A_data dtype {A_data.dtype}", + op="MATMUL_T_BLOCK_SCALED", + ) + if A_scale.dtype != torch.float8_e8m0fnu or B_scale.dtype != torch.float8_e8m0fnu: + raise TosaValueError( + "Scale tensors must use torch.float8_e8m0fnu", + op="MATMUL_T_BLOCK_SCALED", + ) + + +def _validate_shapes( + A_data: torch.Tensor, + A_scale: torch.Tensor, + B_data: torch.Tensor, + B_scale: torch.Tensor, + block_size: int, +) -> tuple[int, int, int]: + if A_data.ndim != 3 or A_scale.ndim != 3 or B_data.ndim != 3 or B_scale.ndim != 3: + raise TosaValueError( + "MATMUL_T_BLOCK_SCALED expects rank-3 tensors for values and scales", + op="MATMUL_T_BLOCK_SCALED", + ) + + N, H, C = A_data.shape + D, W, Cb = B_data.shape + if C != Cb: + raise TosaValueError( + f"A_data last dim {C} must match B_data last dim {Cb}", + op="MATMUL_T_BLOCK_SCALED", + ) + if C % block_size != 0: + raise TosaValueError( + f"Last dim {C} must be divisible by block_size {block_size}", + 
op="MATMUL_T_BLOCK_SCALED", + ) + + expected_a_scale_shape = (N, H, C // block_size) + expected_b_scale_shape = (D, W, C // block_size) + if tuple(A_scale.shape) != expected_a_scale_shape: + raise TosaValueError( + f"A_scale shape {tuple(A_scale.shape)} must match {expected_a_scale_shape}", + op="MATMUL_T_BLOCK_SCALED", + ) + if tuple(B_scale.shape) != expected_b_scale_shape: + raise TosaValueError( + f"B_scale shape {tuple(B_scale.shape)} must match {expected_b_scale_shape}", + op="MATMUL_T_BLOCK_SCALED", + ) + + if D not in (1, N): + raise TosaValueError( + f"B_data batch dim {D} must be 1 or match A_data batch dim {N}", + op="MATMUL_T_BLOCK_SCALED", + ) + + return N, H, W + + +@register_fake_tosa_op( + "MATMUL_T_BLOCK_SCALED(Tensor A_data, Tensor A_scale, Tensor B_data, Tensor B_scale, SymInt block_size) -> Tensor", + [TosaSpecification.create_from_string("TOSA-1.1+FP")], +) +def MATMUL_T_BLOCK_SCALED( + A_data: torch.Tensor, + A_scale: torch.Tensor, + B_data: torch.Tensor, + B_scale: torch.Tensor, + block_size: int, +) -> torch.Tensor: + tosa_spec = get_context_spec() + + if not tosa_spec.support_float() or not tosa_spec.support_extension("mxfp"): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support MXFP block-scaled matmul", + op="MATMUL_T_BLOCK_SCALED", + ) + + _validate_block_size(block_size) + _validate_dtypes(A_data, A_scale, B_data, B_scale) + output_shape = _validate_shapes( + A_data, + A_scale, + B_data, + B_scale, + block_size, + ) + return A_data.new_empty(output_shape, dtype=torch.float32) From c87b5dffd8d6812a00bd44ca3d639234ddb98280 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Fri, 5 Jun 2026 11:14:42 +0200 Subject: [PATCH 3/3] Arm backend: Correct buck2 files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit is the exact suggestion that kirklandsign posted in https://github.com/pytorch/executorch/pull/19969 Signed-off-by: Martin Lindström Suggested-by: 
@kirklandsign Change-Id: Iacce4dd11e61e4d79296e37e59ddb214072dd2ef --- backends/arm/test/BUCK | 49 ++++++++++------------------------- backends/arm/test/targets.bzl | 23 ++-------------- 2 files changed, 15 insertions(+), 57 deletions(-) diff --git a/backends/arm/test/BUCK b/backends/arm/test/BUCK index 534d9206cd4..ddcee35f53e 100644 --- a/backends/arm/test/BUCK +++ b/backends/arm/test/BUCK @@ -49,42 +49,6 @@ fbcode_target(_kind = runtime.python_library, ] ) -fbcode_target(_kind = runtime.python_library, - name = "custom_vgf_test_utils", - srcs = ["_custom_vgf_test_utils.py"], - resources = [ - "assets/test_add_buffer.glsl", - "assets/test_grid_read_tensor_debug.glsl", - "assets/test_grid_sample_buffer_nchw_debug.glsl", - "assets/test_grid_sample_sampler.glsl", - "assets/test_grid_sample_sampler_buffer_debug.glsl", - "assets/test_identity_buffer.glsl", - "assets/test_identity_image_packed_buffer.glsl", - "assets/test_threes_buffer.glsl", - "assets/test_threes_image_packed_buffer.glsl", - ], - deps = [ - "//caffe2:torch", - "//executorch/backends/arm:constants", - "//executorch/backends/arm/_passes:passes", - "//executorch/backends/arm/tosa/dialect:lib", - "//executorch/exir:lib", - ], -) - -fbcode_target(_kind = runtime.python_library, - name = "vgf_runtime_test_utils", - srcs = ["runtime/_vgf_runtime_test_utils.py"], - deps = [ - ":custom_vgf_test_utils", - ":runner_utils", - "//executorch/backends/arm:vgf", - "//executorch/backends/arm/_passes:passes", - "//executorch/exir:lib", - "fbsource//third-party/pypi/pytest:pytest", - ], -) - fbcode_target(_kind = runtime.python_library, name = "arm_tester_serialize", srcs = ["tester/serialize.py"], @@ -120,4 +84,17 @@ fbcode_target(_kind = runtime.python_library, ] ) +fbcode_target(_kind = runtime.python_library, + name = "mxfp_test_common", + srcs = [ + "ops/mxfp/__init__.py", + "ops/mxfp/common.py", + ], + deps = [ + ":arm_tester" if runtime.is_oss else "//executorch/backends/arm/test/tester/fb:arm_tester_fb", + 
"//executorch/backends/arm:ao_ext", + "//executorch/backends/test/harness:tester", + ], +) + fbcode_target(_kind = define_arm_tests,) diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index 0a49046cac9..385d40b9d61 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -43,7 +43,6 @@ def define_arm_tests(): "ops/test_gelu.py", "ops/test_bmm.py", "ops/test_split.py", - "ops/test_custom_shader_lowering.py", ] # Quantization @@ -63,23 +62,12 @@ def define_arm_tests(): "misc/test_tosa_spec.py", "misc/test_bn_relu_folding_qat.py", "misc/test_custom_partition.py", - "misc/test_custom_shader_payloads.py", "misc/test_debug_hook.py", "misc/test_mxfp_linear_ao.py", "misc/test_post_quant_device_switch.py", - "misc/test_vgf_check_env.py", - "misc/test_vgf_backend.py", # "misc/test_dim_order.py", (TODO - T238390249) ] - test_files += [ - "runtime/test_vgf_aliasing_runtime.py", - "runtime/test_vgf_combinations_runtime.py", - "runtime/test_vgf_multi_segment_runtime.py", - "runtime/test_vgf_sampler_image_runtime.py", - "runtime/test_vgf_tensor_buffer_runtime.py", - ] - # Deprecation tests test_files += [ "deprecation/test_arm_compile_spec_deprecation.py", @@ -90,16 +78,10 @@ def define_arm_tests(): for test_file in test_files: test_file_name = paths.basename(test_file) test_name = test_file_name.replace("test_", "").replace(".py", "") - test_srcs = [test_file] - if test_file == "ops/mxfp/test_mxfp_linear.py": - test_srcs += [ - "ops/mxfp/__init__.py", - "ops/mxfp/common.py", - ] python_pytest( name = test_name, - srcs = test_srcs, + srcs = [test_file], pytest_config = "pytest.ini", resources = ["conftest.py"], compile = "with-source", @@ -123,9 +105,8 @@ def define_arm_tests(): deps = [ "//executorch/backends/arm/test:arm_tester" if runtime.is_oss else "//executorch/backends/arm/test/tester/fb:arm_tester_fb", "//executorch/backends/arm/test:conftest", + "//executorch/backends/arm/test:mxfp_test_common", 
"//executorch/backends/arm/test/misc:dw_convs_shared_weights_module", - "//executorch/backends/arm/test:custom_vgf_test_utils", - "//executorch/backends/arm/test:vgf_runtime_test_utils", "//executorch/backends/arm:ao_ext", "//executorch/backends/arm:ethosu", "//executorch/backends/arm/tosa:compile_spec",