linear_congruential_engine: add using more precision to prevent overflow #81583
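@llvm/pr-subscribers-libcxx
Author: None (LRFLEW)
Changes
This PR is a followup to #81080, and as such includes the commit from that PR. This should either be merged after that one is merged (and this one is rebased), or this one can take the place of that PR. Since the other PR has gotten pretty far in terms of reviews, and this PR is more involved, the other PR should probably be merged before this one.
This PR makes two major changes to how the LCG operation is computed:
The first is that I added an additional case where `ax + c` might overflow the intermediate variable, but `ax` by itself won't. In this case, it's much better to use `(ax mod m) + c mod m` than the previous behavior of falling back to Schrage's algorithm. The addition modulo is done in the same way as when using Schrage's algorithm (i.e. `x += c - (x >= m - c)*m`), but the multiplication modulo is calculated directly, which is faster.
The second is that I added handling for the case where the `ax` intermediate might overflow, but Schrage's algorithm doesn't apply (i.e. r > q). In this case, the only real option is to increase the precision of the intermediate values. The good news is that - for `x`, `a`, and `c` being n-bit values - `ax + c` will never overflow a 2n-bit intermediary, meaning this promotion can only happen once, and will always be able to use the simplest implementation. This is already the case for 16-bit LCGs, as libcxx chooses to compute them with 32-bit intermediate values. For 32-bit LCGs, I simply added code similar to the 16-bit case to use the existing 64-bit implementations. Lastly, for 64-bit LCGs, I wrote a case that calculates it using `unsigned __int128` if it is available to use.
While this implementation covers a lot of the missing cases from #81080, this still won't compile every possible `linear_congruential_engine`. Specifically, if `a`, `c`, and `m` are chosen such that it needs 128-bit integers, but the platform doesn't support `__int128` (e.g. 32-bit x86), then it will fail to compile. However, this is a fairly rare case to see actually used, and libcxx would be in good company with this, as libstdc++ also fails to compile under these circumstances. Fixing this gap would require even more work of further complexity, so that would probably be best handled by a different PR (I'll put more details on what that PR would entail in a comment).
I currently consider this PR to be a WIP because a) I haven't gotten the test cases written properly to avoid failing when __int128 isn't available, and b) I'm not 100% sure about how I've structured / formatted the changes, and may still want to tweak it before merging. I'm making the PR now so I can start getting feedback if anybody has any.
Full diff: https://github.com/llvm/llvm-project/pull/81583.diff
6 Files Affected: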
diff --git a/libcxx/include/__random/linear_congruential_engine.h b/libcxx/include/__random/linear_congruential_engine.h
index 51f6b248d8f974..e32f0a9d05395a 100644
--- a/libcxx/include/__random/linear_congruential_engine.h
+++ b/libcxx/include/__random/linear_congruential_engine.h
@@ -30,28 +30,45 @@ template <unsigned long long __a,
unsigned long long __c,
unsigned long long __m,
unsigned long long _Mp,
- bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
- bool _OverflowOK = ((__m | (__m - 1)) > __m), // m = 2^n
- bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
+ bool _HasOverflow = (__a != 0ull && (__m & (__m - 1ull)) != 0ull), // a != 0, m != 0, m != 2^n
+ bool _Full = (!_HasOverflow || __m - 1ull <= (_Mp - __c) / __a), // (a * x + c) % m works
+ bool _Part = (!_HasOverflow || __m - 1ull <= _Mp / __a), // (a * x) % m works
+ bool _Schrage = (_HasOverflow && __m % __a <= __m / __a)> // r <= q
struct __lce_alg_picker {
- static_assert(__a != 0 || __m != 0 || !_MightOverflow || _OverflowOK || _SchrageOK,
- "The current values of a, c, and m cannot generate a number "
- "within bounds of linear_congruential_engine.");
+ static _LIBCPP_CONSTEXPR const int __mode = _Full ? 0 : _Part ? 1 : _Schrage ? 2 : 3;
- static _LIBCPP_CONSTEXPR const bool __use_schrage = _MightOverflow && !_OverflowOK && _SchrageOK;
+#ifndef __SIZEOF_INT128__
+ static_assert(_Mp != (unsigned long long)(~0) || _Full || _Part || _Schrage,
+ "The current values for a, c, and m are not currently supported on platforms without __int128");
+#endif
};
template <unsigned long long __a,
unsigned long long __c,
unsigned long long __m,
unsigned long long _Mp,
- bool _UseSchrage = __lce_alg_picker<__a, __c, __m, _Mp>::__use_schrage>
+ int _Mode = __lce_alg_picker<__a, __c, __m, _Mp>::__mode>
struct __lce_ta;
// 64
+#ifdef __SIZEOF_INT128__
+template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
+struct __lce_ta<_Ap, _Cp, _Mp, (unsigned long long)(~0), 3> {
+ typedef unsigned long long result_type;
+ _LIBCPP_HIDE_FROM_ABI static result_type next(result_type _Xp) {
+ __extension__ typedef unsigned __int128 calc_type;
+ const calc_type __a = static_cast<calc_type>(_Ap);
+ const calc_type __c = static_cast<calc_type>(_Cp);
+ const calc_type __m = static_cast<calc_type>(_Mp);
+ const calc_type __x = static_cast<calc_type>(_Xp);
+ return static_cast<result_type>((__a * __x + __c) % __m);
+ }
+};
+#endif
+
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
-struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), 2> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Schrage's algorithm
@@ -66,7 +83,7 @@ struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), true> {
};
template <unsigned long long __a, unsigned long long __m>
-struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
+struct __lce_ta<__a, 0ull, __m, (unsigned long long)(~0), 2> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
// Schrage's algorithm
@@ -80,21 +97,40 @@ struct __lce_ta<__a, 0, __m, (unsigned long long)(~0), true> {
};
template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
-struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), false> {
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), 1> {
+ typedef unsigned long long result_type;
+ _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+ // Use (((a*x) % m) + c) % m
+ __x = (__a * __x) % __m;
+ __x += __c - (__x >= __m - __c) * __m;
+ return __x;
+ }
+};
+
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
+struct __lce_ta<__a, __c, __m, (unsigned long long)(~0), 0> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return (__a * __x + __c) % __m; }
};
template <unsigned long long __a, unsigned long long __c>
-struct __lce_ta<__a, __c, 0, (unsigned long long)(~0), false> {
+struct __lce_ta<__a, __c, 0ull, (unsigned long long)(~0), 0> {
typedef unsigned long long result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) { return __a * __x + __c; }
};
// 32
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m>
+struct __lce_ta<__a, __c, __m, unsigned(~0), 3> {
+ typedef unsigned result_type;
+ _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+ return static_cast<result_type>(__lce_ta<__a, __c, __m, (unsigned long long)(~0)>::next(__x));
+ }
+};
+
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
-struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), 2> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
@@ -112,7 +148,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), true> {
};
template <unsigned long long _Ap, unsigned long long _Mp>
-struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
+struct __lce_ta<_Ap, 0ull, _Mp, unsigned(~0), 2> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
@@ -128,7 +164,21 @@ struct __lce_ta<_Ap, 0, _Mp, unsigned(~0), true> {
};
template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
-struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), 1> {
+ typedef unsigned result_type;
+ _LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
+ const result_type __a = static_cast<result_type>(_Ap);
+ const result_type __c = static_cast<result_type>(_Cp);
+ const result_type __m = static_cast<result_type>(_Mp);
+ // Use (((a*x) % m) + c) % m
+ __x = (__a * __x) % __m;
+ __x += __c - (__x >= __m - __c) * __m;
+ return __x;
+ }
+};
+
+template <unsigned long long _Ap, unsigned long long _Cp, unsigned long long _Mp>
+struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), 0> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
@@ -139,7 +189,7 @@ struct __lce_ta<_Ap, _Cp, _Mp, unsigned(~0), false> {
};
template <unsigned long long _Ap, unsigned long long _Cp>
-struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
+struct __lce_ta<_Ap, _Cp, 0ull, unsigned(~0), 0> {
typedef unsigned result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
const result_type __a = static_cast<result_type>(_Ap);
@@ -150,8 +200,8 @@ struct __lce_ta<_Ap, _Cp, 0, unsigned(~0), false> {
// 16
-template <unsigned long long __a, unsigned long long __c, unsigned long long __m, bool __b>
-struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __b> {
+template <unsigned long long __a, unsigned long long __c, unsigned long long __m, int __mode>
+struct __lce_ta<__a, __c, __m, (unsigned short)(~0), __mode> {
typedef unsigned short result_type;
_LIBCPP_HIDE_FROM_ABI static result_type next(result_type __x) {
return static_cast<result_type>(__lce_ta<__a, __c, __m, unsigned(~0)>::next(__x));
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
index 77b7c570f85a1d..fff93a895f8955 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
@@ -22,48 +22,96 @@ int main(int, char**)
{
typedef unsigned long long T;
- // m might overflow, but the overflow is OK so it shouldn't use schrage's algorithm
- typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull<<48)> E1;
+ // m might overflow, but the overflow is OK so it shouldn't use Schrage's algorithm
+ typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull << 48)> E1;
E1 e1;
// make sure the right algorithm was used
- assert(e1() == 25214903918);
- assert(e1() == 205774354444503);
- assert(e1() == 158051849450892);
+ assert(e1() == 25214903918ull);
+ assert(e1() == 205774354444503ull);
+ assert(e1() == 158051849450892ull);
// make sure result is in bounds
- assert(e1() < (1ull<<48));
- assert(e1() < (1ull<<48));
- assert(e1() < (1ull<<48));
- assert(e1() < (1ull<<48));
- assert(e1() < (1ull<<48));
+ assert(e1() < (1ull << 48));
+ assert(e1() < (1ull << 48));
+ assert(e1() < (1ull << 48));
+ assert(e1() < (1ull << 48));
+ assert(e1() < (1ull << 48));
// m might overflow. The overflow is not OK and result will be in bounds
- // so we should use shrage's algorithm
- typedef std::linear_congruential_engine<T, (1ull<<2), 0, (1ull<<63) + 1> E2;
+ // so we should use Schrage's algorithm
+ typedef std::linear_congruential_engine<T, 0x100000000ull, 0, (1ull << 63) + 1ull> E2;
E2 e2;
- // make sure shrage's algorithm is used (it would be 0s otherwise)
- assert(e2() == 4);
- assert(e2() == 16);
- assert(e2() == 64);
+ // make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
+ assert(e2() == 0x100000000ull);
+ assert(e2() == (1ull << 63) - 1ull);
+ assert(e2() == (1ull << 63) - 0x1ffffffffull);
// make sure result is in bounds
- assert(e2() < (1ull<<48) + 1);
- assert(e2() < (1ull<<48) + 1);
- assert(e2() < (1ull<<48) + 1);
- assert(e2() < (1ull<<48) + 1);
- assert(e2() < (1ull<<48) + 1);
+ assert(e2() < (1ull << 63) + 1);
+ assert(e2() < (1ull << 63) + 1);
+ assert(e2() < (1ull << 63) + 1);
+ assert(e2() < (1ull << 63) + 1);
+ assert(e2() < (1ull << 63) + 1);
- // m will not overflow so we should not use shrage's algorithm
- typedef std::linear_congruential_engine<T, 1ull, 1, (1ull<<48)> E3;
+ // m might overflow. The overflow is not OK and result will be in bounds
+ // so we should use Schrage's algorithm. m is even
+ typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
E3 e3;
+ // make sure Schrage's algorithm is used
+ assert(e3() == 0x18012348ull);
+ assert(e3() == 0x2401b4ed802468full);
+ assert(e3() == 0x18051ec400369d6ull);
+ // make sure result is in bounds
+ assert(e3() < (3ull << 56));
+ assert(e3() < (3ull << 56));
+ assert(e3() < (3ull << 56));
+ assert(e3() < (3ull << 56));
+ assert(e3() < (3ull << 56));
+
+ // 32-bit case:
+ // m might overflow. The overflow is not OK, result will be in bounds,
+ // and Schrage's algorithm is incompatible here. Need to use 64 bit arithmetic.
+ typedef std::linear_congruential_engine<unsigned, 0x10009u, 0u, 0x7fffffffu> E4;
+ E4 e4;
+ // make sure enough precision is used
+ assert(e4() == 0x10009u);
+ assert(e4() == 0x120053u);
+ assert(e4() == 0xf5030fu);
+ // make sure result is in bounds
+ assert(e4() < 0x7fffffffu);
+ assert(e4() < 0x7fffffffu);
+ assert(e4() < 0x7fffffffu);
+ assert(e4() < 0x7fffffffu);
+ assert(e4() < 0x7fffffffu);
+
+#ifdef __SIZEOF_INT128__
+ // m might overflow. The overflow is not OK, result will be in bounds,
+ // and Schrage's algorithm is incompatible here. Need to use 128 bit arithmetic.
+ typedef std::linear_congruential_engine<T, 0x100000001ull, 0ull, (1ull << 61) - 1ull> E5;
+ E5 e5;
+ // make sure enough precision is used
+ assert(e5() == 0x100000001ull);
+ assert(e5() == 0x200000009ull);
+ assert(e5() == 0xb00000019ull);
+ // make sure result is in bounds
+ assert(e5() < (1ull << 61) - 1ull);
+ assert(e5() < (1ull << 61) - 1ull);
+ assert(e5() < (1ull << 61) - 1ull);
+ assert(e5() < (1ull << 61) - 1ull);
+ assert(e5() < (1ull << 61) - 1ull);
+#endif
+
+ // m will not overflow so we should not use Schrage's algorithm
+ typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E6;
+ E6 e6;
// make sure the correct algorithm was used
- assert(e3() == 2);
- assert(e3() == 3);
- assert(e3() == 4);
+ assert(e6() == 2ull);
+ assert(e6() == 3ull);
+ assert(e6() == 4ull);
// make sure result is in bounds
- assert(e3() < (1ull<<48));
- assert(e3() < (1ull<<48));
- assert(e3() < (1ull<<48));
- assert(e3() < (1ull<<48));
- assert(e2() < (1ull<<48));
+ assert(e6() < (1ull << 48));
+ assert(e6() < (1ull << 48));
+ assert(e6() < (1ull << 48));
+ assert(e6() < (1ull << 48));
+ assert(e6() < (1ull << 48));
return 0;
}
\ No newline at end of file
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
index 12620848626fc8..8f5a861cbff563 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/assign.pass.cpp
@@ -15,6 +15,7 @@
#include <random>
#include <cassert>
+#include <climits>
#include "test_macros.h"
@@ -35,19 +36,39 @@ template <class T>
void
test()
{
- test1<T, 0, 0, 0>();
- test1<T, 0, 1, 2>();
- test1<T, 1, 1, 2>();
- const T M(static_cast<T>(-1));
- test1<T, 0, 0, M>();
- test1<T, 0, M-2, M>();
- test1<T, 0, M-1, M>();
- test1<T, M-2, 0, M>();
- test1<T, M-2, M-2, M>();
- test1<T, M-2, M-1, M>();
- test1<T, M-1, 0, M>();
- test1<T, M-1, M-2, M>();
- test1<T, M-1, M-1, M>();
+ const int W = sizeof(T) * CHAR_BIT;
+ const T M(static_cast<T>(-1));
+ const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+ // Cases where m = 0
+ test1<T, 0, 0, 0>();
+ test1<T, A, 0, 0>();
+ test1<T, 0, 1, 0>();
+ test1<T, A, 1, 0>();
+
+ // Cases where m = 2^n for n < w
+ test1<T, 0, 0, 256>();
+ test1<T, 5, 0, 256>();
+ test1<T, 0, 1, 256>();
+ test1<T, 5, 1, 256>();
+
+ // Cases where m is odd and a = 0
+ test1<T, 0, 0, M>();
+ test1<T, 0, M - 2, M>();
+ test1<T, 0, M - 1, M>();
+
+ // Cases where m is odd and m % a <= m / a (Schrage)
+ test1<T, A, 0, M>();
+ test1<T, A, M - 2, M>();
+ test1<T, A, M - 1, M>();
+
+ // Cases where m is odd and m % a > m / a (not implemented)
+ test1<T, M - 2, 0, M>();
+ test1<T, M - 2, M - 2, M>();
+ test1<T, M - 2, M - 1, M>();
+ test1<T, M - 1, 0, M>();
+ test1<T, M - 1, M - 2, M>();
+ test1<T, M - 1, M - 1, M>();
}
int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
index 5dac0772cb0e94..654352cd13fa8e 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/copy.pass.cpp
@@ -35,19 +35,39 @@ template <class T>
void
test()
{
- test1<T, 0, 0, 0>();
- test1<T, 0, 1, 2>();
- test1<T, 1, 1, 2>();
- const T M(static_cast<T>(-1));
- test1<T, 0, 0, M>();
- test1<T, 0, M-2, M>();
- test1<T, 0, M-1, M>();
- test1<T, M-2, 0, M>();
- test1<T, M-2, M-2, M>();
- test1<T, M-2, M-1, M>();
- test1<T, M-1, 0, M>();
- test1<T, M-1, M-2, M>();
- test1<T, M-1, M-1, M>();
+ const int W = sizeof(T) * CHAR_BIT;
+ const T M(static_cast<T>(-1));
+ const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+ // Cases where m = 0
+ test1<T, 0, 0, 0>();
+ test1<T, A, 0, 0>();
+ test1<T, 0, 1, 0>();
+ test1<T, A, 1, 0>();
+
+ // Cases where m = 2^n for n < w
+ test1<T, 0, 0, 256>();
+ test1<T, 5, 0, 256>();
+ test1<T, 0, 1, 256>();
+ test1<T, 5, 1, 256>();
+
+ // Cases where m is odd and a = 0
+ test1<T, 0, 0, M>();
+ test1<T, 0, M - 2, M>();
+ test1<T, 0, M - 1, M>();
+
+ // Cases where m is odd and m % a <= m / a (Schrage)
+ test1<T, A, 0, M>();
+ test1<T, A, M - 2, M>();
+ test1<T, A, M - 1, M>();
+
+ // Cases where m is odd and m % a > m / a (not implemented)
+ test1<T, M - 2, 0, M>();
+ test1<T, M - 2, M - 2, M>();
+ test1<T, M - 2, M - 1, M>();
+ test1<T, M - 1, 0, M>();
+ test1<T, M - 1, M - 2, M>();
+ test1<T, M - 1, M - 1, M>();
}
int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
index 10bc1d71d8e892..caee6b89571d79 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/default.pass.cpp
@@ -33,19 +33,39 @@ template <class T>
void
test()
{
- test1<T, 0, 0, 0>();
- test1<T, 0, 1, 2>();
- test1<T, 1, 1, 2>();
- const T M(static_cast<T>(-1));
- test1<T, 0, 0, M>();
- test1<T, 0, M-2, M>();
- test1<T, 0, M-1, M>();
- test1<T, M-2, 0, M>();
- test1<T, M-2, M-2, M>();
- test1<T, M-2, M-1, M>();
- test1<T, M-1, 0, M>();
- test1<T, M-1, M-2, M>();
- test1<T, M-1, M-1, M>();
+ const int W = sizeof(T) * CHAR_BIT;
+ const T M(static_cast<T>(-1));
+ const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+ // Cases where m = 0
+ test1<T, 0, 0, 0>();
+ test1<T, A, 0, 0>();
+ test1<T, 0, 1, 0>();
+ test1<T, A, 1, 0>();
+
+ // Cases where m = 2^n for n < w
+ test1<T, 0, 0, 256>();
+ test1<T, 5, 0, 256>();
+ test1<T, 0, 1, 256>();
+ test1<T, 5, 1, 256>();
+
+ // Cases where m is odd and a = 0
+ test1<T, 0, 0, M>();
+ test1<T, 0, M - 2, M>();
+ test1<T, 0, M - 1, M>();
+
+ // Cases where m is odd and m % a <= m / a (Schrage)
+ test1<T, A, 0, M>();
+ test1<T, A, M - 2, M>();
+ test1<T, A, M - 1, M>();
+
+ // Cases where m is odd and m % a > m / a (not implemented)
+ test1<T, M - 2, 0, M>();
+ test1<T, M - 2, M - 2, M>();
+ test1<T, M - 2, M - 1, M>();
+ test1<T, M - 1, 0, M>();
+ test1<T, M - 1, M - 2, M>();
+ test1<T, M - 1, M - 1, M>();
}
int main(int, char**)
diff --git a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
index d9d47c5d8db46c..1af116e529156f 100644
--- a/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
+++ b/libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/values.pass.cpp
@@ -66,19 +66,39 @@ template <class T>
void
test()
{
- test1<T, 0, 0, 0>();
- test1<T, 0, 1, 2>();
- test1<T, 1, 1, 2>();
- const T M(static_cast<T>(-1));
- test1<T, 0, 0, M>();
- test1<T, 0, M-2, M>();
- test1<T, 0, M-1, M>();
- test1<T, M-2, 0, M>();
- test1<T, M-2, M-2, M>();
- test1<T, M-2, M-1, M>();
- test1<T, M-1, 0, M>();
- test1<T, M-1, M-2, M>();
- test1<T, M-1, M-1, M>();
+ const int W = sizeof(T) * CHAR_BIT;
+ const T M(static_cast<T>(-1));
+ const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
+
+ // Cases where m = 0
+ test1<T, 0, 0, 0>();
+ test1<T, A, 0, 0>();
+ test1<T, 0, 1, 0>();
+ test1<T, A, 1, 0>();
+
+ // Cases where m = 2^n for n < w
+ test1<T, 0, 0, 256>();
+ test1<T, 5, 0, 256>();
+ test1<T, 0, 1, 256>();
+ test1<T, 5, 1, 256>();
+
+ // Cases where m is odd and a = 0
+ test1<T, 0, 0, M>();
+ test1<T, 0, M - 2, M>();
+ test1<T, 0, M - 1, M>();
+
+ // Cases where m is odd and m % a <= m / a (Schrage)
+ test1<T, A, 0, M>();
+ test1<T, A, M - 2, M>();
+ test1<T, A, M - 1, M>();
+
+ // Cases where m is odd and m % a > m / a (not implemented)
+ test1<T, M - 2, 0, M>();
+ test1<T, M - 2, M - 2, M>();
+ test1<T, M - 2, M - 1, M>();
+ test1<T, M - 1, 0, M>();
+ test1<T, M - 1, M - 2, M>();
+ test1<T, M - 1, M - 1, M>();
}
int main(int, char**)
|
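For readers skimming the diff, the four-way choice made by `__lce_alg_picker` can be restated as an ordinary function. This is only a sketch of the selection logic as it appears in the header diff: `LceMode`, `has_overflow`, `pick_mode`, and the parameter `max` (standing in for the diff's `_Mp`, the largest value of the result type) are illustrative names, not part of libc++. The parameter sets in the `static_assert`s are taken from `alg.pass.cpp`, except the `Part` example, which is an assumed value chosen to hit that branch.

```cpp
#include <cstdint>

// Illustrative names only; libc++ encodes these as the integers 0..3.
enum class LceMode { Full = 0, Part = 1, Schrage = 2, Promote = 3 };

// Overflow can only matter when a != 0 and m is neither 0 nor a power of two.
constexpr bool has_overflow(std::uint64_t a, std::uint64_t m) {
  return a != 0 && (m & (m - 1)) != 0;
}

// Mirrors the order of the checks in the diff: Full, then Part, then Schrage,
// and finally promotion to a wider intermediate type.
constexpr LceMode pick_mode(std::uint64_t a, std::uint64_t c, std::uint64_t m,
                            std::uint64_t max) {
  return (!has_overflow(a, m) || m - 1 <= (max - c) / a) ? LceMode::Full     // (a*x + c) % m fits
         : (m - 1 <= max / a)                            ? LceMode::Part     // only a*x fits
         : (m % a <= m / a)                              ? LceMode::Schrage  // r <= q
                                                         : LceMode::Promote; // needs 2n bits
}

constexpr std::uint64_t kMax64 = ~0ull;
constexpr std::uint64_t kMax32 = 0xffffffffull;

// E1, E2, E4 and E5 from alg.pass.cpp.
static_assert(pick_mode(25214903917ull, 1, 1ull << 48, kMax64) == LceMode::Full, "E1");
static_assert(pick_mode(0x100000000ull, 0, (1ull << 63) + 1, kMax64) == LceMode::Schrage, "E2");
static_assert(pick_mode(0x10009ull, 0, 0x7fffffffull, kMax32) == LceMode::Promote, "E4");
static_assert(pick_mode(0x100000001ull, 0, (1ull << 61) - 1, kMax64) == LceMode::Promote, "E5");
// Assumed example for the new Part case: 2*x fits in 32 bits, 2*x + 5 may not.
static_assert(pick_mode(2, 5, 0x7fffffffull, kMax32) == LceMode::Part, "Part");
```

The order matters: `Part` is tested before `Schrage` because, as the PR description notes, computing the product directly is faster whenever it cannot overflow.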
✅ With the latest revision this PR passed the C/C++ code formatter. |
Ok, so here's the comment about the missing edge case. Performing 128-bit arithmetic on platforms without 128-bit integer types (mainly platforms with 32-bit registers) can get very complex. It's all entirely possible, and The Art of Computer Programming, Volume 2 seems to be a popular source for the methods to use (IIRC, MSVC's STL references it in the source code it uses to handle this particular case), but it will still be rather involved. From what I've seen, there are three notable options for handling this:
I've experimented with writing some of this myself, so I think I could do it if necessary, but would appreciate some feedback on these options before I make any attempts, and also would appreciate if someone else was willing to take the lead on this. |
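To give a sense of what a fallback without `__int128` involves, here is a minimal sketch of one well-known approach: computing `(a * x) mod m` by binary double-and-add, so that no intermediate ever exceeds 64 bits. This is not necessarily one of the options discussed in this comment, and it is not the method from Knuth; it is simply the easiest one to verify. The function names `addmod`, `mulmod`, and `lcg_next` are hypothetical, and the sketch assumes `m != 0`.

```cpp
#include <cstdint>

// (a + b) mod m for a, b < m, without overflowing 64 bits.
inline std::uint64_t addmod(std::uint64_t a, std::uint64_t b, std::uint64_t m) {
  // a + b >= m is equivalent to a >= m - b, which cannot overflow.
  return a >= m - b ? a - (m - b) : a + b;
}

// (a * x) mod m by binary double-and-add; at most 64 iterations.
inline std::uint64_t mulmod(std::uint64_t a, std::uint64_t x, std::uint64_t m) {
  std::uint64_t result = 0;
  a %= m;
  x %= m;
  while (x != 0) {
    if (x & 1)
      result = addmod(result, a, m); // add the current multiple of a
    a = addmod(a, a, m);             // double a modulo m
    x >>= 1;
  }
  return result;
}

// One LCG step (a * x + c) mod m using only 64-bit arithmetic.
inline std::uint64_t lcg_next(std::uint64_t a, std::uint64_t c, std::uint64_t m,
                              std::uint64_t x) {
  return addmod(mulmod(a, x, m), c % m, m);
}
```

With the E5 parameters from `alg.pass.cpp` (`a = 0x100000001`, `c = 0`, `m = 2^61 - 1`) and the default seed of 1, this reproduces the values the test expects from the `__int128` path. The cost is a loop of up to 64 modular additions per step, which is why a real implementation would more likely use multi-limb multiplication and division.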
That PR was indeed ready, it just needed somebody to merge it. I just merged it, so let's rebase this patch. |
c7aa088 to bd3a3bf
Ok, I'm happy with this PR now, so I'm gonna remove the Draft status. The CI is giving me issues, but it seems like it's passing all the cases at this point. Pretty much the only thing I'm not 100% happy with is the naming of the new functions in the test code, but it should be fine. Let me know if you have any feedback or suggestions on this. |
Thanks for working on this! I mainly took a quick look at the patch. I hope to have time to do a better review soon. (I really need to look at how the algorithm should work in the first place before reviewing.)
libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
2e055e3 to 74657d1
FYI I've not had time to look at this yet, I still hope to find time for it soon. |
Thanks for the patch! I finally found time to have a good look at it. I think it's mostly good. Can you update the commit message; it mentions the patch is WIP.
I'm happy to accept the patch and have a compile-time error when 128-bit integrals are missing and it's needed. We can add that in a later patch.
libcxx/test/std/numerics/rand/rand.eng/rand.eng.lcong/alg.pass.cpp
638cd4c to d9b679b
7dfdfb2 to 3e2efac
I addressed some of the review comments, and responded to others with followup questions / suggestions.
I'm not seeing that. The PR had "[WIP]" at one point, but I don't see that in the commit message itself. Could you clarify what exactly you want me to change? |
I addressed some of the review comments, and responded to others with followup questions / suggestions.
Can you update the commit message; it mentions the patch is WIP.
I'm not seeing that. The PR had "[WIP]" at one point, but I don't see that in the commit message itself. Could you clarify what exactly you want me to change?
At the end of the commit message
"
I currently consider this PR to be a WIP because a) I haven't gotten the test cases written properly to avoid failing when __int128 isn't available, and b) I'm not 100% sure about how I've structured / formatted the changes, and may still want to tweak it before merging. I'm making the PR now so I can start getting feedback if anybody has any.
"
Other than the last comments the patch LGTM! Thanks for your work and patience.
Double-checking this, I realized I forgot to respond to the last message. That message you're referring to was not in the commit message, but only in the PR description. I went ahead and removed it anyways, though. I've addressed all the review comments, so unless there's anything else, this should be ready to merge. |
FYI When we merge in GitHub, by default, the PR description is the commit message. |
@LRFLEW Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested Please check whether problems have been caused by your change specifically, as How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
…low (llvm#81583) This PR is a followup to llvm#81080. This PR makes two major changes to how the LCG operation is computed: The first is that I added an additional case where `ax + c` might overflow the intermediate variable, but `ax` by itself won't. In this case, it's much better to use `(ax mod m) + c mod m` than the previous behavior of falling back to Schrage's algorithm. The addition modulo is done in the same way as when using Schrage's algorithm (i.e. `x += c - (x >= m - c)*m`), but the multiplication modulo is calculated directly, which is faster. The second is that I added handling for the case where the `ax` intermediate might overflow, but Schrage's algorithm doesn't apply (i.e. r > q). In this case, the only real option is to increase the precision of the intermediate values. The good news is that - for `x`, `a`, and `c` being n-bit values - `ax + c` will never overflow a 2n-bit intermediary, meaning this promotion can only happen once, and will always be able to use the simplest implementation. This is already the case for 16-bit LCGs, as libcxx chooses to compute them with 32-bit intermediate values. For 32-bit LCGs, I simply added code similar to the 16-bit case to use the existing 64-bit implementations. Lastly, for 64-bit LCGs, I wrote a case that calculates it using `unsigned __int128` if it is available to use. While this implementation covers a *lot* of the missing cases from llvm#81080, this still won't compile **every** possible `linear_congruential_engine`. Specifically, if `a`, `c`, and `m` are chosen such that it needs 128-bit integers, but the platform doesn't support `__int128` (eg. 32-bit x86), then it will fail to compile. However, this is a fairly rare case to see actually used, and libcxx would be in good company with this, as [libstdc++ also fails to compile under these circumstances](https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87744). 
Fixing **this** gap would require even **more** work of further complexity, so that would probably be best handled by a different PR (I'll put more details on what that PR would entail in a comment).
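The two modular steps named in the commit message can be sketched at 32 bits, where a 64-bit reference makes them easy to check. This is an illustration under stated assumptions, not the libc++ code: the function names are hypothetical, inputs are assumed to satisfy `0 < a < m`, `c < m`, and `x < m`, and `schrage_mul_mod` additionally assumes `m % a <= m / a`.

```cpp
#include <cstdint>

// Reference: widen to 64 bits, which cannot overflow for 32-bit inputs.
inline std::uint32_t lcg_wide(std::uint32_t a, std::uint32_t c, std::uint32_t m,
                              std::uint32_t x) {
  return static_cast<std::uint32_t>((std::uint64_t{a} * x + c) % m);
}

// (x + c) mod m for x, c < m, using the expression from the commit message.
// Unsigned wraparound turns this into x + c - m exactly when x + c >= m.
inline std::uint32_t add_mod(std::uint32_t x, std::uint32_t c, std::uint32_t m) {
  x += c - (x >= m - c) * m;
  return x;
}

// (a * x) mod m by Schrage's method, valid when r = m % a <= q = m / a.
inline std::uint32_t schrage_mul_mod(std::uint32_t a, std::uint32_t m, std::uint32_t x) {
  const std::uint32_t q = m / a;
  const std::uint32_t r = m % a;
  const std::uint32_t t0 = a * (x % q); // < a * q <= m, so no overflow
  const std::uint32_t t1 = r * (x / q); // <= x < m because r <= q
  return t0 >= t1 ? t0 - t1 : t0 + (m - t1);
}

// "Part" case: a * x fits in 32 bits, so only the addition needs care.
inline std::uint32_t lcg_part(std::uint32_t a, std::uint32_t c, std::uint32_t m,
                              std::uint32_t x) {
  return add_mod((a * x) % m, c, m);
}

// Schrage case: a * x may overflow, but r <= q holds.
inline std::uint32_t lcg_schrage(std::uint32_t a, std::uint32_t c, std::uint32_t m,
                                 std::uint32_t x) {
  return add_mod(schrage_mul_mod(a, m, x), c, m);
}
```

For example, with `a = 2`, `c = 5`, `m = 0x7fffffff`, and `x = 0x7ffffffe`, the product `2 * x` still fits in 32 bits but `2 * x + 5` does not; `lcg_part` returns 3, matching `lcg_wide`. The direct product avoids the two divisions Schrage's method needs to split `x`, which is the speed advantage the commit message refers to.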
This PR is a followup to #81080, and as such includes the commit from that PR. This should either be merged after that one is merged (and this one is rebased), or this one can take the place of that PR. Since the other PR has gotten pretty far in terms of reviews, and this PR is more involved, the other PR should probably be merged before this one.
This PR makes two major changes to how the LCG operation is computed:
The first is that I added an additional case where `ax + c` might overflow the intermediate variable, but `ax` by itself won't. In this case, it's much better to use `(ax mod m) + c mod m` than the previous behavior of falling back to Schrage's algorithm. The addition modulo is done in the same way as when using Schrage's algorithm (i.e. `x += c - (x >= m - c)*m`), but the multiplication modulo is calculated directly, which is faster.
The second is that I added handling for the case where the `ax` intermediate might overflow, but Schrage's algorithm doesn't apply (i.e. r > q). In this case, the only real option is to increase the precision of the intermediate values. The good news is that - for `x`, `a`, and `c` being n-bit values - `ax + c` will never overflow a 2n-bit intermediary, meaning this promotion can only happen once, and will always be able to use the simplest implementation. This is already the case for 16-bit LCGs, as libcxx chooses to compute them with 32-bit intermediate values. For 32-bit LCGs, I simply added code similar to the 16-bit case to use the existing 64-bit implementations. Lastly, for 64-bit LCGs, I wrote a case that calculates it using `unsigned __int128` if it is available to use.
While this implementation covers a lot of the missing cases from #81080, this still won't compile every possible `linear_congruential_engine`. Specifically, if `a`, `c`, and `m` are chosen such that it needs 128-bit integers, but the platform doesn't support `__int128` (e.g. 32-bit x86), then it will fail to compile. However, this is a fairly rare case to see actually used, and libcxx would be in good company with this, as [libstdc++ also fails to compile under these circumstances](https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87744). Fixing this gap would require even more work of further complexity, so that would probably be best handled by a different PR (I'll put more details on what that PR would entail in a comment).
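The claim that `ax + c` never overflows a 2n-bit intermediate follows from a short bound. Taking `a`, `x`, and `c` as unsigned n-bit values, each is at most 2^n - 1, so:

```latex
\[
  a x + c \;\le\; (2^n - 1)^2 + (2^n - 1)
          \;=\; (2^n - 1)\,2^n
          \;=\; 2^{2n} - 2^n
          \;<\; 2^{2n}.
\]
```

For n = 64 this means the largest possible intermediate is 2^128 - 2^64, which fits in `unsigned __int128`, so a single widening step is always enough.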