def crosstab(*args):
    """
    Count the occurrences of the unique tuples in ``zip(*args)``.

    When ``len(args) > 1``, the array computed by this function is
    often referred to as a contingency table [1]_.

    Each argument is passed through `unique`, so multi-dimensional inputs
    are treated as flattened sequences.  All arguments must contain the
    same number of elements.

    Parameters
    ----------
    *args : sequences
        Sequences whose unique aligned elements are to be counted.

    Returns
    -------
    unique_elements : tuple of ndarray
        Tuple of length ``len(args)`` containing the sorted unique elements
        of each argument.
    count : ndarray
        Integer array with ``len(args)`` dimensions and shape
        ``(n0, n1, ...)``, where ``nk`` is the number of unique elements
        in ``args[k]``.  ``count[i0, i1, ...]`` is the number of times the
        tuple ``(unique_elements[0][i0], unique_elements[1][i1], ...)``
        occurs in ``zip(*args)``.

    Raises
    ------
    TypeError
        If no arguments are given.
    ValueError
        If the arguments do not all have the same length.

    See Also
    --------
    unique

    References
    ----------
    .. [1] "Contingency table", http://en.wikipedia.org/wiki/Contingency_table

    Examples
    --------
    Apply `crosstab` to a single argument:

    >>> a = ['A', 'B', 'A', 'A', 'B', 'B', 'A', 'A', 'B']
    >>> vals, counts = crosstab(a)
    >>> vals[0].tolist()
    ['A', 'B']
    >>> counts
    array([5, 4])

    Note that this case--a single argument--is also handled by `unique`:

    >>> uvals, ucounts = unique(a, return_counts=True)
    >>> uvals.tolist(), ucounts.tolist()
    (['A', 'B'], [5, 4])

    Include a second argument:

    >>> x = ['X', 'X', 'X', 'Y', 'Z', 'Z', 'Y', 'Y', 'Z']
    >>> (avals, xvals), counts = crosstab(a, x)
    >>> avals.tolist()
    ['A', 'B']
    >>> xvals.tolist()
    ['X', 'Y', 'Z']
    >>> counts
    array([[2, 3, 0],
           [1, 0, 3]])

    Higher dimensional contingency tables can be created.

    >>> p = [0, 0, 0, 0, 1, 1, 1, 0, 0]
    >>> (avals, xvals, pvals), counts = crosstab(a, x, p)
    >>> counts
    array([[[2, 0],
            [2, 1],
            [0, 0]],
    <BLANKLINE>
           [[1, 0],
            [0, 0],
            [1, 2]]])
    >>> counts.shape
    (2, 3, 2)

    """
    if len(args) == 0:
        raise TypeError("crosstab() requires at least one argument.")

    if not all(len(a) == len(args[0]) for a in args[1:]):
        raise ValueError("All arguments must have the same length.")

    # Call np.unique with return_inverse=True on each argument.  The
    # inverse is flattened explicitly: depending on the NumPy version, the
    # inverse of a multi-dimensional input may come back with the input's
    # shape, and differently shaped index arrays would be broadcast against
    # each other by np.add.at instead of being paired element by element.
    unique_elements = []
    inverses = []
    for a in args:
        values, inverse = np.unique(a, return_inverse=True)
        unique_elements.append(values)
        inverses.append(np.ravel(inverse))

    # len() only compares the leading dimension, so inputs such as a
    # (3, 3) array and a length-3 list get past the check above.  Compare
    # the flattened sizes so the mismatch is reported clearly rather than
    # surfacing as an indexing error (or a silent miscount) below.
    if any(inv.size != inverses[0].size for inv in inverses[1:]):
        raise ValueError("All arguments must have the same length.")

    # Count the occurrences of the unique tuples by applying np.add.at
    # to the inverses; unlike ``count[inverses] += 1`` it accumulates
    # repeated indices.
    shape = tuple(len(u) for u in unique_elements)
    count = np.zeros(shape, dtype=np.intp)
    np.add.at(count, tuple(inverses), 1)

    return tuple(unique_elements), count
def test_crosstab(self):
    """Check crosstab on a small two-way table and on an empty input."""
    # Two aligned integer sequences; the table rows follow the unique
    # values of the first and the columns those of the second.
    first = np.array([2, 2, 2, 3, 3, 3, 3])
    second = np.array([7, 9, 8, 9, 7, 9, 7])
    expected_table = [[1, 1, 1], [2, 0, 2]]

    uniques, table = crosstab(first, second)
    assert_array_equal(uniques[0], [2, 3])
    assert_array_equal(uniques[1], [7, 8, 9])
    assert_array_equal(table, expected_table)

    # Edge case: single empty sequence.
    empty_uniques, empty_table = crosstab([])
    assert_array_equal(empty_uniques[0], [])
    assert_array_equal(empty_table, [])