Skip to content

Commit 6f5cd29

Browse files
committed
[ExecuTorch] Dramatically improve op_clamp build time (2/2)
Pull Request resolved: #5784

Instead of building `O(|CTYPE_IN| * |CTYPE_MIN| * |CTYPE_MAX| * |CTYPE_OUT|)` kernel code (where |T| denotes the number of possible types for T), we now build `O((|CTYPE_IN| + |CTYPE_MIN| + |CTYPE_MAX|) * |CTYPE_OUT|)` kernel code. Concretely, `ET_SWITCH_REALHB_TYPES` has 9 possibilities, so I estimate that we went from 9**4 = 6561 template instantiations to 9 * 3 * 9 = 243 instantiations, a 27x reduction.

Differential Revision: [D63681034](https://our.internmc.facebook.com/intern/diff/D63681034/)

ghstack-source-id: 245741505
1 parent dbabcf3 commit 6f5cd29

File tree

1 file changed

+38
-35
lines changed

1 file changed

+38
-35
lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ ET_NODISCARD bool check_bounds(
6666
return is_valid;
6767
}
6868

69+
// Reads the object at `fromPtr` as a `const From` and returns its value
// converted to `To` with `static_cast`.
//
// `fromPtr` is dereferenced unconditionally, so it must point to a readable
// `From` value. Because the signature depends only on `To`, every
// instantiation for a given `To` has the type `To (*)(const void*)`; the
// clamp kernel in this diff stores instantiations in such a function pointer
// so the input, min and max dtypes can each be dispatched independently.
template <typename To, typename From>
70+
To load_and_convert(const void* fromPtr) {
71+
return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
72+
}
6973
} // namespace
7074

7175
Tensor& clamp_out(
@@ -214,43 +218,42 @@ Tensor& clamp_tensor_out(
214218

215219
constexpr auto name = "clamp.Tensor_out";
216220

217-
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
221+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
222+
using ToCtypeOutFn = CTYPE_OUT (*)(const void*);
223+
ToCtypeOutFn in_to_out;
224+
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
225+
in_to_out = load_and_convert<CTYPE_OUT, CTYPE_IN>;
226+
});
227+
ToCtypeOutFn min_to_out;
218228
ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
219-
ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
220-
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
221-
apply_ternary_elementwise_fn<CTYPE_OUT>(
222-
[has_min, has_max](
223-
const CTYPE_OUT val_in,
224-
const CTYPE_OUT val_min,
225-
const CTYPE_OUT val_max) {
226-
CTYPE_OUT val_out = val_in;
227-
if (has_min) {
228-
val_out = utils::max_override(val_out, val_min);
229-
}
230-
if (has_max) {
231-
val_out = utils::min_override(val_out, val_max);
232-
}
233-
return val_out;
234-
},
235-
in,
236-
min,
237-
max,
238-
out,
239-
[](const void* inPtr) {
240-
return static_cast<CTYPE_OUT>(
241-
*reinterpret_cast<const CTYPE_IN*>(inPtr));
242-
},
243-
[](const void* minPtr) {
244-
return static_cast<CTYPE_OUT>(
245-
*reinterpret_cast<const CTYPE_MIN*>(minPtr));
246-
},
247-
[](const void* maxPtr) {
248-
return static_cast<CTYPE_OUT>(
249-
*reinterpret_cast<const CTYPE_MAX*>(maxPtr));
250-
});
251-
});
252-
});
229+
min_to_out = load_and_convert<CTYPE_OUT, CTYPE_MIN>;
253230
});
231+
ToCtypeOutFn max_to_out;
232+
ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
233+
max_to_out = load_and_convert<CTYPE_OUT, CTYPE_MAX>;
234+
});
235+
236+
apply_ternary_elementwise_fn<CTYPE_OUT>(
237+
[has_min, has_max](
238+
const CTYPE_OUT val_in,
239+
const CTYPE_OUT val_min,
240+
const CTYPE_OUT val_max) {
241+
CTYPE_OUT val_out = val_in;
242+
if (has_min) {
243+
val_out = utils::max_override(val_out, val_min);
244+
}
245+
if (has_max) {
246+
val_out = utils::min_override(val_out, val_max);
247+
}
248+
return val_out;
249+
},
250+
in,
251+
min,
252+
max,
253+
out,
254+
in_to_out,
255+
min_to_out,
256+
max_to_out);
254257
});
255258

256259
return out;

0 commit comments

Comments
 (0)