Skip to content

Map values according to a dict #257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions graphblas/core/matrix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
import warnings
from collections.abc import Sequence
from collections.abc import Mapping, Sequence

import numpy as np

Expand All @@ -12,7 +12,14 @@
from .descriptor import lookup as descriptor_lookup
from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater
from .mask import Mask, StructuralMask, ValueMask
from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string
from .operator import (
UNKNOWN_OPCLASS,
_dict_to_func,
find_opclass,
get_semiring,
get_typed_op,
op_from_string,
)
from .scalar import (
_COMPLETE,
_MATERIALIZE,
Expand Down Expand Up @@ -2279,7 +2286,11 @@ def apply(self, op, right=None, *, left=None):
right = False # most basic form of 0 when unifying dtypes
if left is not None:
raise TypeError("Do not pass `left` when applying IndexUnaryOp")

elif opclass == UNKNOWN_OPCLASS and isinstance(op, Mapping):
if left is not None:
raise TypeError("Do not pass `left` when applying a Mapping")
op = _dict_to_func(op, right)
right = None
if left is None and right is None:
op = get_typed_op(op, self.dtype, kind="unary")
self._expect_op(
Expand Down
22 changes: 22 additions & 0 deletions graphblas/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3597,3 +3597,25 @@ def aggregator_from_string(string):
from .. import agg # noqa: E402 isort:skip

agg.from_string = aggregator_from_string


def _dict_to_func(d, default):
# This probably doesn't work on UDTs, and we could probably be smarter with dtypes
if default is None:
default = False
keys, vals = zip(*d.items())
keys = np.array(keys)
lookup_dtype(keys.dtype)
vals = np.array(vals)
lookup_dtype(vals.dtype)
p = np.argsort(keys)
keys = keys[p]
vals = vals[p]

def func(x):
i = np.searchsorted(keys, x)
if i < keys.size and keys[i] == x:
return vals[i]
return default

return func
16 changes: 14 additions & 2 deletions graphblas/core/vector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import warnings
from collections.abc import Mapping

import numpy as np

Expand All @@ -11,7 +12,14 @@
from .descriptor import lookup as descriptor_lookup
from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater
from .mask import Mask, StructuralMask, ValueMask
from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string
from .operator import (
UNKNOWN_OPCLASS,
_dict_to_func,
find_opclass,
get_semiring,
get_typed_op,
op_from_string,
)
from .scalar import (
_COMPLETE,
_MATERIALIZE,
Expand Down Expand Up @@ -1315,7 +1323,11 @@ def apply(self, op, right=None, *, left=None):
right = False # most basic form of 0 when unifying dtypes
if left is not None:
raise TypeError("Do not pass `left` when applying IndexUnaryOp")

elif opclass == UNKNOWN_OPCLASS and isinstance(op, Mapping):
if left is not None:
raise TypeError("Do not pass `left` when applying a Mapping")
op = _dict_to_func(op, right)
right = None
if left is None and right is None:
op = get_typed_op(op, self.dtype, kind="unary")
self._expect_op(
Expand Down
17 changes: 17 additions & 0 deletions graphblas/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,23 @@ def test_apply_indexunary(A):
A.apply(select.valueeq, left=s3)


def test_apply_dict():
    """Check that ``Matrix.apply`` accepts a dict and maps stored values through it."""
    row_idx = [0, 0, 0, 0]
    col_idx = [1, 3, 4, 6]
    V = Matrix.from_coo(row_idx, col_idx, [1, 1, 2, 0])

    # `right` supplies the value for entries not found in the dict
    with_default = V.apply({1: 10, 2: 20}, 100).new()
    assert with_default.isequal(Matrix.from_coo(row_idx, col_idx, [10, 10, 20, 100]))

    # Without `right`, entries not found in the dict become 0
    without_default = V.apply({0: 10, 2: 20}).new()
    assert without_default.isequal(Matrix.from_coo(row_idx, col_idx, [0, 0, 20, 10]))

    # `left` is not accepted together with a mapping
    with pytest.raises(TypeError, match="left"):
        V.apply({0: 10, 2: 20}, left=999)


def test_select(A):
A3 = Matrix.from_coo([0, 3, 3, 6], [3, 0, 2, 4], [3, 3, 3, 3], nrows=7, ncols=7)
w1 = A.select(select.valueeq, 3).new()
Expand Down
23 changes: 23 additions & 0 deletions graphblas/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,29 @@ def test_apply_indexunary(v):
v.apply(indexunary.valueeq, left=s2)


def test_apply_dict(v):
    """Check that ``Vector.apply`` accepts a dict and maps stored values through it.

    ``v`` is a fixture defined elsewhere in the test module; the expected
    results below assume it has entries at indices 1, 3, 4 and 6.
    """
    indices = [1, 3, 4, 6]

    # `right` supplies the value for entries not found in the dict
    with_default = v.apply({1: 10, 2: 20}, 100).new()
    assert with_default.isequal(Vector.from_coo(indices, [10, 10, 20, 100]))

    # Without `right`, entries not found in the dict become 0
    without_default = v.apply({0: 10, 2: 20}).new()
    assert without_default.isequal(Vector.from_coo(indices, [0, 0, 20, 10]))

    # A float default up-casts the result dtype
    upcast = v.apply({1: 10, 2: 20}, 0.5).new()
    assert upcast.isequal(Vector.from_coo(indices, [10, 10, 20, 0.5]))

    # `left` is not accepted together with a mapping
    with pytest.raises(TypeError, match="left"):
        v.apply({0: 10, 2: 20}, left=999)

    # Dict values of an unsupported type are rejected
    with pytest.raises(ValueError, match="Unknown dtype"):
        v.apply({0: 10, 2: object()})

    import numba

    # An unsupported default currently surfaces as a numba typing failure
    with pytest.raises(numba.TypingError):  # TODO: this error and message should be better
        v.apply({0: 10, 2: 20}, object())


def test_select(v):
result = Vector.from_coo([1, 3], [1, 1], size=7)
w1 = v.select(select.valueeq, 1).new()
Expand Down