
Commit 7ecbbfd

Update base for Update on "[ET-VK] Adding UniformData struct in vTensor class to store uniform data, which will be stored using shared ptr and can be shared with push constants."
This diff adds a `UniformData` struct to the `vTensor` class to store uniform data (the tensor's sizes, strides, logical limits, and number of elements) so that the same data can be shared with push constants. It also adds an `Attribute` enum to the tensor class to enumerate the attributes supplied to a dispatch, and a `write_attribute` function on `UniformData` that writes the requested attribute's data to a given memory location.

Differential Revision: [D66733611](https://our.internmc.facebook.com/intern/diff/D66733611/)

[ghstack-poisoned]
2 parents 83a2671 + ec56da8 commit 7ecbbfd
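
To make the description above concrete, the following is a minimal, hypothetical C++ sketch of the kind of structure the commit message describes. Only the names `UniformData`, `Attribute`, and `write_attribute` come from the commit; every field, type, and signature below is an assumption for illustration, not the actual ET-VK implementation. Per the commit title, the real struct is held through a shared pointer so the same uniform data can also back push constants.

// Hypothetical sketch only; field layout, types, and signatures are assumptions.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>

namespace sketch {

// Enumerates the tensor attributes that can be supplied to a dispatch.
enum class Attribute : std::uint8_t { SIZES, STRIDES, LOGICAL_LIMITS, NUMEL };

struct UniformData {
  std::array<std::int32_t, 4> sizes;           // assumed packed tensor sizes
  std::array<std::int32_t, 4> strides;         // assumed packed tensor strides
  std::array<std::int32_t, 3> logical_limits;  // assumed logical texture extents
  std::int32_t numel;                          // number of elements in the tensor

  // Copies the requested attribute into dst and returns the number of bytes
  // written, so the same data could back either a uniform buffer or a push
  // constant range.
  std::size_t write_attribute(void* dst, Attribute attr) const {
    switch (attr) {
      case Attribute::SIZES:
        std::memcpy(dst, sizes.data(), sizeof(sizes));
        return sizeof(sizes);
      case Attribute::STRIDES:
        std::memcpy(dst, strides.data(), sizeof(strides));
        return sizeof(strides);
      case Attribute::LOGICAL_LIMITS:
        std::memcpy(dst, logical_limits.data(), sizeof(logical_limits));
        return sizeof(logical_limits);
      case Attribute::NUMEL:
        std::memcpy(dst, &numel, sizeof(numel));
        return sizeof(numel);
    }
    return 0;
  }
};

} // namespace sketch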

File tree

30 files changed: +985 / -191 lines changed

backends/arm/test/conftest.py

Lines changed: 11 additions & 22 deletions
@@ -11,7 +11,6 @@
 import shutil
 import subprocess
 import sys
-from enum import auto, Enum
 from typing import Any

 import pytest
@@ -22,30 +21,24 @@
 """


-class arm_test_options(Enum):
-    quantize_io = auto()
-    corstone_fvp = auto()
-    fast_fvp = auto()
-
-
-_test_options: dict[arm_test_options, Any] = {}
-
 # ==== Pytest hooks ====


 def pytest_configure(config):
+    pytest._test_options = {}
+
     if config.option.arm_quantize_io:
         _load_libquantized_ops_aot_lib()
-        _test_options[arm_test_options.quantize_io] = True
+        pytest._test_options["quantize_io"] = True
     if config.option.arm_run_corstoneFVP:
         corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55")
         corstone320_exists = shutil.which("FVP_Corstone_SSE-320")
         if not (corstone300_exists and corstone320_exists):
             raise RuntimeError(
                 "Tests are run with --arm_run_corstoneFVP but corstone FVP is not installed."
             )
-        _test_options[arm_test_options.corstone_fvp] = True
-        _test_options[arm_test_options.fast_fvp] = config.option.fast_fvp
+        pytest._test_options["corstone_fvp"] = True
+        pytest._test_options["fast_fvp"] = config.option.fast_fvp
     logging.basicConfig(level=logging.INFO, stream=sys.stdout)
@@ -131,9 +124,7 @@ def expectedFailureOnFVP(test_item):
 # ==== End of Custom Pytest decorators =====


-def is_option_enabled(
-    option: str | arm_test_options, fail_if_not_enabled: bool = False
-) -> bool:
+def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool:
     """
     Returns whether an option is successfully enabled, i.e. if the flag was
     given to pytest and the necessary requirements are available.
@@ -144,10 +135,8 @@ def is_option_enabled(
     The optional parameter 'fail_if_not_enabled' makes the function raise
     a RuntimeError instead of returning False.
     """
-    if isinstance(option, str):
-        option = arm_test_options[option.lower()]

-    if option in _test_options and _test_options[option]:
+    if option in pytest._test_options and pytest._test_options[option]:
         return True
     else:
         if fail_if_not_enabled:
@@ -156,15 +145,15 @@ def is_option_enabled(
         return False


-def get_option(option: arm_test_options) -> Any | None:
+def get_option(option: str) -> Any | None:
     """
     Returns the value of an pytest option if it is set, otherwise None.

     Args:
-        option (arm_test_options): The option to check for.
+        option (str): The option to check for.
     """
-    if option in _test_options:
-        return _test_options[option]
+    if option in pytest._test_options:
+        return pytest._test_options[option]
     return None

backends/arm/test/ops/test_depthwise_conv.py

Lines changed: 7 additions & 6 deletions
@@ -156,14 +156,11 @@
     ("two_dw_conv2d", two_dw_conv2d),
 ]

-testsuite_conv2d_u85 = [
+testsuite_conv2d_u85_xfails = [
     ("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1),
     ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1),
     ("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1),
     ("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias),
-]
-
-testsuite_conv2d_u85_xfails = [
     ("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3),
     ("two_dw_conv2d", two_dw_conv2d),
 ]
@@ -287,7 +284,7 @@ def test_dw_conv1d_u55_BI(
             model.get_inputs(),
         )

-    @parameterized.expand(testsuite_conv1d + testsuite_conv2d_u85)
+    @parameterized.expand(testsuite_conv1d[2:])
     def test_dw_conv_u85_BI(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False
     ):
@@ -299,8 +296,12 @@ def test_dw_conv_u85_BI(
             model.get_inputs(),
         )

+    testsuite_conv2d_u85_xfails.remove(
+        ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1)
+    )  # Works
+
     # All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520
-    @parameterized.expand(testsuite_conv2d_u85_xfails)
+    @parameterized.expand(testsuite_conv2d_u85_xfails + testsuite_conv1d[:2])
     @conftest.expectedFailureOnFVP
     def test_dw_conv_u85_BI_xfails(
         self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False

backends/arm/test/ops/test_div.py

Lines changed: 2 additions & 28 deletions
@@ -183,21 +183,8 @@ def test_div_tosa_BI(
         test_data = (input_, other_)
         self._test_div_tosa_BI_pipeline(self.Div(), test_data)

-    @parameterized.expand(test_data_suite[:2])
-    def test_div_u55_BI(
-        self,
-        test_name: str,
-        input_: Union[torch.Tensor, torch.types.Number],
-        other_: Union[torch.Tensor, torch.types.Number],
-        rounding_mode: Optional[str] = None,
-    ):
-        test_data = (input_, other_)
-        self._test_div_ethos_BI_pipeline(
-            self.Div(), common.get_u55_compile_spec(), test_data
-        )
-
     # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite[2:])
+    @parameterized.expand(test_data_suite)
     @conftest.expectedFailureOnFVP
     def test_div_u55_BI_xfails(
         self,
@@ -211,21 +198,8 @@ def test_div_u55_BI_xfails(
             self.Div(), common.get_u55_compile_spec(), test_data
         )

-    @parameterized.expand(test_data_suite[:2])
-    def test_div_u85_BI(
-        self,
-        test_name: str,
-        input_: Union[torch.Tensor, torch.types.Number],
-        other_: Union[torch.Tensor, torch.types.Number],
-        rounding_mode: Optional[str] = None,
-    ):
-        test_data = (input_, other_)
-        self._test_div_ethos_BI_pipeline(
-            self.Div(), common.get_u85_compile_spec(), test_data
-        )
-
     # Numerical issues on FVP likely due to mul op, MLETORCH-521
-    @parameterized.expand(test_data_suite[2:])
+    @parameterized.expand(test_data_suite)
     @conftest.expectedFailureOnFVP
     def test_div_u85_BI_xfails(
         self,

backends/arm/test/ops/test_mul.py

Lines changed: 6 additions & 1 deletion
@@ -152,7 +152,9 @@ def test_mul_tosa_BI(
         test_data = (input_, other_)
         self._test_mul_tosa_BI_pipeline(self.Mul(), test_data)

+    # Numerical issues on FVP, MLETORCH-521
     @parameterized.expand(test_data_sute)
+    @conftest.expectedFailureOnFVP
     def test_mul_u55_BI(
         self,
         test_name: str,
@@ -164,7 +166,10 @@ def test_mul_u55_BI(
             common.get_u55_compile_spec(), self.Mul(), test_data
         )

-    @parameterized.expand(test_data_sute)
+    # Numerical issues on FVP, MLETORCH-521
+    # test_data_sute[0] works on U85
+    @parameterized.expand(test_data_sute[1:])
+    @conftest.expectedFailureOnFVP
     def test_mul_u85_BI(
         self,
         test_name: str,

backends/arm/test/runner_utils.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
 import numpy as np
 import torch

-from executorch.backends.arm.test.conftest import arm_test_options, is_option_enabled
+from executorch.backends.arm.test.conftest import is_option_enabled

 from torch.export import ExportedProgram
 from torch.fx.node import Node
@@ -251,7 +251,7 @@ def run_corstone(
     cmd_line += f" -i {input_path}"

     ethos_u_extra_args = ""
-    if is_option_enabled(arm_test_options.fast_fvp):
+    if is_option_enabled("fast_fvp"):
         ethos_u_extra_args = ethos_u_extra_args + "--fast"

     command_args = {

backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 17 additions & 4 deletions
@@ -20,7 +20,7 @@
 - op: _softmax.out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::G3::softmax_out
+      kernel_name: cadence::impl::G3::_softmax_out

 - op: add.out
   kernels:
@@ -71,7 +71,7 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::G3::mul_out
-
+
 - op: mul.Scalar_out
   kernels:
     - arg_meta: null
@@ -111,8 +111,21 @@
   kernels:
     - arg_meta: null
      kernel_name: torch::executor::where_out
-
+
 - op: native_layer_norm.out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::G3::native_layer_norm_out
+      kernel_name: cadence::impl::G3::native_layer_norm_out
+
+# custom ops
+- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::native::quantize_per_tensor_out
+
+- func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::native::dequantize_per_tensor_out

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+load("targets.bzl", "define_common_targets")
+
+oncall("odai_jarvis")
+
+define_common_targets()

backends/cadence/fusion_g3/operators/op_dequantize.cpp

Lines changed: 17 additions & 11 deletions
@@ -52,7 +52,7 @@ void check_dequantize_per_tensor_args(
   ET_CHECK_MSG(
       input.scalar_type() == ScalarType::Byte ||
           input.scalar_type() == ScalarType::Char ||
-          input.scalar_type() == ScalarType::Bits16 ||
+          input.scalar_type() == ScalarType::UInt16 ||
           input.scalar_type() == ScalarType::Short ||
           input.scalar_type() == (ScalarType)Ushort ||
          input.scalar_type() == (ScalarType)Bits4 ||
@@ -83,7 +83,7 @@ void check_dequantize_per_tensor_args(
 } // namespace

 /* Local function which calls the kernels based on the input datatype */
-void Dequantize_impl(
+void dequantize_impl(
     Tensor& out,
     const Tensor& input,
     float* scale_data,
@@ -211,7 +211,7 @@ void Dequantize_impl(
         break;
       switch (input.scalar_type()) {
         ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR);
-        ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16);
+        ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, UInt16);
         default:
           ET_CHECK_MSG(
               false,
@@ -302,7 +302,7 @@ void Dequantize_impl(
        break;
      switch (input.scalar_type()) {
        ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL);
-       ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16);
+       ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, UInt16);
        default:
          ET_CHECK_MSG(
              false,
@@ -368,7 +368,7 @@ void Dequantize_impl(
        break;
      switch (input.scalar_type()) {
        ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR);
-       SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16);
+       SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, UInt16);
        default:
          ET_CHECK_MSG(
              false,
@@ -459,7 +459,7 @@ void Dequantize_impl(
        break;
      switch (input.scalar_type()) {
        ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL);
-       SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16);
+       SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, UInt16);
        default:
          ET_CHECK_MSG(
              false,
@@ -502,7 +502,7 @@ Tensor& dequantize_per_tensor_out(
   float scale_data = (float)scale;
   int zero_point_data = (int)zero_point;

-  Dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype);
+  dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype);

   return out;
 }
@@ -620,7 +620,7 @@ Tensor& dequantize_per_channel_out(
   for (int i = 0; i < scale.numel(); i++) {
     scale_data[i] = (float)scale_dt[i];
   }
-  Dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
+  dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);

   return out;
 }
@@ -661,13 +661,19 @@ Tensor& dequantize_per_tensor_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
-    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper
   (void)context;
   return dequantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      dtype,
+      out.scalar_type(),
+      out);
 }

 Tensor& dequantize_per_tensor_tensor_args_out(
@@ -764,4 +770,4 @@ Tensor& dequantize_per_token_out(
 } // namespace native
 } // namespace G3
 } // namespace impl
-} // namespace cadence
+} // namespace cadence
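
As background for the Bits16 to UInt16 changes above, here is a small, self-contained sketch of the affine per-tensor dequantization that such kernels apply to 16-bit unsigned input, out = (q - zero_point) * scale. The function name and container types are simplified assumptions for illustration, not the Fusion G3 kernel's actual signature.

// Simplified illustration only; not the cadence::impl::G3 kernel signature.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Affine per-tensor dequantization: out[i] = (q[i] - zero_point) * scale.
std::vector<float> dequantize_per_tensor_u16(
    const std::vector<std::uint16_t>& q, float scale, std::int32_t zero_point) {
  std::vector<float> out(q.size());
  for (std::size_t i = 0; i < q.size(); ++i) {
    out[i] =
        static_cast<float>(static_cast<std::int32_t>(q[i]) - zero_point) * scale;
  }
  return out;
}

int main() {
  // With zero_point 32768 and scale 0.5, these values map to 0.0, 1.0, -1.0.
  std::vector<std::uint16_t> q = {32768, 32770, 32766};
  for (float v : dequantize_per_tensor_u16(q, 0.5f, 32768)) {
    std::printf("%f\n", v);
  }
  return 0;
}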
