Skip to content

Commit 9788626

Browse files
jseabold authored and charris committed
BUG: Fixed bugs in join_by and added tests
1 parent 834b5bf commit 9788626

File tree

2 files changed

+74
-6
lines changed

2 files changed

+74
-6
lines changed

numpy/lib/recfunctions.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,13 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
895895
(nb1, nb2) = (len(r1), len(r2))
896896
(r1names, r2names) = (r1.dtype.names, r2.dtype.names)
897897

898+
# Check the names for collision
899+
if (set.intersection(set(r1names),set(r2names)).difference(key) and
900+
not (r1postfix or r2postfix)):
901+
msg = "r1 and r2 contain common names, r1postfix and r2postfix "
902+
msg += "can't be empty"
903+
raise ValueError(msg)
904+
898905
# Make temporary arrays of just the keys
899906
r1k = drop_fields(r1, [n for n in r1names if n not in key])
900907
r2k = drop_fields(r2, [n for n in r2names if n not in key])
@@ -937,7 +944,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
937944
name = desc[0]
938945
# Have we seen the current name already ?
939946
if name in names:
940-
nameidx = names.index(name)
947+
nameidx = ndtype.index(desc)
941948
current = ndtype[nameidx]
942949
# The current field is part of the key: take the largest dtype
943950
if name in key:
@@ -960,15 +967,15 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
960967
names = output.dtype.names
961968
for f in r1names:
962969
selected = s1[f]
963-
if f not in names:
970+
if f not in names or (f in r2names and not r2postfix and not f in key):
964971
f += r1postfix
965972
current = output[f]
966973
current[:r1cmn] = selected[:r1cmn]
967974
if jointype in ('outer', 'leftouter'):
968975
current[cmn:cmn + r1spc] = selected[r1cmn:]
969976
for f in r2names:
970977
selected = s2[f]
971-
if f not in names:
978+
if f not in names or (f in r1names and not r1postfix and f not in key):
972979
f += r2postfix
973980
current = output[f]
974981
current[:r2cmn] = selected[:r2cmn]

numpy/lib/tests/test_recfunctions.py

+64-3
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ def test_zip_descr(self):
3636
test = zip_descr((x, x), flatten=False)
3737
assert_equal(test,
3838
np.dtype([('', int), ('', int)]))
39-
# Std & flexible-dtype
39+
# Std & flexible-dtype
4040
test = zip_descr((x, z), flatten=True)
4141
assert_equal(test,
4242
np.dtype([('', int), ('A', '|S3'), ('B', float)]))
4343
test = zip_descr((x, z), flatten=False)
4444
assert_equal(test,
4545
np.dtype([('', int),
4646
('', [('A', '|S3'), ('B', float)])]))
47-
# Standard & nested dtype
47+
# Standard & nested dtype
4848
test = zip_descr((x, w), flatten=True)
4949
assert_equal(test,
5050
np.dtype([('', int),
@@ -259,7 +259,7 @@ def test_standard(self):
259259
control = np.array([(1, 10), (2, 20), (-1, 30)],
260260
dtype=[('f0', int), ('f1', int)])
261261
assert_equal(test, control)
262-
#
262+
#
263263
test = merge_arrays((x, y), usemask=True)
264264
control = ma.array([(1, 10), (2, 20), (-1, 30)],
265265
mask=[(0, 0), (0, 0), (1, 0)],
@@ -615,6 +615,67 @@ def test_leftouter_join(self):
615615
dtype=[('a', int), ('b', int), ('c', int), ('d', int)])
616616

617617

618+
class TestJoinBy2(TestCase):
619+
@classmethod
620+
def setUp(cls):
621+
cls.a = np.array(zip(np.arange(10), np.arange(50, 60),
622+
np.arange(100, 110)),
623+
dtype=[('a', int), ('b', int), ('c', int)])
624+
cls.b = np.array(zip(np.arange(10), np.arange(65, 75),
625+
np.arange(100, 110)),
626+
dtype=[('a', int), ('b', int), ('d', int)])
627+
628+
def test_no_r1postfix(self):
629+
"Basic test of join_by"
630+
a, b = self.a, self.b
631+
632+
test = join_by('a', a, b, r1postfix='', r2postfix='2', jointype='inner')
633+
control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
634+
(2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
635+
(4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
636+
(6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
637+
(8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
638+
dtype=[('a', int), ('b', int), ('b2', int),
639+
('c', int), ('d', int)])
640+
assert_equal(test, control)
641+
642+
643+
def test_no_postfix(self):
644+
self.assertRaises(ValueError, join_by, 'a', self.a, self.b, r1postfix='', r2postfix='')
645+
646+
def test_no_r2postfix(self):
647+
"Basic test of join_by"
648+
a, b = self.a, self.b
649+
650+
test = join_by('a', a, b, r1postfix='1', r2postfix='', jointype='inner')
651+
control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
652+
(2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
653+
(4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
654+
(6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
655+
(8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
656+
dtype=[('a', int), ('b1', int), ('b', int),
657+
('c', int), ('d', int)])
658+
assert_equal(test, control)
659+
660+
def test_two_keys_two_vars(self):
661+
a = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2),
662+
np.arange(50, 60), np.arange(10,20)),
663+
dtype=[('k', int), ('a', int), ('b', int),('c',int)])
664+
665+
b = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2),
666+
np.arange(65, 75), np.arange(0,10)),
667+
dtype=[('k', int), ('a', int), ('b', int), ('c',int)])
668+
669+
control = np.array([(10, 0, 50, 65, 10, 0), (11, 0, 51, 66, 11, 1),
670+
(10, 1, 52, 67, 12, 2), (11, 1, 53, 68, 13, 3),
671+
(10, 2, 54, 69, 14, 4), (11, 2, 55, 70, 15, 5),
672+
(10, 3, 56, 71, 16, 6), (11, 3, 57, 72, 17, 7),
673+
(10, 4, 58, 73, 18, 8), (11, 4, 59, 74, 19, 9)],
674+
dtype=[('k', '<i8'), ('a', '<i8'), ('b1', '<i8'),
675+
('b2', '<i8'), ('c1', '<i8'), ('c2', '<i8')])
676+
test = join_by(['a','k'], a, b, r1postfix='1', r2postfix='2', jointype='inner')
677+
assert_equal(test, control)
678+
618679

619680

620681
if __name__ == '__main__':

0 commit comments

Comments
 (0)