Skip to content

Commit 3b2cea5

Browse files
authored
Merge pull request RustPython#1589 from Writtic/writtic/module_remainder
Update remainder module of math
2 parents 40c8661 + ed075cf commit 3b2cea5

File tree

2 files changed

+144
-5
lines changed

2 files changed

+144
-5
lines changed

tests/snippets/stdlib_math.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,4 +265,95 @@ def assertAllNotClose(examples, *args, **kwargs):
265265
assert math.fmod(3.0, NINF) == 3.0
266266
assert math.fmod(-3.0, NINF) == -3.0
267267
assert math.fmod(0.0, 3.0) == 0.0
268-
assert math.fmod(0.0, NINF) == 0.0
268+
assert math.fmod(0.0, NINF) == 0.0
269+
270+
"""
271+
TODO: math.remainder was added to CPython in 3.7 and RustPython CI runs on 3.6.
272+
So put the tests of math.remainder in a comment for now.
273+
https://github.com/RustPython/RustPython/pull/1589#issuecomment-551424940
274+
"""
275+
276+
# testcases = [
277+
# # Remainders modulo 1, showing the ties-to-even behaviour.
278+
# '-4.0 1 -0.0',
279+
# '-3.8 1 0.8',
280+
# '-3.0 1 -0.0',
281+
# '-2.8 1 -0.8',
282+
# '-2.0 1 -0.0',
283+
# '-1.8 1 0.8',
284+
# '-1.0 1 -0.0',
285+
# '-0.8 1 -0.8',
286+
# '-0.0 1 -0.0',
287+
# ' 0.0 1 0.0',
288+
# ' 0.8 1 0.8',
289+
# ' 1.0 1 0.0',
290+
# ' 1.8 1 -0.8',
291+
# ' 2.0 1 0.0',
292+
# ' 2.8 1 0.8',
293+
# ' 3.0 1 0.0',
294+
# ' 3.8 1 -0.8',
295+
# ' 4.0 1 0.0',
296+
297+
# # Reductions modulo 2*pi
298+
# '0x0.0p+0 0x1.921fb54442d18p+2 0x0.0p+0',
299+
# '0x1.921fb54442d18p+0 0x1.921fb54442d18p+2 0x1.921fb54442d18p+0',
300+
# '0x1.921fb54442d17p+1 0x1.921fb54442d18p+2 0x1.921fb54442d17p+1',
301+
# '0x1.921fb54442d18p+1 0x1.921fb54442d18p+2 0x1.921fb54442d18p+1',
302+
# '0x1.921fb54442d19p+1 0x1.921fb54442d18p+2 -0x1.921fb54442d17p+1',
303+
# '0x1.921fb54442d17p+2 0x1.921fb54442d18p+2 -0x0.0000000000001p+2',
304+
# '0x1.921fb54442d18p+2 0x1.921fb54442d18p+2 0x0p0',
305+
# '0x1.921fb54442d19p+2 0x1.921fb54442d18p+2 0x0.0000000000001p+2',
306+
# '0x1.2d97c7f3321d1p+3 0x1.921fb54442d18p+2 0x1.921fb54442d14p+1',
307+
# '0x1.2d97c7f3321d2p+3 0x1.921fb54442d18p+2 -0x1.921fb54442d18p+1',
308+
# '0x1.2d97c7f3321d3p+3 0x1.921fb54442d18p+2 -0x1.921fb54442d14p+1',
309+
# '0x1.921fb54442d17p+3 0x1.921fb54442d18p+2 -0x0.0000000000001p+3',
310+
# '0x1.921fb54442d18p+3 0x1.921fb54442d18p+2 0x0p0',
311+
# '0x1.921fb54442d19p+3 0x1.921fb54442d18p+2 0x0.0000000000001p+3',
312+
# '0x1.f6a7a2955385dp+3 0x1.921fb54442d18p+2 0x1.921fb54442d14p+1',
313+
# '0x1.f6a7a2955385ep+3 0x1.921fb54442d18p+2 0x1.921fb54442d18p+1',
314+
# '0x1.f6a7a2955385fp+3 0x1.921fb54442d18p+2 -0x1.921fb54442d14p+1',
315+
# '0x1.1475cc9eedf00p+5 0x1.921fb54442d18p+2 0x1.921fb54442d10p+1',
316+
# '0x1.1475cc9eedf01p+5 0x1.921fb54442d18p+2 -0x1.921fb54442d10p+1',
317+
318+
# # Symmetry with respect to signs.
319+
# ' 1 0.c 0.4',
320+
# '-1 0.c -0.4',
321+
# ' 1 -0.c 0.4',
322+
# '-1 -0.c -0.4',
323+
# ' 1.4 0.c -0.4',
324+
# '-1.4 0.c 0.4',
325+
# ' 1.4 -0.c -0.4',
326+
# '-1.4 -0.c 0.4',
327+
328+
# # Huge modulus, to check that the underlying algorithm doesn't
329+
# # rely on 2.0 * modulus being representable.
330+
# '0x1.dp+1023 0x1.4p+1023 0x0.9p+1023',
331+
# '0x1.ep+1023 0x1.4p+1023 -0x0.ap+1023',
332+
# '0x1.fp+1023 0x1.4p+1023 -0x0.9p+1023',
333+
# ]
334+
335+
# for case in testcases:
336+
# x_hex, y_hex, expected_hex = case.split()
337+
# # print(x_hex, y_hex, expected_hex)
338+
# x = float.fromhex(x_hex)
339+
# y = float.fromhex(y_hex)
340+
# expected = float.fromhex(expected_hex)
341+
# actual = math.remainder(x, y)
342+
# # Cheap way of checking that the floats are
343+
# # as identical as we need them to be.
344+
# assert actual.hex() == expected.hex()
345+
# # self.assertEqual(actual.hex(), expected.hex())
346+
347+
348+
# # Test tiny subnormal modulus: there's potential for
349+
# # getting the implementation wrong here (for example,
350+
# # by assuming that modulus/2 is exactly representable).
351+
# tiny = float.fromhex('1p-1074') # min +ve subnormal
352+
# for n in range(-25, 25):
353+
# if n == 0:
354+
# continue
355+
# y = n * tiny
356+
# for m in range(100):
357+
# x = m * tiny
358+
# actual = math.remainder(x, y)
359+
# actual = math.remainder(-x, y)

vm/src/stdlib/math.rs

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,20 +275,67 @@ fn math_modf(x: IntoPyFloat, _vm: &VirtualMachine) -> (f64, f64) {
275275
(x.fract(), x.trunc())
276276
}
277277

278+
fn fmod(x: f64, y: f64) -> f64 {
279+
if y.is_infinite() && x.is_finite() {
280+
return x;
281+
}
282+
283+
x % y
284+
}
285+
278286
fn math_fmod(x: IntoPyFloat, y: IntoPyFloat, vm: &VirtualMachine) -> PyResult<f64> {
279287
let x = x.to_f64();
280288
let y = y.to_f64();
281-
if y.is_infinite() && x.is_finite() {
282-
return Ok(x);
283-
}
284-
let r = x % y;
289+
290+
let r = fmod(x, y);
291+
285292
if r.is_nan() && !x.is_nan() && !y.is_nan() {
286293
return Err(vm.new_value_error("math domain error".to_string()));
287294
}
288295

289296
Ok(r)
290297
}
291298

299+
fn math_remainder(x: IntoPyFloat, y: IntoPyFloat, vm: &VirtualMachine) -> PyResult<f64> {
300+
let x = x.to_f64();
301+
let y = y.to_f64();
302+
if x.is_finite() && y.is_finite() {
303+
if y == 0.0 {
304+
return Ok(std::f64::NAN);
305+
}
306+
307+
let absx = x.abs();
308+
let absy = y.abs();
309+
let modulus = absx % absy;
310+
311+
let c = absy - modulus;
312+
let r;
313+
if modulus < c {
314+
r = modulus;
315+
} else if modulus > c {
316+
r = -c;
317+
} else {
318+
r = modulus - 2.0 * fmod(0.5 * (absx - modulus), absy);
319+
}
320+
321+
return Ok(1.0_f64.copysign(x) * r);
322+
}
323+
324+
if x.is_nan() {
325+
return Ok(x);
326+
}
327+
if y.is_nan() {
328+
return Ok(y);
329+
}
330+
if x.is_infinite() {
331+
return Ok(std::f64::NAN);
332+
}
333+
if y.is_infinite() {
334+
return Err(vm.new_value_error("math domain error".to_string()));
335+
}
336+
Ok(x)
337+
}
338+
292339
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
293340
let ctx = &vm.ctx;
294341

@@ -342,6 +389,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
342389
"ldexp" => ctx.new_rustfunc(math_ldexp),
343390
"modf" => ctx.new_rustfunc(math_modf),
344391
"fmod" => ctx.new_rustfunc(math_fmod),
392+
"remainder" => ctx.new_rustfunc(math_remainder),
345393

346394
// Rounding functions:
347395
"trunc" => ctx.new_rustfunc(math_trunc),

0 commit comments

Comments
 (0)