@@ -22,45 +22,74 @@ using ScalarType = exec_aten::ScalarType;
 
 namespace {
+// NOTE: we bake ArrayRef iterators being pointers into the return
+// type here because we assume that iterators are portable across
+// ArrayRef copies.
+const Tensor::SizesType* arrayref_begin_ignoring_leading_1s(
+    ArrayRef<Tensor::SizesType> arr) {
+  return std::find_if(
+      arr.begin(), arr.end(), [](Tensor::SizesType x) { return x != 1; });
+}
+
 bool sizes_match_ignoring_leading_1s(
     ArrayRef<Tensor::SizesType> lhs,
     ArrayRef<Tensor::SizesType> rhs) {
-  auto lhs_begin = lhs.begin();
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs);
   auto lhs_end = lhs.end();
-  while (lhs_begin != lhs_end && *lhs_begin == 1) {
-    ++lhs_begin;
-  }
 
-  auto rhs_begin = rhs.begin();
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs);
   auto rhs_end = rhs.end();
-  while (rhs_begin != rhs_end && *rhs_begin == 1) {
-    ++rhs_begin;
-  }
 
   return ((lhs_end - lhs_begin) == (rhs_end - rhs_begin)) &&
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }
 
 // Move to generic util as this is applicable to all binary ops
-bool can_use_optimized_path(
-    const Tensor& a,
-    const Tensor& b,
-    const Tensor& out) {
+enum class ElementwiseOptimizedPath {
+  kNone,
+  kTreatAs1d,
+  kBroadcast2dBy1d,
+  kBroadcast2dBy1dReverseArguments,
+};
+
+ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
+    const Tensor& lhs,
+    const Tensor& rhs) {
+  auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
+  auto lhs_end = lhs.sizes().end();
+
+  auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
+  auto rhs_end = rhs.sizes().end();
+
+  const auto lhs_size = lhs_end - lhs_begin;
+  const auto rhs_size = rhs_end - rhs_begin;
+  if (lhs_size == 2 && rhs_size == 1 && lhs_begin[1] == rhs_begin[0]) {
+    return ElementwiseOptimizedPath::kBroadcast2dBy1d;
+  }
+
+  if (lhs_size == 1 && rhs_size == 2 && rhs_begin[1] == lhs_begin[0]) {
+    return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
+  }
+
+  return ElementwiseOptimizedPath::kNone;
+}
+
+ElementwiseOptimizedPath
+select_optimized_path(const Tensor& a, const Tensor& b, const Tensor& out) {
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
 
-  bool can_use_optimized_path = true;
-  can_use_optimized_path =
-      can_use_optimized_path && ((a_type == b_type) && (a_type == out_type));
-  can_use_optimized_path = can_use_optimized_path &&
-      (a_type != ScalarType::Half && b_type != ScalarType::Half);
-  can_use_optimized_path = can_use_optimized_path &&
-      (a.sizes().equals(b.sizes()) ||
-       (a.numel() == b.numel() &&
-        (a.numel() == out.numel() ||
-         sizes_match_ignoring_leading_1s(a.sizes(), b.sizes()))));
-  return can_use_optimized_path;
+  if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half) {
+    return ElementwiseOptimizedPath::kNone;
+  }
+  if (a.sizes().equals(b.sizes()) ||
+      (a.numel() == b.numel() &&
+       (a.numel() == out.numel() ||
+        sizes_match_ignoring_leading_1s(a.sizes(), b.sizes())))) {
+    return ElementwiseOptimizedPath::kTreatAs1d;
+  }
+  return select_broadcast_2d_by_1d_optimized_path(a, b);
 }
 
 template <
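A minimal standalone sketch of the shape-classification rule introduced above, using plain `std::vector<int64_t>` stand-ins instead of `Tensor`/`ArrayRef`. The names `classify` and `begin_ignoring_leading_1s` are illustrative only, and the `kTreatAs1d` check is simplified: it drops the `numel()`-based conditions and keeps only the "same trailing shape after ignoring leading 1s" case.

```cpp
// Illustrative sketch only; mirrors the trailing-shape logic above without
// ExecuTorch types. Not part of the patch.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

enum class Path { kNone, kTreatAs1d, kBroadcast2dBy1d, kBroadcast2dBy1dReverseArguments };

// Skip leading dimensions of size 1, analogous to arrayref_begin_ignoring_leading_1s.
const int64_t* begin_ignoring_leading_1s(const std::vector<int64_t>& sizes) {
  return std::find_if(
      sizes.data(), sizes.data() + sizes.size(), [](int64_t d) { return d != 1; });
}

Path classify(const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs) {
  const int64_t* lb = begin_ignoring_leading_1s(lhs);
  const int64_t* rb = begin_ignoring_leading_1s(rhs);
  const auto lsize = (lhs.data() + lhs.size()) - lb;
  const auto rsize = (rhs.data() + rhs.size()) - rb;
  if (lsize == rsize && std::equal(lb, lb + lsize, rb)) {
    return Path::kTreatAs1d;  // same trailing shape: flat elementwise loop
  }
  if (lsize == 2 && rsize == 1 && lb[1] == rb[0]) {
    return Path::kBroadcast2dBy1d;  // e.g. (M, N) * (N,)
  }
  if (lsize == 1 && rsize == 2 && rb[1] == lb[0]) {
    return Path::kBroadcast2dBy1dReverseArguments;  // e.g. (N,) * (M, N)
  }
  return Path::kNone;  // fall back to the generic broadcast kernel
}

int main() {
  // (1, 4, 8) vs (4, 8): the leading 1 is ignored, so this is kTreatAs1d (1).
  std::cout << static_cast<int>(classify({1, 4, 8}, {4, 8})) << "\n";
  // (4, 8) vs (8,): each row of the 2-D operand times the 1-D operand, kBroadcast2dBy1d (2).
  std::cout << static_cast<int>(classify({4, 8}, {8})) << "\n";
}
```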
@@ -147,7 +176,8 @@ Tensor& opt_mul_out(
     return opt_mul_out(ctx, b, a, out);
   }
 
-  if (can_use_optimized_path(a, b, out)) {
+  auto selected_optimized_path = select_optimized_path(a, b, out);
+  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
     // Resize for dynamic shape
     auto error = resize_tensor(out, a.sizes());
     ET_KERNEL_CHECK_MSG(
@@ -166,6 +196,38 @@ Tensor& opt_mul_out(
           b.const_data_ptr<CTYPE>(),
           out.numel());
     });
+  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
+    const Tensor* lhs;
+    const Tensor* rhs;
+    if (selected_optimized_path ==
+        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
+      lhs = &b;
+      rhs = &a;
+    } else {
+      // Catch failure to update logic when adding new broadcasting possibility.
+      ET_DCHECK(
+          selected_optimized_path ==
+          ElementwiseOptimizedPath::kBroadcast2dBy1d);
+      lhs = &a;
+      rhs = &b;
+    }
+    auto error = resize_tensor(out, lhs->sizes());
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        error == Error::Ok,
+        InvalidArgument,
+        out,
+        "Failed to resize output tensor.");
+    ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
+      using Vec = executorch::vec::Vectorized<CTYPE>;
+      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
+          [](Vec x, Vec y) { return x * y; },
+          out.mutable_data_ptr<CTYPE>(),
+          lhs->const_data_ptr<CTYPE>(),
+          rhs->const_data_ptr<CTYPE>(),
+          lhs->sizes()[lhs->dim() - 2],
+          lhs->sizes()[lhs->dim() - 1]);
+    });
   } else {
     ScalarType common_type =
         promoteTypes(a_type, b_type, /* half_to_float*/ true);
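The vectorized `executorch::vec::broadcasting_map_2d_by_1d` call in the hunk above is assumed to apply the functor row-wise, broadcasting the 1-D operand across each row of the 2-D operand. A scalar reference sketch of that assumed contract follows; `broadcast_2d_by_1d_reference` is an illustrative name, not an ExecuTorch API, and there is no SIMD here.

```cpp
// Scalar reference for the assumed row-wise broadcast: out[i][j] = op(lhs[i][j], rhs[j]).
// Illustration only; not part of the patch.
#include <cstddef>
#include <cstdio>
#include <vector>

template <typename CTYPE, typename Op>
void broadcast_2d_by_1d_reference(
    const Op& op,
    CTYPE* out,
    const CTYPE* lhs,
    const CTYPE* rhs,
    size_t outer_size,   // corresponds to lhs->sizes()[lhs->dim() - 2]
    size_t inner_size) { // corresponds to lhs->sizes()[lhs->dim() - 1], also rhs's length
  for (size_t i = 0; i < outer_size; ++i) {
    for (size_t j = 0; j < inner_size; ++j) {
      out[i * inner_size + j] = op(lhs[i * inner_size + j], rhs[j]);
    }
  }
}

int main() {
  // A (2, 3) buffer multiplied by a (3,) buffer, row by row.
  std::vector<float> lhs{1, 2, 3, 4, 5, 6};
  std::vector<float> rhs{10, 20, 30};
  std::vector<float> out(6);
  broadcast_2d_by_1d_reference(
      [](float x, float y) { return x * y; },
      out.data(), lhs.data(), rhs.data(), 2, 3);
  for (float v : out) {
    std::printf("%g ", v);  // prints: 10 40 90 40 100 180
  }
  std::printf("\n");
}
```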