diff --git a/quantities/quantity.py b/quantities/quantity.py index e6046fd..fa26354 100644 --- a/quantities/quantity.py +++ b/quantities/quantity.py @@ -99,16 +99,6 @@ def g(self, other, *args): return f(self, other, *args) return g -def wrap_comparison(f): - @wraps(f) - def g(self, other): - if isinstance(other, Quantity): - if other._dimensionality != self._dimensionality: - other = other.rescale(self._dimensionality) - other = other.magnitude - return f(self, other) - return g - class Quantity(np.ndarray): @@ -425,14 +415,12 @@ def __setitem__(self, key, value): self.magnitude[key] = value @with_doc(np.ndarray.__lt__) - @wrap_comparison def __lt__(self, other): - return self.magnitude < other + return (self - other).magnitude < 0 @with_doc(np.ndarray.__le__) - @wrap_comparison def __le__(self, other): - return self.magnitude <= other + return (self - other).magnitude <= 0 @with_doc(np.ndarray.__eq__) def __eq__(self, other): @@ -440,8 +428,15 @@ def __eq__(self, other): try: other = other.rescale(self._dimensionality).magnitude except ValueError: - return np.zeros(self.shape, '?') - return self.magnitude == other + return np.logical_and(self.magnitude != other.magnitude, False) + return self.magnitude == other + else: + try: + pln = self.rescale(unit_registry['dimensionless']).magnitude + except ValueError: + return np.logical_and(self.magnitude == other, False) + return pln == other + @with_doc(np.ndarray.__ne__) def __ne__(self, other): @@ -449,18 +444,22 @@ def __ne__(self, other): try: other = other.rescale(self._dimensionality).magnitude except ValueError: - return np.ones(self.shape, '?') - return self.magnitude != other - + return np.logical_or(self.magnitude != other.magnitude, True) + return self.magnitude != other + else: + try: + pln = self.rescale(unit_registry['dimensionless']).magnitude + except ValueError: + return np.logical_or(self.magnitude != other, True) + return pln != other + @with_doc(np.ndarray.__ge__) - @wrap_comparison def __ge__(self, other): - return self.magnitude >= other + return (self - other).magnitude >= 0 @with_doc(np.ndarray.__gt__) - @wrap_comparison def __gt__(self, other): - return self.magnitude > other + return (self - other).magnitude > 0 #I don't think this implementation is particularly efficient, #perhaps there is something better diff --git a/quantities/tests/test_comparison.py b/quantities/tests/test_comparison.py index 523dc3b..59866f8 100644 --- a/quantities/tests/test_comparison.py +++ b/quantities/tests/test_comparison.py @@ -64,7 +64,7 @@ def test_array_equality(self): ) self.assertQuantityEqual( [1, 2, 3, 4]*pq.J == [1, 22, 3, 44], - [1, 0, 1, 0] + [0, 0, 0, 0] ) def test_array_inequality(self): @@ -78,7 +78,7 @@ def test_array_inequality(self): ) self.assertQuantityEqual( [1, 2, 3, 4]*pq.J != [1, 22, 3, 44], - [0, 1, 0, 1] + [1, 1, 1, 1] ) def test_quantity_less_than(self): @@ -90,9 +90,11 @@ def test_quantity_less_than(self): [50, 100, 150]*pq.cm < [1, 1, 1]*pq.m, [1, 0, 0] ) - self.assertQuantityEqual( - [1, 2, 33]*pq.J < [1, 22, 3], - [0, 1, 0] + self.assertRaises( + ValueError, + op.lt, + [1, 2, 33]*pq.J, + [1, 22, 3], ) self.assertRaises( ValueError, @@ -110,9 +112,11 @@ def test_quantity_less_than_or_equal(self): [50, 100, 150]*pq.cm <= [1, 1, 1]*pq.m, [1, 1, 0] ) - self.assertQuantityEqual( - [1, 2, 33]*pq.J <= [1, 22, 3], - [1, 1, 0] + self.assertRaises( + ValueError, + op.le, + [1, 2, 33]*pq.J, + [1, 22, 3], ) self.assertRaises( ValueError, @@ -130,9 +134,11 @@ def test_quantity_greater_than_or_equal(self): [50, 100, 150]*pq.cm >= [1, 1, 1]*pq.m, [0, 1, 1] ) - self.assertQuantityEqual( - [1, 2, 33]*pq.J >= [1, 22, 3], - [1, 0, 1] + self.assertRaises( + ValueError, + op.ge, + [1, 2, 33]*pq.J, + [1, 22, 3], ) self.assertRaises( ValueError, @@ -150,9 +156,11 @@ def test_quantity_greater_than(self): [50, 100, 150]*pq.cm > [1, 1, 1]*pq.m, [0, 0, 1] ) - self.assertQuantityEqual( - [1, 2, 33]*pq.J > [1, 22, 3], - [0, 0, 1] + self.assertRaises( + ValueError, + op.gt, + [1, 2, 33]*pq.J, + [1, 22, 3], ) self.assertRaises( ValueError, @@ -160,3 +168,46 @@ def test_quantity_greater_than(self): [1, 2, 33]*pq.J, [1, 22, 3]*pq.kg, ) + + def test_quantity_more_raises(self): + self.assertRaises(ValueError, op.gt, pq.ms, 1) + self.assertRaises(ValueError, op.ge, pq.ms, 1) + self.assertRaises(ValueError, op.lt, pq.ms, 1) + self.assertRaises(ValueError, op.le, pq.ms, 1) + + def test_quantity_more_equal(self): + self.assertEqual(pq.ms == 1, False) + self.assertEqual(pq.ms != 1, True) + self.assertEqual(pq.ms == .001 * pq.m, False) + self.assertEqual(pq.ms != .001 * pq.m, True) + self.assertEqual(pq.ms == .001 * pq.s, True) + self.assertEqual(pq.ms != .001 * pq.s, False) + + def test_quantity_more_compare(self): + self.assertEqual(pq.ms * pq.Hz < 1, True) + self.assertEqual(pq.ms * pq.Hz <= 1, True) + self.assertEqual(pq.ms * pq.Hz > 1, False) + self.assertEqual(pq.ms * pq.Hz >= 1, False) + self.assertEqual(pq.ms * pq.Hz == 0.001, True) + self.assertEqual(pq.ms * pq.Hz != 0.001, False) + self.assertEqual(pq.ms * pq.Hz >= 0.001, True) + self.assertEqual(pq.ms * pq.Hz <= 0.001, True) + + self.assertQuantityEqual([1, 2, 3, 4] * pq.ms * pq.Hz + == [0.001, 0.022, 0.003, 0.044], + [1, 0, 1, 0]) + self.assertQuantityEqual([1, 2, 3, 4] * pq.ms * pq.Hz + != [0.001, 0.022, 0.003, 0.044], + [0, 1, 0, 1]) + self.assertQuantityEqual([1, 2, 33] * pq.ms * pq.Hz + < [0.001, 0.020, 0.003], + [0, 1, 0]) + self.assertQuantityEqual([1, 2, 33] * pq.ms * pq.Hz + <= [0.001, 0.020, 0.003], + [1, 1, 0]) + self.assertQuantityEqual([1, 2, 33] * pq.ms * pq.Hz + >= [0.001, 0.020, 0.003], + [1, 0, 1]) + self.assertQuantityEqual([1, 2, 33] * pq.ms * pq.Hz + > [0.001, 0.020, 0.003], + [0, 0, 1])