Skip to content

Commit ac64f9e

Browse files
committed
Update
[ghstack-poisoned]
1 parent 6b0e11f commit ac64f9e

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,25 @@ inline void apply_tritensor_elementwise_fn(
199199
}
200200

201201
inline ScalarType get_compute_type(ScalarType& common_type) {
202-
if (common_type == ScalarType::Long) {
203-
return common_type;
204-
}
205-
if (isIntegralType(common_type, /*includeBool=*/true)) {
202+
// Code size optimization: on typical 32-bit or 64-bit CPUs, the ALU should be
203+
// just as good at 32-bit arithmetic as it is at 16-bit or 8-bit
204+
// arithmetic, so don't go out of our way to generate 8-bit or
205+
// 16-bit code.
206+
207+
// Gate above optimization off if we appear to be on some kind of 8-bit or
208+
// 16-bit CPU, which would invalidate our assumption about 32-bit
209+
// math being just as fast.
210+
constexpr bool cpu_appears_to_be_at_least_32_bit = sizeof(void*) >= 4 && sizeof(int) >= 4;
211+
212+
if (cpu_appears_to_be_at_least_32_bit &&
213+
// Don't mess up 64-bit ints.
214+
common_type != ScalarType::Long &&
215+
isIntegralType(common_type, /*includeBool=*/true)) {
206216
return ScalarType::Int;
207217
}
218+
219+
// We compute in float for reduced-precision floating-point types as
220+
// a matter of policy, not size optimization.
208221
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
209222
return ScalarType::Float;
210223
}

0 commit comments

Comments
 (0)