Skip to content

Commit 5bf6b54

Browse files
committed
Update on "[11/n][ET-VK] Introduce vTensor creation from external image"
Nearly all metadata is initialized to null/dummy values, except those absolutely needed in the pipeline: (1) image extents, (2) logical limits. Differential Revision: [D63843819](https://our.internmc.facebook.com/intern/diff/D63843819/) [ghstack-poisoned]
2 parents ecb7c58 + 95dfdfd commit 5bf6b54

File tree

98 files changed

+1635
-654
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+1635
-654
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ exclude_patterns = [
180180
'**/*.bat',
181181
'**/*.jpg',
182182
'**/*.jar',
183+
'**/*.gif',
183184
# File contains @generated
184185
'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h',
185186
'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h',

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
)
1616

1717
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
18-
from torch._export import capture_pre_autograd_graph
1918
from torch.ao.quantization.quantize_pt2e import (
2019
convert_pt2e,
2120
prepare_pt2e,
2221
prepare_qat_pt2e,
2322
)
23+
from torch.export import export_for_training
2424

2525

2626
class TestCoreMLQuantizer:
@@ -32,7 +32,7 @@ def quantize_and_compare(
3232
) -> None:
3333
assert quantization_type in {"PTQ", "QAT"}
3434

35-
pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs)
35+
pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
3636

3737
quantization_config = LinearQuantizerConfig.from_dict(
3838
{

backends/apple/mps/test/test_mps_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,9 @@ def lower_module_and_test_output(
209209

210210
expected_output = model(*sample_inputs)
211211

212-
model = torch._export.capture_pre_autograd_graph(
212+
model = torch.export.export_for_training(
213213
model, sample_inputs, dynamic_shapes=dynamic_shapes
214-
)
214+
).module()
215215

216216
edge_program = export_to_edge(
217217
model,

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5757
exir_ops.edge.aten.sigmoid.default,
5858
exir_ops.edge.aten.mm.default,
5959
exir_ops.edge.aten.repeat.default,
60+
exir_ops.edge.aten.reciprocal.default,
6061
exir_ops.edge.aten.relu.default,
6162
exir_ops.edge.aten.rsqrt.default,
6263
exir_ops.edge.aten._softmax.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
op_cat,
1616
op_conv2d,
1717
op_dequant,
18-
op_div,
1918
op_exp,
2019
op_full,
2120
op_get_item,
@@ -26,6 +25,7 @@
2625
op_mul,
2726
op_permute,
2827
op_quant,
28+
op_reciprocal,
2929
op_relu,
3030
op_repeat,
3131
op_rsqrt,
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import numpy as np
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from executorch.backends.arm.tosa_quant_utils import (
17+
dequantize_value,
18+
get_quant_node_args,
19+
QuantArgs,
20+
quantize_value,
21+
)
22+
from serializer.tosa_serializer import TosaOp
23+
24+
25+
@register_node_visitor
26+
class DivVisitor(NodeVisitor):
27+
target = "aten.reciprocal.default"
28+
29+
def __init__(self, *args):
30+
super().__init__(*args)
31+
32+
def define_node(
33+
self,
34+
node: torch.fx.Node,
35+
tosa_graph: ts.TosaSerializer,
36+
inputs: List[TosaArg],
37+
output: TosaArg,
38+
is_quant_node: bool,
39+
) -> None:
40+
# 1/X
41+
42+
if is_quant_node:
43+
input = inputs[0]
44+
input_qargs = get_quant_node_args(node.all_input_nodes[0])
45+
output_qargs = get_quant_node_args(list(node.users)[0])
46+
47+
div_table = div_table_8bit(input_qargs, output_qargs)
48+
49+
table_attr = ts.TosaSerializerAttribute()
50+
table_attr.TableAttribute(div_table)
51+
tosa_graph.addOperator(
52+
TosaOp.Op().TABLE, [input.name], [output.name], table_attr
53+
)
54+
55+
else:
56+
tosa_graph.addOperator(
57+
TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
58+
)
59+
60+
61+
def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
62+
"""
63+
Returns a table mapping 256 entries to div([qmin,qmax])
64+
"""
65+
66+
def div(x):
67+
# Convert quantized input to floating point div input space.
68+
v1 = dequantize_value(x, in_quantargs)
69+
# Compute div.
70+
v2 = 1.0 / v1
71+
# Convert div output back to quantized space.
72+
v3 = quantize_value(v2, out_quantargs)
73+
74+
return v3
75+
76+
return [
77+
div(x)
78+
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
79+
]

backends/arm/passes/arm_pass_manager.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717
from executorch.backends.arm.passes.convert_split_to_slice import (
1818
ConvertSplitToSlicePass,
1919
)
20+
from executorch.backends.arm.passes.decompose_div_pass import DecomposeDivPass
2021
from executorch.backends.arm.passes.meandim_to_averagepool_pass import (
2122
ConvertMeanDimToAveragePool,
2223
)
2324
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
25+
from executorch.backends.arm.passes.scalars_to_attribute_pass import (
26+
ScalarsToAttributePass,
27+
)
2428
from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
2529
from executorch.exir import ExportedProgram
2630
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -40,6 +44,7 @@ def transform_to_backend_pipeline(
4044
self.add_pass(RemoveClonePass())
4145
self.add_pass(ConvertExpandCopyToRepeatPass())
4246
self.add_pass(ConvertMeanDimToAveragePool())
47+
self.add_pass(DecomposeDivPass())
4348
self.add_pass(ConvertSplitToSlicePass())
4449
for spec in compile_spec:
4550
if spec.key == "permute_memory_format":
@@ -48,3 +53,8 @@ def transform_to_backend_pipeline(
4853
self.add_pass(AnnotateChannelsLastDimOrder())
4954

5055
return self._transform(exported_program.graph_module)
56+
57+
def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
58+
self.add_pass(DecomposeDivPass())
59+
self.add_pass(ScalarsToAttributePass())
60+
return self._transform(graph_module)

backends/arm/passes/arm_pass_utils.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import torch
10+
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from torch._ops import OpOverload
13+
14+
15+
def create_node(
16+
graph: torch.fx.Graph,
17+
op_target: OpOverload,
18+
args: tuple = (),
19+
kwargs: Optional[dict] = None,
20+
quantize: bool = False,
21+
q_params: Optional[tuple] = None,
22+
):
23+
"""
24+
Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
25+
If quantize is true and q_params is not None, a q dq pair is inserted after the newly created node.
26+
"""
27+
28+
node = graph.create_node(
29+
"call_function",
30+
op_target,
31+
args=args,
32+
kwargs=kwargs or {},
33+
)
34+
if quantize and q_params:
35+
return insert_q_dq_pair(graph, node, q_params)
36+
return node
37+
38+
39+
def insert_q_dq_pair(
40+
graph: torch.fx.Graph,
41+
anchor: torch.fx.Node,
42+
q_params: tuple,
43+
):
44+
"""
45+
Inserts a q dq node pair after the node 'anchor'.
46+
"""
47+
48+
with graph.inserting_after(anchor):
49+
q = create_node(
50+
graph=graph,
51+
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
52+
args=(), # We add the argument last
53+
)
54+
q.meta = anchor.meta
55+
with graph.inserting_after(q):
56+
dq = create_node(
57+
graph=graph,
58+
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
59+
args=(q,) + q_params,
60+
)
61+
dq.meta = q.meta
62+
anchor.replace_all_uses_with(dq)
63+
# We add this last so the replace all uses above does not replace the quantized
64+
# node's first use
65+
q.args = (anchor,) + q_params
66+
return dq

backends/arm/passes/convert_expand_copy_to_repeat.py

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88

99
from typing import cast
1010

11-
import torch.fx
1211
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
1312
from executorch.exir.dialects._ops import ops as exir_ops
14-
from executorch.exir.pass_base import ExportPass, PassResult
15-
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
13+
from executorch.exir.pass_base import ExportPass
1614

1715

1816
class ConvertExpandCopyToRepeatPass(ExportPass):
@@ -22,42 +20,26 @@ class ConvertExpandCopyToRepeatPass(ExportPass):
2220

2321
expand_copy = exir_ops.edge.aten.expand_copy.default
2422
repeat = exir_ops.edge.aten.repeat.default
25-
patterns = [{expand_copy: 1}]
2623

27-
def call(self, graph_module: torch.fx.GraphModule):
28-
graph = graph_module.graph
29-
partitions = get_source_partitions(
30-
graph, [torch.expand_copy, torch.Tensor.expand, "expand"]
24+
def call_operator(self, op, args, kwargs, meta):
25+
if op != self.expand_copy:
26+
return super().call_operator(op, args, kwargs, meta)
27+
28+
_, shape, _ = extract_tensor_meta(meta.data)
29+
multiples = cast(list[int], args[1])
30+
expanded_rank = len(multiples)
31+
32+
# Expanded shape is 'shape' front-padded with ones.
33+
padding = expanded_rank - len(shape)
34+
extended_shape = [
35+
shape[i] if i >= 0 else 1 for i in range(-padding, len(shape))
36+
]
37+
38+
# To convert expand arg to repeat arg, non-repeated dims should have
39+
# multiples[dim] = 1.
40+
multiples = [
41+
multiples[i] if extended_shape[i] == 1 else 1 for i in range(expanded_rank)
42+
]
43+
return super().call_operator(
44+
op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
3145
)
32-
for _, src_partitions in partitions.items():
33-
for src_partition in src_partitions:
34-
assert len(src_partition.nodes) == 1
35-
36-
expand_node = src_partition.nodes[0]
37-
_, shape, _ = extract_tensor_meta(expand_node.all_input_nodes[0].meta)
38-
multiples = cast(tuple[int], expand_node.args[1])
39-
expanded_rank = len(multiples)
40-
41-
# Expanded shape is 'shape' front-padded with ones.
42-
padding = expanded_rank - len(shape)
43-
extended_shape = [
44-
shape[i] if i >= 0 else 1 for i in range(-padding, len(shape))
45-
]
46-
47-
# To convert expand arg to repeat arg, non-repeated dims should have
48-
# multiples[dim] = 1.
49-
multiples = [
50-
multiples[i] if extended_shape[i] == 1 else 1
51-
for i in range(expanded_rank)
52-
]
53-
args = (expand_node.args[0], multiples)
54-
55-
with graph_module.graph.inserting_before(expand_node):
56-
repeat_node = graph.create_node("call_function", self.repeat, args)
57-
repeat_node.meta = expand_node.meta
58-
for user in expand_node.users.copy():
59-
user.replace_input_with(expand_node, repeat_node)
60-
61-
graph.eliminate_dead_code()
62-
graph_module.recompile()
63-
return PassResult(graph_module, True)

backends/arm/passes/convert_split_to_slice.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-unsafe
88

99
import torch.fx
10+
from executorch.backends.arm.passes.arm_pass_utils import create_node
1011
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
1112
from executorch.exir.dialects._ops import ops as exir_ops
1213
from executorch.exir.pass_base import ExportPass, PassResult
@@ -55,18 +56,18 @@ def call(self, graph_module: torch.fx.GraphModule):
5556
start = end
5657

5758
# Output nodes are of type getitem
58-
# Create one slice node for each output node with matching argumetns.
59+
# Replace them with one slice node for each output node.
5960
with graph_module.graph.inserting_before(split_node):
6061
for output_node in output_nodes:
6162
index = output_node.args[1]
62-
slice_node = graph.create_node(
63-
"call_function",
63+
slice_node = create_node(
64+
graph,
6465
self.slice,
6566
(input_node, dim, starts[index], ends[index]),
6667
)
6768
slice_node.meta = split_node.meta.copy()
6869
slice_node.meta["val"] = slice_node.meta["val"][index]
69-
output_node.replace_input_with(split_node, slice_node)
70+
output_node.replace_all_uses_with(slice_node)
7071
graph.eliminate_dead_code()
7172
graph_module.recompile()
7273
return PassResult(graph_module, True)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
from executorch.exir.pass_base import ExportPass
10+
11+
12+
def get_div_decomposition(op) -> tuple:
13+
"""
14+
Returns the the (reciprocal_op, mul_op), where the ops depends on if
15+
the div op is in exir_ops torch.ops.aten.
16+
"""
17+
if op == exir_ops.edge.aten.div.Tensor:
18+
return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
19+
if op == torch.ops.aten.div.Tensor:
20+
return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
21+
raise RuntimeError(f"Can't get div decomposition for op {op}")
22+
23+
24+
class DecomposeDivPass(ExportPass):
25+
"""
26+
This pass decomposes div into a mul and a reciprocal node.
27+
28+
Example:
29+
y = div(a,b)
30+
Becomes:
31+
x = reciprocal(b)
32+
y = mul(a,x)
33+
"""
34+
35+
def call_operator(self, op, args, kwargs, meta):
36+
if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
37+
return super().call_operator(op, args, kwargs, meta)
38+
39+
reciprocal_op, mul_op = get_div_decomposition(op)
40+
41+
numerator = args[0]
42+
denominator = args[1]
43+
reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)
44+
45+
return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)

0 commit comments

Comments
 (0)