Skip to content

linear_congruential_engine: add using more precision to prevent overflow #81583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 82 additions & 21 deletions libcxx/include/__random/linear_congruential_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,60 @@ _LIBCPP_PUSH_MACROS

_LIBCPP_BEGIN_NAMESPACE_STD

// Identifies the arithmetic strategy a __lce_ta specialization implements.
// The enumerators are listed in the order __lce_alg_picker tries them.
enum __lce_alg_type {
  _LCE_Full    = 0, // (a * x + c) % m is evaluated directly
  _LCE_Part    = 1, // (a * x) % m is evaluated first, c is added afterwards
  _LCE_Schrage = 2, // Schrage's algorithm (requires m % a <= m / a)
  _LCE_Promote = 3  // the computation is carried out in a wider integer type
};

// Compile-time selection of the strategy used to evaluate (a * x + c) % m
// for a result type whose largest value is _Mp.
//
// NOTE(review): this span appears to be a diff rendered without +/- markers.
// It interleaves the pre-change parameters (_MightOverflow, _OverflowOK,
// _SchrageOK), the first static_assert and __use_schrage with the post-change
// parameters (_HasOverflow, _Full, _Part, _Schrage) and __mode - confirm
// against the real header before relying on it.
template <unsigned long long __a,
unsigned long long __c,
unsigned long long __m,
unsigned long long _Mp,
bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
bool _OverflowOK = ((__m & (__m - 1)) == 0ull), // m = 2^n
bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
bool _HasOverflow = (__a != 0ull && (__m & (__m - 1ull)) != 0ull), // a != 0, m != 0, m != 2^n
bool _Full = (!_HasOverflow || __m - 1ull <= (_Mp - __c) / __a), // (a * x + c) % m works
bool _Part = (!_HasOverflow || __m - 1ull <= _Mp / __a), // (a * x) % m works
bool _Schrage = (_HasOverflow && __m % __a <= __m / __a)> // r <= q
struct __lce_alg_picker {
static_assert(!_MightOverflow || _OverflowOK || _SchrageOK,
"The current values of a, c, and m cannot generate a number "
"within bounds of linear_congruential_engine.");

static _LIBCPP_CONSTEXPR const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK;
// The first applicable strategy wins, tried in the order Full, Part, Schrage;
// _LCE_Promote is the fallback when none of the three flags is set.
static _LIBCPP_CONSTEXPR const __lce_alg_type __mode =
_Full ? _LCE_Full
: _Part ? _LCE_Part
: _Schrage ? _LCE_Schrage
: _LCE_Promote;

// When __int128 is unavailable, reject parameter sets for the widest result
// type (_Mp == unsigned long long max) that none of Full/Part/Schrage covers.
#ifdef _LIBCPP_HAS_NO_INT128
static_assert(_Mp != (unsigned long long)(-1) || _Full || _Part || _Schrage,
"The current values for a, c, and m are not currently supported on platforms without __int128");
#endif
};

// Primary template for the transition step; it is only declared here and the
// partial specializations supply a static next(x). The trailing parameter
// defaults to the choice made by __lce_alg_picker for (a, c, m, _Mp).
// NOTE(review): both the pre-change (bool _UseSchrage) and post-change
// (__lce_alg_type _Mode) last parameter lines are present in this span.
template <unsigned long long __a,
unsigned long long __c,
unsigned long long __m,
unsigned long long _Mp,
bool _UseSchrage = __lce_alg_picker<__a, __c, __m, _Mp>::__use_schrage>
__lce_alg_type _Mode = __lce_alg_picker<__a, __c, __m, _Mp>::__mode>
struct __lce_ta;

// 64

#ifndef _LIBCPP_HAS_NO_INT128
// 64-bit result, _LCE_Promote: evaluate (a * x + c) % m in unsigned __int128
// and narrow the remainder back to unsigned long long.
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
struct __lce_ta<_Ap, _Cp, _Mp, (unsigned long long)(-1), _LCE_Promote> {
  typedef unsigned long long result_type;
  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __xp) {
    __extension__ using __wide_type = unsigned __int128;
    // Build the value step by step in the wide type: a, then a * x,
    // then a * x + c, and finally the remainder modulo m.
    __wide_type __acc = static_cast<__wide_type>(_Ap);
    __acc *= static_cast<__wide_type>(__xp);
    __acc += static_cast<__wide_type>(_Cp);
    __acc %= static_cast<__wide_type>(_Mp);
    return static_cast<result_type>(__acc);
  }
};
#endif

template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Schrage> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Schrage's algorithm
Expand All @@ -66,7 +94,7 @@ struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
};

template <unsigned long long __a, unsigned long long __m>
struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
struct __lce_ta<__a, 0ull, __m, (unsigned long long)(-1), _LCE_Schrage> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Schrage's algorithm
Expand All @@ -80,21 +108,40 @@ struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
};

// 64-bit result, _LCE_Part: reduce a * x modulo m first, then add c modulo m.
// NOTE(review): the pre-change specialization header (..., (~0), false) and
// the post-change one (..., (-1), _LCE_Part) are both present in this span.
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), false> {
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Part> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Use (((a*x) % m) + c) % m
__x = (__a * __x) % __m;
// Add c without a second %: when x >= m - c the comparison yields 1 and m is
// subtracted as well; the unsigned wraparound of (c - m) cancels in the sum.
__x += __c - (__x >= __m - __c) * __m;
return __x;
}
};

// 64-bit result, _LCE_Full: the picker determined that (a * x + c) % m can be
// evaluated directly in unsigned long long.
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Full> {
  typedef unsigned long long result_type;
  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
    const result_type __scaled = __a * __x;
    return (__scaled + __c) % __m;
  }
};

// 64-bit result with m == 0: no explicit modulo is applied; next() returns
// a * x + c evaluated in unsigned long long arithmetic.
// NOTE(review): the pre-change and post-change specialization header lines
// are both present in this span.
template <unsigned long long __a, unsigned long long __c>
struct __lce_ta<__a, __c, 0, (unsigned long long)(~0), false> {
struct __lce_ta<__a, __c, 0ull, (unsigned long long)(-1), _LCE_Full> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return __a * __x + __c; }
};

// 32

// 32-bit result, _LCE_Promote: delegate to the unsigned long long
// implementation for the same (a, c, m) and narrow its result to unsigned.
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
struct __lce_ta<__a, __c, __m, unsigned(-1), _LCE_Promote> {
  typedef unsigned result_type;
  _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
    // The mode of the wide implementation is left to its default, i.e. it is
    // chosen by the picker for the unsigned long long range.
    typedef __lce_ta<__a, __c, __m, (unsigned long long)(-1)> __wide_impl;
    const typename __wide_impl::result_type __wide_next = __wide_impl::next(__x);
    return static_cast<result_type>(__wide_next);
  }
};

template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Schrage> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
Expand All @@ -112,7 +159,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
};

template <unsigned long long _Ap, unsigned long long _Mp>
struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
struct __lce_ta<_Ap, 0ull, _Mp, unsigned(-1), _LCE_Schrage> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
Expand All @@ -128,7 +175,21 @@ struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
};

// 32-bit result, _LCE_Part: reduce a * x modulo m first, then add c modulo m,
// with all arithmetic done in unsigned.
// NOTE(review): the pre-change specialization header (..., unsigned(~0),
// false) and the post-change one are both present in this span.
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Part> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Narrow the template constants to the 32-bit working type.
const result_type __a = static_cast<result_type>(_Ap);
const result_type __c = static_cast<result_type>(_Cp);
const result_type __m = static_cast<result_type>(_Mp);
// Use (((a*x) % m) + c) % m
__x = (__a * __x) % __m;
// Add c without a second %: when x >= m - c the comparison yields 1 and m is
// subtracted as well; the unsigned wraparound of (c - m) cancels in the sum.
__x += __c - (__x >= __m - __c) * __m;
return __x;
}
};

template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Full> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
Expand All @@ -139,7 +200,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
};

template <unsigned long long _Ap, unsigned long long _Cp>
struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
struct __lce_ta<_Ap, _Cp, 0ull, unsigned(-1), _LCE_Full> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
Expand All @@ -150,11 +211,11 @@ struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {

// 16

// 16-bit result: for every mode, forward to the `unsigned` implementation for
// the same (a, c, m) and narrow its result to unsigned short.
// NOTE(review): pre-change lines (bool __b, (~0)) and post-change lines
// (__lce_alg_type __mode, (-1)) are interleaved in this span, including two
// return statements.
template <unsigned long long __a, unsigned long long __c, unsigned long long __m, bool __b>
struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __b> {
template <unsigned long long __a, unsigned long long __c, unsigned long long __m, __lce_alg_type __mode>
struct __lce_ta<__a, __c, __m, (unsigned short)(-1), __mode> {
typedef unsigned short result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(~0)>::next(__x));
return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(-1)>::next(__x));
}
};

Expand All @@ -178,7 +239,7 @@ class _LIBCPP_TEMPLATE_VIS linear_congruential_engine {
private:
result_type __x_;

static _LIBCPP_CONSTEXPR const result_type _Mp = result_type(~0);
static _LIBCPP_CONSTEXPR const result_type _Mp = result_type(-1);

static_assert(__m == 0 || __a < __m, "linear_congruential_engine invalid parameters");
static_assert(__m == 0 || __c < __m, "linear_congruential_engine invalid parameters");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ int main(int, char**)

// m might overflow. The overflow is not OK and result will be in bounds
// so we should use Schrage's algorithm
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1> E2;
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1ull> E2;
E2 e2;
// make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
assert(e2() == (1ull << 32));
assert(e2() == (1ull << 63) - 1ull);
assert(e2() == (1ull << 63) - (1ull << 33) + 1ull);
assert(e2() == (1ull << 63) - 0x1ffffffffull);
// make sure result is in bounds
assert(e2() < (1ull << 63) + 1);
assert(e2() < (1ull << 63) + 1);
Expand All @@ -56,29 +56,62 @@ int main(int, char**)
typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
E3 e3;
// make sure Schrage's algorithm is used
assert(e3() == 402727752ull);
assert(e3() == 162159612030764687ull);
assert(e3() == 108176466184989142ull);
assert(e3() == 0x18012348ull);
assert(e3() == 0x2401b4ed802468full);
assert(e3() == 0x18051ec400369d6ull);
// make sure result is in bounds
assert(e3() < (3ull << 56));
assert(e3() < (3ull << 56));
assert(e3() < (3ull << 56));
assert(e3() < (3ull << 56));
assert(e3() < (3ull << 56));

// m will not overflow so we should not use Schrage's algorithm
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E4;
// 32-bit case:
// m might overflow. The overflow is not OK, result will be in bounds,
// and Schrage's algorithm is incompatible here. Need to use 64 bit arithmetic.
typedef std::linear_congruential_engine<unsigned, 0x10009u, 0u, 0x7fffffffu> E4;
E4 e4;
// make sure enough precision is used
assert(e4() == 0x10009u);
assert(e4() == 0x120053u);
assert(e4() == 0xf5030fu);
// make sure result is in bounds
assert(e4() < 0x7fffffffu);
assert(e4() < 0x7fffffffu);
assert(e4() < 0x7fffffffu);
assert(e4() < 0x7fffffffu);
assert(e4() < 0x7fffffffu);

#ifndef _LIBCPP_HAS_NO_INT128
// m might overflow. The overflow is not OK, result will be in bounds,
// and Schrage's algorithm is incompatible here. Need to use 128 bit arithmetic.
typedef std::linear_congruential_engine<T, 0x100000001ull, 0ull, (1ull << 61) - 1ull> E5;
E5 e5;
// make sure enough precision is used
assert(e5() == 0x100000001ull);
assert(e5() == 0x200000009ull);
assert(e5() == 0xb00000019ull);
// make sure result is in bounds
assert(e5() < (1ull << 61) - 1ull);
assert(e5() < (1ull << 61) - 1ull);
assert(e5() < (1ull << 61) - 1ull);
assert(e5() < (1ull << 61) - 1ull);
assert(e5() < (1ull << 61) - 1ull);
#endif

// m will not overflow so we should not use Schrage's algorithm
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E6;
E6 e6;
// make sure the correct algorithm was used
assert(e4() == 2ull);
assert(e4() == 3ull);
assert(e4() == 4ull);
assert(e6() == 2ull);
assert(e6() == 3ull);
assert(e6() == 4ull);
// make sure result is in bounds
assert(e4() < (1ull << 48));
assert(e4() < (1ull << 48));
assert(e4() < (1ull << 48));
assert(e4() < (1ull << 48));
assert(e4() < (1ull << 48));
assert(e6() < (1ull << 48));
assert(e6() < (1ull << 48));
assert(e6() < (1ull << 48));
assert(e6() < (1ull << 48));
assert(e6() < (1ull << 48));

return 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,34 @@ test()
test1<T, A, 0, M>();
test1<T, A, M - 2, M>();
test1<T, A, M - 1, M>();
}

// Extended parameter sets with m = M (the largest value of T),
// a in {M - 2, M - 1} and c in {0, M - 2, M - 1}.
// NOTE(review): the block-comment markers and the "(not implemented)" line
// look like the removed side of a diff; as literally written here the test1
// calls sit inside a block comment - confirm against the real test file.
template <class T>
void test_ext() {
const T M(static_cast<T>(-1));

/*
// Cases where m is odd and m % a > m / a (not implemented)
// Cases where m is odd and m % a > m / a
test1<T, M - 2, 0, M>();
test1<T, M - 2, M - 2, M>();
test1<T, M - 2, M - 1, M>();
test1<T, M - 1, 0, M>();
test1<T, M - 1, M - 2, M>();
test1<T, M - 1, M - 1, M>();
*/
}

// Runs test<T>() and test_ext<T>() for each unsigned result type; the
// unsigned long long test_ext run is compiled only when __int128 is available.
// NOTE(review): the duplicated `return 0;` appears to be the removed and added
// sides of a diff line.
int main(int, char**)
{
test<unsigned short>();
test_ext<unsigned short>();
test<unsigned int>();
test_ext<unsigned int>();
test<unsigned long>();
test_ext<unsigned long>();
test<unsigned long long>();
// This isn't implemented on platforms without __int128
#ifndef _LIBCPP_HAS_NO_INT128
test_ext<unsigned long long>();
#endif

return 0;
return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,34 @@ test()
test1<T, A, 0, M>();
test1<T, A, M - 2, M>();
test1<T, A, M - 1, M>();
}

// Extended parameter sets with m = M (the largest value of T),
// a in {M - 2, M - 1} and c in {0, M - 2, M - 1}.
// NOTE(review): the block-comment markers and the "(not implemented)" line
// look like the removed side of a diff; as literally written here the test1
// calls sit inside a block comment - confirm against the real test file.
template <class T>
void test_ext() {
const T M(static_cast<T>(-1));

/*
// Cases where m is odd and m % a > m / a (not implemented)
// Cases where m is odd and m % a > m / a
test1<T, M - 2, 0, M>();
test1<T, M - 2, M - 2, M>();
test1<T, M - 2, M - 1, M>();
test1<T, M - 1, 0, M>();
test1<T, M - 1, M - 2, M>();
test1<T, M - 1, M - 1, M>();
*/
}

// Runs test<T>() and test_ext<T>() for each unsigned result type; the
// unsigned long long test_ext run is compiled only when __int128 is available.
// NOTE(review): the duplicated `return 0;` appears to be the removed and added
// sides of a diff line.
int main(int, char**)
{
test<unsigned short>();
test_ext<unsigned short>();
test<unsigned int>();
test_ext<unsigned int>();
test<unsigned long>();
test_ext<unsigned long>();
test<unsigned long long>();
// This isn't implemented on platforms without __int128
#ifndef _LIBCPP_HAS_NO_INT128
test_ext<unsigned long long>();
#endif

return 0;
return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,34 @@ test()
test1<T, A, 0, M>();
test1<T, A, M - 2, M>();
test1<T, A, M - 1, M>();
}

// Extended parameter sets with m = M (the largest value of T),
// a in {M - 2, M - 1} and c in {0, M - 2, M - 1}.
// NOTE(review): the block-comment markers and the "(not implemented)" line
// look like the removed side of a diff; as literally written here the test1
// calls sit inside a block comment - confirm against the real test file.
template <class T>
void test_ext() {
const T M(static_cast<T>(-1));

/*
// Cases where m is odd and m % a > m / a (not implemented)
// Cases where m is odd and m % a > m / a
test1<T, M - 2, 0, M>();
test1<T, M - 2, M - 2, M>();
test1<T, M - 2, M - 1, M>();
test1<T, M - 1, 0, M>();
test1<T, M - 1, M - 2, M>();
test1<T, M - 1, M - 1, M>();
*/
}

// Runs test<T>() and test_ext<T>() for each unsigned result type; the
// unsigned long long test_ext run is compiled only when __int128 is available.
// NOTE(review): the duplicated `return 0;` appears to be the removed and added
// sides of a diff line.
int main(int, char**)
{
test<unsigned short>();
test_ext<unsigned short>();
test<unsigned int>();
test_ext<unsigned int>();
test<unsigned long>();
test_ext<unsigned long>();
test<unsigned long long>();
// This isn't implemented on platforms without __int128
#ifndef _LIBCPP_HAS_NO_INT128
test_ext<unsigned long long>();
#endif

return 0;
return 0;
}
Loading