
Commit 6dd31a8

Update on "Refactor attention v2"
Pull attention creation out of Transformer/TransformerBlock. Instead, pass the layers into Transformer. The motivation is to let us customize the linear layers in attention for LoRA (e.g. make wq a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq, wk, wv, wo out of the attention and pass those in as well. This allows us to customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

I think this modularizes our attention/transformer components, though it also means users have to do some more work to construct the attention layers and pass them to the transformer. It follows the torchtune structure more closely, e.g. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)

[ghstack-poisoned]
2 parents 6375fc2 + d4c9f8b commit 6dd31a8
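
Below is a minimal sketch of the construction pattern described in the commit message. The class names (LoRALinear, SimpleAttention, TransformerBlock, Transformer) are illustrative stand-ins, not the actual ExecuTorch llama modules touched by this commit: the caller builds the attention (and its projections) and passes the finished layers into the transformer, so wq can be swapped for a LoRA projection without attention.py ever seeing ModelArgs.

```python
# Illustrative sketch only: hypothetical stand-in classes, not the ExecuTorch
# llama Transformer/Attention modules changed in this commit.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
    """Base linear plus a low-rank update; shown only so a customized wq can
    be injected from outside the attention module."""

    def __init__(self, in_features: int, out_features: int, rank: int = 8):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.lora_a = nn.Linear(in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(x))


class SimpleAttention(nn.Module):
    """Attention that receives its projections instead of constructing them
    internally from a ModelArgs-style config."""

    def __init__(self, wq, wk, wv, wo, n_heads: int):
        super().__init__()
        self.wq, self.wk, self.wv, self.wo = wq, wk, wv, wo
        self.n_heads = n_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape
        q = self.wq(x).view(b, s, self.n_heads, -1).transpose(1, 2)
        k = self.wk(x).view(b, s, self.n_heads, -1).transpose(1, 2)
        v = self.wv(x).view(b, s, self.n_heads, -1).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.wo(out.transpose(1, 2).reshape(b, s, d))


class TransformerBlock(nn.Module):
    def __init__(self, attention: nn.Module, dim: int):
        super().__init__()
        self.attention = attention  # injected by the caller, not built here
        self.attn_norm = nn.LayerNorm(dim)
        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attention(self.attn_norm(x))
        return x + self.mlp(self.mlp_norm(x))


class Transformer(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers  # fully constructed blocks are passed in

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


# Caller-side construction: wq becomes a LoRA projection while wk/wv/wo stay
# plain linears; the attention and transformer code need no LoRA knowledge.
dim, n_heads, n_layers = 64, 4, 2
layers = nn.ModuleList(
    TransformerBlock(
        SimpleAttention(
            wq=LoRALinear(dim, dim),
            wk=nn.Linear(dim, dim, bias=False),
            wv=nn.Linear(dim, dim, bias=False),
            wo=nn.Linear(dim, dim, bias=False),
            n_heads=n_heads,
        ),
        dim,
    )
    for _ in range(n_layers)
)
model = Transformer(layers)
print(model(torch.randn(1, 8, dim)).shape)  # torch.Size([1, 8, 64])
```

The trade-off matches the one noted in the message: construction moves to the caller (as in torchtune's component builders), in exchange for attention internals that no longer depend on model-wide config.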

File tree

69 files changed: +3764 -449 lines changed


Package.swift

Lines changed: 3 additions & 1 deletion
@@ -96,8 +96,10 @@ let package = Package(
         .copy("resources/add.pte")
       ],
       linkerSettings: [
+        .linkedLibrary("c++"),
         .unsafeFlags([
-          "-Xlinker", "-all_load",
+          "-Xlinker", "-force_load",
+          "-Xlinker", "cmake-out/kernels_portable.xcframework/macos-arm64/libkernels_portable_macos.a",
         ])
       ]
     )

backends/apple/mps/setup.md

Lines changed: 3 additions & 3 deletions
@@ -76,12 +76,12 @@ cd executorch
 ## Run the mv3 generated model using the mps_executor_runner

 ```bash
-./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_bundled_fp16.pte --bundled_program
+./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_float16_bundled.pte --bundled_program
 ```

 - You should see the following results. Note that no output file will be generated in this example:
 ```
-I 00:00:00.003290 executorch:mps_executor_runner.mm:286] Model file mv3_mps_bundled_fp16.pte is loaded.
+I 00:00:00.003290 executorch:mps_executor_runner.mm:286] Model file mv3_mps_float16_bundled.pte is loaded.
 I 00:00:00.003306 executorch:mps_executor_runner.mm:292] Program methods: 1
 I 00:00:00.003308 executorch:mps_executor_runner.mm:294] Running method forward
 I 00:00:00.003311 executorch:mps_executor_runner.mm:349] Setting up non-const buffer 1, size 606112.
@@ -118,7 +118,7 @@ python3 -m examples.apple.mps.scripts.mps_example --model_name="mv3" --generate_
 ```
 2. Run your Program on the ExecuTorch runtime and generate an [ETDump](../../../docs/source/etdump.md).
 ```
-./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_bundled_fp16.pte --bundled_program --dump-outputs
+./cmake-out/examples/apple/mps/mps_executor_runner --model_path mv3_mps_float16_bundled.pte --bundled_program --dump-outputs
 ```
 3. Create an instance of the Inspector API by passing in the ETDump you have sourced from the runtime along with the optionally generated ETRecord from step 1.
 ```bash

backends/arm/operator_support/convolution_support.py

Lines changed: 8 additions & 1 deletion
@@ -11,7 +11,11 @@
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
 )
-from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.backends.arm.tosa_specification import (
+    Tosa_0_80,
+    Tosa_1_00,
+    TosaSpecification,
+)
 from executorch.exir.dialects._ops import ops as exir_ops


@@ -43,6 +47,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

         # Hardware specific constraints
         if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
+            # TODO remove this once TOSA 1.0 support for u55 is added.
+            if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
+                return False
             return True
         else:
             return self._is_node_supported_u55(node)

backends/arm/operators/op_abs.py

Lines changed: 129 additions & 5 deletions
@@ -4,12 +4,11 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe
-from typing import List
+from typing import Any, List

 import executorch.backends.arm.tosa_quant_utils as tqutils
 import executorch.backends.arm.tosa_utils as tutils

-import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -33,10 +32,13 @@ def __init__(self, *args):
     def define_node(
         self,
         node: Node,
-        tosa_graph: ts.TosaSerializer,
+        tosa_graph: Any,
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+
+        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
+
         # Specification (0.80) states that input and output types
         # should all be the same
         if not (inputs[0].dtype == output.dtype):
@@ -53,7 +55,7 @@ def define_node(
         if inputs[0].dtype == ts.DType.INT8:
             rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                 tosa_graph, inputs, node
-            )
+            )  # type: ignore[possibly-undefined]
         else:
             # input[0].dtype == ts.DType.INT32
             # Non quantized input, natively support by TOSA.abs
@@ -96,10 +98,13 @@ def __init__(self, *args):
     def define_node(
         self,
         node: Node,
-        tosa_graph: ts.TosaSerializer,
+        tosa_graph: Any,
         inputs: List[TosaArg],
         output: TosaArg,
     ) -> None:
+
+        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
+
         # Specification (0.80) states that input and output types
         # should all be the same
         if not (inputs[0].dtype == output.dtype):
@@ -129,3 +134,122 @@ def define_node(
             [output.name],
             None,
         )
+
+
+@register_node_visitor
+class AbsVisitor_INT(NodeVisitor):
+    target = "aten.abs.default"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        import serializer.tosa_serializer as ts  # type: ignore
+
+        # Specification (1.0) states that input and output types
+        # should all be the same
+        if not (inputs[0].dtype == output.dtype):
+            raise ValueError(
+                "All inputs and outputs need same dtype."
+                f"Got {inputs[0].dtype=}, {output.dtype=}"
+            )
+        # Handle int8 (quantized) and int32
+        if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
+            raise ValueError(
+                "All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
+            )
+
+        scale_back = 1.0
+        if inputs[0].dtype == ts.DType.INT8:
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+                tosa_graph, inputs, node, self.tosa_specs
+            )  # type: ignore[possibly-undefined]
+        else:
+            # input[0].dtype == ts.DType.INT32
+            # Non quantized input, natively support by TOSA.abs
+            rescaled_inputs = inputs
+
+        if output.dtype == ts.DType.INT8:
+            broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
+            abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
+        else:
+            # output.dtype == ts.DType.INT32
+            abs_output = output
+
+        # Do the INT32 Abs
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().ABS,
+            [
+                rescaled_inputs[0].name,
+            ],
+            [abs_output.name],
+            None,
+        )
+
+        if output.dtype == ts.DType.INT8:
+            # Scale output back to 8 bit
+            # pyre-ignore
+            tqutils.insert_rescale_op_to_int8(
+                tosa_graph, abs_output, scale_back, node, self.tosa_specs
+            )  # type: ignore[possibly-undefined]
+
+
+@register_node_visitor
+class AbsVisitor_FP(AbsVisitor_INT):
+    # inheriting 'target' from BI class
+
+    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        import serializer.tosa_serializer as ts  # type: ignore
+
+        # Specification (1.0) states that input and output types
+        # should all be the same
+        if not (inputs[0].dtype == output.dtype):
+            raise ValueError(
+                "All inputs and output need same dtype."
+                f"Got {inputs[0].dtype=}, {output.dtype=}"
+            )
+
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
+            # Call the inherited define_node for handling integers
+            super().define_node(node, tosa_graph, inputs, output)
+        else:
+            # FP32 Abs lowering
+
+            if not (inputs[0].dtype == ts.DType.FP32):
+                raise ValueError(
+                    "All inputs need to be FP32." f"Got {inputs[0].dtype=}"
+                )
+
+            if not (output.dtype == ts.DType.FP32):
+                raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")
+
+            # MI lowering
+            tosa_graph.addOperator(
+                ts.TosaOp.Op().ABS,
+                [inputs[0].name],
+                [output.name],
+                None,
+            )
