diff --git a/backends/arm/_passes/fuse_duplicate_users_pass.py b/backends/arm/_passes/fuse_duplicate_users_pass.py index 23e1eb6f6d3..58e6d929181 100644 --- a/backends/arm/_passes/fuse_duplicate_users_pass.py +++ b/backends/arm/_passes/fuse_duplicate_users_pass.py @@ -34,6 +34,7 @@ def call(self, graph_module: GraphModule) -> PassResult: graph = graph_module.graph modified = False + node_order = {node: index for index, node in enumerate(graph.nodes)} producers: Deque[Node] = deque(node for node in graph.nodes) while producers: @@ -48,7 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult: if len(user_nodes) < 2: continue - candidate_groups = self._get_candidate_groups(user_nodes) + candidate_groups = self._get_candidate_groups(node_order, user_nodes) signature_to_user: Dict[Tuple[Hashable, ...], Node] = {} for group in candidate_groups: @@ -84,7 +85,7 @@ def call(self, graph_module: GraphModule) -> PassResult: return PassResult(graph_module, modified) - def _get_candidate_groups(self, user_nodes): + def _get_candidate_groups(self, node_order, user_nodes): users_by_target: Dict[Tuple[str, Hashable], List[Node]] = {} for user in user_nodes: if user.graph is None: @@ -98,9 +99,12 @@ def _get_candidate_groups(self, user_nodes): target_signature = (user.op, target_key) users_by_target.setdefault(target_signature, []).append(user) - candidate_groups = [ - group for group in users_by_target.values() if len(group) > 1 - ] + candidate_groups = [] + for group in users_by_target.values(): + if len(group) > 1: + candidate_groups.append( + sorted(group, key=lambda node: node_order[node]) + ) return candidate_groups diff --git a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py index d94e01f9847..3227cfa8755 100644 --- a/backends/arm/test/passes/test_fuse_duplicate_users_pass.py +++ b/backends/arm/test/passes/test_fuse_duplicate_users_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -9,6 +9,7 @@ from executorch.backends.arm._passes import FuseDuplicateUsersPass from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from torch.fx import Graph, GraphModule input_t = Tuple[torch.Tensor] # Input x @@ -55,6 +56,42 @@ def forward(self, x): } +def _set_val(node, val): + node.meta["val"] = val + return node + + +def _graph_with_users_not_in_node_order() -> GraphModule: + graph = Graph() + x = _set_val(graph.placeholder("x"), torch.ones(1)) + y = _set_val(graph.placeholder("y"), torch.ones(1)) + + later_duplicate = _set_val( + graph.call_function(torch.ops.aten.add.Tensor, (x, y)), torch.ones(1) + ) + with graph.inserting_before(later_duplicate): + earlier_duplicate = _set_val( + graph.call_function(torch.ops.aten.add.Tensor, (x, y)), torch.ones(1) + ) + consumer = _set_val( + graph.call_function(torch.ops.aten.neg.default, (earlier_duplicate,)), + torch.ones(1), + ) + + output = graph.output(consumer) + output.meta["val"] = torch.ones(1) + graph.lint() + return GraphModule(torch.nn.Module(), graph) + + +def _add_node_names(graph_module): + return [ + node.name + for node in graph_module.graph.nodes + if node.target == torch.ops.aten.add.Tensor + ] + + @common.parametrize("module", modules) def test_fuse_duplicate_users_tosa_FP(module: ModuleWithOps): pipeline = PassPipeline[input_t]( @@ -68,3 +105,14 @@ def test_fuse_duplicate_users_tosa_FP(module: ModuleWithOps): ], ) pipeline.run() + + +def test_fuse_duplicate_users_preserves_graph_order_for_representative(): + graph_module = _graph_with_users_not_in_node_order() + assert _add_node_names(graph_module) == ["add_tensor_1", "add_tensor"] + + result = FuseDuplicateUsersPass()(graph_module) + + result.graph_module.graph.lint() + assert result.modified + assert len(_add_node_names(result.graph_module)) == 1