Skip to content

Commit 014aa0a

Browse files
al-buscopybara-github
authored andcommitted
Fixing the behavior of assertDictAlmostEqual
PiperOrigin-RevId: 740501314
1 parent 57ea862 commit 014aa0a

File tree

2 files changed

+49
-11
lines changed

2 files changed

+49
-11
lines changed

absl/testing/absltest.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import io
3232
import itertools
3333
import json
34+
import numbers
3435
import os
3536
import random
3637
import re
@@ -1715,7 +1716,9 @@ def assertDictAlmostEqual(
17151716

17161717
# Almost equality with preset places and delta.
17171718
def almost_equal_compare(a_value, b_value):
1718-
if isinstance(a_value, float) or isinstance(b_value, float):
1719+
if isinstance(a_value, numbers.Number) and isinstance(
1720+
b_value, numbers.Number
1721+
):
17191722
try:
17201723
# assertAlmostEqual should be called with at most one of `places`
17211724
# and `delta`. However, it's okay for assertMappingEqual to pass
@@ -1730,7 +1733,14 @@ def almost_equal_compare(a_value, b_value):
17301733
# pytype: enable=wrong-keyword-args
17311734
except self.failureException as err:
17321735
return False, err
1733-
return True, None
1736+
return True, None
1737+
else:
1738+
# Fall back to regular equality check if the values are not numbers.
1739+
try:
1740+
self.assertEqual(a_value, b_value)
1741+
except self.failureException as err:
1742+
return False, err
1743+
return True, None
17341744

17351745
if delta is not None and places is not None:
17361746
raise ValueError('specify delta or places not both\n')

absl/testing/tests/absltest_test.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -624,15 +624,7 @@ def test_assert_set_equal(self):
624624
self.assertRaises(AssertionError, self.assertSetEqual, set1, set2)
625625

626626
@parameterized.named_parameters(
627-
dict(testcase_name='empty', a={}, b={}),
628627
dict(testcase_name='equal_float', a={'a': 1.01}, b={'a': 1.01}),
629-
dict(testcase_name='int_and_float', a={'a': 0}, b={'a': 0.000_000_01}),
630-
dict(testcase_name='float_and_int', a={'a': 0.000_000_01}, b={'a': 0}),
631-
dict(
632-
testcase_name='mixed_elements',
633-
a={'a': 'A', 'b': 1, 'c': 0.999_999_99},
634-
b={'a': 'A', 'b': 1, 'c': 1},
635-
),
636628
dict(
637629
testcase_name='float_artifacts',
638630
a={'a': 0.15000000000000002},
@@ -644,7 +636,7 @@ def test_assert_set_equal(self):
644636
b={'a': 1.000_000_01, 'b': 1.999_999_99},
645637
),
646638
)
647-
def test_assert_dict_almost_equal(self, a, b):
639+
def test_assert_dict_almost_equal_on_floats(self, a, b):
648640
self.assertDictAlmostEqual(a, b)
649641

650642
@parameterized.named_parameters(
@@ -702,6 +694,42 @@ def test_assert_dict_almost_equal_fails_with_tolerance(
702694
with self.assertRaises(self.failureException):
703695
self.assertDictAlmostEqual(a, b, places=places, delta=delta)
704696

697+
@parameterized.named_parameters(
698+
dict(testcase_name='empty', a={}, b={}),
699+
dict(testcase_name='int_and_float', a={'a': 0}, b={'a': 0.000_000_01}),
700+
dict(testcase_name='float_and_int', a={'a': 1e-8}, b={'a': 0}),
701+
dict(testcase_name='float_and_bool', a={'a': 1.0}, b={'a': True}),
702+
dict(testcase_name='complex_and_float', a={'a': 1j * 1j}, b={'a': -1.0}),
703+
dict(
704+
testcase_name='mixed_elements',
705+
a={'a': 'A', 'b': 1, 'c': False, 'd': 0.999_999_99, 'e': 1j},
706+
b={'a': 'A', 'b': 1, 'c': False, 'd': 1, 'e': 0.999_999_99j},
707+
),
708+
dict(testcase_name='list_values', a={'a': [0.1]}, b={'a': [0.1]}),
709+
dict(testcase_name='dict_values', a={'a': {'b': 1}}, b={'a': {'b': 1}}),
710+
)
711+
def test_assert_dict_almost_equal_non_float_comparison_succeeds(self, a, b):
712+
self.assertDictAlmostEqual(a, b)
713+
714+
@parameterized.named_parameters(
715+
dict(testcase_name='different_strings', a={'a': 'b'}, b={'a': 'c'}),
716+
dict(testcase_name='string_and_float', a={'a': 'b'}, b={'a': 0.1}),
717+
dict(testcase_name='complex_and_float', a={'a': 1j}, b={'a': 1.0}),
718+
dict(
719+
testcase_name='different_bools',
720+
a={'a': False},
721+
b={'a': True},
722+
),
723+
dict(
724+
testcase_name='no_recursion',
725+
a={'a': {'b': 0}},
726+
b={'a': {'b': 1e-8}},
727+
),
728+
)
729+
def test_assert_dict_almost_equal_non_float_comparison_fails(self, a, b):
730+
with self.assertRaises(self.failureException):
731+
self.assertDictAlmostEqual(a, b)
732+
705733
def test_assert_dict_almost_equal_assertion_message(self):
706734
with self.assertRaises(AssertionError) as e:
707735
self.assertDictAlmostEqual({'a': 0.6}, {'a': 1.0}, delta=0.1)

0 commit comments

Comments
 (0)