Skip to content

Commit 84483aa

Browse files
authored
GH-100485: Add extended accuracy test. Switch to faster fma() based variant. (GH-101383)
1 parent db757f0 commit 84483aa

File tree

2 files changed

+100
-36
lines changed

2 files changed

+100
-36
lines changed

Lib/test/test_math.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,89 @@ def run(func, *args):
13691369
args,
13701370
)
13711371

1372+
@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
                 "sumprod() accuracy not guaranteed on machines with double rounding")
@support.cpython_only    # Other implementations may choose a different algorithm
@support.requires_resource('cpu')
def test_sumprod_extended_precision_accuracy(self):
    """Check math.sumprod() accuracy on ill-conditioned dot products.

    Random vector pairs with a prescribed (large) condition number are
    generated with GenDot().  For each pair the result of math.sumprod()
    is compared against the exact dot product computed with Fractions,
    and the median relative error over all trials must be below 1e-16.
    """
    import operator
    from fractions import Fraction
    from itertools import starmap
    from collections import namedtuple
    from math import log2, exp2, fabs
    from random import choices, uniform, shuffle
    from statistics import median

    DotExample = namedtuple('DotExample', ('x', 'y', 'target_sumprod', 'condition'))

    def DotExact(x, y):
        "Exact dot product of x and y computed with rational arithmetic."
        vec1 = map(Fraction, x)
        vec2 = map(Fraction, y)
        return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

    def Condition(x, y):
        "Condition number of the dot product: 2 * |x|.|y| / |x.y|."
        return 2.0 * DotExact(map(abs, x), map(abs, y)) / abs(DotExact(x, y))

    def linspace(lo, hi, n):
        "Return n evenly spaced values from lo to hi inclusive (n >= 2)."
        width = (hi - lo) / (n - 1)
        return [lo + width * i for i in range(n)]

    def GenDot(n, c):
        """ Algorithm 6.1 (GenDot) works as follows. The condition number (5.7) of
        the dot product x^T y is proportional to the degree of cancellation. In
        order to achieve a prescribed cancellation, we generate the first half of
        the vectors x and y randomly within a large exponent range. This range is
        chosen according to the anticipated condition number. The second half of x
        and y is then constructed choosing xi randomly with decreasing exponent,
        and calculating yi such that some cancellation occurs. Finally, we permute
        the vectors x, y randomly and calculate the achieved condition number.
        """

        assert n >= 6
        n2 = n // 2
        x = [0.0] * n
        y = [0.0] * n
        b = log2(c)

        # First half with exponents from 0 to |_b/2_| and random ints in between
        e = choices(range(int(b/2)), k=n2)
        e[0] = int(b / 2) + 1
        e[-1] = 0       # keep every exponent an int, like the rest of the list

        x[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e]
        y[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e]

        # Second half: exponents decrease from b/2 down to 0, and each y[i]
        # is chosen so that the running exact dot product mostly cancels.
        e = list(map(round, linspace(b/2, 0.0, n - n2)))
        for i in range(n2, n):
            x[i] = uniform(-1.0, 1.0) * exp2(e[i - n2])
            y[i] = (uniform(-1.0, 1.0) * exp2(e[i - n2]) - DotExact(x, y)) / x[i]

        # Shuffle the (x, y) pairs together so the pairing is preserved
        pairs = list(zip(x, y))
        shuffle(pairs)
        x, y = zip(*pairs)

        return DotExample(x, y, DotExact(x, y), Condition(x, y))

    def RelativeError(res, ex):
        "Relative error of the computed result *res* for example *ex*."
        # Appending the pair (-res, 1) makes the exact dot product equal to
        # (true value - res), evaluated without any rounding.
        error = DotExact(list(ex.x) + [-res], list(ex.y) + [1])
        return fabs(error / ex.target_sumprod)

    def Trial(dotfunc, c, n):
        "Run *dotfunc* on one random example of length n and condition c."
        # Use the requested vector length rather than a hard-coded 10.
        ex = GenDot(n, c)
        res = dotfunc(ex.x, ex.y)
        return RelativeError(res, ex)

    times = 1000          # Number of trials
    n = 20                # Length of vectors
    c = 1e30              # Target condition number

    relative_err = median(Trial(math.sumprod, c, n) for i in range(times))
    self.assertLess(relative_err, 1e-16)
1454+
13721455
def testModf(self):
13731456
self.assertRaises(TypeError, math.modf)
13741457

Modules/mathmodule.c

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2832,12 +2832,7 @@ long_add_would_overflow(long a, long b)
28322832
}
28332833

28342834
/*
2835-
Double and triple length extended precision floating point arithmetic
2836-
based on:
2837-
2838-
A Floating-Point Technique for Extending the Available Precision
2839-
by T. J. Dekker
2840-
https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
2835+
Double and triple length extended precision algorithms from:
28412836
28422837
Accurate Sum and Dot Product
28432838
by Takeshi Ogita, Siegfried M. Rump, and Shin’ichi Oishi
@@ -2848,58 +2843,44 @@ based on:
28482843

28492844
typedef struct{ double hi; double lo; } DoubleLength;
28502845

2851-
static inline DoubleLength
2852-
twosum(double a, double b)
2846+
static DoubleLength
2847+
dl_sum(double a, double b)
28532848
{
2854-
// Rump Algorithm 3.1 Error-free transformation of the sum
2849+
/* Algorithm 3.1 Error-free transformation of the sum */
28552850
double x = a + b;
28562851
double z = x - a;
28572852
double y = (a - (x - z)) + (b - z);
28582853
return (DoubleLength) {x, y};
28592854
}
28602855

2861-
static inline DoubleLength
2862-
dl_split(double x) {
2863-
// Rump Algorithm 3.2 Error-free splitting of a floating point number
2864-
// Dekker (5.5) and (5.6).
2865-
double t = x * 134217729.0; // Veltkamp constant = 2.0 ** 27 + 1
2866-
double hi = t - (t - x);
2867-
double lo = x - hi;
2868-
return (DoubleLength) {hi, lo};
2869-
}
2870-
2871-
static inline DoubleLength
2856+
static DoubleLength
28722857
dl_mul(double x, double y)
28732858
{
2874-
// Dekker (5.12) and mul12()
2875-
DoubleLength xx = dl_split(x);
2876-
DoubleLength yy = dl_split(y);
2877-
double p = xx.hi * yy.hi;
2878-
double q = xx.hi * yy.lo + xx.lo * yy.hi;
2879-
double z = p + q;
2880-
double zz = p - z + q + xx.lo * yy.lo;
2859+
/* Algorithm 3.5. Error-free transformation of a product */
2860+
double z = x * y;
2861+
double zz = fma(x, y, -z);
28812862
return (DoubleLength) {z, zz};
28822863
}
28832864

28842865
typedef struct { double hi; double lo; double tiny; } TripleLength;
28852866

28862867
static const TripleLength tl_zero = {0.0, 0.0, 0.0};
28872868

2888-
static inline TripleLength
2889-
tl_fma(TripleLength total, double x, double y)
2869+
static TripleLength
2870+
tl_fma(double x, double y, TripleLength total)
28902871
{
2891-
// Rump Algorithm 5.10 with K=3 and using SumKVert
2872+
/* Algorithm 5.10 with SumKVert for K=3 */
28922873
DoubleLength pr = dl_mul(x, y);
2893-
DoubleLength sm = twosum(total.hi, pr.hi);
2894-
DoubleLength r1 = twosum(total.lo, pr.lo);
2895-
DoubleLength r2 = twosum(r1.hi, sm.lo);
2874+
DoubleLength sm = dl_sum(total.hi, pr.hi);
2875+
DoubleLength r1 = dl_sum(total.lo, pr.lo);
2876+
DoubleLength r2 = dl_sum(r1.hi, sm.lo);
28962877
return (TripleLength) {sm.hi, r2.hi, total.tiny + r1.lo + r2.lo};
28972878
}
28982879

2899-
static inline double
2880+
static double
29002881
tl_to_d(TripleLength total)
29012882
{
2902-
DoubleLength last = twosum(total.lo, total.hi);
2883+
DoubleLength last = dl_sum(total.lo, total.hi);
29032884
return total.tiny + last.lo + last.hi;
29042885
}
29052886

@@ -3066,7 +3047,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
30663047
} else {
30673048
goto finalize_flt_path;
30683049
}
3069-
TripleLength new_flt_total = tl_fma(flt_total, flt_p, flt_q);
3050+
TripleLength new_flt_total = tl_fma(flt_p, flt_q, flt_total);
30703051
if (isfinite(new_flt_total.hi)) {
30713052
flt_total = new_flt_total;
30723053
flt_total_in_use = true;

0 commit comments

Comments
 (0)