// Copyright (c) Meta Platforms, Inc. and affiliates.

#include <executorch/kernels/kernel_includes.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>

namespace torch {
namespace executor {
namespace native {
11
- using Tensor = exec_aten::Tensor;
12
- using ScalarType = exec_aten::ScalarType;
13
- using Scalar = exec_aten::Scalar;
14
-
15
- namespace {
16
-
17
- template <typename CTYPE_A, typename CTYPE_B, typename CTYPE_OUT>
18
- void add_tensors_impl (
19
- const Tensor& a,
20
- const Tensor& b,
21
- const Scalar& alpha,
22
- Tensor& out) {
23
- // Alpha multiplication is performed in double to maximize precision
24
- double alpha_val = 0 ;
25
- bool ok = utils::extract_scalar (alpha, &alpha_val);
26
- ET_CHECK_MSG (ok, " Invalid alpha value: wrong type or out of range" );
27
-
28
- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
29
- [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
30
- CTYPE_OUT a_casted = static_cast <CTYPE_OUT>(val_a);
31
-
32
- if (alpha_val == 1 .0f ) {
33
- CTYPE_OUT b_casted = static_cast <CTYPE_OUT>(val_b);
34
- return a_casted + b_casted;
35
- }
36
-
37
- double b_casted = static_cast <double >(val_b);
38
- return a_casted + static_cast <CTYPE_OUT>(alpha_val * b_casted);
39
- },
40
- a,
41
- b,
42
- out);
43
- }
44
-
45
- template <typename CTYPE_A, typename CTYPE_B>
46
- void add_tensors_switch_out (
47
- const Tensor& a,
48
- const Tensor& b,
49
- const Scalar& alpha,
50
- Tensor& out) {
51
- #define ADD_TENSORS_SWITCH_OUT_CASE (ctype, dtype ) \
52
- case ScalarType::dtype: \
53
- add_tensors_impl<CTYPE_A, CTYPE_B, ctype>(a, b, alpha, out); \
54
- break ;
55
-
56
- switch (out.scalar_type ()) {
57
- ET_FORALL_REAL_TYPES_AND (Bool, ADD_TENSORS_SWITCH_OUT_CASE)
58
- default :
59
- ET_CHECK_MSG (false , " Unhandled dtype %hhd for out" , out.scalar_type ());
60
- }
61
-
62
- #undef ADD_TENSORS_SWITCH_OUT_CASE
63
- }
64
-
65
- template <typename CTYPE_A>
66
- void add_tensors_switch_b (
67
- const Tensor& a,
68
- const Tensor& b,
69
- const Scalar& alpha,
70
- Tensor& out) {
71
- #define ADD_TENSORS_SWITCH_B_CASE (ctype, dtype ) \
72
- case ScalarType::dtype: \
73
- add_tensors_switch_out<CTYPE_A, ctype>(a, b, alpha, out); \
74
- break ;
75
-
76
- switch (b.scalar_type ()) {
77
- ET_FORALL_REAL_TYPES_AND (Bool, ADD_TENSORS_SWITCH_B_CASE)
78
- default :
79
- ET_CHECK_MSG (false , " Unhandled dtype %hhd for b" , b.scalar_type ());
80
- }
81
-
82
- #undef ADD_TENSORS_SWITCH_B_CASE
83
- }
84
-
85
- void add_tensors_switch_a (
86
- const Tensor& a,
87
- const Tensor& b,
88
- const Scalar& alpha,
89
- Tensor& out) {
90
- #define ADD_TENSORS_SWITCH_A_CASE (ctype, dtype ) \
91
- case ScalarType::dtype: \
92
- add_tensors_switch_b<ctype>(a, b, alpha, out); \
93
- break ;
94
-
95
- switch (a.scalar_type ()) {
96
- ET_FORALL_REAL_TYPES_AND (Bool, ADD_TENSORS_SWITCH_A_CASE)
97
- default :
98
- ET_CHECK_MSG (false , " Unhandled dtype %hhd for a" , a.scalar_type ());
99
- }
100
-
101
- #undef ADD_TENSORS_SWITCH_A_CASE
102
- }
103
-
104
- void check_input_dtypes (
105
- const Tensor& a,
106
- const Tensor& b,
107
- const Scalar& alpha,
108
- Tensor& out) {
109
- // If either input is floating point, the output must also be floating point
110
- if (isFloatingType (a.scalar_type ()) || isFloatingType (b.scalar_type ())) {
111
- ET_CHECK_MSG (
112
- isFloatingType (out.scalar_type ()),
113
- " output must be a floating point type if either input is a floating point type." );
114
- }
115
- // Bool output is only allowed if both inputs are bool
116
- if (out.scalar_type () == ScalarType::Bool) {
117
- ET_CHECK_MSG (
118
- a.scalar_type () == ScalarType::Bool &&
119
- b.scalar_type () == ScalarType::Bool,
120
- " both inputs must be bool type for output to be bool" );
121
- }
122
-
123
- // If both inputs are integral or bool types, then alpha must also be an
124
- // integral type
125
- if (isIntegralType (a.scalar_type (), true ) &&
126
- isIntegralType (b.scalar_type (), true )) {
127
- ET_CHECK_MSG (
128
- alpha.isIntegral (true ),
129
- " alpha must be an integral type if both inputs are integral types" );
130
- }
131
- }
132
-
133
- } // namespace
134
-
135
12
Tensor& add_out (
136
- RuntimeContext& context ,
13
+ RuntimeContext& ctx ,
137
14
const Tensor& a,
138
15
const Tensor& b,
139
16
const Scalar& alpha,
140
17
Tensor& out) {
141
- (void )context ;
18
+ (void )ctx ;
142
19
143
- // Determine output size and resize for dynamic shapes
144
20
resize_to_broadcast_target_size (a, b, out);
145
21
146
- // Check arguments
147
- check_input_dtypes (a, b, alpha, out);
148
-
149
- add_tensors_switch_a (a, b, alpha, out);
22
+ ScalarType a_type = a.scalar_type ();
23
+ ScalarType b_type = b.scalar_type ();
24
+ ScalarType common_type = promoteTypes (a_type, b_type);
25
+ ScalarType out_type = out.scalar_type ();
26
+
27
+ ET_CHECK (canCast (common_type, out_type));
28
+
29
+ ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " add" , CTYPE_A, [&]() {
30
+ ET_SWITCH_REAL_TYPES_AND (Bool, b_type, ctx, " add" , CTYPE_B, [&]() {
31
+ ET_SWITCH_REAL_TYPES_AND (Bool, common_type, ctx, " add" , CTYPE_IN, [&]() {
32
+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, " add" , CTYPE_OUT, [&]() {
33
+ CTYPE_IN alpha_val;
34
+ ET_EXTRACT_SCALAR (alpha, alpha_val);
35
+
36
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
37
+ [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
38
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
39
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
40
+ CTYPE_IN value = a_casted + alpha_val * b_casted;
41
+
42
+ return static_cast <CTYPE_OUT>(value);
43
+ },
44
+ a,
45
+ b,
46
+ out);
47
+ });
48
+ });
49
+ });
50
+ });
150
51
151
52
return out;
152
53
}
0 commit comments