Skip to content

Commit af3e8ac

Browse files
committed
[ExecuTorch] Dramatically improve op_clamp build time (2/2)
Instead of building `O(|CTYPE_IN| * |CTYPE_MIN| * |CTYPE_MAX| * |CTYPE_OUT|)` kernel code (where |T| means the number of possibilities for type T), we build `O((|CTYPE_IN| + |CTYPE_MIN| + |CTYPE_MAX|) * |CTYPE_OUT|)` kernel code. The switch on the output type becomes the outermost one; inside it, the switches on the input, min, and max types no longer nest around the kernel, and each one only selects a function pointer that converts its operand to `CTYPE_OUT`. (Concretely, `ET_SWITCH_REALHB_TYPES` has 9 possibilities, so I estimate that we went from 9**4 = 6561 template instantiations to (9 + 9 + 9) * 9 = 9 * 3 * 9 = 243 kernels, or a 27x reduction.)

Differential Revision: [D63681034](https://our.internmc.facebook.com/intern/diff/D63681034/)
ghstack-source-id: 245613707
Pull Request resolved: #5784
1 parent dbabcf3 commit af3e8ac

File tree

1 file changed

+43
-35
lines changed

1 file changed

+43
-35
lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 43 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -214,43 +214,51 @@ Tensor& clamp_tensor_out(
214214

215215
constexpr auto name = "clamp.Tensor_out";
216216

217-
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
217+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
218+
using ToCtypeOutFn = CTYPE_OUT (*)(const void*);
219+
ToCtypeOutFn in_to_out;
220+
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
221+
in_to_out = [](const void* inPtr) {
222+
return static_cast<CTYPE_OUT>(
223+
*reinterpret_cast<const CTYPE_IN*>(inPtr));
224+
};
225+
});
226+
ToCtypeOutFn min_to_out;
218227
ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
219-
ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
220-
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
221-
apply_ternary_elementwise_fn<CTYPE_OUT>(
222-
[has_min, has_max](
223-
const CTYPE_OUT val_in,
224-
const CTYPE_OUT val_min,
225-
const CTYPE_OUT val_max) {
226-
CTYPE_OUT val_out = val_in;
227-
if (has_min) {
228-
val_out = utils::max_override(val_out, val_min);
229-
}
230-
if (has_max) {
231-
val_out = utils::min_override(val_out, val_max);
232-
}
233-
return val_out;
234-
},
235-
in,
236-
min,
237-
max,
238-
out,
239-
[](const void* inPtr) {
240-
return static_cast<CTYPE_OUT>(
241-
*reinterpret_cast<const CTYPE_IN*>(inPtr));
242-
},
243-
[](const void* minPtr) {
244-
return static_cast<CTYPE_OUT>(
245-
*reinterpret_cast<const CTYPE_MIN*>(minPtr));
246-
},
247-
[](const void* maxPtr) {
248-
return static_cast<CTYPE_OUT>(
249-
*reinterpret_cast<const CTYPE_MAX*>(maxPtr));
250-
});
251-
});
252-
});
228+
min_to_out = [](const void* minPtr) {
229+
return static_cast<CTYPE_OUT>(
230+
*reinterpret_cast<const CTYPE_MIN*>(minPtr));
231+
};
253232
});
233+
ToCtypeOutFn max_to_out;
234+
ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
235+
max_to_out = [](const void* maxPtr) {
236+
return static_cast<CTYPE_OUT>(
237+
*reinterpret_cast<const CTYPE_MAX*>(maxPtr));
238+
};
239+
});
240+
241+
apply_ternary_elementwise_fn<CTYPE_OUT>(
242+
[has_min, has_max](
243+
const CTYPE_OUT val_in,
244+
const CTYPE_OUT val_min,
245+
const CTYPE_OUT val_max) {
246+
CTYPE_OUT val_out = val_in;
247+
if (has_min) {
248+
val_out = utils::max_override(val_out, val_min);
249+
}
250+
if (has_max) {
251+
val_out = utils::min_override(val_out, val_max);
252+
}
253+
return val_out;
254+
},
255+
in,
256+
min,
257+
max,
258+
out,
259+
in_to_out,
260+
min_to_out,
261+
max_to_out);
254262
});
255263

256264
return out;

0 commit comments

Comments
 (0)