
Commit e78e092

[Executorch] Add broadcasting support to optimized op_sub
Summary: This diff builds on top of the previous one to add support for limited broadcasting in the optimized sub op.

Test Plan: tests added

ghstack-source-id: 9741b34
Pull Request resolved: #8256
1 parent 69d52c6 commit e78e092
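
For context, sub.out computes `out = a - alpha * b`, and the "limited" broadcast handling referred to above covers cases like the new 3-D/4-D tests below, where one operand has a size-1 dimension that is repeated against the other. A minimal plain-C++ sketch of the 2-d-by-1-d flavor of that pattern; `sub_broadcast_2d_by_1d` is a hypothetical helper for illustration only, not ExecuTorch API:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper (not ExecuTorch API): "2d by 1d" broadcast subtract.
// `a` is rows x cols, `b` is one row of length cols:
//   out[i][j] = a[i][j] - alpha * b[j]
std::vector<float> sub_broadcast_2d_by_1d(
    const std::vector<float>& a,
    const std::vector<float>& b,
    float alpha,
    std::size_t rows,
    std::size_t cols) {
  std::vector<float> out(rows * cols);
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      out[i * cols + j] = a[i * cols + j] - alpha * b[j];
    }
  }
  return out;
}

int main() {
  // Mirrors one 2x3 slice of the new 3-D test below:
  // {1,2,3,4,5,6} - 1.0 * {2,3,4} broadcast per row -> {-1,-1,-1, 2,2,2}
  auto out = sub_broadcast_2d_by_1d({1, 2, 3, 4, 5, 6}, {2, 3, 4}, 1.0f, 2, 3);
  for (float v : out) {
    std::cout << v << ' ';
  }
  std::cout << '\n';
}
```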

File tree

3 files changed, +122 -104 lines

kernels/optimized/cpu/op_sub.cpp

Lines changed: 5 additions & 104 deletions

```diff
@@ -15,6 +15,8 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -138,110 +140,9 @@ Tensor& opt_sub_out(
     }
   }
 
-  auto selected_optimized_path = select_optimized_path(a, b, out);
-  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REAL_TYPES(a_type, ctx, "sub.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
-  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    const Tensor* lhs;
-    const Tensor* rhs;
-    if (selected_optimized_path ==
-        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-      lhs = &b;
-      rhs = &a;
-    } else {
-      // Catch failure to update logic when subing new broadcasting possibility.
-      ET_DCHECK(
-          selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1d);
-      lhs = &a;
-      rhs = &b;
-    }
-    auto error = resize_tensor(out, lhs->sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-    ET_SWITCH_REAL_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      if (selected_optimized_path ==
-          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [alpha_val](Vec x, Vec y) { return y - Vec(alpha_val) * x; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      } else {
-        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
-            [alpha_val](Vec x, Vec y) { return x - Vec(alpha_val) * y; },
-            out.mutable_data_ptr<CTYPE>(),
-            lhs->const_data_ptr<CTYPE>(),
-            rhs->const_data_ptr<CTYPE>(),
-            lhs->sizes()[lhs->dim() - 2],
-            lhs->sizes()[lhs->dim() - 1]);
-      }
-    });
-  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          SubInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
-  }
-
-  return out;
+  static constexpr const char op_name[] = "sub.out";
+  return torch::executor::kernels::impl::opt_add_sub_out_impl<true, op_name>(
+      ctx, a, b, alpha, out);
 }
 
 Tensor& opt_sub_scalar_out(
```
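
The net change above: the hand-rolled kTreatAs1d, 2d-by-1d broadcast, and type-promoting fallback paths are deleted, and `opt_sub_out` now delegates to the shared `opt_add_sub_out_impl` from `op_add_sub_impl.h`, with the leading `true` template argument selecting subtraction. A hedged sketch of that compile-time dispatch pattern, not the actual header contents:

```cpp
#include <cstddef>
#include <iostream>

// Hedged sketch only -- NOT the contents of op_add_sub_impl.h. It shows the
// pattern the dispatch suggests: one template body serves both add.out and
// sub.out, with `is_sub` choosing the sign at compile time so the inner loop
// stays branch-free.
template <bool is_sub, typename CTYPE>
void add_sub_1d(
    CTYPE* out, const CTYPE* a, const CTYPE* b, CTYPE alpha, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    if constexpr (is_sub) {
      out[i] = a[i] - alpha * b[i];  // sub.out: out = a - alpha * b
    } else {
      out[i] = a[i] + alpha * b[i];  // add.out: out = a + alpha * b
    }
  }
}

int main() {
  float a[] = {1, 2, 3}, b[] = {2, 3, 4}, out[3];
  add_sub_1d</*is_sub=*/true>(out, a, b, /*alpha=*/1.0f, 3);
  std::cout << out[0] << ' ' << out[1] << ' ' << out[2] << '\n';  // -1 -1 -1
}
```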

kernels/optimized/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions

```diff
@@ -95,6 +95,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_sub",
         deps = [
             ":binary_ops",
+            ":add_sub_impl",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
```
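
The new `:add_sub_impl` dep is the helper target behind `op_add_sub_impl.h`, which the previous diff in this stack introduced for `op_add`; with it, add and sub build against a single shared implementation of these fast paths.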

kernels/test/op_sub_test.cpp

Lines changed: 116 additions & 0 deletions

```diff
@@ -99,6 +99,109 @@ class OpSubOutTest : public OperatorTest {
     EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.1, 1.2, 3.4, 7.8}));
   }
 
+  template <ScalarType DTYPE>
+  void test_broadcast_3D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});
+
+    // Destination for output of mul.
+    Tensor out =
+        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    Tensor expected =
+        tf_a.make({2, 2, 3}, /*data=*/{-1, -1, -1, 2, 2, 2, 2, 2, 2, 5, 5, 5});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected);
+    // b - a * 1.5 output should be
+    expected = tf_a.make(
+        {2, 2, 3},
+        /*data=*/
+        {0.5,
+         0.0,
+         -0.5,
+         -4.0,
+         -4.5,
+         -5.0,
+         -5.5,
+         -6.0,
+         -6.5,
+         -10.0,
+         -10.5,
+         -11.0});
+    EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.5, out), expected);
+  }
+
+  template <ScalarType DTYPE>
+  void test_broadcast_4D() {
+    TensorFactory<DTYPE> tf_a;
+
+    Tensor a = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
+                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
+    Tensor b = tf_a.make(
+        {2, 1, 3, 5},
+        /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
+                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
+
+    // Destination for output of mul.
+    Tensor out = tf_a.zeros({2, 2, 3, 5});
+    Tensor expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
+                  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+                  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+                  30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected);
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
+                  0,   0,   0,   -15, -15, -15, -15, -15, -15, -15, -15, -15,
+                  -15, -15, -15, -15, -15, -15, -15, -15, -15, -15, -15, -15,
+                  -15, -15, -15, -15, -15, -15, -15, -15, -15, -30, -30, -30,
+                  -30, -30, -30, -30, -30, -30, -30, -30, -30, -30, -30, -30});
+    EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.0, out), expected);
+
+    b = tf_a.make(
+        {2, 2, 1, 5}, /*data=*/{1,  2,  3,  4,  5,  6,  7,  8,  9,  10,
+                                11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
+    out = tf_a.zeros({2, 2, 3, 5});
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{0,  0,  0,  0,  0,  5,  5,  5,  5,  5,  10, 10, 10, 10, 10,
+                  10, 10, 10, 10, 10, 15, 15, 15, 15, 15, 20, 20, 20, 20, 20,
+                  20, 20, 20, 20, 20, 25, 25, 25, 25, 25, 30, 30, 30, 30, 30,
+                  30, 30, 30, 30, 30, 35, 35, 35, 35, 35, 40, 40, 40, 40, 40});
+
+    // Check that it matches the expected output.
+    EXPECT_TENSOR_CLOSE(op_sub_out(a, b, 1.0, out), expected);
+    expected = tf_a.make(
+        {2, 2, 3, 5},
+        /*data=*/{-0.5000,  -1.0000,  -1.5000,  -2.0000,  -2.5000,
+                  -8.0000,  -8.5000,  -9.0000,  -9.5000,  -10.0000,
+                  -15.5000, -16.0000, -16.5000, -17.0000, -17.5000,
+
+                  -18.0000, -18.5000, -19.0000, -19.5000, -20.0000,
+                  -25.5000, -26.0000, -26.5000, -27.0000, -27.5000,
+                  -33.0000, -33.5000, -34.0000, -34.5000, -35.0000,
+
+                  -35.5000, -36.0000, -36.5000, -37.0000, -37.5000,
+                  -43.0000, -43.5000, -44.0000, -44.5000, -45.0000,
+                  -50.5000, -51.0000, -51.5000, -52.0000, -52.5000,
+
+                  -53.0000, -53.5000, -54.0000, -54.5000, -55.0000,
+                  -60.5000, -61.0000, -61.5000, -62.0000, -62.5000,
+                  -68.0000, -68.5000, -69.0000, -69.5000, -70.0000});
+    EXPECT_TENSOR_CLOSE(op_sub_out(b, a, 1.5, out), expected);
+  }
+
   void test_sub_enumerate_a_types() {
 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
   test_sub_enumerate_b_types<ScalarType::dtype>();
@@ -237,6 +340,19 @@ TEST_F(OpSubOutTest, BroadcastScalarRank0Supported) {
   EXPECT_TENSOR_EQ(out, ret);
 }
 
+TEST_F(OpSubOutTest, BroadcastNDTest) {
+  // Test 3D tensors
+  test_broadcast_3D<ScalarType::Float>();
+  test_broadcast_3D<ScalarType::Half>();
+  // Sub doesnt yet support BFloat16
+  // test_broadcast_3D<ScalarType::BFloat16>();
+
+  // Test 4D tensors
+  test_broadcast_4D<ScalarType::Float>();
+  test_broadcast_4D<ScalarType::Half>();
+  // test_broadcast_4D<ScalarType::BFloat16>();
+}
+
 //
 // Death Tests
 //
```
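
A quick sanity check on the new expectations: `op_sub_out(b, a, 1.5, out)` computes `out = b - 1.5 * a`, so the first element of the final 4-D case is `1 - 1.5 * 1 = -0.5`, matching the `-0.5000` above. Note that each shape pair is deliberately run in both argument orders, which exercises the reversed-argument broadcast handling (cf. `kBroadcast2dBy1dReverseArguments` in the removed code) as well as the direct path.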
