@@ -20,6 +20,52 @@ namespace function {

namespace {

+#define __ET_PRIM_OP_ERROR_IMPL(a, b, context) \
+  else { \
+    ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag); \
+  }
+
+// TODO Fail using runtime context
+#define __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
+  (void)context; \
+  EValue& a = *stack[0]; \
+  EValue& b = *stack[1]; \
+  EValue& out = *stack[2]; \
+  if (a.isInt() && b.isInt()) { \
+    out = EValue(a.toInt() operator b.toInt()); \
+  } else if (a.isDouble() && b.isDouble()) { \
+    out = EValue(a.toDouble() operator b.toDouble()); \
+  } else if (a.isInt() && b.isDouble()) { \
+    out = EValue(a.toInt() operator b.toDouble()); \
+  } else if (a.isDouble() && b.isInt()) { \
+    out = EValue(a.toDouble() operator b.toInt()); \
+  }
+
+#define ALGEBRA_ET_PRIM_OP(operator, stack, context) \
+  __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
+  __ET_PRIM_OP_ERROR_IMPL(a, b, context)
+
+#define BOOLEAN_ET_PRIM_OP(operator, stack, context) \
+  __NUMBER_ET_PRIM_OP_IMPL(operator, stack, context) \
+  else if (a.isBool() && b.isBool()) { \
+    out = EValue(a.toBool() operator b.toBool()); \
+  } \
+  __ET_PRIM_OP_ERROR_IMPL(a, b, context)
+
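+// Usage sketch: each arithmetic kernel below then reduces to a single macro
+// invocation, e.g.
+//   [](RuntimeContext& context, EValue** stack) {
+//     ALGEBRA_ET_PRIM_OP(+, stack, context);
+//   }
+// which expands to the int/double dispatch above plus the ET_CHECK_MSG fallback.
+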
55
+ void floor_div_double (double a, double b, EValue& out) {
56
+ if (b == 0 ) {
57
+ out = EValue (std::signbit (a) ? -INFINITY : INFINITY);
58
+ return ;
59
+ }
60
+ const auto mod = std::fmod (a, b);
61
+ auto div = (a - mod) / b;
62
+ if ((mod != 0 ) && std::signbit (b) != std::signbit (mod)) {
63
+ out = EValue (div - 1 );
64
+ return ;
65
+ }
66
+ out = EValue (div);
67
+ }
68
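+// Worked example: floor_div_double(-7.0, 2.0) gives mod = fmod(-7, 2) = -1 and
+// div = (-7 - (-1)) / 2 = -3; since b and mod have opposite signs the result is
+// adjusted to div - 1 = -4, matching Python's -7 // 2 == -4.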
+
static Kernel prim_ops[] = {
    // aten::sym_size.int(Tensor self, int dim) -> SymInt
    Kernel(
@@ -50,171 +96,118 @@ static Kernel prim_ops[] = {
        "executorch_prim::add.Scalar",
        [](RuntimeContext& context, EValue** stack) {
          (void)context;
-          EValue& a = *stack[0];
-          EValue& b = *stack[1];
-          EValue& out = *stack[2];
-          if (a.isInt() && b.isInt()) {
-            out = EValue(a.toInt() + b.toInt());
-          } else if (a.isDouble() && b.isDouble()) {
-            out = EValue(a.toDouble() + b.toDouble());
-          } else {
-            // TODO Fail using runtime context
-            ET_CHECK(false);
-          }
+          ALGEBRA_ET_PRIM_OP(+, stack, context);
        }),

    // executorch_prim::sub.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::sub.Scalar",
        [](RuntimeContext& context, EValue** stack) {
-          (void)context;
-          EValue& a = *stack[0];
-          EValue& b = *stack[1];
-          EValue& out = *stack[2];
-          if (a.isInt() && b.isInt()) {
-            out = EValue(a.toInt() - b.toInt());
-          } else if (a.isDouble() && b.isDouble()) {
-            out = EValue(a.toDouble() - b.toDouble());
-          } else {
-            // TODO Fail using runtime context
-            ET_CHECK(false);
-          }
+          ALGEBRA_ET_PRIM_OP(-, stack, context);
        }),

    // executorch_prim::mul.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
        "executorch_prim::mul.Scalar",
+        [](RuntimeContext& context, EValue** stack) {
+          ALGEBRA_ET_PRIM_OP(*, stack, context);
+        }),
+
+    /**
+     * Python's __floordiv__ operator is more complicated than just floor(a /
+     * b). It aims to maintain the property: a == (a // b) * b + remainder(a, b)
+     * which can otherwise fail due to rounding errors in the remainder.
+     * So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
+     * with some additional fix-ups added to the result.
+     *
+     * executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar
+     */
+    Kernel(
+        "executorch_prim::floordiv.Scalar",
        [](RuntimeContext& context, EValue** stack) {
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          if (a.isInt() && b.isInt()) {
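+            // C++ '/' truncates toward zero; when the operand signs differ and
+            // the division is inexact, step the quotient down by one to get
+            // floor semantics (e.g. -7 / 2 == -3, but -7 // 2 must be -4).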
-            out = EValue(a.toInt() * b.toInt());
+            const int64_t quot = a.toInt() / b.toInt();
+            if (std::signbit(a.toInt()) == std::signbit(b.toInt())) {
+              out = EValue(quot);
+              return;
+            }
+            const int64_t rem = a.toInt() % b.toInt();
+            out = EValue(rem ? quot - 1 : quot);
+            return;
          } else if (a.isDouble() && b.isDouble()) {
-            out = EValue(a.toDouble() * b.toDouble());
+            floor_div_double(a.toDouble(), b.toDouble(), out);
+          } else if (a.isInt() && b.isDouble()) {
+            floor_div_double(static_cast<double>(a.toInt()), b.toDouble(), out);
+          } else if (a.isDouble() && b.isInt()) {
+            floor_div_double(a.toDouble(), static_cast<double>(b.toInt()), out);
          } else {
            // TODO Fail using runtime context
-            ET_CHECK(false);
+            ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
          }
        }),

    // executorch_prim::floordiv.Scalar(Scalar, Scalar) -> Scalar
    Kernel(
-        "executorch_prim::floordiv.Scalar",
+        "executorch_prim::truediv.Scalar",
        [](RuntimeContext& context, EValue** stack) {
+          // can't use macro because of custom casting behavior
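+          // e.g. 3 / 2 evaluates to 1.5 here: int operands are cast to double
+          // before dividing, unlike C++ integer division.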
          (void)context;
          EValue& a = *stack[0];
          EValue& b = *stack[1];
          EValue& out = *stack[2];
          if (a.isInt() && b.isInt()) {
-            out = EValue(a.toInt() / b.toInt());
+            out = EValue(
+                static_cast<double>(a.toInt()) /
+                static_cast<double>(b.toInt()));
          } else if (a.isDouble() && b.isDouble()) {
            out = EValue(a.toDouble() / b.toDouble());
+          } else if (a.isInt() && b.isDouble()) {
+            out = EValue(a.toInt() / b.toDouble());
+          } else if (a.isDouble() && b.isInt()) {
+            out = EValue(a.toDouble() / b.toInt());
          } else {
            // TODO Fail using runtime context
-            ET_CHECK(false);
+            ET_CHECK_MSG(false, "%zu, %zu", (size_t)a.tag, (size_t)b.tag);
          }
        }),

    // executorch_prim::eq.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::eq.Scalar",
        [](RuntimeContext& context, EValue** stack) {
-          (void)context;
-          EValue& a = *stack[0];
-          EValue& b = *stack[1];
-          EValue& out = *stack[2];
-          if (a.isInt() && b.isInt()) {
-            out = EValue(a.toInt() == b.toInt());
-          } else if (a.isDouble() && b.isDouble()) {
-            out = EValue(a.toDouble() == b.toDouble());
-          } else if (a.isBool() && b.isBool()) {
-            out = EValue(a.toBool() == b.toBool());
-          } else {
-            // TODO Fail using runtime context
-            ET_CHECK(false);
-          }
+          BOOLEAN_ET_PRIM_OP(==, stack, context);
        }),

    // executorch_prim::gt.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::gt.Scalar",
        [](RuntimeContext& context, EValue** stack) {
-          (void)context;
-          EValue& a = *stack[0];
-          EValue& b = *stack[1];
-          EValue& out = *stack[2];
-          if (a.isInt() && b.isInt()) {
-            out = EValue(a.toInt() > b.toInt());
-          } else if (a.isDouble() && b.isDouble()) {
-            out = EValue(a.toDouble() > b.toDouble());
-          } else if (a.isBool() && b.isBool()) {
-            out = EValue(a.toBool() > b.toBool());
-          } else {
-            // TODO Fail using runtime context
-            ET_CHECK(false);
-          }
+          BOOLEAN_ET_PRIM_OP(>, stack, context);
        }),

    // executorch_prim::lt.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::lt.Scalar",
        [](RuntimeContext& context, EValue** stack) {
-          (void)context;
-          EValue& a = *stack[0];
-          EValue& b = *stack[1];
-          EValue& out = *stack[2];
-          if (a.isInt() && b.isInt()) {
-            out = EValue(a.toInt() < b.toInt());
-          } else if (a.isDouble() && b.isDouble()) {
-            out = EValue(a.toDouble() < b.toDouble());
-          } else if (a.isBool() && b.isBool()) {
-            out = EValue(a.toBool() < b.toBool());
-          } else {
-            // TODO Fail using runtime context
-            ET_CHECK(false);
-          }
+          BOOLEAN_ET_PRIM_OP(<, stack, context);
        }),

    // executorch_prim::ge.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::ge.Scalar",
        [](RuntimeContext& context, EValue** stack) {
-          (void)context;
-          EValue& a = *stack[0];
-          EValue& b = *stack[1];
-          EValue& out = *stack[2];
-          if (a.isInt() && b.isInt()) {
-            out = EValue(a.toInt() >= b.toInt());
-          } else if (a.isDouble() && b.isDouble()) {
-            out = EValue(a.toDouble() >= b.toDouble());
-          } else if (a.isBool() && b.isBool()) {
-            out = EValue(a.toBool() >= b.toBool());
-          } else {
-            // TODO Fail using runtime context
-            ET_CHECK(false);
-          }
+          BOOLEAN_ET_PRIM_OP(>=, stack, context);
        }),

    // executorch_prim::le.Scalar(Scalar, Scalar) -> bool
    Kernel(
        "executorch_prim::le.Scalar",
        [](RuntimeContext& context, EValue** stack) {
-          (void)context;
-          EValue& a = *stack[0];
-          EValue& b = *stack[1];
-          EValue& out = *stack[2];
-          if (a.isInt() && b.isInt()) {
-            out = EValue(a.toInt() <= b.toInt());
-          } else if (a.isDouble() && b.isDouble()) {
-            out = EValue(a.toDouble() <= b.toDouble());
-          } else if (a.isBool() && b.isBool()) {
-            out = EValue(a.toBool() <= b.toBool());
-          } else {
-            // TODO Fail using runtime context
-            ET_CHECK(false);
-          }
+          BOOLEAN_ET_PRIM_OP(<=, stack, context);
        }),

    // executorch_prim::floordiv.int(int, int) -> int