Commit afffe5a

SS-JIA authored and facebook-github-bot committed
Clean up binary arithmetic ops
Summary: Refactor `add`, `div`, `sub`, and `mul` to use the `ET_SWITCH_*` macros.

Reviewed By: kimishpatel

Differential Revision: D46990101

fbshipit-source-id: 48249a869f3272fcf6803101e940ab45198eac72
1 parent bf91690 commit afffe5a
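
For context, the removed kernels dispatched over input and output dtypes with hand-written nested `switch` statements plus per-kernel helper macros (e.g. `ADD_TENSORS_SWITCH_*_CASE`). The `ET_SWITCH_*` macros used in the new code fold that boilerplate into a single construct that maps a runtime `ScalarType` to a compile-time C type and runs a lambda with that type bound to a name such as `CTYPE_A`. A minimal, self-contained sketch of the dispatch idea (the enum, macro, and type names below are invented for illustration and are not ExecuTorch's actual definitions):

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for exec_aten::ScalarType.
enum class MyScalarType { Float, Int, Bool };

// Map a runtime dtype to a compile-time C type, bind it to `ctype_alias`,
// and run the lambda passed as `body` in that scope (mirrors how the
// ET_SWITCH_* call sites are used in the diffs below).
#define MY_SWITCH_TYPES(dtype, ctype_alias, body) \
  switch (dtype) {                                \
    case MyScalarType::Float: {                   \
      using ctype_alias = float;                  \
      body();                                     \
      break;                                      \
    }                                             \
    case MyScalarType::Int: {                     \
      using ctype_alias = int32_t;                \
      body();                                     \
      break;                                      \
    }                                             \
    case MyScalarType::Bool: {                    \
      using ctype_alias = bool;                   \
      body();                                     \
      break;                                      \
    }                                             \
  }

int main() {
  MyScalarType out_type = MyScalarType::Float;
  // Like the kernels' nested [&]() lambdas: CTYPE_OUT is whatever C type the
  // macro selected for the runtime dtype.
  MY_SWITCH_TYPES(out_type, CTYPE_OUT, [&]() {
    CTYPE_OUT value = static_cast<CTYPE_OUT>(1) + static_cast<CTYPE_OUT>(2);
    std::printf("result = %f\n", static_cast<double>(value));
  });
  return 0;
}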

File tree

4 files changed: +134 -447 lines changed


kernels/portable/cpu/op_add.cpp

Lines changed: 32 additions & 131 deletions
@@ -1,4 +1,5 @@
 // Copyright (c) Meta Platforms, Inc. and affiliates.
+
 #include <executorch/kernels/kernel_includes.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
@@ -8,145 +9,45 @@ namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-using Scalar = exec_aten::Scalar;
-
-namespace {
-
-template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT>
-void add_tensors_impl(
-    const Tensor& a,
-    const Tensor& b,
-    const Scalar& alpha,
-    Tensor& out) {
-  // Alpha multiplication is performed in double to maximize precision
-  double alpha_val = 0;
-  bool ok = utils::extract_scalar(alpha, &alpha_val);
-  ET_CHECK_MSG(ok, "Invalid alpha value: wrong type or out of range");
-
-  apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-      [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-        CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
-
-        if (alpha_val == 1.0f) {
-          CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
-          return a_casted + b_casted;
-        }
-
-        double b_casted = static_cast<double>(val_b);
-        return a_casted + static_cast<CTYPE_OUT>(alpha_val * b_casted);
-      },
-      a,
-      b,
-      out);
-}
-
-template <typename CTYPE_A, typename CTYPE_B>
-void add_tensors_switch_out(
-    const Tensor& a,
-    const Tensor& b,
-    const Scalar& alpha,
-    Tensor& out) {
-#define ADD_TENSORS_SWITCH_OUT_CASE(ctype, dtype)                \
-  case ScalarType::dtype:                                        \
-    add_tensors_impl<CTYPE_A, CTYPE_B, ctype>(a, b, alpha, out); \
-    break;
-
-  switch (out.scalar_type()) {
-    ET_FORALL_REAL_TYPES_AND(Bool, ADD_TENSORS_SWITCH_OUT_CASE)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd for out", out.scalar_type());
-  }
-
-#undef ADD_TENSORS_SWITCH_OUT_CASE
-}
-
-template <typename CTYPE_A>
-void add_tensors_switch_b(
-    const Tensor& a,
-    const Tensor& b,
-    const Scalar& alpha,
-    Tensor& out) {
-#define ADD_TENSORS_SWITCH_B_CASE(ctype, dtype)               \
-  case ScalarType::dtype:                                     \
-    add_tensors_switch_out<CTYPE_A, ctype>(a, b, alpha, out); \
-    break;
-
-  switch (b.scalar_type()) {
-    ET_FORALL_REAL_TYPES_AND(Bool, ADD_TENSORS_SWITCH_B_CASE)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd for b", b.scalar_type());
-  }
-
-#undef ADD_TENSORS_SWITCH_B_CASE
-}
-
-void add_tensors_switch_a(
-    const Tensor& a,
-    const Tensor& b,
-    const Scalar& alpha,
-    Tensor& out) {
-#define ADD_TENSORS_SWITCH_A_CASE(ctype, dtype)    \
-  case ScalarType::dtype:                          \
-    add_tensors_switch_b<ctype>(a, b, alpha, out); \
-    break;
-
-  switch (a.scalar_type()) {
-    ET_FORALL_REAL_TYPES_AND(Bool, ADD_TENSORS_SWITCH_A_CASE)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd for a", a.scalar_type());
-  }
-
-#undef ADD_TENSORS_SWITCH_A_CASE
-}
-
-void check_input_dtypes(
-    const Tensor& a,
-    const Tensor& b,
-    const Scalar& alpha,
-    Tensor& out) {
-  // If either input is floating point, the output must also be floating point
-  if (isFloatingType(a.scalar_type()) || isFloatingType(b.scalar_type())) {
-    ET_CHECK_MSG(
-        isFloatingType(out.scalar_type()),
-        "output must be a floating point type if either input is a floating point type.");
-  }
-  // Bool output is only allowed if both inputs are bool
-  if (out.scalar_type() == ScalarType::Bool) {
-    ET_CHECK_MSG(
-        a.scalar_type() == ScalarType::Bool &&
-            b.scalar_type() == ScalarType::Bool,
-        "both inputs must be bool type for output to be bool");
-  }
-
-  // If both inputs are integral or bool types, then alpha must also be an
-  // integral type
-  if (isIntegralType(a.scalar_type(), true) &&
-      isIntegralType(b.scalar_type(), true)) {
-    ET_CHECK_MSG(
-        alpha.isIntegral(true),
-        "alpha must be an integral type if both inputs are integral types");
-  }
-}
-
-} // namespace
-
 Tensor& add_out(
-    RuntimeContext& context,
+    RuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)context;
+  (void)ctx;
 
-  // Determine output size and resize for dynamic shapes
   resize_to_broadcast_target_size(a, b, out);
 
-  // Check arguments
-  check_input_dtypes(a, b, alpha, out);
-
-  add_tensors_switch_a(a, b, alpha, out);
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = b.scalar_type();
+  ScalarType common_type = promoteTypes(a_type, b_type);
+  ScalarType out_type = out.scalar_type();
+
+  ET_CHECK(canCast(common_type, out_type));
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "add", CTYPE_A, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "add", CTYPE_B, [&]() {
+      ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "add", CTYPE_IN, [&]() {
+        ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "add", CTYPE_OUT, [&]() {
+          CTYPE_IN alpha_val;
+          ET_EXTRACT_SCALAR(alpha, alpha_val);
+
+          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+              [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
+                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+                CTYPE_IN value = a_casted + alpha_val * b_casted;
+
+                return static_cast<CTYPE_OUT>(value);
+              },
+              a,
+              b,
+              out);
+        });
+      });
+    });
  });
 
   return out;
 }
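
The rewritten `add_out` casts both operands into the promoted common type `CTYPE_IN`, applies `alpha` there, and casts once into `CTYPE_OUT`, rather than the removed `add_tensors_impl`'s mix of `double` and `CTYPE_OUT` arithmetic. A simplified, self-contained sketch of that per-element pattern (the `add_one` helper and the hard-coded template arguments are illustrative only, not part of the kernel):

#include <cstdio>

// Hypothetical scalar helper mirroring the kernel's lambda: cast both inputs
// to the common compute type, apply alpha, then cast to the output type.
template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_IN, typename CTYPE_OUT>
CTYPE_OUT add_one(CTYPE_A val_a, CTYPE_B val_b, CTYPE_IN alpha_val) {
  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
  CTYPE_IN value = a_casted + alpha_val * b_casted;
  return static_cast<CTYPE_OUT>(value);
}

int main() {
  // int + float promotes to float, so 2 + 0.5f * 3.5f = 3.75 is preserved
  // even though one input is integral.
  float out = add_one<int, float, float, float>(2, 3.5f, 0.5f);
  std::printf("%f\n", out);  // prints 3.750000
  return 0;
}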

kernels/portable/cpu/op_div.cpp

Lines changed: 42 additions & 80 deletions
@@ -11,96 +11,58 @@ namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-
 namespace {
 
-template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT>
-void div_tensors_impl(const Tensor& a, const Tensor& b, Tensor& out) {
-  apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-      [](const CTYPE_A val_a, const CTYPE_B val_b) {
-        // Perform math in double for all types to maximize precision
-        double dividend = static_cast<double>(val_a);
-        double divisor = static_cast<double>(val_b);
-        double value = dividend / divisor;
-
-        return static_cast<CTYPE_OUT>(value);
-      },
-      a,
-      b,
-      out);
-}
-
-template <typename CTYPE_A, typename CTYPE_B>
-void div_tensors_switch_out(const Tensor& a, const Tensor& b, Tensor& out) {
-#define DIV_TENSORS_SWITCH_OUT_CASE(ctype, dtype)         \
-  case ScalarType::dtype:                                 \
-    div_tensors_impl<CTYPE_A, CTYPE_B, ctype>(a, b, out); \
-    break;
-
-  switch (out.scalar_type()) {
-    ET_FORALL_FLOAT_TYPES(DIV_TENSORS_SWITCH_OUT_CASE)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd for out", out.scalar_type());
+ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
+  ET_CHECK(
+      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
+  ET_CHECK(
+      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
+
+  if (isFloatingType(a_type) && isFloatingType(b_type)) {
+    return promoteTypes(a_type, b_type);
+  } else if (isFloatingType(a_type)) {
+    return a_type;
+  } else if (isFloatingType(b_type)) {
+    return b_type;
   }
-
-#undef DIV_TENSORS_SWITCH_OUT_CASE
-}
-
-template <typename CTYPE_A>
-void div_tensors_switch_b(const Tensor& a, const Tensor& b, Tensor& out) {
-#define DIV_TENSORS_SWITCH_B_CASE(ctype, dtype)        \
-  case ScalarType::dtype:                              \
-    div_tensors_switch_out<CTYPE_A, ctype>(a, b, out); \
-    break;
-
-  switch (b.scalar_type()) {
-    ET_FORALL_REAL_TYPES_AND(Bool, DIV_TENSORS_SWITCH_B_CASE)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd for b", b.scalar_type());
-  }
-
-#undef DIV_TENSORS_SWITCH_B_CASE
-}
-
-void div_tensors_switch_a(const Tensor& a, const Tensor& b, Tensor& out) {
-#define DIV_TENSORS_SWITCH_A_CASE(ctype, dtype) \
-  case ScalarType::dtype:                       \
-    div_tensors_switch_b<ctype>(a, b, out);     \
-    break;
-
-  switch (a.scalar_type()) {
-    ET_FORALL_REAL_TYPES_AND(Bool, DIV_TENSORS_SWITCH_A_CASE)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd for a", a.scalar_type());
-  }
-
-#undef DIV_TENSORS_SWITCH_A_CASE
-}
-
-void check_input_dtypes(const Tensor& a, const Tensor& b, Tensor& out) {
-  ET_CHECK_MSG(
-      isFloatingType(out.scalar_type()),
-      "output must be a floating point type.");
+  return ScalarType::Float;
 }
 
 } // namespace
 
-Tensor& div_out(
-    RuntimeContext& context,
-    const Tensor& a,
-    const Tensor& b,
-    Tensor& out) {
-  (void)context;
+Tensor&
+div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
+  (void)ctx;
 
-  // Determine output size and resize for dynamic shapes
   resize_to_broadcast_target_size(a, b, out);
 
-  // Check arguments
-  check_input_dtypes(a, b, out);
-
-  div_tensors_switch_a(a, b, out);
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = b.scalar_type();
+  ScalarType common_type = get_compute_type(a_type, b_type);
+  ScalarType out_type = out.scalar_type();
+
+  ET_CHECK(canCast(common_type, out_type));
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div", CTYPE_A, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div", CTYPE_B, [&]() {
+      ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div", CTYPE_IN, [&]() {
+        ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div", CTYPE_OUT, [&]() {
+          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+              [](const CTYPE_A val_a, const CTYPE_B val_b) {
+                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+                CTYPE_IN value = a_casted / b_casted;
+
+                return static_cast<CTYPE_OUT>(value);
+              },
+              a,
+              b,
+              out);
+        });
+      });
+    });
+  });
 
   return out;
 }
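
Note the behavioral intent of `get_compute_type`: when neither input is floating point it falls back to `ScalarType::Float`, so the division is carried out in a floating-point compute type rather than being truncated by integer arithmetic. A tiny standalone illustration of the difference (plain C++, not calling the kernel itself):

#include <cstdio>

int main() {
  int a = 7;
  int b = 2;

  // Dividing directly in the integer type truncates.
  int truncated = a / b;  // 3

  // Casting both operands to a float compute type first, as the div kernel
  // does when neither input is floating point, preserves the fraction.
  float promoted = static_cast<float>(a) / static_cast<float>(b);  // 3.5

  std::printf("truncated=%d promoted=%f\n", truncated, promoted);
  return 0;
}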

0 commit comments
