Skip to content

Commit fe0a9c4

Browse files
committed
MathExtras: avoid unnecessarily widening types
Several multi-argument functions unnecessarily widen types beyond the argument types. Templatize the functions and use std::common_type_t to avoid this, thereby optimizing the functions. While at it, address usage issues raised in #95087. One of the requirements of this patch is to add overflow checks; in addition, one caller in LoopVectorize and one in AMDGPUBaseInfo are manually widened.
1 parent bfd95a0 commit fe0a9c4

File tree

5 files changed

+143
-36
lines changed

5 files changed

+143
-36
lines changed

llvm/include/llvm/Support/MathExtras.h

Lines changed: 102 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,22 @@
2323
#include <type_traits>
2424

2525
namespace llvm {
26+
/// Some template parameter helpers to optimize for bitwidth, for functions that
27+
/// take multiple arguments.
28+
29+
// We can't verify signedness, since callers rely on implicit coercions to
30+
// signed/unsigned.
31+
template <typename T, typename U>
32+
using enableif_int =
33+
std::enable_if_t<std::is_integral_v<T> && std::is_integral_v<U>>;
34+
35+
// Use std::common_type_t to widen only up to the widest argument.
36+
template <typename T, typename U, typename = enableif_int<T, U>>
37+
using common_uint =
38+
std::common_type_t<std::make_unsigned_t<T>, std::make_unsigned_t<U>>;
39+
template <typename T, typename U, typename = enableif_int<T, U>>
40+
using common_sint =
41+
std::common_type_t<std::make_signed_t<T>, std::make_signed_t<U>>;
2642

2743
/// Mathematical constants.
2844
namespace numbers {
@@ -346,7 +362,8 @@ inline unsigned Log2_64_Ceil(uint64_t Value) {
346362

347363
/// A and B are either alignments or offsets. Return the minimum alignment that
348364
/// may be assumed after adding the two together.
349-
constexpr inline uint64_t MinAlign(uint64_t A, uint64_t B) {
365+
template <typename U, typename V, typename T = common_uint<U, V>>
366+
constexpr T MinAlign(U A, V B) {
350367
// The largest power of 2 that divides both A and B.
351368
//
352369
// Replace "-Value" by "1+~Value" in the following commented code to avoid
@@ -375,7 +392,7 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
375392
return UINT64_C(1) << Log2_64_Ceil(A);
376393
}
377394

378-
/// Returns the next integer (mod 2**64) that is greater than or equal to
395+
/// Returns the next integer (mod 2**nbits) that is greater than or equal to
379396
/// \p Value and is a multiple of \p Align. \p Align must be non-zero.
380397
///
381398
/// Examples:
@@ -386,19 +403,50 @@ inline uint64_t PowerOf2Ceil(uint64_t A) {
386403
/// alignTo(321, 255) = 510
387404
/// \endcode
388405
///
389-
/// May overflow.
390-
inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
406+
/// Will overflow only if result is not representable.
407+
/// All arithmetic is carried out in \p T, the common unsigned type of \p U and
/// \p V.
template <typename U, typename V, typename T = common_uint<U, V>>
constexpr T alignTo(U Value, V Align) {
  assert(Align != 0u && "Align can't be 0.");
  // Convert to T up front. Without this, operands narrower than int are
  // integer-promoted to (signed) int, so a negative Value would not wrap
  // modulo 2**nbits the way it does for int-sized and wider arguments.
  const T UValue = static_cast<T>(Value);
  const T UAlign = static_cast<T>(Align);
  // Overflow-free ceil(UValue / UAlign): subtract 1 before dividing and add it
  // back afterwards, except when UValue is 0.
  const T Bias = (UValue != 0);
  const T CeilDiv = (UValue - Bias) / UAlign + Bias;
  // If Value is negative, wrap will occur in the cast.
  // CeilDiv * UAlign is representable iff CeilDiv <= max / UAlign. Using
  // (max - 1) / UAlign here would wrongly reject results equal to max, e.g.
  // alignTo(max, 1).
  if (Value > 0)
    assert(CeilDiv <= std::numeric_limits<T>::max() / UAlign &&
           "alignTo would overflow");
  return static_cast<T>(CeilDiv * UAlign);
}
394418

395-
inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
419+
/// Fallback when arguments aren't integral.
420+
constexpr inline uint64_t alignTo(uint64_t Value, uint64_t Align) {
421+
assert(Align != 0u && "Align can't be 0.");
422+
uint64_t Bias = (Value != 0);
423+
uint64_t CeilDiv = (Value - Bias) / Align + Bias;
424+
return CeilDiv * Align;
425+
}
426+
427+
/// Will overflow only if result is not representable.
428+
template <typename U, typename V, typename T = common_uint<U, V>>
429+
constexpr T alignToPowerOf2(U Value, V Align) {
396430
assert(Align != 0 && (Align & (Align - 1)) == 0 &&
397431
"Align must be a power of 2");
398432
// Replace unary minus to avoid compilation error on Windows:
399433
// "unary minus operator applied to unsigned type, result still unsigned"
400-
uint64_t negAlign = (~Align) + 1;
401-
return (Value + Align - 1) & negAlign;
434+
T NegAlign = (~Align) + 1;
435+
// If Value is negative, wrap will occur in the cast.
436+
if (Value > 0)
437+
assert(static_cast<T>(Value) - 1 <= std::numeric_limits<T>::max() - Align &&
438+
"alignToPowerOf2 would overflow");
439+
return (Value - 1 + Align) & NegAlign;
440+
}
441+
442+
/// Fallback when arguments aren't integral.
443+
constexpr inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
444+
assert(Align != 0 && (Align & (Align - 1)) == 0 &&
445+
"Align must be a power of 2");
446+
// Replace unary minus to avoid compilation error on Windows:
447+
// "unary minus operator applied to unsigned type, result still unsigned"
448+
uint64_t NegAlign = (~Align) + 1;
449+
return (Value - 1 + Align) & NegAlign;
402450
}
403451

404452
/// If non-zero \p Skew is specified, the return value will be a minimal integer
@@ -413,72 +461,99 @@ inline uint64_t alignToPowerOf2(uint64_t Value, uint64_t Align) {
413461
/// alignTo(~0LL, 8, 3) = 3
414462
/// alignTo(321, 255, 42) = 552
415463
/// \endcode
416-
inline uint64_t alignTo(uint64_t Value, uint64_t Align, uint64_t Skew) {
464+
template <typename U, typename V, typename W,
465+
typename T = common_uint<common_uint<U, V>, W>>
466+
constexpr T alignTo(U Value, V Align, W Skew) {
417467
assert(Align != 0u && "Align can't be 0.");
418468
Skew %= Align;
419469
return alignTo(Value - Skew, Align) + Skew;
420470
}
421471

422-
/// Returns the next integer (mod 2**64) that is greater than or equal to
472+
/// Returns the next integer (mod 2**nbits) that is greater than or equal to
423473
/// \p Value and is a multiple of \c Align. \c Align must be non-zero.
424-
template <uint64_t Align> constexpr inline uint64_t alignTo(uint64_t Value) {
474+
///
475+
/// Will overflow only if result is not representable.
476+
/// All arithmetic is carried out in \p T, the common unsigned type of
/// \c decltype(Align) and \p V.
template <auto Align, typename V, typename T = common_uint<decltype(Align), V>>
constexpr T alignTo(V Value) {
  static_assert(Align != 0u, "Align must be non-zero");
  // Convert to T up front. Without this, operands narrower than int are
  // integer-promoted to (signed) int, so a negative Value would not wrap
  // modulo 2**nbits the way it does for int-sized and wider arguments.
  constexpr T UAlign = static_cast<T>(Align);
  const T UValue = static_cast<T>(Value);
  // Overflow-free ceil(UValue / UAlign): subtract 1 before dividing and add it
  // back afterwards, except when UValue is 0.
  const T Bias = (UValue != 0);
  const T CeilDiv = (UValue - Bias) / UAlign + Bias;
  // If Value is negative, wrap will occur in the cast.
  // CeilDiv * UAlign is representable iff CeilDiv <= max / UAlign. Using
  // (max - 1) / UAlign here would wrongly reject results equal to max, e.g.
  // alignTo<1>(max).
  if (Value > 0)
    assert(CeilDiv <= std::numeric_limits<T>::max() / UAlign &&
           "alignTo would overflow");
  return static_cast<T>(CeilDiv * UAlign);
}
428487

429488
/// Returns the integer ceil(Numerator / Denominator). Unsigned version.
430489
/// Guaranteed to never overflow.
431-
inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
490+
template <typename U, typename V, typename T = common_uint<U, V>>
491+
constexpr T divideCeil(U Numerator, V Denominator) {
492+
assert(Denominator && "Division by zero");
493+
T Bias = (Numerator != 0);
494+
return (Numerator - Bias) / Denominator + Bias;
495+
}
496+
497+
/// Fallback when arguments aren't integral.
498+
constexpr inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
432499
assert(Denominator && "Division by zero");
433500
uint64_t Bias = (Numerator != 0);
434501
return (Numerator - Bias) / Denominator + Bias;
435502
}
436503

437504
/// Returns the integer ceil(Numerator / Denominator). Signed version.
438505
/// Guaranteed to never overflow.
439-
inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
506+
template <typename U, typename V, typename T = common_sint<U, V>>
507+
constexpr T divideCeilSigned(U Numerator, V Denominator) {
440508
assert(Denominator && "Division by zero");
441509
if (!Numerator)
442510
return 0;
443511
// C's integer division rounds towards 0.
444-
int64_t Bias = (Denominator >= 0 ? 1 : -1);
512+
T Bias = Denominator >= 0 ? 1 : -1;
445513
bool SameSign = (Numerator >= 0) == (Denominator >= 0);
446514
return SameSign ? (Numerator - Bias) / Denominator + 1
447515
: Numerator / Denominator;
448516
}
449517

450518
/// Returns the integer floor(Numerator / Denominator). Signed version.
451519
/// Guaranteed to never overflow.
452-
inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
520+
template <typename U, typename V, typename T = common_sint<U, V>>
521+
constexpr T divideFloorSigned(U Numerator, V Denominator) {
453522
assert(Denominator && "Division by zero");
454523
if (!Numerator)
455524
return 0;
456525
// C's integer division rounds towards 0.
457-
int64_t Bias = Denominator >= 0 ? -1 : 1;
526+
T Bias = Denominator >= 0 ? -1 : 1;
458527
bool SameSign = (Numerator >= 0) == (Denominator >= 0);
459528
return SameSign ? Numerator / Denominator
460529
: (Numerator - Bias) / Denominator - 1;
461530
}
462531

463532
/// Returns the remainder of the Euclidean division of LHS by RHS. Result is
464533
/// always non-negative.
465-
inline int64_t mod(int64_t Numerator, int64_t Denominator) {
534+
template <typename U, typename V, typename T = common_sint<U, V>>
535+
constexpr T mod(U Numerator, V Denominator) {
466536
assert(Denominator >= 1 && "Mod by non-positive number");
467-
int64_t Mod = Numerator % Denominator;
537+
T Mod = Numerator % Denominator;
468538
return Mod < 0 ? Mod + Denominator : Mod;
469539
}
470540

471541
/// Returns (Numerator / Denominator) rounded by round-half-up. Guaranteed to
472542
/// never overflow.
473-
inline uint64_t divideNearest(uint64_t Numerator, uint64_t Denominator) {
543+
template <typename U, typename V, typename T = common_uint<U, V>>
544+
constexpr T divideNearest(U Numerator, V Denominator) {
474545
assert(Denominator && "Division by zero");
475-
uint64_t Mod = Numerator % Denominator;
476-
return (Numerator / Denominator) + (Mod > (Denominator - 1) / 2);
546+
T Mod = Numerator % Denominator;
547+
return (Numerator / Denominator) +
548+
(Mod > (static_cast<T>(Denominator) - 1) / 2);
477549
}
478550

479-
/// Returns the largest uint64_t less than or equal to \p Value and is
480-
/// \p Skew mod \p Align. \p Align must be non-zero
481-
inline uint64_t alignDown(uint64_t Value, uint64_t Align, uint64_t Skew = 0) {
551+
/// Returns the largest unsigned integer less than or equal to \p Value and is
552+
/// \p Skew mod \p Align. \p Align must be non-zero. Guaranteed to never
553+
/// overflow.
554+
template <typename U, typename V, typename W = uint8_t,
555+
typename T = common_uint<common_uint<U, V>, W>>
556+
constexpr T alignDown(U Value, V Align, W Skew = 0) {
482557
assert(Align != 0u && "Align can't be 0.");
483558
Skew %= Align;
484559
return (Value - Skew) / Align * Align + Skew;
@@ -522,8 +597,8 @@ inline int64_t SignExtend64(uint64_t X, unsigned B) {
522597

523598
/// Subtract two unsigned integers, X and Y, of type T and return the absolute
524599
/// value of the result.
525-
template <typename T>
526-
std::enable_if_t<std::is_unsigned_v<T>, T> AbsoluteDifference(T X, T Y) {
600+
template <typename U, typename V, typename T = common_uint<U, V>>
601+
constexpr T AbsoluteDifference(U X, V Y) {
527602
return X > Y ? (X - Y) : (Y - X);
528603
}
529604

llvm/unittests/Support/MathExtrasTest.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,23 +189,55 @@ TEST(MathExtras, AlignTo) {
189189
EXPECT_EQ(8u, alignTo(5, 8));
190190
EXPECT_EQ(24u, alignTo(17, 8));
191191
EXPECT_EQ(0u, alignTo(~0LL, 8));
192-
EXPECT_EQ(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1,
193-
alignTo(std::numeric_limits<uint32_t>::max(), 2));
192+
EXPECT_EQ(8u, alignTo(5ULL, 8ULL));
193+
EXPECT_EQ(254u,
194+
alignTo(static_cast<uint8_t>(200), static_cast<uint8_t>(127)));
195+
#ifndef NDEBUG
196+
EXPECT_DEATH(alignTo(static_cast<uint8_t>(200), static_cast<uint8_t>(128)),
197+
"alignTo would overflow");
198+
EXPECT_DEATH(alignTo(std::numeric_limits<uint32_t>::max(), 2),
199+
"alignTo would overflow");
200+
#endif
201+
202+
EXPECT_EQ(8u, alignTo<8>(5));
203+
EXPECT_EQ(24u, alignTo<8>(17));
204+
EXPECT_EQ(0u, alignTo<8>(~0LL));
205+
EXPECT_EQ(254u,
206+
alignTo<static_cast<uint8_t>(127)>(static_cast<uint8_t>(200)));
207+
#ifndef NDEBUG
208+
EXPECT_DEATH(alignTo<static_cast<uint8_t>(128)>(static_cast<uint8_t>(200)),
209+
"alignTo would overflow");
210+
EXPECT_DEATH(alignTo<2>(std::numeric_limits<uint32_t>::max()),
211+
"alignTo would overflow");
212+
#endif
194213

195214
EXPECT_EQ(7u, alignTo(5, 8, 7));
196215
EXPECT_EQ(17u, alignTo(17, 8, 1));
197216
EXPECT_EQ(3u, alignTo(~0LL, 8, 3));
198217
EXPECT_EQ(552u, alignTo(321, 255, 42));
199218
EXPECT_EQ(std::numeric_limits<uint32_t>::max(),
200219
alignTo(std::numeric_limits<uint32_t>::max(), 2, 1));
220+
221+
#ifndef NDEBUG
222+
EXPECT_DEATH(alignTo(std::numeric_limits<uint32_t>::max(), 4, 2),
223+
"alignTo would overflow");
224+
#endif
201225
}
202226

203227
TEST(MathExtras, AlignToPowerOf2) {
204228
EXPECT_EQ(8u, alignToPowerOf2(5, 8));
205229
EXPECT_EQ(24u, alignToPowerOf2(17, 8));
206230
EXPECT_EQ(0u, alignToPowerOf2(~0LL, 8));
207-
EXPECT_EQ(static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()) + 1,
208-
alignToPowerOf2(std::numeric_limits<uint32_t>::max(), 2));
231+
EXPECT_EQ(8u, alignToPowerOf2(5ULL, 8ULL));
232+
EXPECT_EQ(240u, alignToPowerOf2(static_cast<uint8_t>(240),
233+
static_cast<uint8_t>(16)));
234+
#ifndef NDEBUG
235+
EXPECT_DEATH(
236+
alignToPowerOf2(static_cast<uint8_t>(200), static_cast<uint8_t>(128)),
237+
"alignToPowerOf2 would overflow");
238+
EXPECT_DEATH(alignToPowerOf2(std::numeric_limits<uint32_t>::max(), 2),
239+
"alignToPowerOf2 would overflow");
240+
#endif
209241
}
210242

211243
TEST(MathExtras, AlignDown) {
@@ -484,6 +516,7 @@ TEST(MathExtras, DivideCeil) {
484516
EXPECT_EQ(divideCeil(3, 1), 3u);
485517
EXPECT_EQ(divideCeil(3, 6), 1u);
486518
EXPECT_EQ(divideCeil(3, 7), 1u);
519+
EXPECT_EQ(divideCeil(3ULL, 7ULL), 1u);
487520
EXPECT_EQ(divideCeil(std::numeric_limits<uint32_t>::max(), 2),
488521
std::numeric_limits<uint32_t>::max() / 2 + 1);
489522
EXPECT_EQ(divideCeil(std::numeric_limits<uint64_t>::max(), 2),

mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
114114
return ShapedType::kDynamic;
115115

116116
assert(dimSize % shardCount == 0);
117-
return llvm::divideCeilSigned(dimSize, shardCount);
117+
return dimSize / shardCount;
118118
}
119119

120120
// Get the size of an unsharded dimension.

mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ void UnrankedMemRefDescriptor::computeSizes(
365365
Value two = createIndexAttrConstant(builder, loc, indexType, 2);
366366
Value indexSize = createIndexAttrConstant(
367367
builder, loc, indexType,
368-
llvm::divideCeilSigned(typeConverter.getIndexTypeBitwidth(), 8));
368+
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
369369

370370
sizes.reserve(sizes.size() + values.size());
371371
for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
@@ -378,8 +378,7 @@ void UnrankedMemRefDescriptor::computeSizes(
378378
// to data layout) into the unranked descriptor.
379379
Value pointerSize = createIndexAttrConstant(
380380
builder, loc, indexType,
381-
llvm::divideCeilSigned(typeConverter.getPointerBitwidth(addressSpace),
382-
8));
381+
llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
383382
Value doublePointerSize =
384383
builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);
385384

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ struct MemorySpaceCastOpLowering
971971
resultUnderlyingDesc, resultElemPtrType);
972972

973973
int64_t bytesToSkip =
974-
2 * llvm::divideCeilSigned(
974+
2 * llvm::divideCeil(
975975
getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
976976
Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
977977
loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));

0 commit comments

Comments
 (0)