Skip to content

Commit 42b7be3

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Dtype compliance: where
Reviewed By: digantdesai, SS-JIA Differential Revision: D47738838 fbshipit-source-id: dccbc5e61ef0b86c3b742486df9fd2e6b8a3cb4c
1 parent 842ddf2 commit 42b7be3

File tree

2 files changed

+53488
-126
lines changed

2 files changed

+53488
-126
lines changed

kernels/portable/cpu/op_where.cpp

Lines changed: 37 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -5,116 +5,51 @@
55
namespace torch {
66
namespace executor {
77
namespace native {
8-
using Tensor = exec_aten::Tensor;
9-
using ScalarType = exec_aten::ScalarType;
10-
using Scalar = exec_aten::Scalar;
11-
namespace {
12-
13-
template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT>
14-
void where_tensors_impl(
15-
const Tensor& condition,
16-
const Tensor& a,
17-
const Tensor& b,
18-
Tensor& out) {
19-
apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, bool, CTYPE_OUT>(
20-
[](const CTYPE_A val_a, const CTYPE_B val_b, const bool val_c) {
21-
CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
22-
CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
23-
24-
return val_c ? a_casted : b_casted;
25-
},
26-
a,
27-
b,
28-
condition,
29-
out);
30-
}
31-
32-
template <typename CTYPE_A, typename CTYPE_B>
33-
void where_tensors_switch_out(
34-
const Tensor& condition,
35-
const Tensor& a,
36-
const Tensor& b,
37-
Tensor& out) {
38-
#define WHERE_TENSORS_SWITCH_OUT_CASE(ctype, dtype) \
39-
case ScalarType::dtype: \
40-
where_tensors_impl<CTYPE_A, CTYPE_B, ctype>(condition, a, b, out); \
41-
break;
42-
43-
switch (out.scalar_type()) {
44-
ET_FORALL_REAL_TYPES_AND(Bool, WHERE_TENSORS_SWITCH_OUT_CASE)
45-
default:
46-
ET_CHECK_MSG(false, "Unhandled dtype %hhd for out", out.scalar_type());
47-
}
48-
49-
#undef WHERE_TENSORS_SWITCH_OUT_CASE
50-
}
51-
52-
template <typename CTYPE_A>
53-
void where_tensors_switch_b(
54-
const Tensor& condition,
55-
const Tensor& a,
56-
const Tensor& b,
57-
Tensor& out) {
58-
#define WHERE_TENSORS_SWITCH_B_CASE(ctype, dtype) \
59-
case ScalarType::dtype: \
60-
where_tensors_switch_out<CTYPE_A, ctype>(condition, a, b, out); \
61-
break;
62-
63-
switch (b.scalar_type()) {
64-
ET_FORALL_REAL_TYPES_AND(Bool, WHERE_TENSORS_SWITCH_B_CASE)
65-
default:
66-
ET_CHECK_MSG(false, "Unhandled dtype %hhd for b", b.scalar_type());
67-
}
68-
69-
#undef WHERE_TENSORS_SWITCH_B_CASE
70-
}
71-
72-
void where_tensors_switch_a(
73-
const Tensor& condition,
74-
const Tensor& a,
75-
const Tensor& b,
76-
Tensor& out) {
77-
#define WHERE_TENSORS_SWITCH_A_CASE(ctype, dtype) \
78-
case ScalarType::dtype: \
79-
where_tensors_switch_b<ctype>(condition, a, b, out); \
80-
break;
81-
82-
switch (a.scalar_type()) {
83-
ET_FORALL_REAL_TYPES_AND(Bool, WHERE_TENSORS_SWITCH_A_CASE)
84-
default:
85-
ET_CHECK_MSG(false, "Unhandled dtype %hhd for a", a.scalar_type());
86-
}
87-
88-
#undef WHERE_TENSORS_SWITCH_A_CASE
89-
}
90-
91-
void check_input_dtypes(
92-
const Tensor& condition,
93-
const Tensor& a,
94-
const Tensor& b,
95-
Tensor& out) {
96-
ET_CHECK_MSG(
97-
condition.scalar_type() == ScalarType::Bool,
98-
"Condition tensor must be boolean type");
99-
}
100-
101-
} // namespace
1028

1039
Tensor& where_out(
104-
RuntimeContext& context,
105-
const Tensor& condition,
10+
RuntimeContext& ctx,
11+
const Tensor& cond,
10612
const Tensor& a,
10713
const Tensor& b,
10814
Tensor& out) {
109-
(void)context;
15+
(void)ctx;
11016

111-
// Determine output size and resize for dynamic shapes
112-
resize_to_broadcast_target_size(a, b, condition, out);
17+
ScalarType cond_type = cond.scalar_type();
18+
ScalarType a_type = a.scalar_type();
19+
ScalarType b_type = b.scalar_type();
20+
ScalarType common_type = promoteTypes(a_type, b_type);
21+
ScalarType out_type = out.scalar_type();
11322

114-
// Check arguments
115-
check_input_dtypes(condition, a, b, out);
23+
ET_CHECK(common_type == out_type);
11624

117-
where_tensors_switch_a(condition, a, b, out);
25+
// Determine output size and resize for dynamic shapes
26+
resize_to_broadcast_target_size(a, b, cond, out);
27+
28+
ET_SWITCH_TWO_TYPES(Bool, Byte, cond_type, ctx, "where", CTYPE_COND, [&]() {
29+
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "where", CTYPE_A, [&]() {
30+
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "where", CTYPE_B, [&]() {
31+
ET_SWITCH_REAL_TYPES_AND(
32+
Bool, out_type, ctx, "where", CTYPE_OUT, [&]() {
33+
apply_ternary_elementwise_fn<
34+
CTYPE_A,
35+
CTYPE_B,
36+
CTYPE_COND,
37+
CTYPE_OUT>(
38+
[](const CTYPE_A val_a,
39+
const CTYPE_B val_b,
40+
const CTYPE_COND val_c) {
41+
CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
42+
CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
43+
return val_c ? a_casted : b_casted;
44+
},
45+
a,
46+
b,
47+
cond,
48+
out);
49+
});
50+
});
51+
});
52+
});
11853

11954
return out;
12055
}

0 commit comments

Comments
 (0)