Skip to content

GH-100485: Tweaks to sumprod() #100857

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Doc/whatsnew/3.12.rst
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,12 @@ dis
:data:`~dis.hasarg` collection instead.
(Contributed by Irit Katriel in :gh:`94216`.)

math
----

* Added :func:`math.sumprod` for computing a sum of products.
(Contributed by Raymond Hettinger in :gh:`100485`.)

os
--

Expand Down
1 change: 1 addition & 0 deletions Lib/test/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,7 @@ def test_sumprod_accuracy(self):
self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0)
self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0)

@support.requires_resource('cpu')
def test_sumprod_stress(self):
sumprod = math.sumprod
product = itertools.product
Expand Down
59 changes: 34 additions & 25 deletions Modules/mathmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -2832,7 +2832,7 @@ long_add_would_overflow(long a, long b)
}

/*
Double length extended precision floating point arithmetic
Double and triple length extended precision floating point arithmetic
based on ideas from three sources:

Improved Kahan–Babuška algorithm by Arnold Neumaier
Expand All @@ -2845,22 +2845,22 @@ based on ideas from three sources:
Ultimately Fast Accurate Summation by Siegfried M. Rump
https://www.tuhh.de/ti3/paper/rump/Ru08b.pdf

The double length routines allow for quite a bit of instruction
level parallelism. On a 3.22 GHz Apple M1 Max, the incremental
cost of increasing the input vector size by one is 6.0 nsec.
Double length functions:
* dl_split() exact split of a C double into two half precision components.
* dl_mul() exact multiplication of two C doubles.

dl_zero() returns an extended precision zero
dl_split() exactly splits a double into two half precision components.
dl_add() performs compensated summation to keep a running total.
dl_mul() implements lossless multiplication of doubles.
dl_fma() implements an extended precision fused-multiply-add.
dl_to_d() converts from extended precision to double precision.
Triple length functions and constant:
* tl_zero is a triple length zero for starting or resetting an accumulation.
* tl_add() compensated addition of a C double to a triple length number.
* tl_fma() performs a triple length fused-multiply-add.
* tl_to_d() converts from triple length number back to a C double.

*/

typedef struct{ double hi; double lo; } DoubleLength;
typedef struct{ double hi; double lo; double tiny; } TripleLength;

static const DoubleLength dl_zero = {0.0, 0.0};
static const TripleLength tl_zero = {0.0, 0.0, 0.0};

static inline DoubleLength
twosum(double a, double b)
Expand All @@ -2874,11 +2874,20 @@ twosum(double a, double b)
return (DoubleLength) {s, t};
}

static inline DoubleLength
dl_add(DoubleLength total, double x)
static inline TripleLength
tl_add(TripleLength total, double x)
{
DoubleLength s = twosum(total.hi, x);
return (DoubleLength) {s.hi, total.lo + s.lo};
/* Input: x total.hi total.lo total.tiny
|--- twosum ---|
s.hi s.lo
|--- twosum ---|
t.hi t.lo
|--- single sum ---|
Output: s.hi t.hi tiny
*/
DoubleLength s = twosum(x, total.hi);
DoubleLength t = twosum(s.lo, total.lo);
return (TripleLength) {s.hi, t.hi, t.lo + total.tiny};
}

static inline DoubleLength
Expand All @@ -2902,18 +2911,18 @@ dl_mul(double x, double y)
return (DoubleLength) {z, zz};
}

static inline DoubleLength
dl_fma(DoubleLength total, double p, double q)
static inline TripleLength
tl_fma(TripleLength total, double p, double q)
{
DoubleLength product = dl_mul(p, q);
total = dl_add(total, product.hi);
return dl_add(total, product.lo);
total = tl_add(total, product.hi);
return tl_add(total, product.lo);
}

static inline double
dl_to_d(DoubleLength total)
tl_to_d(TripleLength total)
{
return total.hi + total.lo;
return total.tiny + total.lo + total.hi;
}

/*[clinic input]
Expand Down Expand Up @@ -2944,7 +2953,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
bool int_path_enabled = true, int_total_in_use = false;
bool flt_path_enabled = true, flt_total_in_use = false;
long int_total = 0;
DoubleLength flt_total = dl_zero;
TripleLength flt_total = tl_zero;

p_it = PyObject_GetIter(p);
if (p_it == NULL) {
Expand Down Expand Up @@ -3079,7 +3088,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
} else {
goto finalize_flt_path;
}
DoubleLength new_flt_total = dl_fma(flt_total, flt_p, flt_q);
TripleLength new_flt_total = tl_fma(flt_total, flt_p, flt_q);
if (isfinite(new_flt_total.hi)) {
flt_total = new_flt_total;
flt_total_in_use = true;
Expand All @@ -3093,7 +3102,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
// We're finished, overflowed, have a non-float, or got a non-finite value
flt_path_enabled = false;
if (flt_total_in_use) {
term_i = PyFloat_FromDouble(dl_to_d(flt_total));
term_i = PyFloat_FromDouble(tl_to_d(flt_total));
if (term_i == NULL) {
goto err_exit;
}
Expand All @@ -3104,7 +3113,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
Py_SETREF(total, new_total);
new_total = NULL;
Py_CLEAR(term_i);
flt_total = dl_zero;
flt_total = tl_zero;
flt_total_in_use = false;
}
}
Expand Down