Skip to content

Commit 29b7176

Browse files
zonglinpeng authored and facebook-github-bot committed
buckify g3 targets, fix issues in quant, dequant, softmax, Replace Bits16 with UInt16 (#7061)
Summary: Pull Request resolved: #7061 update targets in fallback, fixed inherent issues from G3 PR. G3 op status page: https://docs.google.com/document/d/1ZRW6Uoq_NhpVCSH4y-t3Bl2pQZiKXMzSNT5XgrbE0fM/edit?tab=t.0 included D66834249, D66681284 Reviewed By: hsharma35 Differential Revision: D66398494
1 parent f6bfa21 commit 29b7176

File tree

6 files changed

+85
-23
lines changed

6 files changed

+85
-23
lines changed

backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
- op: _softmax.out
2121
kernels:
2222
- arg_meta: null
23-
kernel_name: cadence::impl::G3::softmax_out
23+
kernel_name: cadence::impl::G3::_softmax_out
2424

2525
- op: add.out
2626
kernels:
@@ -71,7 +71,7 @@
7171
kernels:
7272
- arg_meta: null
7373
kernel_name: cadence::impl::G3::mul_out
74-
74+
7575
- op: mul.Scalar_out
7676
kernels:
7777
- arg_meta: null
@@ -111,8 +111,21 @@
111111
kernels:
112112
- arg_meta: null
113113
kernel_name: torch::executor::where_out
114-
114+
115115
- op: native_layer_norm.out
116116
kernels:
117117
- arg_meta: null
118-
kernel_name: cadence::impl::G3::native_layer_norm_out
118+
kernel_name: cadence::impl::G3::native_layer_norm_out
119+
120+
# custom ops
121+
- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
122+
variants: function
123+
kernels:
124+
- arg_meta: null
125+
kernel_name: cadence::impl::G3::native::quantize_per_tensor_out
126+
127+
- func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
128+
variants: function
129+
kernels:
130+
- arg_meta: null
131+
kernel_name: cadence::impl::G3::native::dequantize_per_tensor_out
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
# Buck entry point for this directory. The actual target definitions live in
# targets.bzl so that the TARGETS and BUCK files can share one definition.
load("targets.bzl", "define_common_targets")

# Owning oncall for build breakages in this directory.
oncall("odai_jarvis")

define_common_targets()

backends/cadence/fusion_g3/operators/op_dequantize.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ void check_dequantize_per_tensor_args(
5252
ET_CHECK_MSG(
5353
input.scalar_type() == ScalarType::Byte ||
5454
input.scalar_type() == ScalarType::Char ||
55-
input.scalar_type() == ScalarType::Bits16 ||
55+
input.scalar_type() == ScalarType::UInt16 ||
5656
input.scalar_type() == ScalarType::Short ||
5757
input.scalar_type() == (ScalarType)Ushort ||
5858
input.scalar_type() == (ScalarType)Bits4 ||
@@ -83,7 +83,7 @@ void check_dequantize_per_tensor_args(
8383
} // namespace
8484

8585
/* Local function which calls the kernels based on the input datatype */
86-
void Dequantize_impl(
86+
void dequantize_impl(
8787
Tensor& out,
8888
const Tensor& input,
8989
float* scale_data,
@@ -211,7 +211,7 @@ void Dequantize_impl(
211211
break;
212212
switch (input.scalar_type()) {
213213
ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR);
214-
ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16);
214+
ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, UInt16);
215215
default:
216216
ET_CHECK_MSG(
217217
false,
@@ -302,7 +302,7 @@ void Dequantize_impl(
302302
break;
303303
switch (input.scalar_type()) {
304304
ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL);
305-
ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16);
305+
ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, UInt16);
306306
default:
307307
ET_CHECK_MSG(
308308
false,
@@ -368,7 +368,7 @@ void Dequantize_impl(
368368
break;
369369
switch (input.scalar_type()) {
370370
ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR);
371-
SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16);
371+
SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, UInt16);
372372
default:
373373
ET_CHECK_MSG(
374374
false,
@@ -459,7 +459,7 @@ void Dequantize_impl(
459459
break;
460460
switch (input.scalar_type()) {
461461
ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL);
462-
SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16);
462+
SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, UInt16);
463463
default:
464464
ET_CHECK_MSG(
465465
false,
@@ -502,7 +502,7 @@ Tensor& dequantize_per_tensor_out(
502502
float scale_data = (float)scale;
503503
int zero_point_data = (int)zero_point;
504504

505-
Dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype);
505+
dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype);
506506

507507
return out;
508508
}
@@ -620,7 +620,7 @@ Tensor& dequantize_per_channel_out(
620620
for (int i = 0; i < scale.numel(); i++) {
621621
scale_data[i] = (float)scale_dt[i];
622622
}
623-
Dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
623+
dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
624624

625625
return out;
626626
}
@@ -661,13 +661,19 @@ Tensor& dequantize_per_tensor_out(
661661
int64_t quant_min,
662662
int64_t quant_max,
663663
ScalarType dtype,
664-
exec_aten::optional<ScalarType> out_dtype,
665664
Tensor& out) {
666665
// TODO(larryliu): Add a context arg to the real op function and remove this
667666
// wrapper
668667
(void)context;
669668
return dequantize_per_tensor_out(
670-
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
669+
input,
670+
scale,
671+
zero_point,
672+
quant_min,
673+
quant_max,
674+
dtype,
675+
out.scalar_type(),
676+
out);
671677
}
672678

673679
Tensor& dequantize_per_tensor_tensor_args_out(
@@ -764,4 +770,4 @@ Tensor& dequantize_per_token_out(
764770
} // namespace native
765771
} // namespace G3
766772
} // namespace impl
767-
} // namespace cadence
773+
} // namespace cadence

backends/cadence/fusion_g3/operators/op_quantize.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ void check_quantize_per_tensor_args(
6969
static_cast<int32_t>(std::numeric_limits<int8_t>::min());
7070
quant_max_upper_bound =
7171
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
72-
} else if (dtype == ScalarType::Bits16) {
72+
} else if (dtype == ScalarType::UInt16) {
7373
quant_min_lower_bound = std::numeric_limits<uint16_t>::min();
7474
quant_max_upper_bound = std::numeric_limits<uint16_t>::max();
7575
} else if (dtype == ScalarType::Short) {
@@ -271,7 +271,7 @@ void quantize_impl(
271271
case ScalarType::in_dtype: \
272272
switch (out.scalar_type()) { \
273273
ET_FORALL_INT_TYPES_WITH(IN_CTYPE, ASYM_QUANTIZE_IMPL_TENSOR); \
274-
ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \
274+
ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, UInt16) \
275275
default: \
276276
ET_CHECK_MSG( \
277277
false, \
@@ -343,7 +343,7 @@ void quantize_impl(
343343
case ScalarType::in_dtype: \
344344
switch (out.scalar_type()) { \
345345
ET_FORALL_INT_TYPES_WITH(CTYPE_IN, ASYM_QUANTIZE_IMPL_CHANNEL); \
346-
ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \
346+
ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, UInt16) \
347347
default: \
348348
ET_CHECK_MSG( \
349349
false, \
@@ -458,7 +458,7 @@ void quantize_impl(
458458
case ScalarType::in_dtype: \
459459
switch (out.scalar_type()) { \
460460
ET_FORALL_INT_TYPES_WITH(IN_CTYPE, SYM_QUANTIZE_IMPL_TENSOR); \
461-
SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \
461+
SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, UInt16) \
462462
default: \
463463
ET_CHECK_MSG( \
464464
false, \
@@ -529,7 +529,7 @@ void quantize_impl(
529529
case ScalarType::in_dtype: \
530530
switch (out.scalar_type()) { \
531531
ET_FORALL_INT_TYPES_WITH(CTYPE_IN, SYM_QUANTIZE_IMPL_CHANNEL); \
532-
SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \
532+
SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, UInt16) \
533533
default: \
534534
ET_CHECK_MSG( \
535535
false, \
@@ -803,4 +803,4 @@ Tensor& quantize_per_token_out(
803803
} // namespace native
804804
} // namespace G3
805805
} // namespace impl
806-
} // namespace cadence
806+
} // namespace cadence

backends/cadence/fusion_g3/operators/op_softmax.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace impl {
2424
namespace G3 {
2525
namespace native {
2626

27-
Tensor& softmax_out(
27+
Tensor& _softmax_out(
2828
KernelRuntimeContext& ctx,
2929
const Tensor& in,
3030
int64_t dim,
@@ -112,4 +112,4 @@ Tensor& softmax_out(
112112
} // namespace native
113113
} // namespace G3
114114
} // namespace impl
115-
} // namespace cadence
115+
} // namespace cadence
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

# Xtensa nnlib (Fusion G3) kernel libraries. These appear both as deps and as
# exported_deps of the op library, so they are factored out here to keep the
# two lists in sync.
_NNLIB_FUSION_G3_DEPS = [
    "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib_common",
    "fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib",
]

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    # Define build targets for all operators registered in the tables above.

    # One library bundling every Fusion G3 operator implementation in this
    # directory: all *.cpp sources are compiled, all *.h headers exported.
    runtime.cxx_library(
        name = "cadence_g3_ops",
        srcs = glob(["*.cpp"]),
        exported_headers = glob(["*.h"]),
        platforms = CXX,
        deps = [
            "//executorch/kernels/portable/cpu/util:all_deps",
            "//executorch/kernels/portable/cpu/pattern:all_deps",
            "//executorch/runtime/kernel:kernel_includes",
            "//executorch/kernels/portable/cpu:scalar_utils",
        ] + _NNLIB_FUSION_G3_DEPS,
        visibility = [
            "//executorch/backends/cadence/...",
            "@EXECUTORCH_CLIENTS",
        ],
        # NOTE(review): the nnlib targets are listed in both deps and
        # exported_deps; in Buck, exported_deps alone usually suffices —
        # confirm before pruning the duplication.
        exported_deps = _NNLIB_FUSION_G3_DEPS,
    )

0 commit comments

Comments
 (0)