diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 8b9b4b678..e922ead87 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -1,6 +1,6 @@ import itertools import warnings -from collections.abc import Sequence +from collections.abc import Mapping, Sequence import numpy as np @@ -12,7 +12,14 @@ from .descriptor import lookup as descriptor_lookup from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _dict_to_func, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -2279,7 +2286,11 @@ def apply(self, op, right=None, *, left=None): right = False # most basic form of 0 when unifying dtypes if left is not None: raise TypeError("Do not pass `left` when applying IndexUnaryOp") - + elif opclass == UNKNOWN_OPCLASS and isinstance(op, Mapping): + if left is not None: + raise TypeError("Do not pass `left` when applying a Mapping") + op = _dict_to_func(op, right) + right = None if left is None and right is None: op = get_typed_op(op, self.dtype, kind="unary") self._expect_op( diff --git a/graphblas/core/operator.py b/graphblas/core/operator.py index eca7c9d75..78bf37cc9 100644 --- a/graphblas/core/operator.py +++ b/graphblas/core/operator.py @@ -3597,3 +3597,25 @@ def aggregator_from_string(string): from .. import agg # noqa: E402 isort:skip agg.from_string = aggregator_from_string + + +def _dict_to_func(d, default): + # This probably doesn't work on UDTs, and we could probably be smarter with dtypes + if default is None: + default = False + keys, vals = zip(*d.items()) + keys = np.array(keys) + lookup_dtype(keys.dtype) + vals = np.array(vals) + lookup_dtype(vals.dtype) + p = np.argsort(keys) + keys = keys[p] + vals = vals[p] + + def func(x): + i = np.searchsorted(keys, x) + if i < keys.size and keys[i] == x: + return vals[i] + return default + + return func diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index dd183d856..0ba8cec11 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -1,5 +1,6 @@ import itertools import warnings +from collections.abc import Mapping import numpy as np @@ -11,7 +12,14 @@ from .descriptor import lookup as descriptor_lookup from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _dict_to_func, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -1315,7 +1323,11 @@ def apply(self, op, right=None, *, left=None): right = False # most basic form of 0 when unifying dtypes if left is not None: raise TypeError("Do not pass `left` when applying IndexUnaryOp") - + elif opclass == UNKNOWN_OPCLASS and isinstance(op, Mapping): + if left is not None: + raise TypeError("Do not pass `left` when applying a Mapping") + op = _dict_to_func(op, right) + right = None if left is None and right is None: op = get_typed_op(op, self.dtype, kind="unary") self._expect_op( diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index 40676f71a..160217235 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -1232,6 +1232,23 @@ def test_apply_indexunary(A): A.apply(select.valueeq, left=s3) +def test_apply_dict(): + rows = [0, 0, 0, 0] + cols = [1, 3, 4, 6] + vals = [1, 1, 2, 0] + V = Matrix.from_coo(rows, cols, vals) + # Use right as default + W1 = V.apply({1: 10, 2: 20}, 100).new() + expected = Matrix.from_coo(rows, cols, [10, 10, 20, 100]) + assert W1.isequal(expected) + # Default is 0 if unspecified + W2 = V.apply({0: 10, 2: 20}).new() + expected = Matrix.from_coo(rows, cols, [0, 0, 20, 10]) + assert W2.isequal(expected) + with pytest.raises(TypeError, match="left"): + V.apply({0: 10, 2: 20}, left=999) + + def test_select(A): A3 = Matrix.from_coo([0, 3, 3, 6], [3, 0, 2, 4], [3, 3, 3, 3], nrows=7, ncols=7) w1 = A.select(select.valueeq, 3).new() diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index 8505313e4..7da626b3b 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -724,6 +724,29 @@ def test_apply_indexunary(v): v.apply(indexunary.valueeq, left=s2) +def test_apply_dict(v): + # Use right as default + w1 = v.apply({1: 10, 2: 20}, 100).new() + expected = Vector.from_coo([1, 3, 4, 6], [10, 10, 20, 100]) + assert w1.isequal(expected) + # Default is 0 if unspecified + w2 = v.apply({0: 10, 2: 20}).new() + expected = Vector.from_coo([1, 3, 4, 6], [0, 0, 20, 10]) + assert w2.isequal(expected) + # Scalar default can up-cast dtype + w3 = v.apply({1: 10, 2: 20}, 0.5).new() + expected = Vector.from_coo([1, 3, 4, 6], [10, 10, 20, 0.5]) + assert w3.isequal(expected) + with pytest.raises(TypeError, match="left"): + v.apply({0: 10, 2: 20}, left=999) + with pytest.raises(ValueError, match="Unknown dtype"): + v.apply({0: 10, 2: object()}) + import numba + + with pytest.raises(numba.TypingError): # TODO: this error and message should be better + v.apply({0: 10, 2: 20}, object()) + + def test_select(v): result = Vector.from_coo([1, 3], [1, 1], size=7) w1 = v.select(select.valueeq, 1).new()