Skip to content

Commit 8f2f15e

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Merge arange and arange_start files
Summary: The implementations of arange.out and arange.start_out should live in the same .cpp file (op_arange.cpp). Also, the their tests should live in the same .cpp test file (op_arange_test.cpp). Reviewed By: SS-JIA Differential Revision: D47347092 fbshipit-source-id: bc7aaf82310d359df4e7472abfc5fda2b4647756
1 parent 07958b5 commit 8f2f15e

File tree

6 files changed

+412
-464
lines changed

6 files changed

+412
-464
lines changed

kernels/portable/cpu/op_arange.cpp

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,101 @@ using exec_aten::Tensor;
1515

1616
namespace {
1717

18+
template <class CTYPE>
19+
void check_scalar_value(const Scalar val) {
20+
CTYPE _v = 0;
21+
bool ok = utils::extract_scalar(val, &_v);
22+
ET_CHECK_MSG(ok, "Invalid alpha value: wrong type or out of range");
23+
}
24+
25+
template <class CTYPE>
26+
double extract_scalar_to_double(const Scalar val) {
27+
CTYPE v = 0;
28+
bool ok = utils::extract_scalar(val, &v);
29+
ET_CHECK_MSG(ok, "Invalid end value: wrong type or out of range");
30+
return static_cast<double>(v);
31+
}
32+
33+
void check_precondition(
34+
const Scalar start,
35+
const Scalar end,
36+
const Scalar step,
37+
Tensor& out) {
38+
// Check the type consistency between scalar end and tensor out.
39+
// They should be in floating point or integer simultaneously.
40+
#define CHECK_FLOAT_TENSOR(ctype, dtype) \
41+
case ScalarType::dtype: \
42+
ET_CHECK_MSG( \
43+
end.isFloatingPoint(), \
44+
"end should have same type as out.dtype, but get \
45+
non-floating point end and a floating point out tensor"); \
46+
break;
47+
48+
#define CHECK_INT_TENSOR(ctype, dtype) \
49+
case ScalarType::dtype: \
50+
ET_CHECK_MSG( \
51+
end.isIntegral(true), \
52+
"end should have same type as out, \
53+
but get non-int end and a int out tensor"); \
54+
break;
55+
56+
switch (out.scalar_type()) {
57+
ET_FORALL_FLOAT_TYPES(CHECK_FLOAT_TENSOR);
58+
ET_FORALL_INT_TYPES_AND(Bool, CHECK_INT_TENSOR);
59+
default:
60+
ET_CHECK_MSG(
61+
false,
62+
"out tensor should be in floating point or int dtype, but get %hhd",
63+
out.scalar_type());
64+
}
65+
66+
#undef CHECK_FLOAT_TENSOR
67+
#undef CHECK_INT_TENSOR
68+
69+
ET_CHECK_MSG(
70+
out.sizes().size() == 1,
71+
"out should be a 1-d tensor, but got a %zu-d tensor",
72+
out.sizes().size());
73+
74+
// Check if out size matches end.
75+
76+
// Set includeBool = false here because the following extract_scalar for int
77+
// use includeBool = False. Have deal with boolean type separately.
78+
if (start.isIntegral(false)) {
79+
check_scalar_value<int64_t>(start);
80+
} else if (start.isFloatingPoint()) {
81+
check_scalar_value<double>(start);
82+
} else if (start.isBoolean()) {
83+
check_scalar_value<bool>(start);
84+
} else {
85+
ET_CHECK_MSG(
86+
false,
87+
"Unexepcted type of start. Should be floating point or int type");
88+
}
89+
90+
if (end.isIntegral(false)) {
91+
check_scalar_value<int64_t>(end);
92+
} else if (end.isFloatingPoint()) {
93+
check_scalar_value<double>(end);
94+
} else if (end.isBoolean()) {
95+
check_scalar_value<bool>(end);
96+
} else {
97+
ET_CHECK_MSG(
98+
false, "Unexepcted type of end. Should be floating point or int type");
99+
}
100+
101+
if (step.isIntegral(false)) {
102+
check_scalar_value<int64_t>(step);
103+
} else if (step.isFloatingPoint()) {
104+
check_scalar_value<double>(step);
105+
} else if (step.isBoolean()) {
106+
check_scalar_value<bool>(step);
107+
} else {
108+
ET_CHECK_MSG(
109+
false, "Unexepcted type of step. Should be floating point or int type");
110+
}
111+
};
112+
18113
template <class CTYPE>
19114
void check_end(const Scalar end) {
20115
CTYPE end_v = 0;
@@ -95,6 +190,19 @@ Tensor& set_arange_value(const size_t out_length, Tensor& out) {
95190
return out;
96191
}
97192

193+
template <class CTYPE>
194+
Tensor& set_arange_value(
195+
const double start,
196+
const int64_t out_length,
197+
const double step,
198+
Tensor& out) {
199+
auto out_data = out.data_ptr<CTYPE>();
200+
for (int64_t i = 0; i < out_length; i++) {
201+
out_data[i] = start + i * step;
202+
}
203+
return out;
204+
}
205+
98206
} // namespace
99207

100208
/*
@@ -139,6 +247,77 @@ Tensor& arange_out(RuntimeContext& context, const Scalar& end, Tensor& out) {
139247
return out;
140248
}
141249

250+
Tensor& arange_start_out(
251+
RuntimeContext& context,
252+
const Scalar& start,
253+
const Scalar& end,
254+
const Scalar& step,
255+
Tensor& out) {
256+
(void)context;
257+
check_precondition(start, end, step, out);
258+
259+
double d_start;
260+
if (start.isIntegral(false)) {
261+
d_start = extract_scalar_to_double<int64_t>(start);
262+
} else if (start.isFloatingPoint()) {
263+
d_start = extract_scalar_to_double<double>(start);
264+
} else if (start.isBoolean()) {
265+
d_start = extract_scalar_to_double<bool>(start);
266+
} else {
267+
ET_CHECK_MSG(false, "Unhandled scalar type");
268+
}
269+
270+
double d_end;
271+
if (end.isIntegral(false)) {
272+
d_end = extract_scalar_to_double<int64_t>(end);
273+
} else if (end.isFloatingPoint()) {
274+
d_end = extract_scalar_to_double<double>(end);
275+
} else if (end.isBoolean()) {
276+
d_end = extract_scalar_to_double<bool>(end);
277+
} else {
278+
ET_CHECK_MSG(false, "Unhandled scalar type");
279+
}
280+
281+
double d_step = 0;
282+
if (step.isIntegral(false)) {
283+
d_step = extract_scalar_to_double<int64_t>(step);
284+
} else if (step.isFloatingPoint()) {
285+
d_step = extract_scalar_to_double<double>(step);
286+
} else if (step.isBoolean()) {
287+
d_step = extract_scalar_to_double<bool>(step);
288+
} else {
289+
ET_CHECK_MSG(false, "Unhandled scalar type");
290+
}
291+
292+
ET_CHECK_MSG(
293+
(d_step > 0 && (d_end >= d_start)) || (d_step < 0 && (d_end <= d_start)),
294+
"upper bound and larger bound inconsistent with step sign");
295+
296+
double size_d = (d_end - d_start) / d_step;
297+
int64_t size = static_cast<int64_t>(std::ceil(size_d));
298+
299+
Tensor::SizesType out_target_length = static_cast<Tensor::SizesType>(size);
300+
Error status = resize_tensor(out, {&out_target_length, 1});
301+
ET_CHECK_MSG(status == Error::Ok, "resize_tensor fails\n");
302+
303+
#define SET_START_ARANGE_VALUE_TO_TENSOR(ctype, dtype) \
304+
case ScalarType::dtype: \
305+
out = set_arange_value<ctype>(d_start, size, d_step, out); \
306+
break;
307+
308+
switch (out.scalar_type()) {
309+
ET_FORALL_REAL_TYPES_AND(Bool, SET_START_ARANGE_VALUE_TO_TENSOR)
310+
default:
311+
ET_CHECK_MSG(
312+
false,
313+
"out tensor should be in floating point or int dtype, but get %hhd",
314+
out.scalar_type());
315+
}
316+
#undef SET_START_ARANGE_VALUE_TO_TENSOR
317+
318+
return out;
319+
}
320+
142321
} // namespace native
143322
} // namespace executor
144323
} // namespace torch

0 commit comments

Comments
 (0)