Commit dbabcf3
[ExecuTorch] Rework apply_ternary_elementwise_fn to allow fixing op_clamp build time (1/2)
This diff is a simple refactor that shouldn't materially change the generated code -- we still create a "kernel" for each combination of input1/input2/input3/output dtypes. The following diff will use this to improve op_clamp build time (but not op_where's, because one of its input types is fixed, so it isn't as bad a build-time outlier).

Differential Revision: [D63681033](https://our.internmc.facebook.com/intern/diff/D63681033/)
ghstack-source-id: 245613706
Pull Request resolved: #5783
1 parent 944bd67 commit dbabcf3
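To make the shape of the change concrete, here is a minimal standalone sketch (toy code, not the ExecuTorch API; the kernel names and the add-three computation are illustrative). The old form templates the inner loop on all four ctypes, so D supported dtypes can yield up to D^4 instantiations of the loop body. The new form templates the loop on CTYPE_OUT plus small load-and-convert callables. As the commit message notes, this refactor alone still produces one kernel per dtype combination, because each lambda has a unique type, but it confines the input-dtype dependence to those tiny callables, which is what the follow-up diff can exploit.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Old shape (toy): the whole loop is templated on all four ctypes, so
// every (A, B, C, OUT) dtype combination stamps out a full copy of it.
template <typename A, typename B, typename C, typename OUT>
void old_style_kernel(const A* a, const B* b, const C* c, OUT* out, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = static_cast<OUT>(a[i]) + static_cast<OUT>(b[i]) +
        static_cast<OUT>(c[i]);
  }
}

// New shape (toy): the loop body mentions only OUT; knowledge of the
// input dtypes is confined to small load-and-convert callables that
// read through untyped pointers, mirroring the aToOut/bToOut/cToOut
// parameters in the broadcast_util.h change below.
template <typename OUT, typename LoadA, typename LoadB, typename LoadC>
void new_style_kernel(
    const void* a,
    const void* b,
    const void* c,
    OUT* out,
    int n,
    size_t a_elem,
    size_t b_elem,
    size_t c_elem,
    LoadA load_a,
    LoadB load_b,
    LoadC load_c) {
  const char* pa = static_cast<const char*>(a);
  const char* pb = static_cast<const char*>(b);
  const char* pc = static_cast<const char*>(c);
  for (int i = 0; i < n; ++i) {
    out[i] = load_a(pa + i * a_elem) + load_b(pb + i * b_elem) +
        load_c(pc + i * c_elem);
  }
}

int main() {
  const int16_t a[2] = {1, 2};
  const int32_t b[2] = {10, 20};
  const float c[2] = {0.5f, 0.25f};
  float out[2];
  new_style_kernel<float>(
      a, b, c, out, 2, sizeof(int16_t), sizeof(int32_t), sizeof(float),
      [](const void* p) {
        return static_cast<float>(*static_cast<const int16_t*>(p));
      },
      [](const void* p) {
        return static_cast<float>(*static_cast<const int32_t*>(p));
      },
      [](const void* p) {
        return static_cast<float>(*static_cast<const float*>(p));
      });
  std::printf("%g %g\n", out[0], out[1]);  // 11.5 22.25
}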

File tree: 3 files changed, +62 −30 lines

kernels/portable/cpu/op_clamp.cpp

Lines changed: 20 additions & 14 deletions
@@ -218,30 +218,36 @@ Tensor& clamp_tensor_out(
       ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
         ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
           ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
-            apply_ternary_elementwise_fn<
-                CTYPE_IN,
-                CTYPE_MIN,
-                CTYPE_MAX,
-                CTYPE_OUT>(
+            apply_ternary_elementwise_fn<CTYPE_OUT>(
                 [has_min, has_max](
-                    const CTYPE_IN val_in,
-                    const CTYPE_MIN val_min,
-                    const CTYPE_MAX val_max) {
-                  CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
+                    const CTYPE_OUT val_in,
+                    const CTYPE_OUT val_min,
+                    const CTYPE_OUT val_max) {
+                  CTYPE_OUT val_out = val_in;
                   if (has_min) {
-                    val_out = utils::max_override(
-                        val_out, static_cast<CTYPE_OUT>(val_min));
+                    val_out = utils::max_override(val_out, val_min);
                   }
                   if (has_max) {
-                    val_out = utils::min_override(
-                        val_out, static_cast<CTYPE_OUT>(val_max));
+                    val_out = utils::min_override(val_out, val_max);
                   }
                   return val_out;
                 },
                 in,
                 min,
                 max,
-                out);
+                out,
+                [](const void* inPtr) {
+                  return static_cast<CTYPE_OUT>(
+                      *reinterpret_cast<const CTYPE_IN*>(inPtr));
+                },
+                [](const void* minPtr) {
+                  return static_cast<CTYPE_OUT>(
+                      *reinterpret_cast<const CTYPE_MIN*>(minPtr));
+                },
+                [](const void* maxPtr) {
+                  return static_cast<CTYPE_OUT>(
+                      *reinterpret_cast<const CTYPE_MAX*>(maxPtr));
+                });
           });
         });
       });

kernels/portable/cpu/op_where.cpp

Lines changed: 17 additions & 7 deletions
@@ -48,16 +48,26 @@ Tensor& where_out(
     ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
       using CTYPE_OUT =
           typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-      apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, uint8_t, CTYPE_OUT>(
-          [](const CTYPE_A val_a, const CTYPE_B val_b, const uint8_t val_c) {
-            CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
-            CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
-            return val_c ? a_casted : b_casted;
-          },
+      apply_ternary_elementwise_fn<CTYPE_OUT>(
+          [](const CTYPE_OUT val_a,
+             const CTYPE_OUT val_b,
+             const CTYPE_OUT val_c) { return val_c ? val_a : val_b; },
           a,
           b,
           cond,
-          out);
+          out,
+          [](const void* aPtr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const CTYPE_A*>(aPtr));
+          },
+          [](const void* bPtr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const CTYPE_B*>(bPtr));
+          },
+          [](const void* cPtr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const uint8_t*>(cPtr));
+          });
     });
   });
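Design note: op_where is less of a build-time outlier than op_clamp because its dispatch nest is shallower. The condition tensor's ctype is fixed at uint8_t, and CTYPE_OUT is derived from CTYPE_A and CTYPE_B via promote_types rather than switched over, so the instantiation count scales with |CTYPE_A| * |CTYPE_B| instead of a full four-way product like clamp's.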

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 25 additions & 9 deletions
@@ -313,28 +313,42 @@ inline void apply_binary_elementwise_fn(
  * Useful for ternary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ *
+ * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
+ * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
+ * are passed as CTYPE_OUT and we require conversion functions from
+ * each input type to the output type be provided. Each conversion
+ * function must take a void* pointing to an element of the
+ * corresponding tensor, load that element, and convert it to
+ * CTYPE_OUT.
  */
 template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_C,
     typename CTYPE_OUT,
-    typename Op>
+    typename Op,
+    typename AToOutFunc,
+    typename BToOutFunc,
+    typename CToOutFunc>
 inline void apply_ternary_elementwise_fn(
     const Op& compute_fun,
     const Tensor& a,
     const Tensor& b,
     const Tensor& c,
-    const Tensor& out) {
+    const Tensor& out,
+    AToOutFunc aToOut,
+    BToOutFunc bToOut,
+    CToOutFunc cToOut) {
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
   const bool any_is_broadcasted =
       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
 
-  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
-  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
-  const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
+  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
+  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
+  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
+  const auto a_element_size = a.element_size();
+  const auto b_element_size = b.element_size();
+  const auto c_element_size = c.element_size();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
   for (size_t i = 0; i < out.numel(); ++i) {
@@ -358,7 +372,9 @@ inline void apply_ternary_elementwise_fn(
     }
 
     data_out[i] = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
+        aToOut(&data_a[a_linear_index * a_element_size]),
+        bToOut(&data_b[b_linear_index * b_element_size]),
+        cToOut(&data_c[c_linear_index * c_element_size]));
   }
 }
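A natural question is how this enables the promised build-time win, since each distinct lambda type still instantiates the template separately. The follow-up diff (2/2) is not shown here, so the following is only a hedged sketch of one plausible mechanism: if the conversions are passed as plain function pointers rather than templated callables, the broadcast loop has exactly one instantiation per CTYPE_OUT, and only the tiny per-dtype loaders multiply. All names here (ternary_loop, load_as, ToOutFn) are hypothetical, not ExecuTorch identifiers.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical follow-up shape (NOT the actual ExecuTorch diff 2/2):
// with conversions passed as plain function pointers, this function has
// exactly one instantiation per CTYPE_OUT, no matter how many input
// dtype combinations call it.
template <typename CTYPE_OUT>
using ToOutFn = CTYPE_OUT (*)(const void*);

template <typename CTYPE_OUT>
void ternary_loop(
    const char* a, size_t a_elem,
    const char* b, size_t b_elem,
    const char* c, size_t c_elem,
    CTYPE_OUT* out, size_t n,
    CTYPE_OUT (*fuse)(CTYPE_OUT, CTYPE_OUT, CTYPE_OUT),
    ToOutFn<CTYPE_OUT> a_to_out,
    ToOutFn<CTYPE_OUT> b_to_out,
    ToOutFn<CTYPE_OUT> c_to_out) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = fuse(
        a_to_out(a + i * a_elem),
        b_to_out(b + i * b_elem),
        c_to_out(c + i * c_elem));
  }
}

// One tiny templated loader per (input dtype, output dtype) pair; these
// are cheap to instantiate compared to the broadcast/indexing loop.
template <typename IN, typename OUT>
OUT load_as(const void* p) {
  return static_cast<OUT>(*static_cast<const IN*>(p));
}

int main() {
  const int8_t a[2] = {1, 2};
  const int64_t b[2] = {3, 4};
  const float c[2] = {0.5f, 1.5f};
  float out[2];
  ternary_loop<float>(
      reinterpret_cast<const char*>(a), sizeof(int8_t),
      reinterpret_cast<const char*>(b), sizeof(int64_t),
      reinterpret_cast<const char*>(c), sizeof(float),
      out, 2,
      // A captureless lambda decays to the fuse function pointer.
      [](float x, float y, float z) { return x + y + z; },
      load_as<int8_t, float>, load_as<int64_t, float>, load_as<float, float>);
  std::printf("%g %g\n", out[0], out[1]);  // 4.5 7.5
}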
