From dcaecbafb586d6b48dd2ca7779f6036deb6b947b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Dahlgren?= Date: Sun, 28 Jan 2018 17:56:38 +0100 Subject: [PATCH 1/2] Change comparison of Quantity (to fix gh-146) --- quantities/quantity.py | 51 +++++++++++++++-------------- quantities/tests/test_comparison.py | 40 +++++++++++++--------- quantities/tests/test_umath.py | 16 ++++----- quantities/umath.py | 2 +- 4 files changed, 58 insertions(+), 51 deletions(-) diff --git a/quantities/quantity.py b/quantities/quantity.py index c3887514..ff3908a7 100644 --- a/quantities/quantity.py +++ b/quantities/quantity.py @@ -102,11 +102,7 @@ def g(self, other, *args): def wrap_comparison(f): @wraps(f) def g(self, other): - if isinstance(other, Quantity): - if other._dimensionality != self._dimensionality: - other = other.rescale(self._dimensionality) - other = other.magnitude - return f(self, other) + return f(self, self._magnitude_of_other(other)) return g @@ -375,43 +371,48 @@ def __setitem__(self, key, value): value = value.rescale(self._dimensionality) self.magnitude[key] = value + def _magnitude_of_other(self, other): + if not isinstance(other, Quantity): + other = other * unit_registry['dimensionless'] + return other.rescale(self._dimensionality).magnitude + @with_doc(np.ndarray.__lt__) @wrap_comparison - def __lt__(self, other): - return self.magnitude < other + def __lt__(self, other_mag): + return self.magnitude < other_mag @with_doc(np.ndarray.__le__) @wrap_comparison - def __le__(self, other): - return self.magnitude <= other + def __le__(self, other_mag): + return self.magnitude <= other_mag @with_doc(np.ndarray.__eq__) def __eq__(self, other): - if isinstance(other, Quantity): - try: - other = other.rescale(self._dimensionality).magnitude - except ValueError: - return np.zeros(self.shape, '?') - return self.magnitude == other + try: + other_mag = self._magnitude_of_other(other) + except ValueError: + return np.zeros(self.shape, '?') + else: + return self.magnitude == other_mag @with_doc(np.ndarray.__ne__) def __ne__(self, other): - if isinstance(other, Quantity): - try: - other = other.rescale(self._dimensionality).magnitude - except ValueError: - return np.ones(self.shape, '?') - return self.magnitude != other + try: + other_mag = self._magnitude_of_other(other) + except ValueError: + return np.ones(self.shape, '?') + else: + return self.magnitude != other_mag @with_doc(np.ndarray.__ge__) @wrap_comparison - def __ge__(self, other): - return self.magnitude >= other + def __ge__(self, other_mag): + return self.magnitude >= other_mag @with_doc(np.ndarray.__gt__) @wrap_comparison - def __gt__(self, other): - return self.magnitude > other + def __gt__(self, other_mag): + return self.magnitude > other_mag #I don't think this implementation is particularly efficient, #perhaps there is something better diff --git a/quantities/tests/test_comparison.py b/quantities/tests/test_comparison.py index 8485aac7..43b26906 100644 --- a/quantities/tests/test_comparison.py +++ b/quantities/tests/test_comparison.py @@ -19,6 +19,8 @@ def test_scalar_equality(self): self.assertEqual(pq.J == 2*pq.kg*pq.m**2/pq.s**2, [False]) self.assertEqual(pq.J == pq.kg, [False]) + self.assertTrue(1e3*pq.m == pq.km) + self.assertFalse(1e3*pq.m == [1e3]) def test_scalar_inequality(self): self.assertEqual(pq.J != pq.erg, [True]) @@ -66,7 +68,7 @@ def test_array_equality(self): ) self.assertQuantityEqual( [1, 2, 3, 4]*pq.J == [1, 22, 3, 44], - [1, 0, 1, 0] + [0, 0, 0, 0] ) def test_array_inequality(self): @@ -80,7 +82,7 @@ def test_array_inequality(self): ) self.assertQuantityEqual( [1, 2, 3, 4]*pq.J != [1, 22, 3, 44], - [0, 1, 0, 1] + [1, 1, 1, 1] ) def test_quantity_less_than(self): @@ -92,9 +94,11 @@ def test_quantity_less_than(self): [50, 100, 150]*pq.cm < [1, 1, 1]*pq.m, [1, 0, 0] ) - self.assertQuantityEqual( - [1, 2, 33]*pq.J < [1, 22, 3], - [0, 1, 0] + self.assertRaises( + ValueError, + op.lt, + [1, 2, 33]*pq.J, + [1, 22, 3] ) self.assertRaises( ValueError, @@ -112,9 +116,11 @@ def test_quantity_less_than_or_equal(self): [50, 100, 150]*pq.cm <= [1, 1, 1]*pq.m, [1, 1, 0] ) - self.assertQuantityEqual( - [1, 2, 33]*pq.J <= [1, 22, 3], - [1, 1, 0] + self.assertRaises( + ValueError, + op.le, + [1, 2, 33]*pq.J, + [1, 22, 3] ) self.assertRaises( ValueError, @@ -132,9 +138,11 @@ def test_quantity_greater_than_or_equal(self): [50, 100, 150]*pq.cm >= [1, 1, 1]*pq.m, [0, 1, 1] ) - self.assertQuantityEqual( - [1, 2, 33]*pq.J >= [1, 22, 3], - [1, 0, 1] + self.assertRaises( + ValueError, + op.ge, + [1, 2, 33]*pq.J, + [1, 22, 3] ) self.assertRaises( ValueError, @@ -152,13 +160,15 @@ def test_quantity_greater_than(self): [50, 100, 150]*pq.cm > [1, 1, 1]*pq.m, [0, 0, 1] ) - self.assertQuantityEqual( - [1, 2, 33]*pq.J > [1, 22, 3], - [0, 0, 1] + self.assertRaises( + ValueError, + op.gt, + [1, 2, 33]*pq.J, + [1, 22, 3] ) self.assertRaises( ValueError, op.gt, [1, 2, 33]*pq.J, - [1, 22, 3]*pq.kg, + [1, 22, 3]*pq.kg ) diff --git a/quantities/tests/test_umath.py b/quantities/tests/test_umath.py index 5a37c616..d626ff49 100644 --- a/quantities/tests/test_umath.py +++ b/quantities/tests/test_umath.py @@ -151,17 +151,13 @@ def test_ceil(self): [-1., -1., -0., 1., 2., 2., 2.] * pq.m ) - @unittest.expectedFailure def test_fix(self): - try: - self.assertQuantityEqual(np.fix(3.14 * pq.degF), 3.0 * pq.degF) - self.assertQuantityEqual(np.fix(3.0 * pq.degF), 3.0 * pq.degF) - self.assertQuantityEqual( - np.fix([2.1, 2.9, -2.1, -2.9] * pq.degF), - [2., 2., -2., -2.] * pq.degF - ) - except ValueError as e: - raise self.failureException(e) + self.assertQuantityEqual(np.fix(3.14 * pq.degF), 3.0 * pq.degF) + self.assertQuantityEqual(np.fix(3.0 * pq.degF), 3.0 * pq.degF) + self.assertQuantityEqual( + np.fix([2.1, 2.9, -2.1, -2.9] * pq.degF), + [2., 2., -2., -2.] * pq.degF + ) def test_exp(self): self.assertQuantityEqual(np.exp(1*pq.dimensionless), np.e) diff --git a/quantities/umath.py b/quantities/umath.py index 5c8019e4..b4a71f17 100644 --- a/quantities/umath.py +++ b/quantities/umath.py @@ -81,7 +81,7 @@ def gradient(f, *varargs): for i,j in zip( ret, varargsQuantities)]) @with_doc(np.cross) -def cross (a, b , axisa=-1, axisb=-1, axisc=-1, axis=None): +def cross(a, b , axisa=-1, axisb=-1, axisc=-1, axis=None): if not (isinstance(a, Quantity) and isinstance(b, Quantity)): return np.cross(a, b, axisa, axisb, axisc, axis) From f6a5a261c7ec135bc96b70083108910fd7658906 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Dahlgren?= Date: Mon, 29 Jan 2018 10:52:13 +0100 Subject: [PATCH 2/2] Re-add xfail marker for test_fix --- quantities/tests/test_umath.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/quantities/tests/test_umath.py b/quantities/tests/test_umath.py index d626ff49..63691173 100644 --- a/quantities/tests/test_umath.py +++ b/quantities/tests/test_umath.py @@ -151,13 +151,17 @@ def test_ceil(self): [-1., -1., -0., 1., 2., 2., 2.] * pq.m ) + @unittest.expectedFailure def test_fix(self): - self.assertQuantityEqual(np.fix(3.14 * pq.degF), 3.0 * pq.degF) - self.assertQuantityEqual(np.fix(3.0 * pq.degF), 3.0 * pq.degF) - self.assertQuantityEqual( - np.fix([2.1, 2.9, -2.1, -2.9] * pq.degF), - [2., 2., -2., -2.] * pq.degF - ) + try: + self.assertQuantityEqual(np.fix(3.14 * pq.degF), 3.0 * pq.degF) + self.assertQuantityEqual(np.fix(3.0 * pq.degF), 3.0 * pq.degF) + self.assertQuantityEqual( + np.fix([2.1, 2.9, -2.1, -2.9] * pq.degF), + [2., 2., -2., -2.] * pq.degF + ) + except ValueError as e: + raise self.failureException(e) def test_exp(self): self.assertQuantityEqual(np.exp(1*pq.dimensionless), np.e)