Commit 69461cb

slight improvement on "[ExecuTorch] Dramatically improve op_clamp build time (2/2)"
Instead of building `O(|CTYPE_IN| * |CTYPE_MIN| * |CTYPE_MAX| * |CTYPE_OUT|)` kernel code (where |T| means the number of possibilities for type T), we build `O((|CTYPE_IN| + |CTYPE_MIN| + |CTYPE_MAX|) * |CTYPE_OUT|)` kernel code.

Concretely, `ET_SWITCH_REALHB_TYPES` has 9 possibilities, so I estimate that we went from 9**4 = 6561 template instantiations to 9 * 3 * 9 = 243 instantiations, or a 27x reduction.

Differential Revision: [D63681034](https://our.internmc.facebook.com/intern/diff/D63681034/)

[ghstack-poisoned]
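The estimate above can be checked at compile time. The following is a minimal sketch, not part of the commit: the names `kNumTypes`, `instantiations_before`, and `instantiations_after` are hypothetical, and the count of 9 dtypes is taken from the commit message rather than from the macro definition.

```cpp
#include <cstddef>

// Number of dtypes one ET_SWITCH_REALHB_TYPES dispatch can select,
// as stated in the commit message (assumed here, not read from the macro).
constexpr std::size_t kNumTypes = 9;

// Before: one kernel per (in, min, max, out) dtype combination.
constexpr std::size_t instantiations_before(std::size_t n) {
  return n * n * n * n;
}

// After: one converter per (operand dtype, out dtype) pair,
// for each of the three operands in, min, and max.
constexpr std::size_t instantiations_after(std::size_t n) {
  return (n + n + n) * n;
}

static_assert(instantiations_before(kNumTypes) == 6561, "9**4");
static_assert(instantiations_after(kNumTypes) == 243, "9 * 3 * 9");
static_assert(
    instantiations_before(kNumTypes) / instantiations_after(kNumTypes) == 27,
    "27x reduction");
```

These figures count template instantiations only; the effect on build time or binary size is not derived here.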
1 parent a00c741 commit 69461cb


1 file changed, +7 -12 lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 7 additions & 12 deletions
@@ -66,6 +66,10 @@ ET_NODISCARD bool check_bounds(
   return is_valid;
 }
 
+template <typename To, typename From>
+To load_and_convert(const void* fromPtr) {
+  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
+}
 } // namespace
 
 Tensor& clamp_out(
@@ -218,24 +222,15 @@ Tensor& clamp_tensor_out(
     using ToCtypeOutFn = CTYPE_OUT (*)(const void*);
     ToCtypeOutFn in_to_out;
     ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
-      in_to_out = [](const void* inPtr) {
-        return static_cast<CTYPE_OUT>(
-            *reinterpret_cast<const CTYPE_IN*>(inPtr));
-      };
+      in_to_out = load_and_convert<CTYPE_OUT, CTYPE_IN>;
     });
     ToCtypeOutFn min_to_out;
     ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
-      min_to_out = [](const void* minPtr) {
-        return static_cast<CTYPE_OUT>(
-            *reinterpret_cast<const CTYPE_MIN*>(minPtr));
-      };
+      min_to_out = load_and_convert<CTYPE_OUT, CTYPE_MIN>;
     });
     ToCtypeOutFn max_to_out;
     ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
-      max_to_out = [](const void* maxPtr) {
-        return static_cast<CTYPE_OUT>(
-            *reinterpret_cast<const CTYPE_MAX*>(maxPtr));
-      };
+      max_to_out = load_and_convert<CTYPE_OUT, CTYPE_MAX>;
     });
 
     apply_ternary_elementwise_fn<CTYPE_OUT>(
