Skip to content

Commit 87db0c0

Browse files
[libc] Add bigint casting between word types (#111914)
Previously you could cast between bigints with different numbers of bits, but only if they had the same underlying type. This patch adds the ability to cast between bigints with different underlying types, which is needed for #110894.
1 parent 7e72e5b commit 87db0c0

File tree

2 files changed

+229
-12
lines changed

2 files changed

+229
-12
lines changed

libc/src/__support/big_int.h

Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "src/__support/CPP/limits.h"
1515
#include "src/__support/CPP/optional.h"
1616
#include "src/__support/CPP/type_traits.h"
17-
#include "src/__support/macros/attributes.h" // LIBC_INLINE
17+
#include "src/__support/macros/attributes.h" // LIBC_INLINE
1818
#include "src/__support/macros/config.h"
1919
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
2020
#include "src/__support/macros/properties/compiler.h" // LIBC_COMPILER_IS_CLANG
@@ -361,17 +361,94 @@ struct BigInt {
361361

362362
LIBC_INLINE constexpr BigInt(const BigInt &other) = default;
363363

364-
template <size_t OtherBits, bool OtherSigned>
364+
template <size_t OtherBits, bool OtherSigned, typename OtherWordType>
365365
LIBC_INLINE constexpr BigInt(
366-
const BigInt<OtherBits, OtherSigned, WordType> &other) {
367-
if (OtherBits >= Bits) { // truncate
368-
for (size_t i = 0; i < WORD_COUNT; ++i)
369-
val[i] = other[i];
370-
} else { // zero or sign extend
371-
size_t i = 0;
372-
for (; i < OtherBits / WORD_SIZE; ++i)
373-
val[i] = other[i];
374-
extend(i, Signed && other.is_neg());
366+
const BigInt<OtherBits, OtherSigned, OtherWordType> &other) {
367+
using BigIntOther = BigInt<OtherBits, OtherSigned, OtherWordType>;
368+
const bool should_sign_extend = Signed && other.is_neg();
369+
370+
static_assert(!(Bits == OtherBits && WORD_SIZE != BigIntOther::WORD_SIZE) &&
371+
"This is currently untested for casting between bigints with "
372+
"the same bit width but different word sizes.");
373+
374+
if constexpr (BigIntOther::WORD_SIZE < WORD_SIZE) {
375+
// OtherWordType is smaller
376+
constexpr size_t WORD_SIZE_RATIO = WORD_SIZE / BigIntOther::WORD_SIZE;
377+
static_assert(
378+
(WORD_SIZE % BigIntOther::WORD_SIZE) == 0 &&
379+
"Word types must be multiples of each other for correct conversion.");
380+
if constexpr (OtherBits >= Bits) { // truncate
381+
// for each big word
382+
for (size_t i = 0; i < WORD_COUNT; ++i) {
383+
WordType cur_word = 0;
384+
// combine WORD_SIZE_RATIO small words into a big word
385+
for (size_t j = 0; j < WORD_SIZE_RATIO; ++j)
386+
cur_word |= static_cast<WordType>(other[(i * WORD_SIZE_RATIO) + j])
387+
<< (BigIntOther::WORD_SIZE * j);
388+
389+
val[i] = cur_word;
390+
}
391+
} else { // zero or sign extend
392+
size_t i = 0;
393+
WordType cur_word = 0;
394+
// for each small word
395+
for (; i < BigIntOther::WORD_COUNT; ++i) {
396+
// combine WORD_SIZE_RATIO small words into a big word
397+
cur_word |= static_cast<WordType>(other[i])
398+
<< (BigIntOther::WORD_SIZE * (i % WORD_SIZE_RATIO));
399+
// if we've completed a big word, copy it into place and reset
400+
if ((i % WORD_SIZE_RATIO) == WORD_SIZE_RATIO - 1) {
401+
val[i / WORD_SIZE_RATIO] = cur_word;
402+
cur_word = 0;
403+
}
404+
}
405+
// Pretend there are extra words of the correct sign extension as needed
406+
407+
const WordType extension_bits =
408+
should_sign_extend ? cpp::numeric_limits<WordType>::max()
409+
: cpp::numeric_limits<WordType>::min();
410+
if ((i % WORD_SIZE_RATIO) != 0) {
411+
cur_word |= static_cast<WordType>(extension_bits)
412+
<< (BigIntOther::WORD_SIZE * (i % WORD_SIZE_RATIO));
413+
}
414+
// Copy the last word into place.
415+
val[(i / WORD_SIZE_RATIO)] = cur_word;
416+
extend((i / WORD_SIZE_RATIO) + 1, should_sign_extend);
417+
}
418+
} else if constexpr (BigIntOther::WORD_SIZE == WORD_SIZE) {
419+
if constexpr (OtherBits >= Bits) { // truncate
420+
for (size_t i = 0; i < WORD_COUNT; ++i)
421+
val[i] = other[i];
422+
} else { // zero or sign extend
423+
size_t i = 0;
424+
for (; i < BigIntOther::WORD_COUNT; ++i)
425+
val[i] = other[i];
426+
extend(i, should_sign_extend);
427+
}
428+
} else {
429+
// OtherWordType is bigger.
430+
constexpr size_t WORD_SIZE_RATIO = BigIntOther::WORD_SIZE / WORD_SIZE;
431+
static_assert(
432+
(BigIntOther::WORD_SIZE % WORD_SIZE) == 0 &&
433+
"Word types must be multiples of each other for correct conversion.");
434+
if constexpr (OtherBits >= Bits) { // truncate
435+
// for each small word
436+
for (size_t i = 0; i < WORD_COUNT; ++i) {
437+
// split each big word into WORD_SIZE_RATIO small words
438+
val[i] = static_cast<WordType>(other[i / WORD_SIZE_RATIO] >>
439+
((i % WORD_SIZE_RATIO) * WORD_SIZE));
440+
}
441+
} else { // zero or sign extend
442+
size_t i = 0;
443+
// for each big word
444+
for (; i < BigIntOther::WORD_COUNT; ++i) {
445+
// split each big word into WORD_SIZE_RATIO small words
446+
for (size_t j = 0; j < WORD_SIZE_RATIO; ++j)
447+
val[(i * WORD_SIZE_RATIO) + j] =
448+
static_cast<WordType>(other[i] >> (j * WORD_SIZE));
449+
}
450+
extend(i * WORD_SIZE_RATIO, should_sign_extend);
451+
}
375452
}
376453
}
377454

libc/test/src/__support/big_int_test.cpp

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#include "src/__support/CPP/optional.h"
1010
#include "src/__support/big_int.h"
11-
#include "src/__support/integer_literals.h" // parse_unsigned_bigint
11+
#include "src/__support/integer_literals.h" // parse_unsigned_bigint
1212
#include "src/__support/macros/config.h"
1313
#include "src/__support/macros/properties/types.h" // LIBC_TYPES_HAS_INT128
1414

@@ -208,6 +208,7 @@ TYPED_TEST(LlvmLibcUIntClassTest, CountBits, Types) {
208208
}
209209

210210
using LL_UInt16 = UInt<16>;
211+
using LL_UInt32 = UInt<32>;
211212
using LL_UInt64 = UInt<64>;
212213
// We want to test UInt<128> explicitly. So, for
213214
// convenience, we use a sugar which does not conflict with the UInt128 type
@@ -927,4 +928,143 @@ TEST(LlvmLibcUIntClassTest, OtherWordTypeTests) {
927928
ASSERT_EQ(static_cast<int>(a >> 64), 1);
928929
}
929930

931+
TEST(LlvmLibcUIntClassTest, OtherWordTypeCastTests) {
932+
using LL_UInt96 = BigInt<96, false, uint32_t>;
933+
934+
LL_UInt96 a({123, 456, 789});
935+
936+
ASSERT_EQ(static_cast<int>(a), 123);
937+
ASSERT_EQ(static_cast<int>(a >> 32), 456);
938+
ASSERT_EQ(static_cast<int>(a >> 64), 789);
939+
940+
// Smaller word with less bits to bigger word with more bits.
941+
LL_UInt128 b(a);
942+
943+
ASSERT_EQ(static_cast<int>(b), 123);
944+
ASSERT_EQ(static_cast<int>(b >> 32), 456);
945+
ASSERT_EQ(static_cast<int>(b >> 64), 789);
946+
ASSERT_EQ(static_cast<int>(b >> 96), 0);
947+
948+
b = (b << 32) + 987;
949+
950+
ASSERT_EQ(static_cast<int>(b), 987);
951+
ASSERT_EQ(static_cast<int>(b >> 32), 123);
952+
ASSERT_EQ(static_cast<int>(b >> 64), 456);
953+
ASSERT_EQ(static_cast<int>(b >> 96), 789);
954+
955+
// Bigger word with more bits to smaller word with less bits.
956+
LL_UInt96 c(b);
957+
958+
ASSERT_EQ(static_cast<int>(c), 987);
959+
ASSERT_EQ(static_cast<int>(c >> 32), 123);
960+
ASSERT_EQ(static_cast<int>(c >> 64), 456);
961+
962+
// Smaller word with more bits to bigger word with less bits
963+
LL_UInt64 d(c);
964+
965+
ASSERT_EQ(static_cast<int>(d), 987);
966+
ASSERT_EQ(static_cast<int>(d >> 32), 123);
967+
968+
// Bigger word with less bits to smaller word with more bits
969+
970+
LL_UInt96 e(d);
971+
972+
ASSERT_EQ(static_cast<int>(e), 987);
973+
ASSERT_EQ(static_cast<int>(e >> 32), 123);
974+
975+
e = (e << 32) + 654;
976+
977+
ASSERT_EQ(static_cast<int>(e), 654);
978+
ASSERT_EQ(static_cast<int>(e >> 32), 987);
979+
ASSERT_EQ(static_cast<int>(e >> 64), 123);
980+
}
981+
982+
TEST(LlvmLibcUIntClassTest, SignedOtherWordTypeCastTests) {
983+
using LL_Int64 = BigInt<64, true, uint64_t>;
984+
using LL_Int96 = BigInt<96, true, uint32_t>;
985+
986+
LL_Int64 zero_64(0);
987+
LL_Int96 zero_96(0);
988+
LL_Int192 zero_192(0);
989+
990+
LL_Int96 plus_a({0x1234, 0x5678, 0x9ABC});
991+
992+
ASSERT_EQ(static_cast<int>(plus_a), 0x1234);
993+
ASSERT_EQ(static_cast<int>(plus_a >> 32), 0x5678);
994+
ASSERT_EQ(static_cast<int>(plus_a >> 64), 0x9ABC);
995+
996+
LL_Int96 minus_a(-plus_a);
997+
998+
// The reason that the numbers are inverted and not negated is that we're
999+
// using two's complement. To negate a two's complement number you flip the
1000+
// bits and add 1, so minus_a is {~0x1234, ~0x5678, ~0x9ABC} + {1,0,0}.
1001+
ASSERT_EQ(static_cast<int>(minus_a), (~0x1234) + 1);
1002+
ASSERT_EQ(static_cast<int>(minus_a >> 32), ~0x5678);
1003+
ASSERT_EQ(static_cast<int>(minus_a >> 64), ~0x9ABC);
1004+
1005+
ASSERT_TRUE(plus_a + minus_a == zero_96);
1006+
1007+
// Use 192 bits so there's an extra word that must be filled by the extension.
1008+
LL_Int192 bigger_plus_a(plus_a);
1009+
1010+
ASSERT_EQ(static_cast<int>(bigger_plus_a), 0x1234);
1011+
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 32), 0x5678);
1012+
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 64), 0x9ABC);
1013+
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 96), 0);
1014+
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 128), 0);
1015+
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 160), 0);
1016+
1017+
LL_Int192 bigger_minus_a(minus_a);
1018+
1019+
ASSERT_EQ(static_cast<int>(bigger_minus_a), (~0x1234) + 1);
1020+
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 32), ~0x5678);
1021+
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 64), ~0x9ABC);
1022+
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 96), ~0);
1023+
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 128), ~0);
1024+
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 160), ~0);
1025+
1026+
ASSERT_TRUE(bigger_plus_a + bigger_minus_a == zero_192);
1027+
1028+
LL_Int64 smaller_plus_a(plus_a);
1029+
1030+
ASSERT_EQ(static_cast<int>(smaller_plus_a), 0x1234);
1031+
ASSERT_EQ(static_cast<int>(smaller_plus_a >> 32), 0x5678);
1032+
1033+
LL_Int64 smaller_minus_a(minus_a);
1034+
1035+
ASSERT_EQ(static_cast<int>(smaller_minus_a), (~0x1234) + 1);
1036+
ASSERT_EQ(static_cast<int>(smaller_minus_a >> 32), ~0x5678);
1037+
1038+
ASSERT_TRUE(smaller_plus_a + smaller_minus_a == zero_64);
1039+
1040+
// Also try going from bigger word size to smaller word size
1041+
LL_Int96 smaller_back_plus_a(smaller_plus_a);
1042+
1043+
ASSERT_EQ(static_cast<int>(smaller_back_plus_a), 0x1234);
1044+
ASSERT_EQ(static_cast<int>(smaller_back_plus_a >> 32), 0x5678);
1045+
ASSERT_EQ(static_cast<int>(smaller_back_plus_a >> 64), 0);
1046+
1047+
LL_Int96 smaller_back_minus_a(smaller_minus_a);
1048+
1049+
ASSERT_EQ(static_cast<int>(smaller_back_minus_a), (~0x1234) + 1);
1050+
ASSERT_EQ(static_cast<int>(smaller_back_minus_a >> 32), ~0x5678);
1051+
ASSERT_EQ(static_cast<int>(smaller_back_minus_a >> 64), ~0);
1052+
1053+
ASSERT_TRUE(smaller_back_plus_a + smaller_back_minus_a == zero_96);
1054+
1055+
LL_Int96 bigger_back_plus_a(bigger_plus_a);
1056+
1057+
ASSERT_EQ(static_cast<int>(bigger_back_plus_a), 0x1234);
1058+
ASSERT_EQ(static_cast<int>(bigger_back_plus_a >> 32), 0x5678);
1059+
ASSERT_EQ(static_cast<int>(bigger_back_plus_a >> 64), 0x9ABC);
1060+
1061+
LL_Int96 bigger_back_minus_a(bigger_minus_a);
1062+
1063+
ASSERT_EQ(static_cast<int>(bigger_back_minus_a), (~0x1234) + 1);
1064+
ASSERT_EQ(static_cast<int>(bigger_back_minus_a >> 32), ~0x5678);
1065+
ASSERT_EQ(static_cast<int>(bigger_back_minus_a >> 64), ~0x9ABC);
1066+
1067+
ASSERT_TRUE(bigger_back_plus_a + bigger_back_minus_a == zero_96);
1068+
}
1069+
9301070
} // namespace LIBC_NAMESPACE_DECL

0 commit comments

Comments
 (0)