|
10 | 10 | from executorch.exir.dialects._ops import ops as exir_ops
|
11 | 11 | from executorch.exir.pass_base import ExportPass, PassResult
|
12 | 12 |
|
| 13 | +from torch._subclasses.fake_tensor import FakeTensor |
| 14 | + |
| 15 | + |
def node_is_local_scalar_dense_chain(node: torch.fx.Node) -> bool:
    """
    Return True if `node` begins a select_copy -> _local_scalar_dense chain.

    This is the pattern checked for here: a `select_copy.int` edge op whose
    single user targets `aten._local_scalar_dense.default` (per the original
    author's note, it arises from converting a tensor to a scalar via
    `tensor[0].item()`).
    """
    is_single_use_select = (
        node.op == "call_function"
        and node.target == exir_ops.edge.aten.select_copy.int
        and len(node.users) == 1
    )
    if not is_single_use_select:
        return False

    # Exactly one user is guaranteed by the check above.
    (sole_user,) = node.users
    return sole_user.target == torch.ops.aten._local_scalar_dense.default
| 30 | + |
| 31 | + |
def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None:
    """
    Tag `node` as a scalar tensor when it qualifies as one.

    A scalar tensor in the Vulkan backend is a tensor that can be represented as a
    scalar value instead of a Tensor object. A node is tagged only when ALL of the
    following hold:

    1. The node's "val" metadata is a FakeTensor
    2. That tensor has exactly 1 element
    3. At least one of the node's users starts a select_copy + local_scalar_dense
       chain (see `node_is_local_scalar_dense_chain()`)

    When the criteria are met, `node.meta["vkdg_is_scalar_tensor"]` is set to True so
    that the tensor can be added as a scalar value during serialization. Nodes that
    do not qualify (including nodes with no "val" metadata) are left untouched.
    """
    # Not every node is guaranteed to carry a "val" entry; a missing entry simply
    # means the node cannot be a scalar tensor, so do not raise KeyError here.
    tensor_val = node.meta.get("val")
    if not isinstance(tensor_val, FakeTensor):
        return

    # Scalar tensors must have only one element
    if tensor_val.numel() != 1:
        return

    # One matching user is enough; stop scanning as soon as it is found.
    if any(node_is_local_scalar_dense_chain(user) for user in node.users):
        node.meta["vkdg_is_scalar_tensor"] = True
| 56 | + |
| 57 | + |
def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None:
    """
    Redirect every use of `node` (a local_scalar_dense op) to its producer.

    By default the replacement is the op's first argument. If that argument is a
    `select_copy.int` call whose own input holds exactly one element, the select
    is skipped as well and the single-element input tensor is used directly.
    `node` itself is not erased here; it is only left without users.
    """
    source = node.args[0]
    assert isinstance(source, torch.fx.Node)

    feeds_from_select = (
        source.op == "call_function"
        and source.target == exir_ops.edge.aten.select_copy.int
    )
    if feeds_from_select:
        select_input = source.args[0]
        # pyre-ignore
        if select_input.meta["val"].numel() == 1:
            # The select only unwraps a one-element tensor, so bypass it.
            source = select_input
            assert isinstance(source, torch.fx.Node)
            # Only fails if the tag was explicitly set to a falsy value.
            assert source.meta.get("vkdg_is_scalar_tensor", True)

    with graph.inserting_after(node):
        node.replace_all_uses_with(source)
| 81 | + |
13 | 82 |
|
def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Tag scalar tensors and strip local_scalar_dense ops from `graph`.

    For every node in the graph, in order:
    1. Run `tag_node_if_scalar_tensor()` on it (see that function for the criteria).
    2. If the node is a call to `aten._local_scalar_dense.default`, rewire its users
       via `remove_local_scalar_dense_chain()`.

    Dead code is eliminated afterwards, and the same graph object is returned.

    This makes it easier to deal with scalar tensors in the Vulkan backend. In
    particular, it allows serializing scalar tensors as SymInt objects instead of
    Tensor objects. Because scalar tensors are often used to inform tensor shapes,
    their values need to be easily accessed by the CPU during resizing logic, while
    also being able to reflect updates to their value in any GPU shaders that
    reference them.
    """
    local_scalar_dense = torch.ops.aten._local_scalar_dense.default

    for candidate in graph.nodes:
        tag_node_if_scalar_tensor(candidate)

        is_local_scalar_dense = (
            candidate.op == "call_function"
            and candidate.target == local_scalar_dense
        )
        if not is_local_scalar_dense:
            continue

        remove_local_scalar_dense_chain(graph, candidate)

    # The bypassed ops no longer have users; drop them from the graph.
    graph.eliminate_dead_code()
    return graph
|
|
0 commit comments