Skip to content

BUG: Fixed bugs in join_by and added tests #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions numpy/lib/recfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,13 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
(nb1, nb2) = (len(r1), len(r2))
(r1names, r2names) = (r1.dtype.names, r2.dtype.names)

# Check the names for collision
if (set.intersection(set(r1names),set(r2names)).difference(key) and
not (r1postfix or r2postfix)):
msg = "r1 and r2 contain common names, r1postfix and r2postfix "
msg += "can't be empty"
raise ValueError(msg)

# Make temporary arrays of just the keys
r1k = drop_fields(r1, [n for n in r1names if n not in key])
r2k = drop_fields(r2, [n for n in r2names if n not in key])
Expand Down Expand Up @@ -937,7 +944,7 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
name = desc[0]
# Have we seen the current name already ?
if name in names:
nameidx = names.index(name)
nameidx = ndtype.index(desc)
current = ndtype[nameidx]
# The current field is part of the key: take the largest dtype
if name in key:
Expand All @@ -960,15 +967,15 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
names = output.dtype.names
for f in r1names:
selected = s1[f]
if f not in names:
if f not in names or (f in r2names and not r2postfix and not f in key):
f += r1postfix
current = output[f]
current[:r1cmn] = selected[:r1cmn]
if jointype in ('outer', 'leftouter'):
current[cmn:cmn + r1spc] = selected[r1cmn:]
for f in r2names:
selected = s2[f]
if f not in names:
if f not in names or (f in r1names and not r1postfix and f not in key):
f += r2postfix
current = output[f]
current[:r2cmn] = selected[:r2cmn]
Expand Down
67 changes: 64 additions & 3 deletions numpy/lib/tests/test_recfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def test_zip_descr(self):
test = zip_descr((x, x), flatten=False)
assert_equal(test,
np.dtype([('', int), ('', int)]))
# Std & flexible-dtype
# Std & flexible-dtype
test = zip_descr((x, z), flatten=True)
assert_equal(test,
np.dtype([('', int), ('A', '|S3'), ('B', float)]))
test = zip_descr((x, z), flatten=False)
assert_equal(test,
np.dtype([('', int),
('', [('A', '|S3'), ('B', float)])]))
# Standard & nested dtype
# Standard & nested dtype
test = zip_descr((x, w), flatten=True)
assert_equal(test,
np.dtype([('', int),
Expand Down Expand Up @@ -259,7 +259,7 @@ def test_standard(self):
control = np.array([(1, 10), (2, 20), (-1, 30)],
dtype=[('f0', int), ('f1', int)])
assert_equal(test, control)
#
#
test = merge_arrays((x, y), usemask=True)
control = ma.array([(1, 10), (2, 20), (-1, 30)],
mask=[(0, 0), (0, 0), (1, 0)],
Expand Down Expand Up @@ -615,6 +615,67 @@ def test_leftouter_join(self):
dtype=[('a', int), ('b', int), ('c', int), ('d', int)])


class TestJoinBy2(TestCase):
@classmethod
def setUp(cls):
cls.a = np.array(zip(np.arange(10), np.arange(50, 60),
np.arange(100, 110)),
dtype=[('a', int), ('b', int), ('c', int)])
cls.b = np.array(zip(np.arange(10), np.arange(65, 75),
np.arange(100, 110)),
dtype=[('a', int), ('b', int), ('d', int)])

def test_no_r1postfix(self):
"Basic test of join_by"
a, b = self.a, self.b

test = join_by('a', a, b, r1postfix='', r2postfix='2', jointype='inner')
control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
(2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
(4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
(6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
(8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
dtype=[('a', int), ('b', int), ('b2', int),
('c', int), ('d', int)])
assert_equal(test, control)


def test_no_postfix(self):
self.assertRaises(ValueError, join_by, 'a', self.a, self.b, r1postfix='', r2postfix='')

def test_no_r2postfix(self):
"Basic test of join_by"
a, b = self.a, self.b

test = join_by('a', a, b, r1postfix='1', r2postfix='', jointype='inner')
control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
(2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
(4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
(6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
(8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
dtype=[('a', int), ('b1', int), ('b', int),
('c', int), ('d', int)])
assert_equal(test, control)

def test_two_keys_two_vars(self):
a = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2),
np.arange(50, 60), np.arange(10,20)),
dtype=[('k', int), ('a', int), ('b', int),('c',int)])

b = np.array(zip(np.tile([10,11],5),np.repeat(np.arange(5),2),
np.arange(65, 75), np.arange(0,10)),
dtype=[('k', int), ('a', int), ('b', int), ('c',int)])

control = np.array([(10, 0, 50, 65, 10, 0), (11, 0, 51, 66, 11, 1),
(10, 1, 52, 67, 12, 2), (11, 1, 53, 68, 13, 3),
(10, 2, 54, 69, 14, 4), (11, 2, 55, 70, 15, 5),
(10, 3, 56, 71, 16, 6), (11, 3, 57, 72, 17, 7),
(10, 4, 58, 73, 18, 8), (11, 4, 59, 74, 19, 9)],
dtype=[('k', '<i8'), ('a', '<i8'), ('b1', '<i8'),
('b2', '<i8'), ('c1', '<i8'), ('c2', '<i8')])
test = join_by(['a','k'], a, b, r1postfix='1', r2postfix='2', jointype='inner')
assert_equal(test, control)



if __name__ == '__main__':
Expand Down