From 871b02bf8bffe75e25425dfd808b1e49a15ff918 Mon Sep 17 00:00:00 2001 From: Skipper Seabold Date: Wed, 29 Jun 2011 22:19:52 -0400 Subject: [PATCH] BUG: Fixed bugs in join_by and added tests --- numpy/lib/recfunctions.py | 13 ++++-- numpy/lib/tests/test_recfunctions.py | 67 ++++++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py index b3c210fff721..0127df9f9113 100644 --- a/numpy/lib/recfunctions.py +++ b/numpy/lib/recfunctions.py @@ -895,6 +895,13 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', (nb1, nb2) = (len(r1), len(r2)) (r1names, r2names) = (r1.dtype.names, r2.dtype.names) + # Check the names for collision + if (set.intersection(set(r1names),set(r2names)).difference(key) and + not (r1postfix or r2postfix)): + msg = "r1 and r2 contain common names, r1postfix and r2postfix " + msg += "can't be empty" + raise ValueError(msg) + # Make temporary arrays of just the keys r1k = drop_fields(r1, [n for n in r1names if n not in key]) r2k = drop_fields(r2, [n for n in r2names if n not in key]) @@ -937,7 +944,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', name = desc[0] # Have we seen the current name already ? if name in names: - nameidx = names.index(name) + nameidx = ndtype.index(desc) current = ndtype[nameidx] # The current field is part of the key: take the largest dtype if name in key: @@ -960,7 +967,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', names = output.dtype.names for f in r1names: selected = s1[f] - if f not in names: + if f not in names or (f in r2names and not r2postfix and not f in key): f += r1postfix current = output[f] current[:r1cmn] = selected[:r1cmn] @@ -968,7 +975,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', current[cmn:cmn + r1spc] = selected[r1cmn:] for f in r2names: selected = s2[f] - if f not in names: + if f not in names or (f in r1names and not r1postfix and f not in key): f += r2postfix current = output[f] current[:r2cmn] = selected[:r2cmn] diff --git a/numpy/lib/tests/test_recfunctions.py b/numpy/lib/tests/test_recfunctions.py index 57d977814092..c6befa5f6da4 100644 --- a/numpy/lib/tests/test_recfunctions.py +++ b/numpy/lib/tests/test_recfunctions.py @@ -36,7 +36,7 @@ def test_zip_descr(self): test = zip_descr((x, x), flatten=False) assert_equal(test, np.dtype([('', int), ('', int)])) - # Std & flexible-dtype + # Std & flexible-dtype test = zip_descr((x, z), flatten=True) assert_equal(test, np.dtype([('', int), ('A', '|S3'), ('B', float)])) @@ -44,7 +44,7 @@ def test_zip_descr(self): assert_equal(test, np.dtype([('', int), ('', [('A', '|S3'), ('B', float)])])) - # Standard & nested dtype + # Standard & nested dtype test = zip_descr((x, w), flatten=True) assert_equal(test, np.dtype([('', int), @@ -259,7 +259,7 @@ def test_standard(self): control = np.array([(1, 10), (2, 20), (-1, 30)], dtype=[('f0', int), ('f1', int)]) assert_equal(test, control) - # + # test = merge_arrays((x, y), usemask=True) control = ma.array([(1, 10), (2, 20), (-1, 30)], mask=[(0, 0), (0, 0), (1, 0)], @@ -615,6 +615,67 @@ def test_leftouter_join(self): dtype=[('a', int), ('b', int), ('c', int), ('d', int)]) +class TestJoinBy2(TestCase): + @classmethod + def setUp(cls): + cls.a = np.array(zip(np.arange(10), np.arange(50, 60), + np.arange(100, 110)), + dtype=[('a', int), ('b', int), ('c', int)]) + cls.b = np.array(zip(np.arange(10), np.arange(65, 75), + np.arange(100, 110)), + dtype=[('a', int), ('b', int), ('d', int)]) + + def test_no_r1postfix(self): + "Basic test of join_by" + a, b = self.a, self.b + + test = join_by('a', a, b, r1postfix='', r2postfix='2', jointype='inner') + control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101), + (2, 52, 67, 102, 102), (3, 53, 68, 103, 103), + (4, 54, 69, 104, 104), (5, 55, 70, 105, 105), + (6, 56, 71, 106, 106), (7, 57, 72, 107, 107), + (8, 58, 73, 108, 108), (9, 59, 74, 109, 109)], + dtype=[('a', int), ('b', int), ('b2', int), + ('c', int), ('d', int)]) + assert_equal(test, control) + + + def test_no_postfix(self): + self.assertRaises(ValueError, join_by, 'a', self.a, self.b, r1postfix='', r2postfix='') + + def test_no_r2postfix(self): + "Basic test of join_by" + a, b = self.a, self.b + + test = join_by('a', a, b, r1postfix='1', r2postfix='', jointype='inner') + control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101), + (2, 52, 67, 102, 102), (3, 53, 68, 103, 103), + (4, 54, 69, 104, 104), (5, 55, 70, 105, 105), + (6, 56, 71, 106, 106), (7, 57, 72, 107, 107), + (8, 58, 73, 108, 108), (9, 59, 74, 109, 109)], + dtype=[('a', int), ('b1', int), ('b', int), + ('c', int), ('d', int)]) + assert_equal(test, control) + + def test_two_keys_two_vars(self): + a = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2), + np.arange(50, 60), np.arange(10,20)), + dtype=[('k', int), ('a', int), ('b', int),('c',int)]) + + b = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2), + np.arange(65, 75), np.arange(0,10)), + dtype=[('k', int), ('a', int), ('b', int), ('c',int)]) + + control = np.array([(10, 0, 50, 65, 10, 0), (11, 0, 51, 66, 11, 1), + (10, 1, 52, 67, 12, 2), (11, 1, 53, 68, 13, 3), + (10, 2, 54, 69, 14, 4), (11, 2, 55, 70, 15, 5), + (10, 3, 56, 71, 16, 6), (11, 3, 57, 72, 17, 7), + (10, 4, 58, 73, 18, 8), (11, 4, 59, 74, 19, 9)], + dtype=[('k', '