
Commit b8380d5

[ExecuTorch] Simplify function pointers for apply_ternary_elementwise_fn
Cleaning up some of the required boilerplate. I updated op_clamp and op_where, but op_where is still not optimized for size/build time; ideal usage optimized for size/build time would look like op_clamp.

Differential Revision: [D63790004](https://our.internmc.facebook.com/intern/diff/D63790004/)

[ghstack-poisoned]
1 parent c49f48a commit b8380d5
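Note, for readers skimming the diffs below: the core of the cleanup is a pair of type-erased load/store templates plus helpers that map a runtime dtype to the matching instantiation, so each argument tensor costs one type switch instead of one nested switch per dtype combination. Below is a minimal, self-contained sketch of that idea: load_and_convert and convert_and_store mirror the templates this commit moves into broadcast_util.h, while DType and get_loader are simplified stand-ins for ScalarType and the ET_SWITCH_*-based getters, not the actual ExecuTorch API.

#include <cstdint>
#include <cstdio>

namespace internal {
// Same shape as the helpers this commit adds to broadcast_util.h.
template <typename To, typename From>
To load_and_convert(const void* from_ptr) {
  return static_cast<To>(*reinterpret_cast<const From*>(from_ptr));
}

template <typename To, typename From>
void convert_and_store(From f, void* dst) {
  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
}
} // namespace internal

// Simplified stand-in for a runtime dtype tag such as ScalarType.
enum class DType { Float, Int };

template <typename CTYPE_COMMON>
using load_fn = CTYPE_COMMON (*)(const void*);

// One runtime switch per argument tensor resolves a single function pointer,
// instead of nesting a switch per dtype combination at the call site.
template <typename CTYPE_COMMON>
load_fn<CTYPE_COMMON> get_loader(DType d) {
  switch (d) {
    case DType::Float:
      return internal::load_and_convert<CTYPE_COMMON, float>;
    case DType::Int:
      return internal::load_and_convert<CTYPE_COMMON, int32_t>;
  }
  return nullptr;
}

int main() {
  float a = 2.5f;
  int32_t b = 7;
  // Both inputs are read through the same erased signature.
  float sum = get_loader<float>(DType::Float)(&a) +
      get_loader<float>(DType::Int)(&b);
  std::printf("%f\n", sum); // prints 9.500000
  return 0;
}

Resolving the pointer once per tensor keeps template instantiations linear in the number of dtypes rather than combinatorial across input and output dtypes, which is the size/build-time benefit the commit message refers to.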


3 files changed: 50 additions & 47 deletions


kernels/portable/cpu/op_clamp.cpp

Lines changed: 6 additions & 32 deletions
@@ -66,15 +66,6 @@ ET_NODISCARD bool check_bounds(
   return is_valid;
 }
 
-template <typename To, typename From>
-To load_and_convert(const void* fromPtr) {
-  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
-}
-
-template <typename To, typename From>
-void convert_and_store(From f, void* dst) {
-  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
-}
 } // namespace
 
 Tensor& clamp_out(
@@ -221,26 +212,9 @@ Tensor& clamp_tensor_out(
 
   ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
 
-  constexpr auto name = "clamp.Tensor_out";
+  static constexpr const char op_name[] = "clamp.Tensor_out";
 
-  ET_SWITCH_REALHB_TYPES(common_type, ctx, name, CTYPE_COMMON, [&]() {
-    using ToCtypeCommonFn = CTYPE_COMMON (*)(const void*);
-    ToCtypeCommonFn in_to_common;
-    ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
-      in_to_common = load_and_convert<CTYPE_COMMON, CTYPE_IN>;
-    });
-    ToCtypeCommonFn min_to_common;
-    ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
-      min_to_common = load_and_convert<CTYPE_COMMON, CTYPE_MIN>;
-    });
-    ToCtypeCommonFn max_to_common;
-    ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
-      max_to_common = load_and_convert<CTYPE_COMMON, CTYPE_MAX>;
-    });
-    void (*common_to_out)(CTYPE_COMMON, void*);
-    ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
-      common_to_out = convert_and_store<CTYPE_OUT, CTYPE_COMMON>;
-    });
+  ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
     apply_ternary_elementwise_fn<CTYPE_COMMON>(
         [has_min, has_max](
             const CTYPE_COMMON val_in,
@@ -259,10 +233,10 @@ Tensor& clamp_tensor_out(
         min,
         max,
         out,
-        in_to_common,
-        min_to_common,
-        max_to_common,
-        common_to_out);
+        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(in),
+        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(min),
+        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(max),
+        get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(out));
   });
 
   return out;
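A side note on the constexpr auto name -> static constexpr const char op_name[] change above: the new getters take the op name as a const char* non-type template parameter (the ET_SWITCH_* macros want a name for their diagnostics), and a valid non-type template argument of pointer type must designate a named object with static storage duration; the address of a string literal, which is what the old constexpr auto name held, is explicitly disallowed. A tiny standalone illustration of the rule (not ExecuTorch code):

// A function template taking the name as a non-type template parameter,
// analogous to the new getters in broadcast_util.h.
template <const char* Name>
const char* tag() {
  return Name;
}

// A named constexpr array with static storage duration is a valid argument.
static constexpr const char kOpName[] = "clamp.Tensor_out";

const char* example() {
  return tag<kOpName>(); // compiles
  // constexpr auto name = "clamp.Tensor_out";
  // return tag<name>(); // would not compile: the value is the address of a
  //                     // string literal
}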

kernels/portable/cpu/op_where.cpp

Lines changed: 4 additions & 15 deletions
@@ -56,21 +56,10 @@ Tensor& where_out(
             b,
             cond,
             out,
-            [](const void* a_ptr) {
-              return static_cast<CTYPE_OUT>(
-                  *reinterpret_cast<const CTYPE_A*>(a_ptr));
-            },
-            [](const void* b_ptr) {
-              return static_cast<CTYPE_OUT>(
-                  *reinterpret_cast<const CTYPE_B*>(b_ptr));
-            },
-            [](const void* c_ptr) {
-              return static_cast<CTYPE_OUT>(
-                  *reinterpret_cast<const uint8_t*>(c_ptr));
-            },
-            [](CTYPE_OUT result, void* out) {
-              *reinterpret_cast<CTYPE_OUT*>(out) = result;
-            });
+            internal::load_and_convert<CTYPE_OUT, CTYPE_A>,
+            internal::load_and_convert<CTYPE_OUT, CTYPE_B>,
+            internal::load_and_convert<CTYPE_OUT, uint8_t>,
+            internal::convert_and_store<CTYPE_OUT, CTYPE_OUT>);
       });
     });
 
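As the commit message says, op_where intentionally keeps its current structure: the function pointers above are still selected inside the nested ET_SWITCH over CTYPE_A, CTYPE_B, and CTYPE_OUT, so the kernel body is still instantiated once per dtype combination and this change is size/build-time neutral for op_where. A size/build-time-optimized version would mirror op_clamp: switch once on the common type and resolve the per-tensor loads and stores through the new getters. A rough, hypothetical sketch of such a rewrite (not part of this commit; the op name, the switch set, and the predicate handling are assumptions):

// Hypothetical size-optimized call site, mirroring op_clamp.
static constexpr const char op_name[] = "where.self_out"; // name assumed

ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
  apply_ternary_elementwise_fn<CTYPE_COMMON>(
      [](const CTYPE_COMMON val_a,
         const CTYPE_COMMON val_b,
         const CTYPE_COMMON val_c) {
        // Simplified predicate handling for the sketch.
        return val_c ? val_a : val_b;
      },
      a,
      b,
      cond,
      out,
      get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(a),
      get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(b),
      get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(cond),
      get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(out));
});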

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 40 additions & 0 deletions
@@ -270,6 +270,46 @@ size_t linearize_access_indexes(
 // Mapping with broadcasting
 //
 
+namespace internal {
+template <typename To, typename From>
+To load_and_convert(const void* fromPtr) {
+  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
+}
+
+template <typename To, typename From>
+void convert_and_store(From f, void* dst) {
+  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
+}
+} // namespace internal
+
+template <typename CTYPE_COMMON>
+using load_to_common_fn = CTYPE_COMMON (*)(const void*);
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_REALHBBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON>
+using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_REALHBBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
 /**
  * Useful for binary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
