Skip to content

Commit 8a85ced

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Dtype compliance: arange
Reviewed By: kirklandsign Differential Revision: D48423293 fbshipit-source-id: 2da22be4310376b982159fea28d93b83efe577cc
1 parent 87cf2bd commit 8a85ced

File tree

2 files changed

+89
-377
lines changed

2 files changed

+89
-377
lines changed

kernels/portable/cpu/op_arange.cpp

Lines changed: 75 additions & 283 deletions
Original file line numberDiff line numberDiff line change
@@ -18,238 +18,43 @@ namespace torch {
1818
namespace executor {
1919
namespace native {
2020

21-
using exec_aten::Tensor;
22-
23-
namespace {
24-
25-
template <class CTYPE>
26-
void check_scalar_value(const Scalar val) {
27-
CTYPE _v = 0;
28-
bool ok = utils::extract_scalar(val, &_v);
29-
ET_CHECK_MSG(ok, "Invalid alpha value: wrong type or out of range");
30-
}
31-
32-
template <class CTYPE>
33-
double extract_scalar_to_double(const Scalar val) {
34-
CTYPE v = 0;
35-
bool ok = utils::extract_scalar(val, &v);
36-
ET_CHECK_MSG(ok, "Invalid end value: wrong type or out of range");
37-
return static_cast<double>(v);
38-
}
39-
40-
void check_precondition(
41-
const Scalar start,
42-
const Scalar end,
43-
const Scalar step,
44-
Tensor& out) {
45-
// Check the type consistency between scalar end and tensor out.
46-
// They should be in floating point or integer simultaneously.
47-
#define CHECK_FLOAT_TENSOR(ctype, dtype) \
48-
case ScalarType::dtype: \
49-
ET_CHECK_MSG( \
50-
end.isFloatingPoint(), \
51-
"end should have same type as out.dtype, but get \
52-
non-floating point end and a floating point out tensor"); \
53-
break;
54-
55-
#define CHECK_INT_TENSOR(ctype, dtype) \
56-
case ScalarType::dtype: \
57-
ET_CHECK_MSG( \
58-
end.isIntegral(true), \
59-
"end should have same type as out, \
60-
but get non-int end and a int out tensor"); \
61-
break;
62-
63-
switch (out.scalar_type()) {
64-
ET_FORALL_FLOAT_TYPES(CHECK_FLOAT_TENSOR);
65-
ET_FORALL_INT_TYPES_AND(Bool, CHECK_INT_TENSOR);
66-
default:
67-
ET_CHECK_MSG(
68-
false,
69-
"out tensor should be in floating point or int dtype, but get %hhd",
70-
out.scalar_type());
71-
}
72-
73-
#undef CHECK_FLOAT_TENSOR
74-
#undef CHECK_INT_TENSOR
75-
76-
ET_CHECK_MSG(
77-
out.sizes().size() == 1,
78-
"out should be a 1-d tensor, but got a %zu-d tensor",
79-
out.sizes().size());
80-
81-
// Check if out size matches end.
82-
83-
// Set includeBool = false here because the following extract_scalar for int
84-
// use includeBool = False. Have deal with boolean type separately.
85-
if (start.isIntegral(false)) {
86-
check_scalar_value<int64_t>(start);
87-
} else if (start.isFloatingPoint()) {
88-
check_scalar_value<double>(start);
89-
} else if (start.isBoolean()) {
90-
check_scalar_value<bool>(start);
91-
} else {
92-
ET_CHECK_MSG(
93-
false,
94-
"Unexepcted type of start. Should be floating point or int type");
95-
}
96-
97-
if (end.isIntegral(false)) {
98-
check_scalar_value<int64_t>(end);
99-
} else if (end.isFloatingPoint()) {
100-
check_scalar_value<double>(end);
101-
} else if (end.isBoolean()) {
102-
check_scalar_value<bool>(end);
103-
} else {
104-
ET_CHECK_MSG(
105-
false, "Unexepcted type of end. Should be floating point or int type");
106-
}
107-
108-
if (step.isIntegral(false)) {
109-
check_scalar_value<int64_t>(step);
110-
} else if (step.isFloatingPoint()) {
111-
check_scalar_value<double>(step);
112-
} else if (step.isBoolean()) {
113-
check_scalar_value<bool>(step);
114-
} else {
115-
ET_CHECK_MSG(
116-
false, "Unexepcted type of step. Should be floating point or int type");
117-
}
118-
};
119-
120-
template <class CTYPE>
121-
void check_end(const Scalar end) {
122-
CTYPE end_v = 0;
123-
bool ok = utils::extract_scalar(end, &end_v);
124-
ET_CHECK_MSG(ok, "Invalid alpha value: wrong type or out of range");
125-
ET_CHECK_MSG(end_v >= 0, "end shall be larger than or equal to 0\n");
126-
}
127-
128-
// end here is non-negative scalar, so we can floor it by casting it to int.
129-
template <class CTYPE>
130-
int64_t floor_scalar_to_nearest_int(const Scalar end) {
131-
CTYPE end_v = 0;
132-
bool ok = utils::extract_scalar(end, &end_v);
133-
ET_CHECK_MSG(end_v >= 0, "Input end should be non-negative.");
134-
ET_CHECK_MSG(ok, "Invalid end value: wrong type or out of range");
135-
return static_cast<int64_t>(end_v);
136-
}
137-
138-
void check_precondition(const Scalar end, Tensor& out) {
139-
// Check the type consistency between scalar end and tensor out.
140-
// They should be in floating point or integer simultaneously.
141-
#define CHECK_FLOAT_TENSOR(ctype, dtype) \
142-
case ScalarType::dtype: \
143-
ET_CHECK_MSG( \
144-
end.isFloatingPoint(), \
145-
"end should have same type as out.dtype, but get \
146-
non-floating point end and a floating point out tensor"); \
147-
break;
148-
149-
#define CHECK_INT_TENSOR(ctype, dtype) \
150-
case ScalarType::dtype: \
151-
ET_CHECK_MSG( \
152-
end.isIntegral(true), \
153-
"end should have same type as out, \
154-
but get non-int end and a int out tensor"); \
155-
break;
156-
157-
switch (out.scalar_type()) {
158-
ET_FORALL_FLOAT_TYPES(CHECK_FLOAT_TENSOR);
159-
ET_FORALL_INT_TYPES_AND(Bool, CHECK_INT_TENSOR);
160-
default:
161-
ET_CHECK_MSG(
162-
false,
163-
"out tensor should be in floating point or int dtype, but get %hhd",
164-
out.scalar_type());
165-
}
166-
167-
#undef CHECK_FLOAT_TENSOR
168-
#undef CHECK_INT_TENSOR
169-
170-
ET_CHECK_MSG(
171-
out.sizes().size() == 1,
172-
"out should be a 1-d tensor, but got a %zu-d tensor",
173-
out.sizes().size());
174-
175-
// Check if out size matches end.
176-
177-
// Set includeBool = false here because the following extract_scalar for int
178-
// use includeBool = False. Have deal with boolean type separately.
179-
if (end.isIntegral(false)) {
180-
check_end<int64_t>(end);
181-
} else if (end.isFloatingPoint()) {
182-
check_end<double>(end);
183-
} else if (end.isBoolean()) {
184-
check_end<bool>(end);
185-
} else {
186-
ET_CHECK_MSG(
187-
false, "Unexepcted type of end. Should be floating point or int type");
188-
}
189-
};
190-
191-
template <class CTYPE>
192-
Tensor& set_arange_value(const size_t out_length, Tensor& out) {
193-
auto out_data = out.mutable_data_ptr<CTYPE>();
194-
for (size_t i = 0; i < out_length; i++) {
195-
out_data[i] = static_cast<CTYPE>(i);
196-
}
197-
return out;
198-
}
199-
200-
template <class CTYPE>
201-
Tensor& set_arange_value(
202-
const double start,
203-
const int64_t out_length,
204-
const double step,
205-
Tensor& out) {
206-
auto out_data = out.mutable_data_ptr<CTYPE>();
207-
for (int64_t i = 0; i < out_length; i++) {
208-
out_data[i] = start + i * step;
209-
}
210-
return out;
211-
}
212-
213-
} // namespace
214-
215-
/*
216-
* Fill out tensor using arange(0, end)
217-
*
218-
* arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
219-
*/
22021
Tensor& arange_out(RuntimeContext& ctx, const Scalar& end, Tensor& out) {
221-
check_precondition(end, out);
222-
223-
int64_t end_floor = 0;
224-
if (end.isIntegral(false)) {
225-
end_floor = floor_scalar_to_nearest_int<int64_t>(end);
226-
} else if (end.isFloatingPoint()) {
227-
end_floor = floor_scalar_to_nearest_int<double>(end);
228-
} else if (end.isBoolean()) {
229-
end_floor = floor_scalar_to_nearest_int<bool>(end);
230-
} else {
231-
ET_CHECK_MSG(false, "Unhandled scalar type");
232-
}
233-
234-
Tensor::SizesType out_target_length =
235-
static_cast<Tensor::SizesType>(end_floor);
236-
Error status = resize_tensor(out, {&out_target_length, 1});
237-
ET_CHECK_MSG(status == Error::Ok, "resize_tensor fails\n");
238-
239-
#define SET_ARANGE_VALUE_TO_TENSOR(ctype, dtype) \
240-
case ScalarType::dtype: \
241-
out = set_arange_value<ctype>(end_floor, out); \
242-
break;
243-
244-
switch (out.scalar_type()) {
245-
ET_FORALL_REAL_TYPES_AND(Bool, SET_ARANGE_VALUE_TO_TENSOR)
246-
default:
247-
ET_CHECK_MSG(
248-
false,
249-
"out tensor should be in floating point or int dtype, but get %hhd",
250-
out.scalar_type());
251-
}
252-
#undef SET_ARANGE_VALUE_TO_TENSOR
22+
ET_KERNEL_CHECK_MSG(
23+
ctx,
24+
out.dim() == 1,
25+
InvalidArgument,
26+
out,
27+
"out should be a 1-d tensor, but got a %zu-d tensor",
28+
out.dim());
29+
30+
ScalarType end_type = utils::get_scalar_dtype(end);
31+
32+
double end_val = 0;
33+
ET_SWITCH_SCALAR_OBJ_TYPES(end_type, ctx, __func__, CTYPE_END, [&]() {
34+
CTYPE_END end_v;
35+
ET_EXTRACT_SCALAR(end, end_v);
36+
ET_KERNEL_CHECK_MSG(
37+
ctx,
38+
end_v >= 0,
39+
InvalidArgument,
40+
out,
41+
"Input end should be non-negative.");
42+
end_val = static_cast<double>(end_v);
43+
});
44+
45+
size_t size = static_cast<size_t>(std::ceil(end_val));
46+
47+
Tensor::SizesType out_length = static_cast<Tensor::SizesType>(size);
48+
Error status = resize_tensor(out, {&out_length, 1});
49+
ET_KERNEL_CHECK_MSG(
50+
ctx, status == Error::Ok, InvalidArgument, out, "resize_tensor fails");
51+
52+
ET_SWITCH_REAL_TYPES(out.scalar_type(), ctx, __func__, CTYPE, [&]() {
53+
auto out_data = out.mutable_data_ptr<CTYPE>();
54+
for (size_t i = 0; i < size; i++) {
55+
out_data[i] = static_cast<CTYPE>(i);
56+
}
57+
});
25358

25459
return out;
25560
}
@@ -261,66 +66,53 @@ Tensor& arange_start_out(
26166
const Scalar& step,
26267
Tensor& out) {
26368
(void)ctx;
264-
check_precondition(start, end, step, out);
26569

266-
double d_start;
267-
if (start.isIntegral(false)) {
268-
d_start = extract_scalar_to_double<int64_t>(start);
269-
} else if (start.isFloatingPoint()) {
270-
d_start = extract_scalar_to_double<double>(start);
271-
} else if (start.isBoolean()) {
272-
d_start = extract_scalar_to_double<bool>(start);
273-
} else {
274-
ET_CHECK_MSG(false, "Unhandled scalar type");
275-
}
70+
ScalarType start_type = utils::get_scalar_dtype(start);
71+
ScalarType end_type = utils::get_scalar_dtype(end);
72+
ScalarType step_type = utils::get_scalar_dtype(step);
27673

277-
double d_end;
278-
if (end.isIntegral(false)) {
279-
d_end = extract_scalar_to_double<int64_t>(end);
280-
} else if (end.isFloatingPoint()) {
281-
d_end = extract_scalar_to_double<double>(end);
282-
} else if (end.isBoolean()) {
283-
d_end = extract_scalar_to_double<bool>(end);
284-
} else {
285-
ET_CHECK_MSG(false, "Unhandled scalar type");
286-
}
74+
double d_start = 0;
75+
ET_SWITCH_SCALAR_OBJ_TYPES(start_type, ctx, __func__, CTYPE_END, [&]() {
76+
CTYPE_END start_v;
77+
ET_EXTRACT_SCALAR(start, start_v);
78+
d_start = static_cast<double>(start_v);
79+
});
28780

288-
double d_step = 0;
289-
if (step.isIntegral(false)) {
290-
d_step = extract_scalar_to_double<int64_t>(step);
291-
} else if (step.isFloatingPoint()) {
292-
d_step = extract_scalar_to_double<double>(step);
293-
} else if (step.isBoolean()) {
294-
d_step = extract_scalar_to_double<bool>(step);
295-
} else {
296-
ET_CHECK_MSG(false, "Unhandled scalar type");
297-
}
81+
double d_end = 0;
82+
ET_SWITCH_SCALAR_OBJ_TYPES(end_type, ctx, __func__, CTYPE_END, [&]() {
83+
CTYPE_END end_v;
84+
ET_EXTRACT_SCALAR(end, end_v);
85+
d_end = static_cast<double>(end_v);
86+
});
29887

299-
ET_CHECK_MSG(
88+
double d_step = 0;
89+
ET_SWITCH_SCALAR_OBJ_TYPES(step_type, ctx, __func__, CTYPE_END, [&]() {
90+
CTYPE_END step_v;
91+
ET_EXTRACT_SCALAR(step, step_v);
92+
d_step = static_cast<double>(step_v);
93+
});
94+
95+
ET_KERNEL_CHECK_MSG(
96+
ctx,
30097
(d_step > 0 && (d_end >= d_start)) || (d_step < 0 && (d_end <= d_start)),
98+
InvalidArgument,
99+
out,
301100
"upper bound and larger bound inconsistent with step sign");
302101

303102
double size_d = (d_end - d_start) / d_step;
304-
int64_t size = static_cast<int64_t>(std::ceil(size_d));
305-
306-
Tensor::SizesType out_target_length = static_cast<Tensor::SizesType>(size);
307-
Error status = resize_tensor(out, {&out_target_length, 1});
308-
ET_CHECK_MSG(status == Error::Ok, "resize_tensor fails\n");
309-
310-
#define SET_START_ARANGE_VALUE_TO_TENSOR(ctype, dtype) \
311-
case ScalarType::dtype: \
312-
out = set_arange_value<ctype>(d_start, size, d_step, out); \
313-
break;
314-
315-
switch (out.scalar_type()) {
316-
ET_FORALL_REAL_TYPES_AND(Bool, SET_START_ARANGE_VALUE_TO_TENSOR)
317-
default:
318-
ET_CHECK_MSG(
319-
false,
320-
"out tensor should be in floating point or int dtype, but get %hhd",
321-
out.scalar_type());
322-
}
323-
#undef SET_START_ARANGE_VALUE_TO_TENSOR
103+
size_t size = static_cast<size_t>(std::ceil(size_d));
104+
105+
Tensor::SizesType out_length = static_cast<Tensor::SizesType>(size);
106+
Error status = resize_tensor(out, {&out_length, 1});
107+
ET_KERNEL_CHECK_MSG(
108+
ctx, status == Error::Ok, InvalidArgument, out, "resize_tensor fails");
109+
110+
ET_SWITCH_REAL_TYPES(out.scalar_type(), ctx, __func__, CTYPE, [&]() {
111+
auto out_data = out.mutable_data_ptr<CTYPE>();
112+
for (size_t i = 0; i < size; i++) {
113+
out_data[i] = convert<CTYPE, double>(d_start + i * d_step);
114+
}
115+
});
324116

325117
return out;
326118
}

0 commit comments

Comments
 (0)