Skip to content

Commit a216a03

Browse files
JacobSzwejbka authored and facebook-github-bot committed
update prim ops lib (#349)
Summary: Pull Request resolved: #349. Saw that floordiv was more complicated than just calling floor. Added support for truediv. Added support for mixed-dtype operations. Reduced code duplication with a macro. Per tugsbayasgalan, testing suggests that operator.truediv always returns SymFloat while operator.floordiv respects the input type. Reviewed By: larryliu0820 Differential Revision: D48369564 fbshipit-source-id: 3d8c259b76f83c67aa97f4ebc04523a6d41f2e8b
1 parent 31c80cf commit a216a03

File tree

5 files changed

+128
-125
lines changed

5 files changed

+128
-125
lines changed

exir/passes/executorch_prim_ops_registry.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ def floordiv(a: _SymScalar, b: _SymScalar) -> _SymScalar:
4242
return a // b # pyre-ignore
4343

4444

45+
@bind_pattern_to_op(
46+
executorch_prims_lib, "truediv.Scalar(Scalar a, Scalar b) -> Scalar"
47+
)
48+
def truediv(a: _SymScalar, b: _SymScalar) -> _SymScalar:
49+
return a / b # pyre-ignore
50+
51+
4552
# TODO: ideally we should return SymBool in the schema, but it seems
4653
# the schema parser does not recognize SymBool yet: P629748075
4754
@bind_pattern_to_op(executorch_prims_lib, "gt.Scalar(Scalar a, Scalar b) -> bool")
@@ -74,6 +81,7 @@ def eq(a: _SymScalar, b: _SymScalar) -> bool:
7481
operator.mul: ops.backend.executorch_prim.mul.Scalar,
7582
operator.add: ops.backend.executorch_prim.add.Scalar,
7683
operator.floordiv: ops.backend.executorch_prim.floordiv.Scalar,
84+
operator.truediv: ops.backend.executorch_prim.truediv.Scalar,
7785
operator.eq: ops.backend.executorch_prim.eq.Scalar,
7886
operator.gt: ops.backend.executorch_prim.gt.Scalar,
7987
operator.lt: ops.backend.executorch_prim.lt.Scalar,

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 92 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,52 @@ namespace function {
2020

2121
namespace {
2222

23+
#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
24+
else { \
25+
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \
26+
}
27+
28+
// TODO Fail using runtime context
29+
#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
30+
(void)context; \
31+
EValue& a = *stack[0]; \
32+
EValue& b = *stack[1]; \
33+
EValue& out = *stack[2]; \
34+
if (a.isInt() && b.isInt()) { \
35+
out = EValue(a.toInt() operator b.toInt()); \
36+
} else if (a.isDouble() && b.isDouble()) { \
37+
out = EValue(a.toDouble() operator b.toDouble()); \
38+
} else if (a.isInt() && b.isDouble()) { \
39+
out = EValue(a.toInt() operator b.toDouble()); \
40+
} else if (a.isDouble() && b.isInt()) { \
41+
out = EValue(a.toDouble() operator b.toInt()); \
42+
}
43+
44+
#define ALGEBRA_ET_PRIM_OP(operator, stack, context) \
45+
__NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
46+
__ET_PRIM_OP_ERROR_IMPL(a, b, context)
47+
48+
#define BOOLEAN_ET_PRIM_OP(operator, stack, context) \
49+
__NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
50+
else if (a.isBool() && b.isBool()) { \
51+
out = EValue(a.toBool() operator b.toBool()); \
52+
} \
53+
__ET_PRIM_OP_ERROR_IMPL(a, b, context)
54+
55+
void floor_div_double(double a, double b, EValue& out) {
56+
if (b == 0) {
57+
out = EValue(std::signbit(a) ? -INFINITY : INFINITY);
58+
return;
59+
}
60+
const auto mod = std::fmod(a, b);
61+
auto div = (a - mod) / b;
62+
if ((mod != 0) && std::signbit(b) != std::signbit(mod)) {
63+
out = EValue(div - 1);
64+
return;
65+
}
66+
out = EValue(div);
67+
}
68+
2369
static Kernel prim_ops[] = {
2470
// aten::sym_size.int(Tensor self, int dim) -> SymInt
2571
Kernel(
@@ -50,171 +96,118 @@ static Kernel prim_ops[] = {
5096
"executorch_prim::add.Scalar",
5197
[](RuntimeContext& context, EValue** stack) {
5298
(void)context;
53-
EValue& a = *stack[0];
54-
EValue& b = *stack[1];
55-
EValue& out = *stack[2];
56-
if (a.isInt() && b.isInt()) {
57-
out = EValue(a.toInt() + b.toInt());
58-
} else if (a.isDouble() && b.isDouble()) {
59-
out = EValue(a.toDouble() + b.toDouble());
60-
} else {
61-
// TODO Fail using runtime context
62-
ET_CHECK(false);
63-
}
99+
ALGEBRA_ET_PRIM_OP(+, stack, context);
64100
}),
65101

66102
// executorch_prim::sub.Scalar(Scalar, Scalar) -> Scalar
67103
Kernel(
68104
"executorch_prim::sub.Scalar",
69105
[](RuntimeContext& context, EValue** stack) {
70-
(void)context;
71-
EValue& a = *stack[0];
72-
EValue& b = *stack[1];
73-
EValue& out = *stack[2];
74-
if (a.isInt() && b.isInt()) {
75-
out = EValue(a.toInt() - b.toInt());
76-
} else if (a.isDouble() && b.isDouble()) {
77-
out = EValue(a.toDouble() - b.toDouble());
78-
} else {
79-
// TODO Fail using runtime context
80-
ET_CHECK(false);
81-
}
106+
ALGEBRA_ET_PRIM_OP(-, stack, context);
82107
}),
83108

84109
// executorch_prim::mul.Scalar(Scalar, Scalar) -> Scalar
85110
Kernel(
86111
"executorch_prim::mul.Scalar",
112+
[](RuntimeContext& context, EValue** stack) {
113+
ALGEBRA_ET_PRIM_OP(*, stack, context);
114+
}),
115+
116+
/**
117+
* Python's __floordiv__ operator is more complicated than just floor(a /
118+
* b). It aims to maintain the property: a == (a // b) * b + remainder(a, b)
119+
* which can otherwise fail due to rounding errors in the remainder.
120+
* So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
121+
* With some additional fix-ups added to the result.
122+
*
123+
* executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar
124+
*/
125+
Kernel(
126+
"executorch_prim::floordiv.Scalar",
87127
[](RuntimeContext& context, EValue** stack) {
88128
(void)context;
89129
EValue& a = *stack[0];
90130
EValue& b = *stack[1];
91131
EValue& out = *stack[2];
92132
if (a.isInt() && b.isInt()) {
93-
out = EValue(a.toInt() * b.toInt());
133+
const int64_t quot = a.toInt() / b.toInt();
134+
if (std::signbit(a.toInt()) == std::signbit(b.toInt())) {
135+
out = EValue(quot);
136+
return;
137+
}
138+
const int64_t rem = a.toInt() % b.toInt();
139+
out = EValue(rem ? quot - 1 : quot);
140+
return;
94141
} else if (a.isDouble() && b.isDouble()) {
95-
out = EValue(a.toDouble() * b.toDouble());
142+
floor_div_double(a.toDouble(), b.toDouble(), out);
143+
} else if (a.isInt() && b.isDouble()) {
144+
floor_div_double(static_cast<double>(a.toInt()), b.toDouble(), out);
145+
} else if (a.isDouble() && b.isInt()) {
146+
floor_div_double(a.toDouble(), static_cast<double>(b.toInt()), out);
96147
} else {
97148
// TODO Fail using runtime context
98-
ET_CHECK(false);
149+
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
99150
}
100151
}),
101152

102153
// executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar
103154
Kernel(
104-
"executorch_prim::floordiv.Scalar",
155+
"executorch_prim::truediv.Scalar",
105156
[](RuntimeContext& context, EValue** stack) {
157+
// can't use macro because of custom casting behavior
106158
(void)context;
107159
EValue& a = *stack[0];
108160
EValue& b = *stack[1];
109161
EValue& out = *stack[2];
110162
if (a.isInt() && b.isInt()) {
111-
out = EValue(a.toInt() / b.toInt());
163+
out = EValue(
164+
static_cast<double>(a.toInt()) /
165+
static_cast<double>(b.toInt()));
112166
} else if (a.isDouble() && b.isDouble()) {
113167
out = EValue(a.toDouble() / b.toDouble());
168+
} else if (a.isInt() && b.isDouble()) {
169+
out = EValue(a.toInt() / b.toDouble());
170+
} else if (a.isDouble() && b.isInt()) {
171+
out = EValue(a.toDouble() / b.toInt());
114172
} else {
115173
// TODO Fail using runtime context
116-
ET_CHECK(false);
174+
ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
117175
}
118176
}),
119177

120178
// executorch_prim::eq.Scalar(Scalar, Scalar) -> bool
121179
Kernel(
122180
"executorch_prim::eq.Scalar",
123181
[](RuntimeContext& context, EValue** stack) {
124-
(void)context;
125-
EValue& a = *stack[0];
126-
EValue& b = *stack[1];
127-
EValue& out = *stack[2];
128-
if (a.isInt() && b.isInt()) {
129-
out = EValue(a.toInt() == b.toInt());
130-
} else if (a.isDouble() && b.isDouble()) {
131-
out = EValue(a.toDouble() == b.toDouble());
132-
} else if (a.isBool() && b.isBool()) {
133-
out = EValue(a.toBool() == b.toBool());
134-
} else {
135-
// TODO Fail using runtime context
136-
ET_CHECK(false);
137-
}
182+
BOOLEAN_ET_PRIM_OP(==, stack, context);
138183
}),
139184

140185
// executorch_prim::gt.Scalar(Scalar, Scalar) -> bool
141186
Kernel(
142187
"executorch_prim::gt.Scalar",
143188
[](RuntimeContext& context, EValue** stack) {
144-
(void)context;
145-
EValue& a = *stack[0];
146-
EValue& b = *stack[1];
147-
EValue& out = *stack[2];
148-
if (a.isInt() && b.isInt()) {
149-
out = EValue(a.toInt() > b.toInt());
150-
} else if (a.isDouble() && b.isDouble()) {
151-
out = EValue(a.toDouble() > b.toDouble());
152-
} else if (a.isBool() && b.isBool()) {
153-
out = EValue(a.toBool() > b.toBool());
154-
} else {
155-
// TODO Fail using runtime context
156-
ET_CHECK(false);
157-
}
189+
BOOLEAN_ET_PRIM_OP(>, stack, context);
158190
}),
159191

160192
// executorch_prim::lt.Scalar(Scalar, Scalar) -> bool
161193
Kernel(
162194
"executorch_prim::lt.Scalar",
163195
[](RuntimeContext& context, EValue** stack) {
164-
(void)context;
165-
EValue& a = *stack[0];
166-
EValue& b = *stack[1];
167-
EValue& out = *stack[2];
168-
if (a.isInt() && b.isInt()) {
169-
out = EValue(a.toInt() < b.toInt());
170-
} else if (a.isDouble() && b.isDouble()) {
171-
out = EValue(a.toDouble() < b.toDouble());
172-
} else if (a.isBool() && b.isBool()) {
173-
out = EValue(a.toBool() < b.toBool());
174-
} else {
175-
// TODO Fail using runtime context
176-
ET_CHECK(false);
177-
}
196+
BOOLEAN_ET_PRIM_OP(<, stack, context);
178197
}),
179198

180199
// executorch_prim::ge.Scalar(Scalar, Scalar) -> bool
181200
Kernel(
182201
"executorch_prim::ge.Scalar",
183202
[](RuntimeContext& context, EValue** stack) {
184-
(void)context;
185-
EValue& a = *stack[0];
186-
EValue& b = *stack[1];
187-
EValue& out = *stack[2];
188-
if (a.isInt() && b.isInt()) {
189-
out = EValue(a.toInt() >= b.toInt());
190-
} else if (a.isDouble() && b.isDouble()) {
191-
out = EValue(a.toDouble() >= b.toDouble());
192-
} else if (a.isBool() && b.isBool()) {
193-
out = EValue(a.toBool() >= b.toBool());
194-
} else {
195-
// TODO Fail using runtime context
196-
ET_CHECK(false);
197-
}
203+
BOOLEAN_ET_PRIM_OP(>=, stack, context);
198204
}),
199205

200206
// executorch_prim::le.Scalar(Scalar, Scalar) -> bool
201207
Kernel(
202208
"executorch_prim::le.Scalar",
203209
[](RuntimeContext& context, EValue** stack) {
204-
(void)context;
205-
EValue& a = *stack[0];
206-
EValue& b = *stack[1];
207-
EValue& out = *stack[2];
208-
if (a.isInt() && b.isInt()) {
209-
out = EValue(a.toInt() <= b.toInt());
210-
} else if (a.isDouble() && b.isDouble()) {
211-
out = EValue(a.toDouble() <= b.toDouble());
212-
} else if (a.isBool() && b.isBool()) {
213-
out = EValue(a.toBool() <= b.toBool());
214-
} else {
215-
// TODO Fail using runtime context
216-
ET_CHECK(false);
217-
}
210+
BOOLEAN_ET_PRIM_OP(<=, stack, context);
218211
}),
219212

220213
// executorch_prim::floordiv.int(int, int) -> int

kernels/prim_ops/test/TARGETS

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
77
oncall("executorch")
88

99
python_unittest(
10-
name = "test_prim_ops",
10+
name = "prim_ops_test_py",
1111
srcs = [
12-
"test_prim_ops.py",
12+
"prim_ops_test.py",
1313
],
1414
deps = [
1515
"//caffe2:torch",
@@ -18,9 +18,9 @@ python_unittest(
1818
)
1919

2020
cpp_unittest(
21-
name = "register_prim_ops_test",
21+
name = "prim_ops_test_cpp",
2222
srcs = [
23-
"register_prim_ops_test.cpp",
23+
"prim_ops_test.cpp",
2424
],
2525
supports_static_listing = True,
2626
deps = [

0 commit comments

Comments
 (0)