
Commit c523172

zonglinpeng authored and DannyYuyang-quic committed
[cadence][g3] link m3 ops and add testcases (pytorch#8824)
Differential Revision: D71155209
Pull Request resolved: pytorch#9289
Parent: 5d86c56

7 files changed: +43 −25 lines

backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 4 additions & 4 deletions
@@ -51,7 +51,7 @@
 - op: clamp.Tensor_out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::G3::clamp_tensor_out
+      kernel_name: cadence::impl::G3::clamp_Tensor_out

 - op: clone.out
   kernels:
@@ -81,12 +81,12 @@
 - op: lt.Scalar_out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::G3::lt_scalar_out
+      kernel_name: cadence::impl::G3::lt_Scalar_out

 - op: lt.Tensor_out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::G3::lt_tensor_out
+      kernel_name: cadence::impl::G3::lt_Tensor_out

 - op: mul.out
   kernels:
@@ -155,7 +155,7 @@
 - op: where.self_out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::G3::where_out
+      kernel_name: cadence::impl::G3::where_self_out

 - op: native_layer_norm.out
   kernels:
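These renames enforce one convention: the C++ kernel symbol is the op's base name joined to its overload name with an underscore (clamp.Tensor_out -> clamp_Tensor_out, where.self_out -> where_self_out). The C++ changes in the files below rename the definitions to match, so the yaml-declared symbols resolve at link time. A minimal sketch of a checker for that convention, assuming the yaml parses as a flat list of {op, kernels} entries and that PyYAML is available (check_kernel_names.py and its logic are hypothetical, not part of this commit):

    # check_kernel_names.py - hypothetical helper, not part of this commit.
    import yaml  # assumes PyYAML is installed

    def expected_symbol(op: str) -> str:
        # "clamp.Tensor_out" -> "clamp_Tensor_out"; an op without an
        # overload keeps its base name unchanged.
        base, _, overload = op.partition(".")
        return f"{base}_{overload}" if overload else base

    with open("backends/cadence/aot/functions_fusion_g3.yaml") as f:
        for entry in yaml.safe_load(f):
            for kernel in entry.get("kernels", []):
                symbol = kernel["kernel_name"].split("::")[-1]
                assert symbol == expected_symbol(entry["op"]), (entry["op"], symbol)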

backends/cadence/fusion_g3/operators/op_clamp.cpp

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ Tensor& clamp_out(
   return out;
 }

-Tensor& clamp_tensor_out(
+Tensor& clamp_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     const optional<Tensor>& min_opt,

backends/cadence/fusion_g3/operators/op_lt.cpp

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ namespace impl {
 namespace G3 {
 namespace native {

-Tensor& lt_tensor_out(
+Tensor& lt_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
@@ -141,7 +141,7 @@ Tensor& lt_tensor_out(
   return out;
 }

-Tensor& lt_scalar_out(
+Tensor& lt_Scalar_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Scalar& b,

backends/cadence/fusion_g3/operators/op_permute_copy.cpp

Lines changed: 1 addition & 1 deletion
@@ -157,4 +157,4 @@ Tensor& permute_copy_out(
 } // namespace native
 } // namespace G3
 } // namespace impl
-} // namespace cadence
+} // namespace cadence

(The removed and added lines are textually identical; this hunk most likely adds a missing trailing newline at end of file.)

backends/cadence/fusion_g3/operators/op_where.cpp

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ namespace impl {
 namespace G3 {
 namespace native {

-Tensor& where_out(
+Tensor& where_self_out(
     KernelRuntimeContext& ctx,
     const Tensor& cond,
     const Tensor& a,

backends/cadence/fusion_g3/operators/targets.bzl

Lines changed: 8 additions & 0 deletions
@@ -35,6 +35,14 @@ def define_operator(name: str, deps: list[str] | None = None) -> None:
 OPERATORS = [
     "add",
     "cat",
+    "clamp",
+    "lt",
+    "rsqrt",
+    "sigmoid",
+    "sqrt",
+    "tanh",
+    "transpose_copy",
+    "where",
     "dequantize",
     "mul",
     "native_layer_norm",

backends/cadence/utils/facto_util.py

Lines changed: 26 additions & 16 deletions
@@ -22,9 +22,29 @@


 def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
+    additional_tensor_constraints = [
+        cp.Dtype.In(lambda deps: [torch.int, torch.float]),
+        cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
+        cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+        cp.Value.Le(lambda deps, dtype, struct: 2**4),
+        cp.Rank.Ge(lambda deps: 1),
+        cp.Size.Ge(lambda deps, r, d: 1),
+        cp.Size.Le(lambda deps, r, d: 2**9),
+    ]
+
     match op_name:
+        case "where.self":
+            additional_tensor_constraints = [
+                cp.Dtype.In(lambda deps: [torch.float, torch.int, torch.bool]),
+                cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
+                cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+                cp.Value.Le(lambda deps, dtype, struct: 2**4),
+                cp.Rank.Ge(lambda deps: 1),
+                cp.Size.Ge(lambda deps, r, d: 1),
+                cp.Size.Le(lambda deps, r, d: 2**9),
+            ]
         case "sigmoid.default" | "rsqrt.default":
-            tensor_constraints.extend(
+            additional_tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float]),
                     cp.Rank.Le(lambda deps: 2**2),
@@ -33,45 +53,35 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
                 ]
             )
         case "mean.dim":
-            tensor_constraints.extend(
+            additional_tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float]),
                     cp.Rank.Le(lambda deps: 2**2),
                 ]
             )
         case "exp.default":
-            tensor_constraints.extend(
+            additional_tensor_constraints.extend(
                 [
                     cp.Rank.Le(lambda deps: 2**3),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**2)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**2),
                 ]
             )
         case "slice_copy.Tensor":
-            tensor_constraints.extend(
+            additional_tensor_constraints.extend(
                 [
                     cp.Rank.Le(lambda deps: 2),
                     cp.Value.Ge(lambda deps, dtype, struct: 1),
                     cp.Value.Le(lambda deps, dtype, struct: 2),
                 ]
             )
         case _:
-            tensor_constraints.extend(
+            additional_tensor_constraints.extend(
                 [
                     cp.Rank.Le(lambda deps: 2**2),
                 ]
             )
-    tensor_constraints.extend(
-        [
-            cp.Dtype.In(lambda deps: [torch.int, torch.float]),
-            cp.Dtype.NotIn(lambda deps: [torch.int64, torch.float64]),
-            cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
-            cp.Value.Le(lambda deps, dtype, struct: 2**4),
-            cp.Rank.Ge(lambda deps: 1),
-            cp.Size.Ge(lambda deps, r, d: 1),
-            cp.Size.Le(lambda deps, r, d: 2**9),
-        ]
-    )
+    tensor_constraints.extend(additional_tensor_constraints)


 def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
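Net effect of the refactor: a shared baseline constraint set (int/float dtypes, values in [-16, 16], rank >= 1, dimension sizes in [1, 512]) is built once up front, replaced wholesale for where.self so the condition tensor can be torch.bool, extended per op in the match arms, and applied to the caller's list in a single extend. A minimal usage sketch, with the call site hypothetical:

    # Hypothetical call site: collect the constraints FACTO applies when
    # generating where.self test inputs.
    tensor_constraints: list[object] = []
    apply_tensor_contraints("where.self", tensor_constraints)
    # tensor_constraints now holds the where.self-specific set, which admits
    # torch.bool alongside the int/float baseline.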
