Skip to content

Commit d4a7d55

Browse files
zonglinpeng authored and facebook-github-bot committed
buckify g3 targets, fix issues in quant, dequant, softmax (#7061)
Summary: update targets in fallback, fixed inherent issues from G3 PR. G3 op status page: https://docs.google.com/document/d/1ZRW6Uoq_NhpVCSH4y-t3Bl2pQZiKXMzSNT5XgrbE0fM/edit?tab=t.0 Reviewed By: hsharma35 Differential Revision: D66398494
1 parent ec68eb3 commit d4a7d55

File tree

6 files changed

+74
-15
lines changed

6 files changed

+74
-15
lines changed

backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
- op: _softmax.out
2121
kernels:
2222
- arg_meta: null
23-
kernel_name: cadence::impl::G3::softmax_out
23+
kernel_name: cadence::impl::G3::_softmax_out
2424

2525
- op: add.out
2626
kernels:
@@ -71,7 +71,7 @@
7171
kernels:
7272
- arg_meta: null
7373
kernel_name: cadence::impl::G3::mul_out
74-
74+
7575
- op: mul.Scalar_out
7676
kernels:
7777
- arg_meta: null
@@ -111,8 +111,21 @@
111111
kernels:
112112
- arg_meta: null
113113
kernel_name: torch::executor::where_out
114-
114+
115115
- op: native_layer_norm.out
116116
kernels:
117117
- arg_meta: null
118-
kernel_name: cadence::impl::G3::native_layer_norm_out
118+
kernel_name: cadence::impl::G3::native_layer_norm_out
119+
120+
# custom ops
121+
- func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
122+
variants: function
123+
kernels:
124+
- arg_meta: null
125+
kernel_name: cadence::impl::G3::native::quantize_per_tensor_out
126+
127+
- func: cadence::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
128+
variants: function
129+
kernels:
130+
- arg_meta: null
131+
kernel_name: cadence::impl::G3::native::dequantize_per_tensor_out
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
load("targets.bzl", "define_common_targets")
2+
3+
oncall("odai_jarvis")
4+
5+
define_common_targets()

backends/cadence/fusion_g3/operators/op_dequantize.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void check_dequantize_per_tensor_args(
8383
} // namespace
8484

8585
/* Local function which calls the kernels based on the input datatype */
86-
void Dequantize_impl(
86+
void dequantize_impl(
8787
Tensor& out,
8888
const Tensor& input,
8989
float* scale_data,
@@ -502,7 +502,7 @@ Tensor& dequantize_per_tensor_out(
502502
float scale_data = (float)scale;
503503
int zero_point_data = (int)zero_point;
504504

505-
Dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype);
505+
dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype);
506506

507507
return out;
508508
}
@@ -620,7 +620,7 @@ Tensor& dequantize_per_channel_out(
620620
for (int i = 0; i < scale.numel(); i++) {
621621
scale_data[i] = (float)scale_dt[i];
622622
}
623-
Dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
623+
dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
624624

625625
return out;
626626
}
@@ -661,13 +661,19 @@ Tensor& dequantize_per_tensor_out(
661661
int64_t quant_min,
662662
int64_t quant_max,
663663
ScalarType dtype,
664-
exec_aten::optional<ScalarType> out_dtype,
665664
Tensor& out) {
666665
// TODO(larryliu): Add a context arg to the real op function and remove this
667666
// wrapper
668667
(void)context;
669668
return dequantize_per_tensor_out(
670-
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
669+
input,
670+
scale,
671+
zero_point,
672+
quant_min,
673+
quant_max,
674+
dtype,
675+
out.scalar_type(),
676+
out);
671677
}
672678

673679
Tensor& dequantize_per_tensor_tensor_args_out(
@@ -764,4 +770,4 @@ Tensor& dequantize_per_token_out(
764770
} // namespace native
765771
} // namespace G3
766772
} // namespace impl
767-
} // namespace cadence
773+
} // namespace cadence

backends/cadence/fusion_g3/operators/op_quantize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ enum datatype { Ushort = 20, Bits4u = 21, Bits4 = 22 };
3131
*/
3232
namespace cadence {
3333
namespace impl {
34-
namespace FusionG3 {
34+
namespace G3 {
3535
namespace native {
3636

3737
namespace {
@@ -802,6 +802,6 @@ Tensor& quantize_per_token_out(
802802
}
803803

804804
} // namespace native
805-
} // namespace FusionG3
805+
} // namespace G3
806806
} // namespace impl
807-
} // namespace cadence
807+
} // namespace cadence

backends/cadence/fusion_g3/operators/op_softmax.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace impl {
2424
namespace G3 {
2525
namespace native {
2626

27-
Tensor& softmax_out(
27+
Tensor& _softmax_out(
2828
KernelRuntimeContext& ctx,
2929
const Tensor& in,
3030
int64_t dim,
@@ -112,4 +112,4 @@ Tensor& softmax_out(
112112
} // namespace native
113113
} // namespace G3
114114
} // namespace impl
115-
} // namespace cadence
115+
} // namespace cadence
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
load("@fbsource//tools/build_defs:platform_defs.bzl", "CXX")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3+
4+
def define_common_targets():
5+
"""Defines targets that should be shared between fbcode and xplat.
6+
7+
The directory containing this targets.bzl file should also contain both
8+
TARGETS and BUCK files that call this function.
9+
"""
10+
11+
# Define build targets for all operators registered in the tables above.
12+
13+
runtime.cxx_library(
14+
name = "cadence_g3_ops",
15+
srcs = glob([
16+
"*.cpp",
17+
]),
18+
platforms = CXX,
19+
deps = [
20+
"//executorch/kernels/portable/cpu/util:all_deps",
21+
"//executorch/kernels/portable/cpu/pattern:all_deps",
22+
"//executorch/runtime/kernel:kernel_includes",
23+
"//executorch/kernels/portable/cpu:scalar_utils",
24+
"fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib_common",
25+
"fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib",
26+
],
27+
visibility = [
28+
"//executorch/backends/cadence/...",
29+
"@EXECUTORCH_CLIENTS",
30+
],
31+
exported_deps = [
32+
"fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib_common",
33+
"fbsource//third-party/nnlib-FusionG3/xa_nnlib:libxa_nnlib",
34+
],
35+
)

0 commit comments

Comments (0)