Skip to content

Commit 3d1f6ce

Browse files
committed
Update on "[ET-VK] Simplifying conv1d op shader by changing it to process one output texel per thread."
This diff changes the conv1d shader to process one output texel per thread, increasing GPU occupancy and improving performance. Differential Revision: [D74097560](https://our.internmc.facebook.com/intern/diff/D74097560/) [ghstack-poisoned]
2 parents 7c92551 + 0c03969 commit 3d1f6ce

19 files changed

+1814
-174
lines changed

backends/arm/operator_support/convolution_support.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
register_tosa_support_check,
1212
SupportedTOSAOperatorCheck,
1313
)
14-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
14+
from executorch.backends.arm.tosa_specification import (
15+
Tosa_0_80,
16+
Tosa_1_00,
17+
TosaSpecification,
18+
)
1519
from executorch.exir.dialects._ops import ops as exir_ops
1620

1721

@@ -43,6 +47,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4347

4448
# Hardware specific constraints
4549
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
50+
# TODO remove this once TOSA 1.0 support for u55 is added.
51+
if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
52+
return False
4653
return True
4754
else:
4855
return self._is_node_supported_u55(node)

backends/arm/operators/op_abs.py

Lines changed: 129 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import executorch.backends.arm.tosa_quant_utils as tqutils
1010
import executorch.backends.arm.tosa_utils as tutils
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.operators.node_visitor import (
1413
NodeVisitor,
1514
register_node_visitor,
@@ -33,10 +32,13 @@ def __init__(self, *args):
3332
def define_node(
3433
self,
3534
node: Node,
36-
tosa_graph: ts.TosaSerializer,
35+
tosa_graph: Any,
3736
inputs: List[TosaArg],
3837
output: TosaArg,
3938
) -> None:
39+
40+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
41+
4042
# Specification (0.80) states that input and output types
4143
# should all be the same
4244
if not (inputs[0].dtype == output.dtype):
@@ -53,7 +55,7 @@ def define_node(
5355
if inputs[0].dtype == ts.DType.INT8:
5456
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
5557
tosa_graph, inputs, node
56-
)
58+
) # type: ignore[possibly-undefined]
5759
else:
5860
# input[0].dtype == ts.DType.INT32
5961
# Non-quantized input, natively supported by TOSA.abs
@@ -96,10 +98,13 @@ def __init__(self, *args):
9698
def define_node(
9799
self,
98100
node: Node,
99-
tosa_graph: ts.TosaSerializer,
101+
tosa_graph: Any,
100102
inputs: List[TosaArg],
101103
output: TosaArg,
102104
) -> None:
105+
106+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107+
103108
# Specification (0.80) states that input and output types
104109
# should all be the same
105110
if not (inputs[0].dtype == output.dtype):
@@ -129,3 +134,122 @@ def define_node(
129134
[output.name],
130135
None,
131136
)
137+
138+
139+
@register_node_visitor
140+
class AbsVisitor_INT(NodeVisitor):
141+
target = "aten.abs.default"
142+
143+
tosa_specs = [
144+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
145+
]
146+
147+
def __init__(self, *args):
148+
super().__init__(*args)
149+
150+
def define_node(
151+
self,
152+
node: Node,
153+
tosa_graph: Any,
154+
inputs: List[TosaArg],
155+
output: TosaArg,
156+
) -> None:
157+
158+
import serializer.tosa_serializer as ts # type: ignore
159+
160+
# Specification (1.0) states that input and output types
161+
# should all be the same
162+
if not (inputs[0].dtype == output.dtype):
163+
raise ValueError(
164+
"All inputs and outputs need same dtype."
165+
f"Got {inputs[0].dtype=}, {output.dtype=}"
166+
)
167+
# Handle int8 (quantized) and int32
168+
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
169+
raise ValueError(
170+
"All inputs need to be INT8 or INT32." f"Got {inputs[0].dtype=}"
171+
)
172+
173+
scale_back = 1.0
174+
if inputs[0].dtype == ts.DType.INT8:
175+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
176+
tosa_graph, inputs, node, self.tosa_specs
177+
) # type: ignore[possibly-undefined]
178+
else:
179+
# input[0].dtype == ts.DType.INT32
180+
# Non-quantized input, natively supported by TOSA.abs
181+
rescaled_inputs = inputs
182+
183+
if output.dtype == ts.DType.INT8:
184+
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
185+
abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
186+
else:
187+
# output.dtype == ts.DType.INT32
188+
abs_output = output
189+
190+
# Do the INT32 Abs
191+
tosa_graph.addOperator(
192+
ts.TosaOp.Op().ABS,
193+
[
194+
rescaled_inputs[0].name,
195+
],
196+
[abs_output.name],
197+
None,
198+
)
199+
200+
if output.dtype == ts.DType.INT8:
201+
# Scale output back to 8 bit
202+
# pyre-ignore
203+
tqutils.insert_rescale_op_to_int8(
204+
tosa_graph, abs_output, scale_back, node, self.tosa_specs
205+
) # type: ignore[possibly-undefined]
206+
207+
208+
@register_node_visitor
209+
class AbsVisitor_FP(AbsVisitor_INT):
210+
# inheriting 'target' from INT class
211+
212+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
213+
214+
def __init__(self, *args):
215+
super().__init__(*args)
216+
217+
def define_node(
218+
self,
219+
node: Node,
220+
tosa_graph: Any,
221+
inputs: List[TosaArg],
222+
output: TosaArg,
223+
) -> None:
224+
225+
import serializer.tosa_serializer as ts # type: ignore
226+
227+
# Specification (1.0) states that input and output types
228+
# should all be the same
229+
if not (inputs[0].dtype == output.dtype):
230+
raise ValueError(
231+
"All inputs and output need same dtype."
232+
f"Got {inputs[0].dtype=}, {output.dtype=}"
233+
)
234+
235+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
236+
# Call the inherited define_node for handling integers
237+
super().define_node(node, tosa_graph, inputs, output)
238+
else:
239+
# FP32 Abs lowering
240+
241+
if not (inputs[0].dtype == ts.DType.FP32):
242+
raise ValueError(
243+
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
244+
)
245+
246+
if not (output.dtype == ts.DType.FP32):
247+
raise ValueError("All outputs need to be FP32." f"Got {output.dtype=}")
248+
249+
# FP lowering
250+
tosa_graph.addOperator(
251+
ts.TosaOp.Op().ABS,
252+
[inputs[0].name],
253+
[output.name],
254+
None,
255+
)

backends/arm/operators/op_add.py

Lines changed: 133 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55

66
# pyre-unsafe
77

8-
from typing import List
8+
from typing import Any, List
99

1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import executorch.backends.arm.tosa_utils as tutils
1212

13-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1413
from executorch.backends.arm.operators.node_visitor import (
1514
NodeVisitor,
1615
register_node_visitor,
@@ -34,10 +33,13 @@ def __init__(self, *args):
3433
def define_node(
3534
self,
3635
node: Node,
37-
tosa_graph: ts.TosaSerializer,
36+
tosa_graph: Any,
3837
inputs: List[TosaArg],
3938
output: TosaArg,
4039
) -> None:
40+
41+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
42+
4143
# Specification (0.80) states that input and output types
4244
# should all be the same
4345
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -58,7 +60,7 @@ def define_node(
5860
if len(inputs[0].shape) > len(inputs[1].shape)
5961
else inputs[1].dim_order
6062
)
61-
63+
scale_back = 1.0
6264
if inputs[0].dtype == ts.DType.INT8:
6365
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
6466
tosa_graph, inputs, node
@@ -90,7 +92,9 @@ def define_node(
9092
if output.dtype == ts.DType.INT8:
9193
# Scale output back to 8 bit
9294
# pyre-ignore
93-
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]
95+
tqutils.insert_rescale_op_to_int8(
96+
tosa_graph, add_output, scale_back, node
97+
) # type: ignore[possibly-undefined]
9498

9599

96100
@register_node_visitor
@@ -107,10 +111,13 @@ def __init__(self, *args):
107111
def define_node(
108112
self,
109113
node: Node,
110-
tosa_graph: ts.TosaSerializer,
114+
tosa_graph: Any,
111115
inputs: List[TosaArg],
112116
output: TosaArg,
113117
) -> None:
118+
119+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120+
114121
# Specification (0.80) states that input and output types
115122
# should all be the same
116123
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -130,7 +137,7 @@ def define_node(
130137
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
131138
)
132139

133-
input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
140+
input1, input2 = inputs
134141

135142
# MI lowering
136143
tosa_graph.addOperator(
@@ -139,3 +146,122 @@ def define_node(
139146
[output.name],
140147
None,
141148
)
149+
150+
151+
@register_node_visitor
152+
class AddVisitor_INT(NodeVisitor):
153+
target = "aten.add.Tensor"
154+
155+
tosa_specs = [
156+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
157+
]
158+
159+
def __init__(self, *args):
160+
super().__init__(*args)
161+
162+
def define_node(
163+
self,
164+
node: Node,
165+
tosa_graph: Any,
166+
inputs: List[TosaArg],
167+
output: TosaArg,
168+
) -> None:
169+
170+
import serializer.tosa_serializer as ts # type: ignore
171+
172+
# Specification (1.0) states that input and output types
173+
# should all be the same
174+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
175+
raise TypeError(
176+
f"All IO needs to have the same data type, got input 1: "
177+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
178+
f"{output.dtype}"
179+
)
180+
# Handle int8 (quantized) and int32
181+
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
182+
if inputs[0].dtype not in supported_dtypes:
183+
raise TypeError(
184+
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
185+
)
186+
scale_back = 1.0
187+
if inputs[0].dtype == ts.DType.INT8:
188+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
189+
tosa_graph, inputs, node, self.tosa_specs
190+
)
191+
else:
192+
# input[0].dtype == ts.DType.INT32
193+
# Non-quantized input, natively supported by TOSA.ADD
194+
rescaled_inputs = inputs
195+
196+
if output.dtype == ts.DType.INT8:
197+
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
198+
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
199+
else:
200+
# output.dtype == ts.DType.INT32
201+
add_output = output
202+
203+
input1, input2 = rescaled_inputs
204+
205+
# Do the INT32 Add
206+
tosa_graph.addOperator(
207+
ts.TosaOp.Op().ADD,
208+
[input1.name, input2.name],
209+
[add_output.name],
210+
None,
211+
)
212+
213+
if output.dtype == ts.DType.INT8:
214+
# Scale output back to 8 bit
215+
# pyre-ignore
216+
tqutils.insert_rescale_op_to_int8(
217+
tosa_graph, add_output, scale_back, node, self.tosa_specs
218+
) # type: ignore[possibly-undefined]
219+
220+
221+
@register_node_visitor
222+
class AddVisitor_FP(AddVisitor_INT):
223+
# inheriting 'target' from INT class
224+
225+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
226+
227+
def __init__(self, *args):
228+
super().__init__(*args)
229+
230+
def define_node(
231+
self,
232+
node: Node,
233+
tosa_graph: Any,
234+
inputs: List[TosaArg],
235+
output: TosaArg,
236+
) -> None:
237+
238+
import serializer.tosa_serializer as ts # type: ignore
239+
240+
# Specification (1.0) states that input and output types
241+
# should all be the same
242+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
243+
raise TypeError(
244+
f"All IO needs to have the same data type, got input 1: "
245+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
246+
f"{output.dtype}"
247+
)
248+
249+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
250+
# Call the inherited define_node for handling integers
251+
super().define_node(node, tosa_graph, inputs, output)
252+
else:
253+
# FP32 Add lowering
254+
if inputs[0].dtype != ts.DType.FP32:
255+
raise TypeError(
256+
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
257+
)
258+
259+
input1, input2 = inputs
260+
261+
# FP lowering
262+
tosa_graph.addOperator(
263+
ts.TosaOp.Op().ADD,
264+
[input1.name, input2.name],
265+
[output.name],
266+
None,
267+
)

0 commit comments

Comments
 (0)