
Commit b6134e0

manuelcandales authored and facebook-github-bot committed
Add scalar variants for add, sub, mul, div
Reviewed By: SS-JIA
Differential Revision: D47037034
fbshipit-source-id: 5e1b2aed031dd99dcae6bb7ff3310bcc02160f26
1 parent 83fb692 commit b6134e0

File tree

10 files changed: +309 -3 lines changed

kernels/aten/functions.yaml

Lines changed: 8 additions & 0 deletions
@@ -29,6 +29,8 @@
 
 - op: add.out
 
+- op: add.Scalar_out
+
 - op: addmm.out
 
 - op: amax.out
@@ -104,6 +106,8 @@
 
 - op: div.out
 
+- op: div.Scalar_out
+
 - op: embedding.out
 
 - op: empty.out
@@ -208,6 +212,8 @@
 
 - op: mul.out
 
+- op: mul.Scalar_out
+
 - op: native_batch_norm.out
 
 - op: native_layer_norm.out
@@ -284,6 +290,8 @@
 
 - op: sub.out
 
+- op: sub.Scalar_out
+
 - op: sum.IntList_out
 
 - op: t_copy.out

kernels/portable/cpu/op_add.cpp

Lines changed: 47 additions & 0 deletions
@@ -1,5 +1,6 @@
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 
+#include <executorch/core/Assert.h>
 #include <executorch/kernels/kernel_includes.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
@@ -52,6 +53,52 @@ Tensor& add_out(
   return out;
 }
 
+Tensor& add_scalar_out(
+    RuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  (void)ctx;
+
+  // Resize for dynamic shape
+  auto error = resize_tensor(out, a.sizes());
+  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = utils::get_scalar_dtype(b);
+  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
+  ScalarType out_type = out.scalar_type();
+
+  ET_CHECK(common_type == out_type);
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "add", CTYPE_A, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "add", CTYPE_B, [&]() {
+      ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "add", CTYPE_IN, [&]() {
+        ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "add", CTYPE_OUT, [&]() {
+          CTYPE_B b_val;
+          ET_EXTRACT_SCALAR(b, b_val);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
+          CTYPE_IN alpha_val;
+          ET_EXTRACT_SCALAR(alpha, alpha_val);
+
+          apply_unary_map_fn(
+              [b_casted, alpha_val](const CTYPE_A val_a) {
+                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                CTYPE_IN value = a_casted + alpha_val * b_casted;
+                return static_cast<CTYPE_OUT>(value);
+              },
+              a.const_data_ptr<CTYPE_A>(),
+              out.mutable_data_ptr<CTYPE_OUT>(),
+              out.numel());
+        });
+      });
+    });
+  });
+
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
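For orientation, a minimal sketch of what the new kernel computes, consistent with the OpAddScalarOutKernelTest SanityCheck added at the bottom of this commit: out[i] = a[i] + alpha * b, with out required to already carry the dtype produced by promote_type_with_scalar. The TensorFactory helper, the omitted includes, and the fully qualified call path below are assumptions based on the other files in this diff, not something this commit adds.

// Sketch only, not part of this commit. Headers omitted; they would mirror
// kernels/test/op_add_test.cpp. Helper names below are assumptions.
void add_scalar_out_example() {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Int> tf;
  exec_aten::RuntimeContext ctx{};

  exec_aten::Tensor a = tf.make({2, 2}, {1, 2, 4, 8});
  // promote_type_with_scalar(Int, bool scalar) == Int, so out stays Int.
  exec_aten::Tensor out = tf.zeros({2, 2});

  // out[i] = a[i] + alpha * b; with b = true (promoted to 1) and alpha = 2,
  // out becomes {3, 4, 6, 10}, matching the SanityCheck test.
  torch::executor::native::add_scalar_out(ctx, a, /*b=*/true, /*alpha=*/2, out);
}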

kernels/portable/cpu/op_div.cpp

Lines changed: 44 additions & 2 deletions
@@ -2,10 +2,9 @@
 
 #include <executorch/core/Assert.h>
 #include <executorch/kernels/kernel_includes.h>
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <cmath>
-#include <type_traits>
 
 namespace torch {
 namespace executor {
@@ -67,6 +66,49 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
   return out;
 }
 
+Tensor& div_scalar_out(
+    RuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    Tensor& out) {
+  (void)ctx;
+
+  // Resize for dynamic shape
+  auto error = resize_tensor(out, a.sizes());
+  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = utils::get_scalar_dtype(b);
+  ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
+  ScalarType out_type = out.scalar_type();
+
+  ET_CHECK(common_type == out_type);
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div", CTYPE_A, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div", CTYPE_B, [&]() {
+      ET_SWITCH_REAL_TYPES(common_type, ctx, "div", CTYPE_IN, [&]() {
+        ET_SWITCH_REAL_TYPES(out_type, ctx, "div", CTYPE_OUT, [&]() {
+          CTYPE_B b_val;
+          ET_EXTRACT_SCALAR(b, b_val);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
+
+          apply_unary_map_fn(
+              [b_casted](const CTYPE_A val_a) {
+                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                CTYPE_IN value = a_casted / b_casted;
+                return static_cast<CTYPE_OUT>(value);
+              },
+              a.const_data_ptr<CTYPE_A>(),
+              out.mutable_data_ptr<CTYPE_OUT>(),
+              out.numel());
+        });
+      });
+    });
+  });
+
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
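One behavioral difference from the other three variants: div_scalar_out does not call promote_type_with_scalar; it forces common_type to Float whenever the input is not already floating-point, so an integral input needs a floating-point out tensor or the ET_CHECK(common_type == out_type) aborts. A hedged sketch of the implied usage (TensorFactory and the call path are assumptions carried over from the test changes in this commit):

// Sketch only, not part of this commit.
void div_scalar_out_example() {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Int> tf_int;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tf_float;
  exec_aten::RuntimeContext ctx{};

  exec_aten::Tensor a = tf_int.make({2, 2}, {2, 4, 6, 8});
  // Int input -> common_type is Float, so the output tensor must be Float.
  exec_aten::Tensor out = tf_float.zeros({2, 2});

  // The division runs in float: out == {1.0f, 2.0f, 3.0f, 4.0f}.
  torch::executor::native::div_scalar_out(ctx, a, /*b=*/2, out);
}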

kernels/portable/cpu/op_mul.cpp

Lines changed: 44 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 #include <executorch/core/Assert.h>
 #include <executorch/kernels/kernel_includes.h>
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/functional_util.h>
 
@@ -45,6 +46,49 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
   return out;
 }
 
+Tensor& mul_scalar_out(
+    RuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    Tensor& out) {
+  (void)ctx;
+
+  // Resize for dynamic shape
+  auto error = resize_tensor(out, a.sizes());
+  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = utils::get_scalar_dtype(b);
+  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
+  ScalarType out_type = out.scalar_type();
+
+  ET_CHECK(common_type == out_type);
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "mul", CTYPE_A, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "mul", CTYPE_B, [&]() {
+      ET_SWITCH_REAL_TYPES_AND(Bool, common_type, ctx, "mul", CTYPE_IN, [&]() {
+        ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "mul", CTYPE_OUT, [&]() {
+          CTYPE_B b_val;
+          ET_EXTRACT_SCALAR(b, b_val);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
+
+          apply_unary_map_fn(
+              [b_casted](const CTYPE_A val_a) {
+                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                CTYPE_IN value = a_casted * b_casted;
+                return static_cast<CTYPE_OUT>(value);
+              },
+              a.const_data_ptr<CTYPE_A>(),
+              out.mutable_data_ptr<CTYPE_OUT>(),
+              out.numel());
+        });
+      });
+    });
+  });
+
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch

kernels/portable/cpu/op_sub.cpp

Lines changed: 48 additions & 0 deletions
@@ -1,4 +1,6 @@
 // Copyright (c) Meta Platforms, Inc. and affiliates.
+
+#include <executorch/core/Assert.h>
 #include <executorch/kernels/kernel_includes.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
@@ -51,6 +53,52 @@ Tensor& sub_out(
   return out;
 }
 
+Tensor& sub_scalar_out(
+    RuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  (void)ctx;
+
+  // Resize for dynamic shape
+  auto error = resize_tensor(out, a.sizes());
+  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = utils::get_scalar_dtype(b);
+  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
+  ScalarType out_type = out.scalar_type();
+
+  ET_CHECK(common_type == out_type);
+
+  ET_SWITCH_REAL_TYPES(a_type, ctx, "sub", CTYPE_A, [&]() {
+    ET_SWITCH_REAL_TYPES(b_type, ctx, "sub", CTYPE_B, [&]() {
+      ET_SWITCH_REAL_TYPES(common_type, ctx, "sub", CTYPE_IN, [&]() {
+        ET_SWITCH_REAL_TYPES(out_type, ctx, "sub", CTYPE_OUT, [&]() {
+          CTYPE_B b_val;
+          ET_EXTRACT_SCALAR(b, b_val);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
+          CTYPE_IN alpha_val;
+          ET_EXTRACT_SCALAR(alpha, alpha_val);
+
+          apply_unary_map_fn(
+              [b_casted, alpha_val](const CTYPE_A val_a) {
+                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+                CTYPE_IN value = a_casted - alpha_val * b_casted;
+                return static_cast<CTYPE_OUT>(value);
+              },
+              a.const_data_ptr<CTYPE_A>(),
+              out.mutable_data_ptr<CTYPE_OUT>(),
+              out.numel());
+        });
+      });
+    });
+  });
+
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
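Note that sub_scalar_out dispatches through ET_SWITCH_REAL_TYPES rather than the ET_SWITCH_REAL_TYPES_AND(Bool, ...) variants used by add and mul above, so boolean tensors and scalars are not handled here. Otherwise the semantics mirror add with the alpha term subtracted, as in this hedged sketch (helper and namespace assumptions as in the earlier sketches, not part of this diff):

// Sketch only, not part of this commit: out[i] = a[i] - alpha * b.
void sub_scalar_out_example() {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Int> tf;
  exec_aten::RuntimeContext ctx{};

  exec_aten::Tensor a = tf.make({2, 2}, {10, 20, 30, 40});
  exec_aten::Tensor out = tf.zeros({2, 2});

  // b = 3, alpha = 2 -> out == {4, 14, 24, 34}.
  torch::executor::native::sub_scalar_out(ctx, a, /*b=*/3, /*alpha=*/2, out);
}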

kernels/portable/functions.yaml

Lines changed: 20 additions & 0 deletions
@@ -52,6 +52,11 @@
     - arg_meta: null
       kernel_name: torch::executor::add_out
 
+- op: add.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::add_scalar_out
+
 - op: addmm.out
   kernels:
     - arg_meta: null
@@ -213,6 +218,11 @@
     - arg_meta: null
       kernel_name: torch::executor::div_out
 
+- op: div.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::div_scalar_out
+
 - op: embedding.out
   kernels:
     - arg_meta: null
@@ -436,6 +446,11 @@
     - arg_meta: null
      kernel_name: torch::executor::mul_out
 
+- op: mul.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::mul_scalar_out
+
 - op: native_layer_norm.out
   kernels:
     - arg_meta: null
@@ -596,6 +611,11 @@
     - arg_meta: null
       kernel_name: torch::executor::sub_out
 
+- op: sub.Scalar_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::sub_scalar_out
+
 - op: sum.IntList_out
   kernels:
     - arg_meta: null

kernels/test/op_add_test.cpp

Lines changed: 22 additions & 0 deletions
@@ -27,6 +27,15 @@ Tensor& add_out(
   return torch::executor::aten::add_outf(context, self, other, alpha, out);
 }
 
+Tensor& add_scalar_out(
+    const Tensor& self,
+    const Scalar& other,
+    const Scalar& alpha,
+    Tensor& out) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::aten::add_outf(context, self, other, alpha, out);
+}
+
 template <ScalarType DTYPE_A, ScalarType DTYPE_B, ScalarType DTYPE_OUT>
 void test_add() {
   TensorFactory<DTYPE_A> tf_a;
@@ -512,3 +521,16 @@ TEST(OpAddOutKernelTest, DynamicShapeUnbound) {
   Tensor ret = add_out(x, y, 1, out);
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
+
+TEST(OpAddScalarOutKernelTest, SanityCheck) {
+  TensorFactory<ScalarType::Int> tf;
+
+  const std::vector<int32_t> sizes = {2, 2};
+
+  Tensor out = tf.zeros(sizes);
+
+  add_scalar_out(tf.make(sizes, {1, 2, 4, 8}), true, /*alpha=*/2, out);
+
+  // Check that it matches the expected output.
+  EXPECT_TENSOR_EQ(out, tf.make(sizes, {3, 4, 6, 10}));
+}
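Only the add variant gains test coverage in this commit. A hedged sketch of how the same pattern could extend to the others, e.g. mul, assuming a mul_scalar_out wrapper and a kernels/test/op_mul_test.cpp analogous to the add wrapper above (neither is part of this diff):

// Hypothetical follow-up test, not included in this commit.
TEST(OpMulScalarOutKernelTest, SanityCheck) {
  TensorFactory<ScalarType::Int> tf;

  const std::vector<int32_t> sizes = {2, 2};

  Tensor out = tf.zeros(sizes);

  // Assumes a mul_scalar_out wrapper mirroring the add_scalar_out wrapper above.
  mul_scalar_out(tf.make(sizes, {1, 2, 4, 8}), /*b=*/3, out);

  // out[i] = a[i] * 3
  EXPECT_TENSOR_EQ(out, tf.make(sizes, {3, 6, 12, 24}));
}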
