Skip to content

Commit 7dfdfb2

Browse files
committed
linear_congruential_engine: use more precision to prevent overflow
1 parent fad3752 commit 7dfdfb2

File tree

6 files changed

+180
-50
lines changed

6 files changed

+180
-50
lines changed

libcxx/include/__random/linear_congruential_engine.h

Lines changed: 75 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,56 @@ _LIBCPP_PUSH_MACROS
2626

2727
_LIBCPP_BEGIN_NAMESPACE_STD
2828

29+
// Tag naming the arithmetic strategy an __lce_ta specialization uses to compute the next state.
// __lce_alg_picker selects one of these values from a, c, m and the maximum of the result type.
enum __lce_alg_type {
30+
_LCE_Full, // (a * x + c) % m can be evaluated directly
31+
_LCE_Part, // (a * x) % m can be evaluated directly; c is added in a separate reduction step
32+
_LCE_Schrage, // use Schrage's algorithm (picker requires m % a <= m / a)
33+
_LCE_Promote, // none of the above applies: carry out the computation in a wider integer type
34+
};
35+
2936
template <unsigned long long __a,
3037
unsigned long long __c,
3138
unsigned long long __m,
3239
unsigned long long _Mp,
33-
bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
34-
bool _OverflowOK = ((__m & (__m - 1)) == 0ull), // m = 2^n
35-
bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
40+
bool _HasOverflow = (__a != 0ull && (__m & (__m - 1ull)) != 0ull), // a != 0, m != 0, m != 2^n
41+
bool _Full = (!_HasOverflow || __m - 1ull <= (_Mp - __c) / __a), // (a * x + c) % m works
42+
bool _Part = (!_HasOverflow || __m - 1ull <= _Mp / __a), // (a * x) % m works
43+
bool _Schrage = (_HasOverflow && __m % __a <= __m / __a)> // r <= q
3644
struct __lce_alg_picker {
37-
static_assert(!_MightOverflow || _OverflowOK || _SchrageOK,
38-
"The current values of a, c, and m cannot generate a number "
39-
"within bounds of linear_congruential_engine.");
45+
static _LIBCPP_CONSTEXPR const __lce_alg_type __mode = _Full ? _LCE_Full : _Part ? _LCE_Part : _Schrage ? _LCE_Schrage : _LCE_Promote;
4046

41-
static _LIBCPP_CONSTEXPR const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK;
47+
#ifdef _LIBCPP_HAS_NO_INT128
48+
static_assert(_Mp != (unsigned long long)(~0) || _Full || _Part || _Schrage,
49+
"The current values for a, c, and m are not currently supported on platforms without __int128");
50+
#endif
4251
};
4352

4453
template <unsigned long long __a,
4554
unsigned long long __c,
4655
unsigned long long __m,
4756
unsigned long long _Mp,
48-
bool _UseSchrage = __lce_alg_picker<__a, __c, __m, _Mp>::__use_schrage>
57+
__lce_alg_type _Mode = __lce_alg_picker<__a, __c, __m, _Mp>::__mode>
4958
struct __lce_ta;
5059

5160
// 64
5261

62+
// 64-bit result type, _LCE_Promote mode.
// Only compiled when _LIBCPP_HAS_NO_INT128 is not defined, because it relies on unsigned __int128.
#ifndef _LIBCPP_HAS_NO_INT128
63+
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
64+
struct __lce_ta<_Ap, _Cp, _Mp, (unsigned long long)(~0), _LCE_Promote> {
65+
typedef unsigned long long result_type;
66+
// Returns (_Ap * __xp + _Cp) % _Mp, narrowed back to result_type.
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __xp) {
67+
// Every operand is widened to 128 bits first: each one fits in 64 bits, so
// a * x + c cannot wrap in __calc_type before the modulo is taken.
__extension__ using __calc_type = unsigned __int128;
68+
const __calc_type __a = static_cast<__calc_type>(_Ap);
69+
const __calc_type __c = static_cast<__calc_type>(_Cp);
70+
const __calc_type __m = static_cast<__calc_type>(_Mp);
71+
const __calc_type __x = static_cast<__calc_type>(__xp);
72+
return static_cast<result_type>((__a * __x + __c) % __m);
73+
}
74+
};
75+
#endif
76+
5377
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
54-
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
78+
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), _LCE_Schrage> {
5579
typedef unsigned long long result_type;
5680
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
5781
// Schrage's algorithm
@@ -66,7 +90,7 @@ struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
6690
};
6791

6892
template <unsigned long long __a, unsigned long long __m>
69-
struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
93+
struct __lce_ta<__a, 0ull, __m, (unsigned long long)(~0), _LCE_Schrage> {
7094
typedef unsigned long long result_type;
7195
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
7296
// Schrage's algorithm
@@ -80,21 +104,40 @@ struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
80104
};
81105

82106
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
83-
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), false> {
107+
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), _LCE_Part> {
108+
typedef unsigned long long result_type;
109+
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
110+
// Use (((a*x) % m) + c) % m
111+
__x = (__a * __x) % __m;
112+
__x += __c - (__x >= __m - __c) * __m;
113+
return __x;
114+
}
115+
};
116+
117+
// 64-bit result type, _LCE_Full mode: next() evaluates (a * x + c) % m directly in unsigned long long.
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
118+
struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), _LCE_Full> {
84119
typedef unsigned long long result_type;
85120
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return (__a * __x + __c) % __m; }
86121
};
87122

88123
template <unsigned long long __a, unsigned long long __c>
89-
struct __lce_ta<__a, __c, 0, (unsigned long long)(~0), false> {
124+
struct __lce_ta<__a, __c, 0ull, (unsigned long long)(~0), _LCE_Full> {
90125
typedef unsigned long long result_type;
91126
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return __a * __x + __c; }
92127
};
93128

94129
// 32
95130

131+
// 32-bit result type, _LCE_Promote mode: the step is delegated to the __lce_ta specialization
// for the unsigned long long maximum, whose mode comes from the default template argument.
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
132+
struct __lce_ta<__a, __c, __m, unsigned(~0), _LCE_Promote> {
133+
typedef unsigned result_type;
134+
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
135+
// __x is converted to the 64-bit next()'s parameter type; the result is narrowed back to unsigned.
return static_cast<result_type>(__lce_ta<__a, __c, __m, (unsigned long long)(~0)>::next(__x));
136+
}
137+
};
138+
96139
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
97-
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
140+
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), _LCE_Schrage> {
98141
typedef unsigned result_type;
99142
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
100143
const result_type __a = static_cast<result_type>(_Ap);
@@ -112,7 +155,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
112155
};
113156

114157
template <unsigned long long _Ap, unsigned long long _Mp>
115-
struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
158+
struct __lce_ta<_Ap, 0ull, _Mp, unsigned(~0), _LCE_Schrage> {
116159
typedef unsigned result_type;
117160
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
118161
const result_type __a = static_cast<result_type>(_Ap);
@@ -128,7 +171,21 @@ struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
128171
};
129172

130173
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
131-
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
174+
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), _LCE_Part> {
175+
typedef unsigned result_type;
176+
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
177+
const result_type __a = static_cast<result_type>(_Ap);
178+
const result_type __c = static_cast<result_type>(_Cp);
179+
const result_type __m = static_cast<result_type>(_Mp);
180+
// Use (((a*x) % m) + c) % m
181+
__x = (__a * __x) % __m;
182+
__x += __c - (__x >= __m - __c) * __m;
183+
return __x;
184+
}
185+
};
186+
187+
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
188+
struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), _LCE_Full> {
132189
typedef unsigned result_type;
133190
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
134191
const result_type __a = static_cast<result_type>(_Ap);
@@ -139,7 +196,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
139196
};
140197

141198
template <unsigned long long _Ap, unsigned long long _Cp>
142-
struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
199+
struct __lce_ta<_Ap, _Cp, 0ull, unsigned(~0), _LCE_Full> {
143200
typedef unsigned result_type;
144201
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
145202
const result_type __a = static_cast<result_type>(_Ap);
@@ -150,8 +207,8 @@ struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
150207

151208
// 16
152209

153-
template <unsigned long long __a, unsigned long long __c, unsigned long long __m, bool __b>
154-
struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __b> {
210+
template <unsigned long long __a, unsigned long long __c, unsigned long long __m, int __mode>
211+
struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __mode> {
155212
typedef unsigned short result_type;
156213
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
157214
return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(~0)>::next(__x));

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ int main(int, char**)
3838

3939
// m might overflow. The overflow is not OK and result will be in bounds
4040
// so we should use Schrage's algorithm
41-
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1> E2;
41+
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1ull> E2;
4242
E2 e2;
4343
// make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
4444
assert(e2() == (1ull << 32));
4545
assert(e2() == (1ull << 63) - 1ull);
46-
assert(e2() == (1ull << 63) - (1ull << 33) + 1ull);
46+
assert(e2() == (1ull << 63) - 0x1ffffffffull);
4747
// make sure result is in bounds
4848
assert(e2() < (1ull << 63) + 1);
4949
assert(e2() < (1ull << 63) + 1);
@@ -56,29 +56,62 @@ int main(int, char**)
5656
typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
5757
E3 e3;
5858
// make sure Schrage's algorithm is used
59-
assert(e3() == 402727752ull);
60-
assert(e3() == 162159612030764687ull);
61-
assert(e3() == 108176466184989142ull);
59+
assert(e3() == 0x18012348ull);
60+
assert(e3() == 0x2401b4ed802468full);
61+
assert(e3() == 0x18051ec400369d6ull);
6262
// make sure result is in bounds
6363
assert(e3() < (3ull << 56));
6464
assert(e3() < (3ull << 56));
6565
assert(e3() < (3ull << 56));
6666
assert(e3() < (3ull << 56));
6767
assert(e3() < (3ull << 56));
6868

69-
// m will not overflow so we should not use Schrage's algorithm
70-
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E4;
69+
// 32-bit case:
70+
// m might overflow. The overflow is not OK, result will be in bounds,
71+
// and Schrage's algorithm is incompatible here. Need to use 64 bit arithmetic.
72+
typedef std::linear_congruential_engine<unsigned, 0x10009u, 0u, 0x7fffffffu> E4;
7173
E4 e4;
74+
// make sure enough precision is used
75+
assert(e4() == 0x10009u);
76+
assert(e4() == 0x120053u);
77+
assert(e4() == 0xf5030fu);
78+
// make sure result is in bounds
79+
assert(e4() < 0x7fffffffu);
80+
assert(e4() < 0x7fffffffu);
81+
assert(e4() < 0x7fffffffu);
82+
assert(e4() < 0x7fffffffu);
83+
assert(e4() < 0x7fffffffu);
84+
85+
#ifndef _LIBCPP_HAS_NO_INT128
86+
// m might overflow. The overflow is not OK, result will be in bounds,
87+
// and Schrage's algorithm is incompatible here. Need to use 128 bit arithmetic.
88+
typedef std::linear_congruential_engine<T, 0x100000001ull, 0ull, (1ull << 61) - 1ull> E5;
89+
E5 e5;
90+
// make sure enough precision is used
91+
assert(e5() == 0x100000001ull);
92+
assert(e5() == 0x200000009ull);
93+
assert(e5() == 0xb00000019ull);
94+
// make sure result is in bounds
95+
assert(e5() < (1ull << 61) - 1ull);
96+
assert(e5() < (1ull << 61) - 1ull);
97+
assert(e5() < (1ull << 61) - 1ull);
98+
assert(e5() < (1ull << 61) - 1ull);
99+
assert(e5() < (1ull << 61) - 1ull);
100+
#endif
101+
102+
// m will not overflow so we should not use Schrage's algorithm
103+
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E6;
104+
E6 e6;
72105
// make sure the correct algorithm was used
73-
assert(e4() == 2ull);
74-
assert(e4() == 3ull);
75-
assert(e4() == 4ull);
106+
assert(e6() == 2ull);
107+
assert(e6() == 3ull);
108+
assert(e6() == 4ull);
76109
// make sure result is in bounds
77-
assert(e4() < (1ull << 48));
78-
assert(e4() < (1ull << 48));
79-
assert(e4() < (1ull << 48));
80-
assert(e4() < (1ull << 48));
81-
assert(e4() < (1ull << 48));
110+
assert(e6() < (1ull << 48));
111+
assert(e6() < (1ull << 48));
112+
assert(e6() < (1ull << 48));
113+
assert(e6() < (1ull << 48));
114+
assert(e6() < (1ull << 48));
82115

83116
return 0;
84-
}
117+
}

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,24 +61,34 @@ test()
6161
test1<T, A, 0, M>();
6262
test1<T, A, M - 2, M>();
6363
test1<T, A, M - 1, M>();
64+
}
65+
66+
template <class T>
67+
void test_ext() {
68+
const T M(static_cast<T>(-1));
6469

65-
/*
66-
// Cases where m is odd and m % a > m / a (not implemented)
70+
// Cases where m is odd and m % a > m / a
6771
test1<T, M - 2, 0, M>();
6872
test1<T, M - 2, M - 2, M>();
6973
test1<T, M - 2, M - 1, M>();
7074
test1<T, M - 1, 0, M>();
7175
test1<T, M - 1, M - 2, M>();
7276
test1<T, M - 1, M - 1, M>();
73-
*/
7477
}
7578

7679
int main(int, char**)
7780
{
7881
test<unsigned short>();
82+
test_ext<unsigned short>();
7983
test<unsigned int>();
84+
test_ext<unsigned int>();
8085
test<unsigned long>();
86+
test_ext<unsigned long>();
8187
test<unsigned long long>();
88+
// This isn't implemented on platforms without __int128
89+
#ifndef _LIBCPP_HAS_NO_INT128
90+
test_ext<unsigned long long>();
91+
#endif
8292

83-
return 0;
93+
return 0;
8494
}

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,34 @@ test()
6060
test1<T, A, 0, M>();
6161
test1<T, A, M - 2, M>();
6262
test1<T, A, M - 1, M>();
63+
}
64+
65+
template <class T>
66+
void test_ext() {
67+
const T M(static_cast<T>(-1));
6368

64-
/*
65-
// Cases where m is odd and m % a > m / a (not implemented)
69+
// Cases where m is odd and m % a > m / a
6670
test1<T, M - 2, 0, M>();
6771
test1<T, M - 2, M - 2, M>();
6872
test1<T, M - 2, M - 1, M>();
6973
test1<T, M - 1, 0, M>();
7074
test1<T, M - 1, M - 2, M>();
7175
test1<T, M - 1, M - 1, M>();
72-
*/
7376
}
7477

7578
int main(int, char**)
7679
{
7780
test<unsigned short>();
81+
test_ext<unsigned short>();
7882
test<unsigned int>();
83+
test_ext<unsigned int>();
7984
test<unsigned long>();
85+
test_ext<unsigned long>();
8086
test<unsigned long long>();
87+
// This isn't implemented on platforms without __int128
88+
#ifndef _LIBCPP_HAS_NO_INT128
89+
test_ext<unsigned long long>();
90+
#endif
8191

82-
return 0;
92+
return 0;
8393
}

libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,34 @@ test()
5858
test1<T, A, 0, M>();
5959
test1<T, A, M - 2, M>();
6060
test1<T, A, M - 1, M>();
61+
}
62+
63+
template <class T>
64+
void test_ext() {
65+
const T M(static_cast<T>(-1));
6166

62-
/*
63-
// Cases where m is odd and m % a > m / a (not implemented)
67+
// Cases where m is odd and m % a > m / a
6468
test1<T, M - 2, 0, M>();
6569
test1<T, M - 2, M - 2, M>();
6670
test1<T, M - 2, M - 1, M>();
6771
test1<T, M - 1, 0, M>();
6872
test1<T, M - 1, M - 2, M>();
6973
test1<T, M - 1, M - 1, M>();
70-
*/
7174
}
7275

7376
int main(int, char**)
7477
{
7578
test<unsigned short>();
79+
test_ext<unsigned short>();
7680
test<unsigned int>();
81+
test_ext<unsigned int>();
7782
test<unsigned long>();
83+
test_ext<unsigned long>();
7884
test<unsigned long long>();
85+
// This isn't implemented on platforms without __int128
86+
#ifndef _LIBCPP_HAS_NO_INT128
87+
test_ext<unsigned long long>();
88+
#endif
7989

80-
return 0;
90+
return 0;
8191
}

0 commit comments

Comments
 (0)