Commit ceb22c1

Zonglin Peng authored and facebook-github-bot committed

migrate to OSS passes, [cadence][pass] move compiler utils to OSS for passes

Summary: Differential Revision: D65908549

1 parent ad15852 commit ceb22c1

File tree: 3 files changed, +344 −32 lines

3 files changed

+344
-32
lines changed

backends/cadence/aot/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -76,6 +76,7 @@ python_library(
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:lib",
         "//executorch/exir/passes:spec_prop_pass",
+        "//executorch/backends/transforms:remove_clone_ops"
     ],
 )

@@ -117,3 +118,15 @@ python_unittest(
         "//executorch/exir:pass_base",
     ],
 )
+
+python_library(
+    name = "compiler_utils",
+    srcs = [
+        "compiler_utils.py",
+    ],
+    typing = True,
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/dialects:lib",
+    ],
+)
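The new target maps to a plain Python module, so OSS passes can pull these helpers in directly. A one-line illustrative import (assuming the usual executorch Python package layout; not shown in this diff):

# Illustrative only; assumes the standard executorch package layout.
from executorch.backends.cadence.aot.compiler_utils import get_shape, broadcastable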
backends/cadence/aot/compiler_utils.py

Lines changed: 302 additions & 0 deletions

@@ -0,0 +1,302 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict


# This file contains all the helper utility functions.

from itertools import zip_longest
from math import frexp, isclose, trunc
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.fx

from executorch.exir.dialects._ops import ops as exir_ops
from torch.utils._pytree import tree_flatten


# Return the output node of the graph
def get_output_node(graph: torch.fx.Graph) -> torch.fx.Node:
    assert graph is not None, "Cannot get output of an empty graph"
    output_node = next(iter(reversed(graph.nodes)))
    assert (
        output_node and output_node.op == "output" and len(output_node.args) == 1
    ), "Failed to find output node"
    return output_node


# Return true if the node is part of the flattened output
def is_node_in_flattened_output(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    output_node = get_output_node(graph)
    return node in tree_flatten(output_node.args[0])[0]


# Returns a list with placeholders/inputs
def get_placeholders(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    return list(filter(lambda x: x.op == "placeholder", graph.nodes))
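These graph-inspection helpers work on any torch.fx graph. A minimal usage sketch (illustrative, not part of this commit) with a symbolically traced module:

# Illustrative only.
import torch
import torch.fx


class AddOne(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1.0


gm = torch.fx.symbolic_trace(AddOne())
print(get_placeholders(gm.graph))        # [x]
print(get_output_node(gm.graph).op)      # "output"
add_node = get_output_node(gm.graph).args[0]
print(is_node_in_flattened_output(gm.graph, add_node))  # True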
# Return the shape of the incoming node.
def get_shape(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node
) -> Union[torch.Size, None]:
    """
    Return the shape of the tensor corresponding to node. If the node has a
    tensor spec, return the shape from the metadata. If the node is a param,
    return its shape. Otherwise return None.
    """
    try:
        # Case 1. node is a scalar (this pass happens before tensorization)
        if isinstance(node, (float, int, bool)):
            return torch.Size([1])
        # Case 2. node has TensorSpec metadata
        fake_tensor = node.meta.get("val")
        if fake_tensor is not None:
            return fake_tensor.shape
        # Case 3. node holds a param
        if node.op == "get_attr":
            attr_node = getattr(graph_module, node.target)
            return attr_node.shape
        # Default: return None
        return None
    except RuntimeError:
        return None
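A minimal sketch of how get_shape picks up case 2 (illustrative, not part of this commit; it assumes a graph produced by make_fx, which populates node.meta["val"] with fake tensors):

# Illustrative only; assumes make_fx populates node.meta["val"].
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(x: torch.Tensor) -> torch.Tensor:
    return x.relu()


gm = make_fx(f)(torch.randn(2, 3))
for n in gm.graph.nodes:
    print(n.op, get_shape(gm, n))  # e.g. "placeholder torch.Size([2, 3])"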
# Return true if shape_2 can be broadcasted to shape_1
def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
    """
    Check if 'shape_2' can be broadcasted to 'shape_1'. The broadcast is
    feasible if:
    (1) shape_2 does not have higher dimensionality than shape_1;
    (2) the value at each dimension of shape_2 is either the same as shape_1 or 1;
    (3) shape_1 or shape_2 is empty.
    """
    return (
        not shape_1
        or not shape_2
        or all(
            x == y or y == 1 or y is None
            for x, y in zip_longest(shape_1[::-1], shape_2[::-1])
        )
    )
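A few illustrative checks of the broadcast rule (not part of this commit); shapes are compared from the trailing dimension backwards:

# Illustrative only.
assert broadcastable((4, 3, 2), (3, 2))      # trailing dims match
assert broadcastable((4, 3, 2), (1, 2))      # a 1 broadcasts along that dim
assert broadcastable((4, 3, 2), ())          # an empty shape always broadcasts
assert not broadcastable((3, 2), (4, 3, 2))  # shape_2 has more dims than shape_1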
# Return a chain of nodes with target in op_targets
def get_cascaded_ops(
    nodes: List[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_targets: Iterable[Union[Callable[..., Any], str]],
) -> Sequence[torch.fx.Node]:
    """
    'nodes' contains a chain of ops with target in 'op_targets'. Extend that chain
    by one if nodes[-1] has a single user with its op target in 'op_targets'.
    """
    cur = nodes[-1]
    users = list(cur.users.keys())
    # Check that (a) there is only one user of cur, and (b) that user is
    # one of the ops in op_targets.
    if len(users) == 1 and users[0].target in op_targets:
        nodes.append(users[0])
        # Recursively find the chain starting at the user
        return get_cascaded_ops(nodes, op_targets)

    return nodes
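A minimal sketch (illustrative, not part of this commit) that grows a chain of consecutive torch.relu calls from a symbolically traced module; edge-dialect targets would be used the same way:

# Illustrative only; uses plain ATen targets instead of edge-dialect ops.
import torch
import torch.fx


class TwoRelu(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(torch.relu(x))


gm = torch.fx.symbolic_trace(TwoRelu())
start = [n for n in gm.graph.nodes if n.target is torch.relu][0]
chain = get_cascaded_ops([start], [torch.relu])
print(len(chain))  # 2: both relu nodes form one cascaded chain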
# Capture the effect of transpose op on incoming dimension order
def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
    """
    Given a transpose node, and the incoming dimension ordering of the input
    tensor to the transpose node, return the net effect of transpose op on the
    dimension order.
    """
    assert node.target == exir_ops.edge.aten.transpose_copy.int
    # Assert that dims is not None
    assert dims is not None
    dim_len = len(dims)
    # Get dim0 and dim1 from the transpose op args
    transpose_dims0 = node.args[1]
    transpose_dims1 = node.args[2]
    assert isinstance(transpose_dims0, int)
    assert isinstance(transpose_dims1, int)
    dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
    dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
    # Perform transpose on the dimension ordering (dims)
    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
    return dims


# Capture the effect of permute op on incoming dimension order
def get_permuted_dims(node: torch.fx.Node, dims: Optional[List[int]]) -> List[int]:
    """
    Given a permute node, and the incoming dimension ordering of the input
    tensor to the permute node, return the net effect of permute op on the
    dimension order.
    """
    assert node.target == exir_ops.edge.aten.permute_copy.default
    # Permute each index of the dimension ordering (dims)
    permute_dims = node.args[1]
    assert isinstance(permute_dims, List)
    assert all(isinstance(x, int) for x in permute_dims)
    # If the dims is empty, we can simply return the permute order
    if not dims:
        return permute_dims
    dims = [dims[x] for x in permute_dims]
    return dims
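A short worked example of the dimension-order bookkeeping (illustrative, not part of this commit): starting from the identity order [0, 1, 2, 3] on a 4-D tensor,

    transpose(dim0=1, dim1=3)  turns [0, 1, 2, 3] into [0, 3, 2, 1]
    permute([0, 2, 3, 1])      then turns [0, 3, 2, 1] into [0, 2, 1, 3]

which is exactly what get_transposed_dims and get_permuted_dims compute from the corresponding transpose_copy and permute_copy nodes.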
# Return the tensor of buffer/parameter op
def get_tensor_from_attr(
    graph_module: torch.fx.GraphModule, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    For an input node that is a named buffer or parameter, return
    the underlying tensor.
    """
    if node is None:
        return None
    assert node.op == "get_attr"
    return getattr(graph_module, node.target)


def is_node_with_op(node: torch.fx.Node, op: str) -> bool:
    """
    Return true if the incoming node has the given op type
    """
    return node.op == op


def count_users_with_target_op_type(
    nodes: Iterable[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_target: Union[Callable[..., Any], str],
) -> int:
    """
    Given a set of nodes and a node target type `op_target`, iterate over all
    the users of nodes, and return the total number of users with target
    op_target.
    """

    def contributions_per_node(
        node: torch.fx.Node,
        # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
        op_target: Union[Callable[..., Any], str],
    ) -> int:
        return [use.target for use in node.users if use.op == "call_function"].count(
            op_target
        )

    return sum([contributions_per_node(node, op_target) for node in nodes])


def contains_node_with_matching_target(
    nodes: Iterable[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_target: Union[Callable[..., Any], str],
) -> bool:
    """
    Given a list of nodes, return true if any node in the list has target
    'op_target'.
    """
    return any(node.target == op_target for node in nodes)
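A quick illustrative sketch (not part of this commit) of the two query helpers on a traced function:

# Illustrative only.
import torch
import torch.fx


def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(f)
print(count_users_with_target_op_type(get_placeholders(gm.graph), torch.relu))  # 1
print(contains_node_with_matching_target(gm.graph.nodes, torch.relu))           # True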
def is_quantized_tensor(x: torch.Tensor) -> bool:
    """
    Return true if the tensor x is quantized
    """
    return x.is_quantized


def get_scale(x: torch.Tensor) -> torch.Tensor:
    """
    Return the scale of a quantized tensor as a float32 tensor.
    """
    return (
        x.q_per_channel_scales().to(torch.float32)
        if x.qscheme() == torch.per_channel_affine
        else torch.tensor([x.q_scale()], dtype=torch.float32)
    )


def get_zero_point(x: torch.Tensor, reduce: bool = True) -> torch.Tensor:
    """
    Return the zero point of a quantized tensor as an int32 tensor.
    """
    # If x was quantized per-tensor, simply create a tensor out of the scalar
    # zero_point, and return it.
    if x.qscheme() == torch.per_tensor_affine:
        return torch.tensor([x.q_zero_point()], dtype=torch.int32)
    # If x was quantized per-channel, check if all the zero_points are
    # identical. If so, then we can compress the zero_point tensor to a scalar.
    assert x.qscheme() == torch.per_channel_affine, "Unhandled quantization scheme"
    zero_point = x.q_per_channel_zero_points().to(torch.int32)
    return (
        torch.tensor([zero_point[0]], dtype=torch.int32)
        if reduce and all(zero_point == zero_point[0])
        else zero_point
    )
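A small illustrative sketch (not part of this commit) with a per-tensor quantized tensor:

# Illustrative only.
import torch

xq = torch.quantize_per_tensor(
    torch.randn(4), scale=0.05, zero_point=3, dtype=torch.qint8
)
print(is_quantized_tensor(xq))  # True
print(get_scale(xq))            # tensor([0.0500])
print(get_zero_point(xq))       # tensor([3], dtype=torch.int32)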
def quantize_tensor_multiplier(
    requantize_scale_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given requantize_scale_tensor with values in the interval (0, 1),
    produce a pair of tensors (out_multiplier, right_shift), where out_multiplier
    is an int32 tensor representing fixed-point values in the interval [-1, 1),
    and right_shift is an amount to shift right by, so that the floating-point
    multiplication of some int32 input with each value of requantize_scale_tensor:
        result = int32_value * requantize_scale_tensor[i]
    is best approximated by the integer-arithmetic-only code:
        result = RoundingRightShift(FixedPointMultiplication(int32_value,
                 out_multiplier[i]), right_shift[i])
    """

    # This is identical to C++11 std::round(). Python's built-in round() uses
    # banker's rounding (round half to even), whereas C++ rounds halfway cases
    # away from zero.
    # pyre-fixme[2]: Parameter must be annotated.
    def round_away_zero(f) -> int:
        r = -0.5 if (f < 0) else 0.5
        return trunc(f + r)

    def quantize_scalar_multiplier(requantize_scale: float) -> Tuple[int, int]:
        significand, exponent = frexp(requantize_scale)
        significand_q31 = int(round_away_zero(significand * (1 << 31)))
        # Handle the special case when the real multiplier was so close to 1
        # that its fixed-point approximation was indistinguishable from 1.
        # We handle this by dividing it by two, and incrementing the exponent
        # (the right shift amount) by 1.
        if significand_q31 == (1 << 31):
            significand_q31 //= 2
            exponent += 1

        # Verify that the decomposition of requantize_scale into significand
        # and exponent is correct.
        reconstructed = significand_q31 / (1 << 31) * pow(2, exponent)
        assert isclose(
            requantize_scale, reconstructed, rel_tol=1e-4, abs_tol=1e-4
        ), "computation of significand and exponent from requantize_scale is not accurate"

        return (significand_q31, exponent)

    # Flatten the input scale tensor so that we can operate on individual values
    orig_shape = requantize_scale_tensor.shape
    flattened_tensor = requantize_scale_tensor.flatten().to(torch.float32)
    out_multiplier = torch.zeros(flattened_tensor.shape, dtype=torch.int32)
    right_shift = torch.zeros(flattened_tensor.shape, dtype=torch.int32)

    # Iterate over the flattened scale tensor and compute the decomposition of
    # each value in the scale tensor into significand (out_multiplier) and
    # exponent (right_shift)
    for idx, scale in enumerate(flattened_tensor):
        (si, ex) = quantize_scalar_multiplier(scale)
        out_multiplier[idx], right_shift[idx] = si, ex

    # Reshape the tensors back to the original shape
    out_multiplier = out_multiplier.reshape(orig_shape)
    right_shift = right_shift.reshape(orig_shape)

    return (out_multiplier, right_shift)
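A tiny illustrative check (not part of this commit) of the fixed-point decomposition, using a scale whose frexp decomposition is exact:

# Illustrative only: 0.5 decomposes exactly as significand 0.5, exponent 0.
import torch

mult, shift = quantize_tensor_multiplier(torch.tensor([0.5]))
print(mult)   # tensor([1073741824], dtype=torch.int32), i.e. 0.5 in Q31
print(shift)  # tensor([0], dtype=torch.int32)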
