
Commit aa8a93c

swolchok authored and facebook-github-bot committed
Dramatically improve op_clamp build time (#5784)
Summary:
Pull Request resolved: #5784

Instead of building `O(|CTYPE_IN| * |CTYPE_MIN| * |CTYPE_MAX| * |CTYPE_OUT|)` kernel code (where |T| means the number of possibilities for type T), we build `O((|CTYPE_IN| + |CTYPE_MIN| + |CTYPE_MAX| + |CTYPE_COMMON|) * |CTYPE_OUT|)` kernel code. (Concretely, `ET_SWITCH_REALHB_TYPES` has 9 possibilities, so I estimate that we went from 9**4 = 6561 template instantiations to 9 * 4 * 9 = 324 instantiations, or a 20x reduction.)

ghstack-source-id: 246132260
Reviewed By: manuelcandales
Differential Revision: D63681034
fbshipit-source-id: 41891b8fbdfa0d6f6126342febf59e6aae5b4876
1 parent 9c3ebfe commit aa8a93c
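
To make the build-time tradeoff concrete, here is a minimal standalone sketch of the technique (the types and values are invented for illustration; only load_and_convert mirrors the helper added in op_clamp.cpp below). Each type switch selects a cheap conversion-function pointer instead of enclosing the kernel body, so instantiation count grows with the sum of per-argument type counts rather than their product.

#include <cstdint>
#include <cstdio>

// Same shape as the load_and_convert helper added in op_clamp.cpp: one tiny
// instantiation per (To, From) pair, selected independently per argument.
template <typename To, typename From>
To load_and_convert(const void* from) {
  return static_cast<To>(*reinterpret_cast<const From*>(from));
}

int main() {
  // Pretend each pointer below was picked by an independent runtime dtype
  // switch: with N supported types per argument, this instantiates O(N)
  // converters per argument, while nested switches would instantiate one
  // full kernel body per type combination (O(N^k) for k arguments).
  float in = 7.0f;
  int32_t min_v = 3;
  double (*in_to_common)(const void*) = load_and_convert<double, float>;
  double (*min_to_common)(const void*) = load_and_convert<double, int32_t>;

  // The single kernel body runs entirely in the common type.
  double v = in_to_common(&in);
  if (v < min_to_common(&min_v)) {
    v = min_to_common(&min_v);
  }
  std::printf("%f\n", v);  // prints 7.000000
  return 0;
}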

File tree

kernels/portable/cpu/op_clamp.cpp
kernels/portable/cpu/op_where.cpp
kernels/portable/cpu/util/broadcast_util.h

3 files changed: +98 -53 lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 47 additions & 33 deletions

@@ -66,6 +66,15 @@ ET_NODISCARD bool check_bounds(
   return is_valid;
 }
 
+template <typename To, typename From>
+To load_and_convert(const void* fromPtr) {
+  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
+}
+
+template <typename To, typename From>
+void convert_and_store(From f, void* dst) {
+  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
+}
 } // namespace
 
 Tensor& clamp_out(
@@ -214,41 +223,46 @@ Tensor& clamp_tensor_out(
 
   constexpr auto name = "clamp.Tensor_out";
 
-  ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
+  ET_SWITCH_REALHB_TYPES(common_type, ctx, name, CTYPE_COMMON, [&]() {
+    using ToCtypeCommonFn = CTYPE_COMMON (*)(const void*);
+    ToCtypeCommonFn in_to_common;
+    ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
+      in_to_common = load_and_convert<CTYPE_COMMON, CTYPE_IN>;
+    });
+    ToCtypeCommonFn min_to_common;
     ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
-      ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
-        ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
-          using CTYPE_MINMAX = typename torch::executor::
-              promote_types<CTYPE_MIN, CTYPE_MAX>::type;
-          using CTYPE = typename torch::executor::
-              promote_types<CTYPE_IN, CTYPE_MINMAX>::type;
-          apply_ternary_elementwise_fn<
-              CTYPE_IN,
-              CTYPE_MIN,
-              CTYPE_MAX,
-              CTYPE_OUT>(
-              [has_min, has_max](
-                  const CTYPE_IN val_in,
-                  const CTYPE_MIN val_min,
-                  const CTYPE_MAX val_max) {
-                CTYPE val_out = static_cast<CTYPE>(val_in);
-                if (has_min) {
-                  val_out =
-                      utils::max_override(val_out, static_cast<CTYPE>(val_min));
-                }
-                if (has_max) {
-                  val_out =
-                      utils::min_override(val_out, static_cast<CTYPE>(val_max));
-                }
-                return static_cast<CTYPE_OUT>(val_out);
-              },
-              in,
-              min,
-              max,
-              out);
-        });
-      });
+      min_to_common = load_and_convert<CTYPE_COMMON, CTYPE_MIN>;
+    });
+    ToCtypeCommonFn max_to_common;
+    ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
+      max_to_common = load_and_convert<CTYPE_COMMON, CTYPE_MAX>;
+    });
+    void (*common_to_out)(CTYPE_COMMON, void*);
+    ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+      common_to_out = convert_and_store<CTYPE_OUT, CTYPE_COMMON>;
     });
+    apply_ternary_elementwise_fn<CTYPE_COMMON>(
+        [has_min, has_max](
+            const CTYPE_COMMON val_in,
+            const CTYPE_COMMON val_min,
+            const CTYPE_COMMON val_max) {
+          CTYPE_COMMON val_out = val_in;
+          if (has_min) {
+            val_out = utils::max_override(val_out, val_min);
+          }
+          if (has_max) {
+            val_out = utils::min_override(val_out, val_max);
+          }
+          return val_out;
+        },
+        in,
+        min,
+        max,
+        out,
+        in_to_common,
+        min_to_common,
+        max_to_common,
+        common_to_out);
   });
 
   return out;
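
As a sanity check on the refactored clamp body (a hedged standalone sketch: std::max/std::min stand in for utils::max_override/utils::min_override, and the concrete types are invented for the example), the computation converts each argument to the common type once and then clamps entirely in that type:

#include <algorithm>
#include <cstdio>

int main() {
  // Mixed argument types, as in clamp.Tensor_out.
  int in_val = 7;        // CTYPE_IN
  float min_val = 2.5f;  // CTYPE_MIN
  long max_val = 5;      // CTYPE_MAX

  // Convert each argument to the common type up front, then compute
  // entirely in that type, mirroring the new CTYPE_COMMON lambda above.
  double v = static_cast<double>(in_val);
  v = std::max(v, static_cast<double>(min_val));  // has_min branch
  v = std::min(v, static_cast<double>(max_val));  // has_max branch
  std::printf("%f\n", v);  // prints 5.000000
  return 0;
}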

kernels/portable/cpu/op_where.cpp

Lines changed: 20 additions & 7 deletions

@@ -48,16 +48,29 @@ Tensor& where_out(
     ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
       using CTYPE_OUT =
           typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-      apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, uint8_t, CTYPE_OUT>(
-          [](const CTYPE_A val_a, const CTYPE_B val_b, const uint8_t val_c) {
-            CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
-            CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
-            return val_c ? a_casted : b_casted;
-          },
+      apply_ternary_elementwise_fn<CTYPE_OUT>(
+          [](const CTYPE_OUT val_a,
+             const CTYPE_OUT val_b,
+             const CTYPE_OUT val_c) { return val_c ? val_a : val_b; },
           a,
           b,
           cond,
-          out);
+          out,
+          [](const void* a_ptr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const CTYPE_A*>(a_ptr));
+          },
+          [](const void* b_ptr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const CTYPE_B*>(b_ptr));
+          },
+          [](const void* c_ptr) {
+            return static_cast<CTYPE_OUT>(
+                *reinterpret_cast<const uint8_t*>(c_ptr));
+          },
+          [](CTYPE_OUT result, void* out) {
+            *reinterpret_cast<CTYPE_OUT*>(out) = result;
+          });
     });
   });
 
kernels/portable/cpu/util/broadcast_util.h

Lines changed: 31 additions & 13 deletions

@@ -313,29 +313,44 @@ inline void apply_binary_elementwise_fn(
  * Useful for ternary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
  * Tensor broadcasting is applied wherever it is required.
+ *
+ * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
+ * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
+ * are passed as CTYPE_COMMON. We require compute_fun to return
+ * CTYPE_COMMON, and we require loading conversion functions from each
+ * input type to CTYPE_COMMON and a storing conversion from
+ * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
+ * must take a void* pointing to an element of the corresponding
+ * tensor, load that element, and convert it to CTYPE_COMMON. The
+ * storing conversion function must have the signature
+ * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
+ * and store it to the given location.
  */
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_C,
-    typename CTYPE_OUT,
-    typename Op>
+template <typename CTYPE_COMMON, typename Op>
 inline void apply_ternary_elementwise_fn(
     const Op& compute_fun,
     const Tensor& a,
     const Tensor& b,
     const Tensor& c,
-    const Tensor& out) {
+    const Tensor& out,
+    CTYPE_COMMON (*load_a_to_common)(const void*),
+    CTYPE_COMMON (*load_b_to_common)(const void*),
+    CTYPE_COMMON (*load_c_to_common)(const void*),
+    void (*store_common_to_out)(CTYPE_COMMON, void*)) {
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
   const bool any_is_broadcasted =
       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
 
-  const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
-  const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
-  const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
-  CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
+  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
+  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
+  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
+  const auto a_element_size = a.element_size();
+  const auto b_element_size = b.element_size();
+  const auto c_element_size = c.element_size();
+  const auto out_element_size = out.element_size();
+  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   for (size_t i = 0; i < out.numel(); ++i) {
     size_t a_linear_index = i;
@@ -357,8 +372,11 @@ inline void apply_ternary_elementwise_fn(
       }
     }
 
-    data_out[i] = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
+    auto result = compute_fun(
+        load_a_to_common(&data_a[a_linear_index * a_element_size]),
+        load_b_to_common(&data_b[b_linear_index * b_element_size]),
+        load_c_to_common(&data_c[c_linear_index * c_element_size]));
+    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
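
For reference, a hedged sketch of a call site against the new signature (adapted from the op_where hunk above; the concrete choices — int32_t inputs, uint8_t condition, float as both CTYPE_COMMON and CTYPE_OUT — are illustrative assumptions, with a, b, cond, and out assumed to be Tensors in scope):

// Captureless lambdas convert to the plain function pointers the new
// signature expects.
apply_ternary_elementwise_fn<float>(
    // compute_fun: arguments and result are all CTYPE_COMMON (float here).
    [](float val_a, float val_b, float val_c) { return val_c ? val_a : val_b; },
    a,
    b,
    cond,
    out,
    // Loading conversions: read a raw element, widen to the common type.
    [](const void* p) { return static_cast<float>(*reinterpret_cast<const int32_t*>(p)); },
    [](const void* p) { return static_cast<float>(*reinterpret_cast<const int32_t*>(p)); },
    [](const void* p) { return static_cast<float>(*reinterpret_cast<const uint8_t*>(p)); },
    // Storing conversion: narrow from the common type and write in place.
    [](float v, void* p) { *reinterpret_cast<float*>(p) = v; });

Using plain function pointers rather than std::function keeps the per-element call cheap, and captureless lambdas can be passed directly.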
