Skip to content

Commit 41e6962

Browse files
authored
linear_congruential_engine: add using more precision to prevent overflow (#81583)
This PR is a follow-up to #81080. It makes two major changes to how the LCG operation is computed: The first is that I added an additional case where `ax + c` might overflow the intermediate variable, but `ax` by itself won't. In this case, it's much better to use `((ax mod m) + c) mod m` than the previous behavior of falling back to Schrage's algorithm. The modular addition is done in the same way as when using Schrage's algorithm (i.e. `x += c - (x >= m - c)*m`), but the modular multiplication is calculated directly, which is faster. The second is that I added handling for the case where the `ax` intermediate might overflow, but Schrage's algorithm doesn't apply (i.e. `r > q`, where `q = m / a` and `r = m % a`). In this case, the only real option is to increase the precision of the intermediate values. The good news is that, for `x`, `a`, and `c` being n-bit values, `ax + c` will never overflow a 2n-bit intermediate, meaning this promotion only needs to happen once, and the promoted computation will always be able to use the simplest implementation. This is already the case for 16-bit LCGs, as libcxx chooses to compute them with 32-bit intermediate values. For 32-bit LCGs, I simply added code similar to the 16-bit case to use the existing 64-bit implementations. Lastly, for 64-bit LCGs, I wrote a case that calculates the result using `unsigned __int128` if it is available. While this implementation covers a *lot* of the missing cases from #81080, it still won't compile **every** possible `linear_congruential_engine`. Specifically, if `a`, `c`, and `m` are chosen such that the computation needs 128-bit integers, but the platform doesn't support `__int128` (e.g. 32-bit x86), then it will fail to compile. However, such parameters are rarely used in practice, and libcxx would be in good company here, as [libstdc++ also fails to compile under these circumstances](https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87744). 
Fixing **this** gap would require even **more** work and further complexity, so it would probably be best handled by a different PR (I'll put more details on what that PR would entail in a comment).
1 parent 82c320c commit 41e6962

File tree

6 files changed

+187
-53
lines changed

6 files changed

+187
-53
lines changed

libcxx/include/__random/linear_congruential_engine.h

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,60 @@ _LIBCPP_PUSH_MACROS
2626

2727
_LIBCPP_BEGIN_NAMESPACE_STD
2828

29+
enum __lce_alg_type {
30+
_LCE_Full,
31+
_LCE_Part,
32+
_LCE_Schrage,
33+
_LCE_Promote,
34+
};
35+
2936
template <unsigned long long __a,
3037
unsigned long long __c,
3138
unsigned long long __m,
3239
unsigned long long _Mp,
33-
bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
34-
bool _OverflowOK = ((__m & (__m - 1)) == 0ull), // m = 2^n
35-
bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
40+
bool _HasOverflow = (__a != 0ull && (__m & (__m - 1ull)) != 0ull), // a != 0, m != 0, m != 2^n
41+
bool _Full = (!_HasOverflow || __m - 1ull <= (_Mp - __c) / __a), // (a * x + c) % m works
42+
bool _Part = (!_HasOverflow || __m - 1ull <= _Mp / __a), // (a * x) % m works
43+
bool _Schrage = (_HasOverflow && __m % __a <= __m / __a)> // r <= q
3644
struct __lce_alg_picker {
37-
static_assert(!_MightOverflow || _OverflowOK || _SchrageOK,
38-
"The current values of a, c, and m cannot generate a number "
39-
"within bounds of linear_congruential_engine.");
40-
41-
static _LIBCPP_CONSTEXPR const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK;
45+
static _LIBCPP_CONSTEXPR const __lce_alg_type __mode =
46+
_Full ? _LCE_Full
47+
: _Part ? _LCE_Part
48+
: _Schrage ? _LCE_Schrage
49+
: _LCE_Promote;
50+
51+
#ifdef _LIBCPP_HAS_NO_INT128
52+
static_assert(_Mp != (unsigned long long)(-1) || _Full || _Part || _Schrage,
53+
"The current values for a, c, and m are not currently supported on platforms without __int128");
54+
#endif
4255
};
4356

4457
template <unsigned long long __a,
4558
unsigned long long __c,
4659
unsigned long long __m,
4760
unsigned long long _Mp,
48-
bool _UseSchrage = __lce_alg_picker<__a, __c, __m, _Mp>::__use_schrage>
61+
__lce_alg_type _Mode = __lce_alg_picker<__a, __c, __m, _Mp>::__mode>
4962
struct __lce_ta;
5063

5164
// 64
5265

66+
#ifndef _LIBCPP_HAS_NO_INT128
67+
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
68+
struct __lce_ta<_Ap, _Cp, _Mp, (unsigned long long)(-1), _LCE_Promote> {
69+
typedef unsigned long long result_type;
70+
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __xp) {
71+
__extension__ using __calc_type = unsigned __int128;
72+
const __calc_type __a = static_cast<__calc_type>(_Ap);
73+
const __calc_type __c = static_cast<__calc_type>(_Cp);
74+
const __calc_type __m = static_cast<__calc_type>(_Mp);
75+
const __calc_type __x = static_cast<__calc_type>(__xp);
76+
return static_cast<result_type>((__a * __x + __c) % __m);
77+
}
78+
};
79+
#endif
80+
5381
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
54-
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
82+
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Schrage> {
5583
typedef unsigned long long result_type;
5684
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
5785
// Schrage's algorithm
@@ -66,7 +94,7 @@ struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
6694
};
6795

6896
template <unsigned long long __a, unsigned long long __m>
69-
struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
97+
struct __lce_ta<__a, 0ull, __m, (unsigned long long)(-1), _LCE_Schrage> {
7098
typedef unsigned long long result_type;
7199
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
72100
// Schrage's algorithm
@@ -80,21 +108,40 @@ struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
80108
};
81109

82110
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
83-
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), false> {
111+
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Part> {
112+
typedef unsigned long long result_type;
113+
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
114+
// Use (((a*x) % m) + c) % m
115+
__x = (__a * __x) % __m;
116+
__x += __c - (__x >= __m - __c) * __m;
117+
return __x;
118+
}
119+
};
120+
121+
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
122+
struct __lce_ta<__a, __c, __m, (unsigned long long)(-1), _LCE_Full> {
84123
typedef unsigned long long result_type;
85124
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return (__a * __x + __c) % __m; }
86125
};
87126

88127
template <unsigned long long __a, unsigned long long __c>
89-
struct __lce_ta<__a, __c, 0, (unsigned long long)(~0), false> {
128+
struct __lce_ta<__a, __c, 0ull, (unsigned long long)(-1), _LCE_Full> {
90129
typedef unsigned long long result_type;
91130
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return __a * __x + __c; }
92131
};
93132

94133
// 32
95134

135+
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
136+
struct __lce_ta<__a, __c, __m, unsigned(-1), _LCE_Promote> {
137+
typedef unsigned result_type;
138+
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
139+
return static_cast<result_type>(__lce_ta<__a, __c, __m, (unsigned long long)(-1)>::next(__x));
140+
}
141+
};
142+
96143
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
97-
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
144+
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Schrage> {
98145
typedef unsigned result_type;
99146
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
100147
const result_type __a = static_cast<result_type>(_Ap);
@@ -112,7 +159,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
112159
};
113160

114161
template <unsigned long long _Ap, unsigned long long _Mp>
115-
struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
162+
struct __lce_ta<_Ap, 0ull, _Mp, unsigned(-1), _LCE_Schrage> {
116163
typedef unsigned result_type;
117164
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
118165
const result_type __a = static_cast<result_type>(_Ap);
@@ -128,7 +175,21 @@ struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
128175
};
129176

130177
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
131-
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
178+
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Part> {
179+
typedef unsigned result_type;
180+
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
181+
const result_type __a = static_cast<result_type>(_Ap);
182+
const result_type __c = static_cast<result_type>(_Cp);
183+
const result_type __m = static_cast<result_type>(_Mp);
184+
// Use (((a*x) % m) + c) % m
185+
__x = (__a * __x) % __m;
186+
__x += __c - (__x >= __m - __c) * __m;
187+
return __x;
188+
}
189+
};
190+
191+
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
192+
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(-1), _LCE_Full> {
132193
typedef unsigned result_type;
133194
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
134195
const result_type __a = static_cast<result_type>(_Ap);
@@ -139,7 +200,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
139200
};
140201

141202
template <unsigned long long _Ap, unsigned long long _Cp>
142-
struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
203+
struct __lce_ta<_Ap, _Cp, 0ull, unsigned(-1), _LCE_Full> {
143204
typedef unsigned result_type;
144205
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
145206
const result_type __a = static_cast<result_type>(_Ap);
@@ -150,11 +211,11 @@ struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
150211

151212
// 16
152213

153-
template <unsigned long long __a, unsigned long long __c, unsigned long long __m, bool __b>
154-
struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __b> {
214+
template <unsigned long long __a, unsigned long long __c, unsigned long long __m, __lce_alg_type __mode>
215+
struct __lce_ta<__a, __c, __m, (unsigned short)(-1), __mode> {
155216
typedef unsigned short result_type;
156217
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
157-
return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(~0)>::next(__x));
218+
return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(-1)>::next(__x));
158219
}
159220
};
160221

@@ -178,7 +239,7 @@ class _LIBCPP_TEMPLATE_VIS linear_congruential_engine {
178239
private:
179240
result_type __x_;
180241

181-
static _LIBCPP_CONSTEXPR const result_type _Mp = result_type(~0);
242+
static _LIBCPP_CONSTEXPR const result_type _Mp = result_type(-1);
182243

183244
static_assert(__m == 0 || __a < __m, "linear_congruential_engine invalid parameters");
184245
static_assert(__m == 0 || __c < __m, "linear_congruential_engine invalid parameters");

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ int main(int, char**)
3838

3939
// m might overflow. The overflow is not OK and result will be in bounds
4040
// so we should use Schrage's algorithm
41-
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1> E2;
41+
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1ull> E2;
4242
E2 e2;
4343
// make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
4444
assert(e2() == (1ull << 32));
4545
assert(e2() == (1ull << 63) - 1ull);
46-
assert(e2() == (1ull << 63) - (1ull << 33) + 1ull);
46+
assert(e2() == (1ull << 63) - 0x1ffffffffull);
4747
// make sure result is in bounds
4848
assert(e2() < (1ull << 63) + 1);
4949
assert(e2() < (1ull << 63) + 1);
@@ -56,29 +56,62 @@ int main(int, char**)
5656
typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
5757
E3 e3;
5858
// make sure Schrage's algorithm is used
59-
assert(e3() == 402727752ull);
60-
assert(e3() == 162159612030764687ull);
61-
assert(e3() == 108176466184989142ull);
59+
assert(e3() == 0x18012348ull);
60+
assert(e3() == 0x2401b4ed802468full);
61+
assert(e3() == 0x18051ec400369d6ull);
6262
// make sure result is in bounds
6363
assert(e3() < (3ull << 56));
6464
assert(e3() < (3ull << 56));
6565
assert(e3() < (3ull << 56));
6666
assert(e3() < (3ull << 56));
6767
assert(e3() < (3ull << 56));
6868

69-
// m will not overflow so we should not use Schrage's algorithm
70-
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E4;
69+
// 32-bit case:
70+
// m might overflow. The overflow is not OK, result will be in bounds,
71+
// and Schrage's algorithm is incompatible here. Need to use 64 bit arithmetic.
72+
typedef std::linear_congruential_engine<unsigned, 0x10009u, 0u, 0x7fffffffu> E4;
7173
E4 e4;
74+
// make sure enough precision is used
75+
assert(e4() == 0x10009u);
76+
assert(e4() == 0x120053u);
77+
assert(e4() == 0xf5030fu);
78+
// make sure result is in bounds
79+
assert(e4() < 0x7fffffffu);
80+
assert(e4() < 0x7fffffffu);
81+
assert(e4() < 0x7fffffffu);
82+
assert(e4() < 0x7fffffffu);
83+
assert(e4() < 0x7fffffffu);
84+
85+
#ifndef _LIBCPP_HAS_NO_INT128
86+
// m might overflow. The overflow is not OK, result will be in bounds,
87+
// and Schrage's algorithm is incompatible here. Need to use 128 bit arithmetic.
88+
typedef std::linear_congruential_engine<T, 0x100000001ull, 0ull, (1ull << 61) - 1ull> E5;
89+
E5 e5;
90+
// make sure enough precision is used
91+
assert(e5() == 0x100000001ull);
92+
assert(e5() == 0x200000009ull);
93+
assert(e5() == 0xb00000019ull);
94+
// make sure result is in bounds
95+
assert(e5() < (1ull << 61) - 1ull);
96+
assert(e5() < (1ull << 61) - 1ull);
97+
assert(e5() < (1ull << 61) - 1ull);
98+
assert(e5() < (1ull << 61) - 1ull);
99+
assert(e5() < (1ull << 61) - 1ull);
100+
#endif
101+
102+
// m will not overflow so we should not use Schrage's algorithm
103+
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E6;
104+
E6 e6;
72105
// make sure the correct algorithm was used
73-
assert(e4() == 2ull);
74-
assert(e4() == 3ull);
75-
assert(e4() == 4ull);
106+
assert(e6() == 2ull);
107+
assert(e6() == 3ull);
108+
assert(e6() == 4ull);
76109
// make sure result is in bounds
77-
assert(e4() < (1ull << 48));
78-
assert(e4() < (1ull << 48));
79-
assert(e4() < (1ull << 48));
80-
assert(e4() < (1ull << 48));
81-
assert(e4() < (1ull << 48));
110+
assert(e6() < (1ull << 48));
111+
assert(e6() < (1ull << 48));
112+
assert(e6() < (1ull << 48));
113+
assert(e6() < (1ull << 48));
114+
assert(e6() < (1ull << 48));
82115

83116
return 0;
84-
}
117+
}

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,24 +61,34 @@ test()
6161
test1<T, A, 0, M>();
6262
test1<T, A, M - 2, M>();
6363
test1<T, A, M - 1, M>();
64+
}
65+
66+
template <class T>
67+
void test_ext() {
68+
const T M(static_cast<T>(-1));
6469

65-
/*
66-
// Cases where m is odd and m % a > m / a (not implemented)
70+
// Cases where m is odd and m % a > m / a
6771
test1<T, M - 2, 0, M>();
6872
test1<T, M - 2, M - 2, M>();
6973
test1<T, M - 2, M - 1, M>();
7074
test1<T, M - 1, 0, M>();
7175
test1<T, M - 1, M - 2, M>();
7276
test1<T, M - 1, M - 1, M>();
73-
*/
7477
}
7578

7679
int main(int, char**)
7780
{
7881
test<unsigned short>();
82+
test_ext<unsigned short>();
7983
test<unsigned int>();
84+
test_ext<unsigned int>();
8085
test<unsigned long>();
86+
test_ext<unsigned long>();
8187
test<unsigned long long>();
88+
// This isn't implemented on platforms without __int128
89+
#ifndef _LIBCPP_HAS_NO_INT128
90+
test_ext<unsigned long long>();
91+
#endif
8292

83-
return 0;
93+
return 0;
8494
}

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,34 @@ test()
6060
test1<T, A, 0, M>();
6161
test1<T, A, M - 2, M>();
6262
test1<T, A, M - 1, M>();
63+
}
64+
65+
template <class T>
66+
void test_ext() {
67+
const T M(static_cast<T>(-1));
6368

64-
/*
65-
// Cases where m is odd and m % a > m / a (not implemented)
69+
// Cases where m is odd and m % a > m / a
6670
test1<T, M - 2, 0, M>();
6771
test1<T, M - 2, M - 2, M>();
6872
test1<T, M - 2, M - 1, M>();
6973
test1<T, M - 1, 0, M>();
7074
test1<T, M - 1, M - 2, M>();
7175
test1<T, M - 1, M - 1, M>();
72-
*/
7376
}
7477

7578
int main(int, char**)
7679
{
7780
test<unsigned short>();
81+
test_ext<unsigned short>();
7882
test<unsigned int>();
83+
test_ext<unsigned int>();
7984
test<unsigned long>();
85+
test_ext<unsigned long>();
8086
test<unsigned long long>();
87+
// This isn't implemented on platforms without __int128
88+
#ifndef _LIBCPP_HAS_NO_INT128
89+
test_ext<unsigned long long>();
90+
#endif
8191

82-
return 0;
92+
return 0;
8393
}

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,34 @@ test()
5858
test1<T, A, 0, M>();
5959
test1<T, A, M - 2, M>();
6060
test1<T, A, M - 1, M>();
61+
}
62+
63+
template <class T>
64+
void test_ext() {
65+
const T M(static_cast<T>(-1));
6166

62-
/*
63-
// Cases where m is odd and m % a > m / a (not implemented)
67+
// Cases where m is odd and m % a > m / a
6468
test1<T, M - 2, 0, M>();
6569
test1<T, M - 2, M - 2, M>();
6670
test1<T, M - 2, M - 1, M>();
6771
test1<T, M - 1, 0, M>();
6872
test1<T, M - 1, M - 2, M>();
6973
test1<T, M - 1, M - 1, M>();
70-
*/
7174
}
7275

7376
int main(int, char**)
7477
{
7578
test<unsigned short>();
79+
test_ext<unsigned short>();
7680
test<unsigned int>();
81+
test_ext<unsigned int>();
7782
test<unsigned long>();
83+
test_ext<unsigned long>();
7884
test<unsigned long long>();
85+
// This isn't implemented on platforms without __int128
86+
#ifndef _LIBCPP_HAS_NO_INT128
87+
test_ext<unsigned long long>();
88+
#endif
7989

80-
return 0;
90+
return 0;
8191
}

0 commit comments

Comments
 (0)