Skip to content

Commit 7c3cf08

Browse files
manuelcandales authored and facebook-github-bot committed
Add optimized ops add, sub, mul, div
Reviewed By: SS-JIA Differential Revision: D47060957 fbshipit-source-id: 3785ba915ccb3e926ef95c7d437767406780a0ec
1 parent b6134e0 commit 7c3cf08

File tree

9 files changed

+676
-0
lines changed

9 files changed

+676
-0
lines changed

kernels/optimized/cpu/op_add.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
#include <executorch/core/Assert.h>
4+
#include <executorch/kernels/kernel_includes.h>
5+
#include <executorch/kernels/optimized/vec/functional.h>
6+
#include <executorch/kernels/optimized/vec/vec.h>
7+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
8+
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
9+
10+
namespace torch {
11+
namespace executor {
12+
namespace native {
13+
14+
using Tensor = exec_aten::Tensor;
15+
using ScalarType = exec_aten::ScalarType;
16+
17+
// Computes out = a + alpha * b.
//
// Fast path: when both inputs and the output share one dtype and the inputs
// have identical sizes, a single vectorized pass is used. Otherwise falls
// back to an element-wise loop with dtype promotion and broadcasting.
//
// Returns `out` to allow call chaining, matching the out-variant convention.
Tensor& opt_add_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

  if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes())) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "add", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_EXTRACT_SCALAR(alpha, alpha_val);

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else {
    ScalarType common_type = promoteTypes(a_type, b_type);
    ET_CHECK(canCast(common_type, out_type));

    // Fix: the resize result was previously discarded. A failed broadcast
    // resize would leave `out` with a stale shape while we still wrote
    // through its data pointer below.
    ET_CHECK_MSG(
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        "Failed to resize output tensor.");

    // Nested dtype dispatch: compute in the promoted type, then narrow to
    // the output type per element.
    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "add", CTYPE_A, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "add", CTYPE_B, [&]() {
        ET_SWITCH_REAL_TYPES_AND(
            Bool, common_type, ctx, "add", CTYPE_IN, [&]() {
              ET_SWITCH_REAL_TYPES_AND(
                  Bool, out_type, ctx, "add", CTYPE_OUT, [&]() {
                    CTYPE_IN alpha_val;
                    ET_EXTRACT_SCALAR(alpha, alpha_val);

                    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
                        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
                          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
                          CTYPE_IN value = a_casted + alpha_val * b_casted;

                          return static_cast<CTYPE_OUT>(value);
                        },
                        a,
                        b,
                        out);
                  });
            });
      });
    });
  }

  return out;
}
80+
81+
// Computes out = a + alpha * b, where b is a Scalar.
//
// Fast path: when the tensor dtype already equals the promoted compute type
// (and the output type), the scalar term `alpha * b` is folded into a single
// vector constant and added in one vectorized pass. Otherwise falls back to
// a scalar loop with per-element casts.
//
// Returns `out` to allow call chaining, matching the out-variant convention.
Tensor& opt_add_scalar_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    const Scalar& alpha,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

  if (a_type == common_type && a_type == out_type) {
    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "add", CTYPE, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "add", CTYPE_B, [&]() {
        CTYPE_B b_val;
        ET_EXTRACT_SCALAR(b, b_val);
        CTYPE b_casted = static_cast<CTYPE>(b_val);
        CTYPE alpha_val;
        ET_EXTRACT_SCALAR(alpha, alpha_val);

        using Vec = executorch::vec::Vectorized<CTYPE>;
        executorch::vec::map<CTYPE>(
            [alpha_val, b_casted](Vec x) {
              return x + Vec(alpha_val * b_casted);
            },
            out.mutable_data_ptr<CTYPE>(),
            a.const_data_ptr<CTYPE>(),
            out.numel());
      });
    });
  } else {
    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "add", CTYPE_A, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "add", CTYPE_B, [&]() {
        ET_SWITCH_REAL_TYPES_AND(
            Bool, common_type, ctx, "add", CTYPE_IN, [&]() {
              ET_SWITCH_REAL_TYPES_AND(
                  Bool, out_type, ctx, "add", CTYPE_OUT, [&]() {
                    CTYPE_B b_val;
                    ET_EXTRACT_SCALAR(b, b_val);
                    CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                    CTYPE_IN alpha_val;
                    ET_EXTRACT_SCALAR(alpha, alpha_val);

                    const size_t n = a.numel();
                    const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                    CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
                    // Fix: `auto i = 0` deduced int, causing a
                    // signed/unsigned comparison against size_t `n` and
                    // potential overflow for very large tensors.
                    for (size_t i = 0; i < n; ++i) {
                      out_data[i] = static_cast<CTYPE_OUT>(
                          static_cast<CTYPE_IN>(a_data[i]) +
                          alpha_val * b_casted);
                    }
                  });
            });
      });
    });
  }

  return out;
}
148+
149+
} // namespace native
150+
} // namespace executor
151+
} // namespace torch

kernels/optimized/cpu/op_div.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
#include <executorch/core/Assert.h>
4+
#include <executorch/kernels/kernel_includes.h>
5+
#include <executorch/kernels/optimized/vec/functional.h>
6+
#include <executorch/kernels/optimized/vec/vec.h>
7+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
8+
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
9+
10+
namespace torch {
11+
namespace executor {
12+
namespace native {
13+
14+
namespace {
15+
16+
ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
17+
ET_CHECK(
18+
!isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
19+
ET_CHECK(
20+
!isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
21+
22+
if (isFloatingType(a_type) && isFloatingType(b_type)) {
23+
return promoteTypes(a_type, b_type);
24+
} else if (isFloatingType(a_type)) {
25+
return a_type;
26+
} else if (isFloatingType(b_type)) {
27+
return b_type;
28+
}
29+
return ScalarType::Float;
30+
}
31+
32+
} // namespace
33+
34+
// Computes out = a / b.
//
// Fast path: when both inputs and the output share one dtype and the inputs
// have identical sizes, a single vectorized pass is used. Otherwise falls
// back to an element-wise loop computing in the floating compute type from
// get_compute_type, with broadcasting.
//
// Returns `out` to allow call chaining, matching the out-variant convention.
Tensor& opt_div_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

  if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes())) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "div", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [](Vec x, Vec y) { return x / y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else {
    ScalarType common_type = get_compute_type(a_type, b_type);
    ET_CHECK(canCast(common_type, out_type));

    // Fix: the resize result was previously discarded. A failed broadcast
    // resize would leave `out` with a stale shape while we still wrote
    // through its data pointer below.
    ET_CHECK_MSG(
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        "Failed to resize output tensor.");

    // Nested dtype dispatch: divide in the compute type, then narrow to the
    // output type per element.
    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div", CTYPE_A, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div", CTYPE_B, [&]() {
        ET_SWITCH_REAL_TYPES_AND(
            Bool, common_type, ctx, "div", CTYPE_IN, [&]() {
              ET_SWITCH_REAL_TYPES_AND(
                  Bool, out_type, ctx, "div", CTYPE_OUT, [&]() {
                    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
                        [](const CTYPE_A val_a, const CTYPE_B val_b) {
                          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
                          CTYPE_IN value = a_casted / b_casted;

                          return static_cast<CTYPE_OUT>(value);
                        },
                        a,
                        b,
                        out);
                  });
            });
      });
    });
  }

  return out;
}
90+
91+
// Computes out = a / b, where b is a Scalar.
//
// Fast path: when the tensor dtype is floating point and already matches the
// output type, the reciprocal-free vectorized divide by a broadcast constant
// is used. Otherwise falls back to a scalar loop with per-element casts.
//
// Returns `out` to allow call chaining, matching the out-variant convention.
Tensor& opt_div_scalar_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  // Division always produces a floating result: keep a floating a_type,
  // otherwise compute in Float.
  ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

  if (a_type == common_type && a_type == out_type) {
    ET_SWITCH_REAL_TYPES(a_type, ctx, "div", CTYPE, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div", CTYPE_B, [&]() {
        CTYPE_B b_val;
        ET_EXTRACT_SCALAR(b, b_val);
        CTYPE b_casted = static_cast<CTYPE>(b_val);

        using Vec = executorch::vec::Vectorized<CTYPE>;
        executorch::vec::map<CTYPE>(
            [b_casted](Vec x) { return x / Vec(b_casted); },
            out.mutable_data_ptr<CTYPE>(),
            a.const_data_ptr<CTYPE>(),
            out.numel());
      });
    });
  } else {
    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div", CTYPE_A, [&]() {
      ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div", CTYPE_B, [&]() {
        ET_SWITCH_REAL_TYPES(common_type, ctx, "div", CTYPE_IN, [&]() {
          ET_SWITCH_REAL_TYPES(out_type, ctx, "div", CTYPE_OUT, [&]() {
            CTYPE_B b_val;
            ET_EXTRACT_SCALAR(b, b_val);
            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);

            const size_t n = a.numel();
            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
            CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
            // Fix: `auto i = 0` deduced int, causing a signed/unsigned
            // comparison against size_t `n` and potential overflow for very
            // large tensors.
            for (size_t i = 0; i < n; ++i) {
              out_data[i] = static_cast<CTYPE_OUT>(
                  static_cast<CTYPE_IN>(a_data[i]) / b_casted);
            }
          });
        });
      });
    });
  }

  return out;
}
148+
149+
} // namespace native
150+
} // namespace executor
151+
} // namespace torch

0 commit comments

Comments
 (0)