Skip to content

Commit 43a7ba9

Browse files
committed
Compute TWO_PRODUCT() with FMA
1 parent 3f700be commit 43a7ba9

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

Lib/test/test_math.py

+86
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,92 @@ def run(func, *args):
13691369
args,
13701370
)
13711371

1372+
@requires_IEEE_754
1373+
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
1374+
"sumprod() accuracy not guaranteed on machines with double rounding")
1375+
@support.cpython_only # Other implementations may choose a different algorithm
1376+
@support.requires_resource('cpu')
1377+
def test_sumprod_extended_precision_accuracy(self):
1378+
sumprod = math.sumprod
1379+
import operator
1380+
from fractions import Fraction
1381+
from itertools import starmap
1382+
from collections import namedtuple
1383+
from math import log2, exp2, log10, fabs
1384+
from random import choices, uniform, shuffle
1385+
from statistics import median
1386+
from functools import partial
1387+
from pprint import pp
1388+
1389+
DotExample = namedtuple('DotExample', ('x', 'y', 'target_sumprod', 'condition'))
1390+
1391+
def DotExact(x, y):
1392+
vec1 = map(Fraction, x)
1393+
vec2 = map(Fraction, y)
1394+
return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))
1395+
1396+
def Condition(x, y):
1397+
return 2.0 * DotExact(map(abs, x), map(abs, y)) / abs(DotExact(x, y))
1398+
1399+
def linspace(lo, hi, n):
1400+
width = (hi - lo) / (n - 1)
1401+
return [lo + width * i for i in range(n)]
1402+
1403+
def GenDot(n, c=1e12):
1404+
""" Algorithm 6.1 (GenDot) works as follows. The condition number (5.7) of
1405+
the dot product xT y is proportional to the degree of cancellation. In
1406+
order to achieve a prescribed cancellation, we generate the first half of
1407+
the vectors x and y randomly within a large exponent range. This range is
1408+
chosen according to the anticipated condition number. The second half of x
1409+
and y is then constructed choosing xi randomly with decreasing exponent,
1410+
and calculating yi such that some cancellation occurs. Finally, we permute
1411+
the vectors x, y randomly and calculate the achieved condition number.
1412+
"""
1413+
1414+
assert n >= 6
1415+
n2 = n // 2
1416+
x = [0.0] * n
1417+
y = [0.0] * n
1418+
b = log2(c)
1419+
1420+
# First half with exponents from 0 to |_b/2_| and random ints in between
1421+
e = choices(range(int(b/2)), k=n2)
1422+
e[0] = int(b / 2) + 1
1423+
e[-1] = 0.0
1424+
1425+
x[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e]
1426+
y[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e]
1427+
1428+
# Second half
1429+
e = list(map(round, linspace(b/2, 0.0 , n-n2)))
1430+
for i in range(n2, n):
1431+
x[i] = uniform(-1.0, 1.0) * exp2(e[i - n2])
1432+
y[i] = (uniform(-1.0, 1.0) * exp2(e[i - n2]) - DotExact(x, y)) / x[i]
1433+
1434+
# Shuffle
1435+
pairs = list(zip(x, y))
1436+
shuffle(pairs)
1437+
x, y = zip(*pairs)
1438+
1439+
return DotExample(x, y, DotExact(x, y), Condition(x, y))
1440+
1441+
def RelativeError(res, ex):
1442+
x, y, target_sumprod, condition = ex
1443+
n = DotExact(list(x) + [-res], list(y) + [1])
1444+
return fabs(n / target_sumprod)
1445+
1446+
def Trial(dotfunc, c=10e7, n=10):
1447+
ex = GenDot(10, c)
1448+
res = dotfunc(ex.x, ex.y)
1449+
return RelativeError(res, ex)
1450+
1451+
times = 1000 # Number of trials
1452+
n = 20 # Length of vectors
1453+
c = 1e30 # Target condition number
1454+
1455+
relative_err = median(Trial(sumprod, c=c, n=n) for i in range(times))
1456+
self.assertLess(relative_err, 1e-16)
1457+
13721458
def testModf(self):
13731459
self.assertRaises(TypeError, math.modf)
13741460

0 commit comments

Comments
 (0)