@@ -18,238 +18,43 @@ namespace torch {
18
18
namespace executor {
19
19
namespace native {
20
20
21
- using exec_aten::Tensor;
22
-
23
- namespace {
24
-
25
- template <class CTYPE >
26
- void check_scalar_value (const Scalar val) {
27
- CTYPE _v = 0 ;
28
- bool ok = utils::extract_scalar (val, &_v);
29
- ET_CHECK_MSG (ok, " Invalid alpha value: wrong type or out of range" );
30
- }
31
-
32
- template <class CTYPE >
33
- double extract_scalar_to_double (const Scalar val) {
34
- CTYPE v = 0 ;
35
- bool ok = utils::extract_scalar (val, &v);
36
- ET_CHECK_MSG (ok, " Invalid end value: wrong type or out of range" );
37
- return static_cast <double >(v);
38
- }
39
-
40
- void check_precondition (
41
- const Scalar start,
42
- const Scalar end,
43
- const Scalar step,
44
- Tensor& out) {
45
- // Check the type consistency between scalar end and tensor out.
46
- // They should be in floating point or integer simultaneously.
47
- #define CHECK_FLOAT_TENSOR (ctype, dtype ) \
48
- case ScalarType::dtype: \
49
- ET_CHECK_MSG ( \
50
- end.isFloatingPoint (), \
51
- " end should have same type as out.dtype, but get \
52
- non-floating point end and a floating point out tensor" ); \
53
- break ;
54
-
55
- #define CHECK_INT_TENSOR (ctype, dtype ) \
56
- case ScalarType::dtype: \
57
- ET_CHECK_MSG ( \
58
- end.isIntegral (true ), \
59
- " end should have same type as out, \
60
- but get non-int end and a int out tensor" ); \
61
- break ;
62
-
63
- switch (out.scalar_type ()) {
64
- ET_FORALL_FLOAT_TYPES (CHECK_FLOAT_TENSOR);
65
- ET_FORALL_INT_TYPES_AND (Bool, CHECK_INT_TENSOR);
66
- default :
67
- ET_CHECK_MSG (
68
- false ,
69
- " out tensor should be in floating point or int dtype, but get %hhd" ,
70
- out.scalar_type ());
71
- }
72
-
73
- #undef CHECK_FLOAT_TENSOR
74
- #undef CHECK_INT_TENSOR
75
-
76
- ET_CHECK_MSG (
77
- out.sizes ().size () == 1 ,
78
- " out should be a 1-d tensor, but got a %zu-d tensor" ,
79
- out.sizes ().size ());
80
-
81
- // Check if out size matches end.
82
-
83
- // Set includeBool = false here because the following extract_scalar for int
84
- // use includeBool = False. Have deal with boolean type separately.
85
- if (start.isIntegral (false )) {
86
- check_scalar_value<int64_t >(start);
87
- } else if (start.isFloatingPoint ()) {
88
- check_scalar_value<double >(start);
89
- } else if (start.isBoolean ()) {
90
- check_scalar_value<bool >(start);
91
- } else {
92
- ET_CHECK_MSG (
93
- false ,
94
- " Unexepcted type of start. Should be floating point or int type" );
95
- }
96
-
97
- if (end.isIntegral (false )) {
98
- check_scalar_value<int64_t >(end);
99
- } else if (end.isFloatingPoint ()) {
100
- check_scalar_value<double >(end);
101
- } else if (end.isBoolean ()) {
102
- check_scalar_value<bool >(end);
103
- } else {
104
- ET_CHECK_MSG (
105
- false , " Unexepcted type of end. Should be floating point or int type" );
106
- }
107
-
108
- if (step.isIntegral (false )) {
109
- check_scalar_value<int64_t >(step);
110
- } else if (step.isFloatingPoint ()) {
111
- check_scalar_value<double >(step);
112
- } else if (step.isBoolean ()) {
113
- check_scalar_value<bool >(step);
114
- } else {
115
- ET_CHECK_MSG (
116
- false , " Unexepcted type of step. Should be floating point or int type" );
117
- }
118
- };
119
-
120
- template <class CTYPE >
121
- void check_end (const Scalar end) {
122
- CTYPE end_v = 0 ;
123
- bool ok = utils::extract_scalar (end, &end_v);
124
- ET_CHECK_MSG (ok, " Invalid alpha value: wrong type or out of range" );
125
- ET_CHECK_MSG (end_v >= 0 , " end shall be larger than or equal to 0\n " );
126
- }
127
-
128
- // end here is non-negative scalar, so we can floor it by casting it to int.
129
- template <class CTYPE >
130
- int64_t floor_scalar_to_nearest_int (const Scalar end) {
131
- CTYPE end_v = 0 ;
132
- bool ok = utils::extract_scalar (end, &end_v);
133
- ET_CHECK_MSG (end_v >= 0 , " Input end should be non-negative." );
134
- ET_CHECK_MSG (ok, " Invalid end value: wrong type or out of range" );
135
- return static_cast <int64_t >(end_v);
136
- }
137
-
138
- void check_precondition (const Scalar end, Tensor& out) {
139
- // Check the type consistency between scalar end and tensor out.
140
- // They should be in floating point or integer simultaneously.
141
- #define CHECK_FLOAT_TENSOR (ctype, dtype ) \
142
- case ScalarType::dtype: \
143
- ET_CHECK_MSG ( \
144
- end.isFloatingPoint (), \
145
- " end should have same type as out.dtype, but get \
146
- non-floating point end and a floating point out tensor" ); \
147
- break ;
148
-
149
- #define CHECK_INT_TENSOR (ctype, dtype ) \
150
- case ScalarType::dtype: \
151
- ET_CHECK_MSG ( \
152
- end.isIntegral (true ), \
153
- " end should have same type as out, \
154
- but get non-int end and a int out tensor" ); \
155
- break ;
156
-
157
- switch (out.scalar_type ()) {
158
- ET_FORALL_FLOAT_TYPES (CHECK_FLOAT_TENSOR);
159
- ET_FORALL_INT_TYPES_AND (Bool, CHECK_INT_TENSOR);
160
- default :
161
- ET_CHECK_MSG (
162
- false ,
163
- " out tensor should be in floating point or int dtype, but get %hhd" ,
164
- out.scalar_type ());
165
- }
166
-
167
- #undef CHECK_FLOAT_TENSOR
168
- #undef CHECK_INT_TENSOR
169
-
170
- ET_CHECK_MSG (
171
- out.sizes ().size () == 1 ,
172
- " out should be a 1-d tensor, but got a %zu-d tensor" ,
173
- out.sizes ().size ());
174
-
175
- // Check if out size matches end.
176
-
177
- // Set includeBool = false here because the following extract_scalar for int
178
- // use includeBool = False. Have deal with boolean type separately.
179
- if (end.isIntegral (false )) {
180
- check_end<int64_t >(end);
181
- } else if (end.isFloatingPoint ()) {
182
- check_end<double >(end);
183
- } else if (end.isBoolean ()) {
184
- check_end<bool >(end);
185
- } else {
186
- ET_CHECK_MSG (
187
- false , " Unexepcted type of end. Should be floating point or int type" );
188
- }
189
- };
190
-
191
- template <class CTYPE >
192
- Tensor& set_arange_value (const size_t out_length, Tensor& out) {
193
- auto out_data = out.mutable_data_ptr <CTYPE>();
194
- for (size_t i = 0 ; i < out_length; i++) {
195
- out_data[i] = static_cast <CTYPE>(i);
196
- }
197
- return out;
198
- }
199
-
200
- template <class CTYPE >
201
- Tensor& set_arange_value (
202
- const double start,
203
- const int64_t out_length,
204
- const double step,
205
- Tensor& out) {
206
- auto out_data = out.mutable_data_ptr <CTYPE>();
207
- for (int64_t i = 0 ; i < out_length; i++) {
208
- out_data[i] = start + i * step;
209
- }
210
- return out;
211
- }
212
-
213
- } // namespace
214
-
215
- /*
216
- * Fill out tensor using arange(0, end)
217
- *
218
- * arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
219
- */
220
21
Tensor& arange_out (RuntimeContext& ctx, const Scalar& end, Tensor& out) {
221
- check_precondition (end, out);
222
-
223
- int64_t end_floor = 0 ;
224
- if (end.isIntegral (false )) {
225
- end_floor = floor_scalar_to_nearest_int<int64_t >(end);
226
- } else if (end.isFloatingPoint ()) {
227
- end_floor = floor_scalar_to_nearest_int<double >(end);
228
- } else if (end.isBoolean ()) {
229
- end_floor = floor_scalar_to_nearest_int<bool >(end);
230
- } else {
231
- ET_CHECK_MSG (false , " Unhandled scalar type" );
232
- }
233
-
234
- Tensor::SizesType out_target_length =
235
- static_cast <Tensor::SizesType>(end_floor);
236
- Error status = resize_tensor (out, {&out_target_length, 1 });
237
- ET_CHECK_MSG (status == Error::Ok, " resize_tensor fails\n " );
238
-
239
- #define SET_ARANGE_VALUE_TO_TENSOR (ctype, dtype ) \
240
- case ScalarType::dtype: \
241
- out = set_arange_value<ctype>(end_floor, out); \
242
- break ;
243
-
244
- switch (out.scalar_type ()) {
245
- ET_FORALL_REAL_TYPES_AND (Bool, SET_ARANGE_VALUE_TO_TENSOR)
246
- default :
247
- ET_CHECK_MSG (
248
- false ,
249
- " out tensor should be in floating point or int dtype, but get %hhd" ,
250
- out.scalar_type ());
251
- }
252
- #undef SET_ARANGE_VALUE_TO_TENSOR
22
+ ET_KERNEL_CHECK_MSG (
23
+ ctx,
24
+ out.dim () == 1 ,
25
+ InvalidArgument,
26
+ out,
27
+ " out should be a 1-d tensor, but got a %zu-d tensor" ,
28
+ out.dim ());
29
+
30
+ ScalarType end_type = utils::get_scalar_dtype (end);
31
+
32
+ double end_val = 0 ;
33
+ ET_SWITCH_SCALAR_OBJ_TYPES (end_type, ctx, __func__, CTYPE_END, [&]() {
34
+ CTYPE_END end_v;
35
+ ET_EXTRACT_SCALAR (end, end_v);
36
+ ET_KERNEL_CHECK_MSG (
37
+ ctx,
38
+ end_v >= 0 ,
39
+ InvalidArgument,
40
+ out,
41
+ " Input end should be non-negative." );
42
+ end_val = static_cast <double >(end_v);
43
+ });
44
+
45
+ size_t size = static_cast <size_t >(std::ceil (end_val));
46
+
47
+ Tensor::SizesType out_length = static_cast <Tensor::SizesType>(size);
48
+ Error status = resize_tensor (out, {&out_length, 1 });
49
+ ET_KERNEL_CHECK_MSG (
50
+ ctx, status == Error::Ok, InvalidArgument, out, " resize_tensor fails" );
51
+
52
+ ET_SWITCH_REAL_TYPES (out.scalar_type (), ctx, __func__, CTYPE, [&]() {
53
+ auto out_data = out.mutable_data_ptr <CTYPE>();
54
+ for (size_t i = 0 ; i < size; i++) {
55
+ out_data[i] = static_cast <CTYPE>(i);
56
+ }
57
+ });
253
58
254
59
return out;
255
60
}
@@ -261,66 +66,53 @@ Tensor& arange_start_out(
261
66
const Scalar& step,
262
67
Tensor& out) {
263
68
(void )ctx;
264
- check_precondition (start, end, step, out);
265
69
266
- double d_start;
267
- if (start.isIntegral (false )) {
268
- d_start = extract_scalar_to_double<int64_t >(start);
269
- } else if (start.isFloatingPoint ()) {
270
- d_start = extract_scalar_to_double<double >(start);
271
- } else if (start.isBoolean ()) {
272
- d_start = extract_scalar_to_double<bool >(start);
273
- } else {
274
- ET_CHECK_MSG (false , " Unhandled scalar type" );
275
- }
70
+ ScalarType start_type = utils::get_scalar_dtype (start);
71
+ ScalarType end_type = utils::get_scalar_dtype (end);
72
+ ScalarType step_type = utils::get_scalar_dtype (step);
276
73
277
- double d_end;
278
- if (end.isIntegral (false )) {
279
- d_end = extract_scalar_to_double<int64_t >(end);
280
- } else if (end.isFloatingPoint ()) {
281
- d_end = extract_scalar_to_double<double >(end);
282
- } else if (end.isBoolean ()) {
283
- d_end = extract_scalar_to_double<bool >(end);
284
- } else {
285
- ET_CHECK_MSG (false , " Unhandled scalar type" );
286
- }
74
+ double d_start = 0 ;
75
+ ET_SWITCH_SCALAR_OBJ_TYPES (start_type, ctx, __func__, CTYPE_END, [&]() {
76
+ CTYPE_END start_v;
77
+ ET_EXTRACT_SCALAR (start, start_v);
78
+ d_start = static_cast <double >(start_v);
79
+ });
287
80
288
- double d_step = 0 ;
289
- if (step.isIntegral (false )) {
290
- d_step = extract_scalar_to_double<int64_t >(step);
291
- } else if (step.isFloatingPoint ()) {
292
- d_step = extract_scalar_to_double<double >(step);
293
- } else if (step.isBoolean ()) {
294
- d_step = extract_scalar_to_double<bool >(step);
295
- } else {
296
- ET_CHECK_MSG (false , " Unhandled scalar type" );
297
- }
81
+ double d_end = 0 ;
82
+ ET_SWITCH_SCALAR_OBJ_TYPES (end_type, ctx, __func__, CTYPE_END, [&]() {
83
+ CTYPE_END end_v;
84
+ ET_EXTRACT_SCALAR (end, end_v);
85
+ d_end = static_cast <double >(end_v);
86
+ });
298
87
299
- ET_CHECK_MSG (
88
+ double d_step = 0 ;
89
+ ET_SWITCH_SCALAR_OBJ_TYPES (step_type, ctx, __func__, CTYPE_END, [&]() {
90
+ CTYPE_END step_v;
91
+ ET_EXTRACT_SCALAR (step, step_v);
92
+ d_step = static_cast <double >(step_v);
93
+ });
94
+
95
+ ET_KERNEL_CHECK_MSG (
96
+ ctx,
300
97
(d_step > 0 && (d_end >= d_start)) || (d_step < 0 && (d_end <= d_start)),
98
+ InvalidArgument,
99
+ out,
301
100
" upper bound and larger bound inconsistent with step sign" );
302
101
303
102
double size_d = (d_end - d_start) / d_step;
304
- int64_t size = static_cast <int64_t >(std::ceil (size_d));
305
-
306
- Tensor::SizesType out_target_length = static_cast <Tensor::SizesType>(size);
307
- Error status = resize_tensor (out, {&out_target_length, 1 });
308
- ET_CHECK_MSG (status == Error::Ok, " resize_tensor fails\n " );
309
-
310
- #define SET_START_ARANGE_VALUE_TO_TENSOR (ctype, dtype ) \
311
- case ScalarType::dtype: \
312
- out = set_arange_value<ctype>(d_start, size, d_step, out); \
313
- break ;
314
-
315
- switch (out.scalar_type ()) {
316
- ET_FORALL_REAL_TYPES_AND (Bool, SET_START_ARANGE_VALUE_TO_TENSOR)
317
- default :
318
- ET_CHECK_MSG (
319
- false ,
320
- " out tensor should be in floating point or int dtype, but get %hhd" ,
321
- out.scalar_type ());
322
- }
323
- #undef SET_START_ARANGE_VALUE_TO_TENSOR
103
+ size_t size = static_cast <size_t >(std::ceil (size_d));
104
+
105
+ Tensor::SizesType out_length = static_cast <Tensor::SizesType>(size);
106
+ Error status = resize_tensor (out, {&out_length, 1 });
107
+ ET_KERNEL_CHECK_MSG (
108
+ ctx, status == Error::Ok, InvalidArgument, out, " resize_tensor fails" );
109
+
110
+ ET_SWITCH_REAL_TYPES (out.scalar_type (), ctx, __func__, CTYPE, [&]() {
111
+ auto out_data = out.mutable_data_ptr <CTYPE>();
112
+ for (size_t i = 0 ; i < size; i++) {
113
+ out_data[i] = convert<CTYPE, double >(d_start + i * d_step);
114
+ }
115
+ });
324
116
325
117
return out;
326
118
}
0 commit comments