@@ -15,6 +15,101 @@ using exec_aten::Tensor;
15
15
16
16
namespace {
17
17
18
+ template <class CTYPE >
19
+ void check_scalar_value (const Scalar val) {
20
+ CTYPE _v = 0 ;
21
+ bool ok = utils::extract_scalar (val, &_v);
22
+ ET_CHECK_MSG (ok, " Invalid alpha value: wrong type or out of range" );
23
+ }
24
+
25
+ template <class CTYPE >
26
+ double extract_scalar_to_double (const Scalar val) {
27
+ CTYPE v = 0 ;
28
+ bool ok = utils::extract_scalar (val, &v);
29
+ ET_CHECK_MSG (ok, " Invalid end value: wrong type or out of range" );
30
+ return static_cast <double >(v);
31
+ }
32
+
33
+ void check_precondition (
34
+ const Scalar start,
35
+ const Scalar end,
36
+ const Scalar step,
37
+ Tensor& out) {
38
+ // Check the type consistency between scalar end and tensor out.
39
+ // They should be in floating point or integer simultaneously.
40
+ #define CHECK_FLOAT_TENSOR (ctype, dtype ) \
41
+ case ScalarType::dtype: \
42
+ ET_CHECK_MSG ( \
43
+ end.isFloatingPoint (), \
44
+ " end should have same type as out.dtype, but get \
45
+ non-floating point end and a floating point out tensor" ); \
46
+ break ;
47
+
48
+ #define CHECK_INT_TENSOR (ctype, dtype ) \
49
+ case ScalarType::dtype: \
50
+ ET_CHECK_MSG ( \
51
+ end.isIntegral (true ), \
52
+ " end should have same type as out, \
53
+ but get non-int end and a int out tensor" ); \
54
+ break ;
55
+
56
+ switch (out.scalar_type ()) {
57
+ ET_FORALL_FLOAT_TYPES (CHECK_FLOAT_TENSOR);
58
+ ET_FORALL_INT_TYPES_AND (Bool, CHECK_INT_TENSOR);
59
+ default :
60
+ ET_CHECK_MSG (
61
+ false ,
62
+ " out tensor should be in floating point or int dtype, but get %hhd" ,
63
+ out.scalar_type ());
64
+ }
65
+
66
+ #undef CHECK_FLOAT_TENSOR
67
+ #undef CHECK_INT_TENSOR
68
+
69
+ ET_CHECK_MSG (
70
+ out.sizes ().size () == 1 ,
71
+ " out should be a 1-d tensor, but got a %zu-d tensor" ,
72
+ out.sizes ().size ());
73
+
74
+ // Check if out size matches end.
75
+
76
+ // Set includeBool = false here because the following extract_scalar for int
77
+ // use includeBool = False. Have deal with boolean type separately.
78
+ if (start.isIntegral (false )) {
79
+ check_scalar_value<int64_t >(start);
80
+ } else if (start.isFloatingPoint ()) {
81
+ check_scalar_value<double >(start);
82
+ } else if (start.isBoolean ()) {
83
+ check_scalar_value<bool >(start);
84
+ } else {
85
+ ET_CHECK_MSG (
86
+ false ,
87
+ " Unexepcted type of start. Should be floating point or int type" );
88
+ }
89
+
90
+ if (end.isIntegral (false )) {
91
+ check_scalar_value<int64_t >(end);
92
+ } else if (end.isFloatingPoint ()) {
93
+ check_scalar_value<double >(end);
94
+ } else if (end.isBoolean ()) {
95
+ check_scalar_value<bool >(end);
96
+ } else {
97
+ ET_CHECK_MSG (
98
+ false , " Unexepcted type of end. Should be floating point or int type" );
99
+ }
100
+
101
+ if (step.isIntegral (false )) {
102
+ check_scalar_value<int64_t >(step);
103
+ } else if (step.isFloatingPoint ()) {
104
+ check_scalar_value<double >(step);
105
+ } else if (step.isBoolean ()) {
106
+ check_scalar_value<bool >(step);
107
+ } else {
108
+ ET_CHECK_MSG (
109
+ false , " Unexepcted type of step. Should be floating point or int type" );
110
+ }
111
+ };
112
+
18
113
template <class CTYPE >
19
114
void check_end (const Scalar end) {
20
115
CTYPE end_v = 0 ;
@@ -95,6 +190,19 @@ Tensor& set_arange_value(const size_t out_length, Tensor& out) {
95
190
return out;
96
191
}
97
192
193
+ template <class CTYPE >
194
+ Tensor& set_arange_value (
195
+ const double start,
196
+ const int64_t out_length,
197
+ const double step,
198
+ Tensor& out) {
199
+ auto out_data = out.data_ptr <CTYPE>();
200
+ for (int64_t i = 0 ; i < out_length; i++) {
201
+ out_data[i] = start + i * step;
202
+ }
203
+ return out;
204
+ }
205
+
98
206
} // namespace
99
207
100
208
/*
@@ -139,6 +247,77 @@ Tensor& arange_out(RuntimeContext& context, const Scalar& end, Tensor& out) {
139
247
return out;
140
248
}
141
249
250
+ Tensor& arange_start_out (
251
+ RuntimeContext& context,
252
+ const Scalar& start,
253
+ const Scalar& end,
254
+ const Scalar& step,
255
+ Tensor& out) {
256
+ (void )context;
257
+ check_precondition (start, end, step, out);
258
+
259
+ double d_start;
260
+ if (start.isIntegral (false )) {
261
+ d_start = extract_scalar_to_double<int64_t >(start);
262
+ } else if (start.isFloatingPoint ()) {
263
+ d_start = extract_scalar_to_double<double >(start);
264
+ } else if (start.isBoolean ()) {
265
+ d_start = extract_scalar_to_double<bool >(start);
266
+ } else {
267
+ ET_CHECK_MSG (false , " Unhandled scalar type" );
268
+ }
269
+
270
+ double d_end;
271
+ if (end.isIntegral (false )) {
272
+ d_end = extract_scalar_to_double<int64_t >(end);
273
+ } else if (end.isFloatingPoint ()) {
274
+ d_end = extract_scalar_to_double<double >(end);
275
+ } else if (end.isBoolean ()) {
276
+ d_end = extract_scalar_to_double<bool >(end);
277
+ } else {
278
+ ET_CHECK_MSG (false , " Unhandled scalar type" );
279
+ }
280
+
281
+ double d_step = 0 ;
282
+ if (step.isIntegral (false )) {
283
+ d_step = extract_scalar_to_double<int64_t >(step);
284
+ } else if (step.isFloatingPoint ()) {
285
+ d_step = extract_scalar_to_double<double >(step);
286
+ } else if (step.isBoolean ()) {
287
+ d_step = extract_scalar_to_double<bool >(step);
288
+ } else {
289
+ ET_CHECK_MSG (false , " Unhandled scalar type" );
290
+ }
291
+
292
+ ET_CHECK_MSG (
293
+ (d_step > 0 && (d_end >= d_start)) || (d_step < 0 && (d_end <= d_start)),
294
+ " upper bound and larger bound inconsistent with step sign" );
295
+
296
+ double size_d = (d_end - d_start) / d_step;
297
+ int64_t size = static_cast <int64_t >(std::ceil (size_d));
298
+
299
+ Tensor::SizesType out_target_length = static_cast <Tensor::SizesType>(size);
300
+ Error status = resize_tensor (out, {&out_target_length, 1 });
301
+ ET_CHECK_MSG (status == Error::Ok, " resize_tensor fails\n " );
302
+
303
+ #define SET_START_ARANGE_VALUE_TO_TENSOR (ctype, dtype ) \
304
+ case ScalarType::dtype: \
305
+ out = set_arange_value<ctype>(d_start, size, d_step, out); \
306
+ break ;
307
+
308
+ switch (out.scalar_type ()) {
309
+ ET_FORALL_REAL_TYPES_AND (Bool, SET_START_ARANGE_VALUE_TO_TENSOR)
310
+ default :
311
+ ET_CHECK_MSG (
312
+ false ,
313
+ " out tensor should be in floating point or int dtype, but get %hhd" ,
314
+ out.scalar_type ());
315
+ }
316
+ #undef SET_START_ARANGE_VALUE_TO_TENSOR
317
+
318
+ return out;
319
+ }
320
+
142
321
} // namespace native
143
322
} // namespace executor
144
323
} // namespace torch
0 commit comments