From a14fa33306e65bb91738ea9106e97bc1780d320e Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Mon, 30 Oct 2023 15:27:41 -0500 Subject: [PATCH] Handle dtypes (esp. UDTs) better in ewise_union --- .pre-commit-config.yaml | 6 +- graphblas/core/infix.py | 123 +++++++++++++++++------------- graphblas/core/matrix.py | 29 ++++--- graphblas/core/operator/monoid.py | 26 +------ graphblas/core/scalar.py | 33 +++++--- graphblas/core/vector.py | 29 ++++--- graphblas/tests/test_matrix.py | 7 ++ graphblas/tests/test_vector.py | 1 + scripts/check_versions.sh | 4 +- 9 files changed, 151 insertions(+), 107 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b2e08e638..3766e2e7c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -61,12 +61,12 @@ repos: - id: auto-walrus args: [--line-length, "100"] - repo: https://github.com/psf/black - rev: 23.10.0 + rev: 23.10.1 hooks: - id: black - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.1 + rev: v0.1.3 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -94,7 +94,7 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.1 + rev: v0.1.3 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 88fc52dbe..09b6a6811 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -1,8 +1,9 @@ from .. import backend, binary from ..dtypes import BOOL +from ..exceptions import DimensionMismatch from ..monoid import land, lor from ..semiring import any_pair -from . import automethods, utils +from . import automethods, recorder, utils from .base import _expect_op, _expect_type from .expr import InfixExprBase from .mask import Mask @@ -402,43 +403,62 @@ def __init__(self, left, right, *, nrows, ncols): utils._output_types[MatrixMatMulExpr] = Matrix +def _dummy(obj, obj_type): + with recorder.skip_record: + return output_type(obj)(BOOL, *obj.shape, name="") + + +def _mismatched(left, right, method, op): + # Create dummy expression to raise on incompatible dimensions + getattr(_dummy(left) if isinstance(left, InfixExprBase) else left, method)( + _dummy(right) if isinstance(right, InfixExprBase) else right, op + ) + raise DimensionMismatch # pragma: no cover + + def _ewise_infix_expr(left, right, *, method, within): left_type = output_type(left) right_type = output_type(right) types = {Vector, Matrix, TransposedMatrix} if left_type in types and right_type in types: - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(left, method)(right, binary.any) - if expr.output_type is Vector: - if method == "ewise_mult": - return VectorEwiseMultExpr(left, right) - return VectorEwiseAddExpr(left, right) + if left_type is Vector: + if right_type is Vector: + if left._size != right._size: + _mismatched(left, right, method, binary.first) + if method == "ewise_mult": + return VectorEwiseMultExpr(left, right) + return VectorEwiseAddExpr(left, right) + if left._size != right._nrows: + _mismatched(left, right, method, binary.first) + elif right_type is Vector: + if left._ncols != right._size: + _mismatched(left, right, method, binary.first) + elif left.shape != right.shape: + _mismatched(left, right, method, binary.first) if method == "ewise_mult": return MatrixEwiseMultExpr(left, right) return MatrixEwiseAddExpr(left, right) + if within == "__or__" and isinstance(right, Mask): return right.__ror__(left) if within == "__and__" and isinstance(right, Mask): return right.__rand__(left) if left_type in types: left._expect_type(right, tuple(types), within=within, argname="right") - elif right_type in types: + if right_type in types: right._expect_type(left, tuple(types), within=within, argname="left") - elif left_type is Scalar: - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(left, method)(right, binary.any) + if left_type is Scalar: if method == "ewise_mult": return ScalarEwiseMultExpr(left, right) return ScalarEwiseAddExpr(left, right) - elif right_type is Scalar: - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(right, method)(left, binary.any) + if right_type is Scalar: if method == "ewise_mult": return ScalarEwiseMultExpr(right, left) return ScalarEwiseAddExpr(right, left) - else: # pragma: no cover (sanity) - raise TypeError(f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}") + raise TypeError( # pragma: no cover (sanity) + f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}" + ) def _matmul_infix_expr(left, right, *, within): @@ -447,54 +467,51 @@ def _matmul_infix_expr(left, right, *, within): if left_type is Vector: if right_type is Matrix or right_type is TransposedMatrix: - method = "vxm" - elif right_type is Vector: - method = "inner" - else: - right = left._expect_type( - right, - (Matrix, TransposedMatrix), - within=within, - argname="right", - ) - elif left_type is Matrix or left_type is TransposedMatrix: + if left._size != right._nrows: + _mismatched(left, right, "vxm", any_pair[BOOL]) + return VectorMatMulExpr(left, right, method_name="vxm", size=right._ncols) if right_type is Vector: - method = "mxv" - elif right_type is Matrix or right_type is TransposedMatrix: - method = "mxm" - else: - right = left._expect_type( - right, - (Vector, Matrix, TransposedMatrix), - within=within, - argname="right", - ) - elif right_type is Vector: - left = right._expect_type( + if left._size != right._size: + _mismatched(left, right, "inner", any_pair[BOOL]) + return ScalarMatMulExpr(left, right) + left._expect_type( + right, + (Matrix, TransposedMatrix, Vector), + within=within, + argname="right", + ) + if left_type is Matrix or left_type is TransposedMatrix: + if right_type is Vector: + if left._ncols != right._size: + _mismatched(left, right, "mxv", any_pair[BOOL]) + return VectorMatMulExpr(left, right, method_name="mxv", size=left._nrows) + if right_type is Matrix or right_type is TransposedMatrix: + if left._ncols != right._nrows: + _mismatched(left, right, "mxm", any_pair[BOOL]) + return MatrixMatMulExpr(left, right, nrows=left._nrows, ncols=right._ncols) + left._expect_type( + right, + (Vector, Matrix, TransposedMatrix), + within=within, + argname="right", + ) + if right_type is Vector: + right._expect_type( left, (Matrix, TransposedMatrix), within=within, argname="left", ) - elif right_type is Matrix or right_type is TransposedMatrix: - left = right._expect_type( + if right_type is Matrix or right_type is TransposedMatrix: + right._expect_type( left, (Vector, Matrix, TransposedMatrix), within=within, argname="left", ) - else: # pragma: no cover (sanity) - raise TypeError( - f"Bad types for matmul infix: {type(left).__name__}, {type(right).__name__}" - ) - - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(left, method)(right, any_pair[bool]) - if expr.output_type is Vector: - return VectorMatMulExpr(left, right, method_name=method, size=expr._size) - if expr.output_type is Matrix: - return MatrixMatMulExpr(left, right, nrows=expr._nrows, ncols=expr._ncols) - return ScalarMatMulExpr(left, right) + raise TypeError( # pragma: no cover (sanity) + f"Bad types for matmul infix: {type(left).__name__}, {type(right).__name__}" + ) # Import infixmethods, which has side effects diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index aed98f57d..5e1a76720 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -67,13 +67,13 @@ def _m_mult_v(updater, left, right, op): updater << left.mxm(right.diag(name="M_temp"), get_semiring(monoid.any, op)) -def _m_union_m(updater, left, right, left_default, right_default, op, dtype): +def _m_union_m(updater, left, right, left_default, right_default, op): mask = updater.kwargs.get("mask") opts = updater.opts - new_left = left.dup(dtype, clear=True) + new_left = left.dup(op.type, clear=True) new_left(mask=mask, **opts) << binary.second(right, left_default) new_left(mask=mask, **opts) << binary.first(left | new_left) - new_right = right.dup(dtype, clear=True) + new_right = right.dup(op.type2, clear=True) new_right(mask=mask, **opts) << binary.second(left, right_default) new_right(mask=mask, **opts) << binary.first(right | new_right) updater << op(new_left & new_right) @@ -2078,7 +2078,10 @@ def ewise_union(self, other, op, left_default, right_default): other = self._expect_type( other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op ) - dtype = self.dtype if self.dtype._is_udt else None + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -2095,6 +2098,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -2111,12 +2116,19 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - scalar_dtype = unify(left.dtype, right.dtype) - nonscalar_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 1: # Broadcast rowwise from the right @@ -2146,11 +2158,10 @@ def ewise_union(self, other, op, left_default, right_default): expr_repr=expr_repr, ) else: - dtype = unify(scalar_dtype, nonscalar_dtype, is_left_scalar=True) expr = MatrixExpression( method_name, None, - [self, left, other, right, _m_union_m, (self, other, left, right, op, dtype)], + [self, left, other, right, _m_union_m, (self, other, left, right, op)], expr_repr=expr_repr, nrows=self._nrows, ncols=self._ncols, diff --git a/graphblas/core/operator/monoid.py b/graphblas/core/operator/monoid.py index fc327b4a7..21d2b7cac 100644 --- a/graphblas/core/operator/monoid.py +++ b/graphblas/core/operator/monoid.py @@ -19,10 +19,9 @@ ) from ...exceptions import check_status_carg from .. import ffi, lib -from ..expr import InfixExprBase from ..utils import libget -from .base import OpBase, ParameterizedUdf, TypedOpBase, _call_op, _hasop -from .binary import BinaryOp, ParameterizedBinaryOp +from .base import OpBase, ParameterizedUdf, TypedOpBase, _hasop +from .binary import BinaryOp, ParameterizedBinaryOp, TypedBuiltinBinaryOp ffi_new = ffi.new @@ -36,25 +35,6 @@ def __init__(self, parent, name, type_, return_type, gb_obj, gb_name): super().__init__(parent, name, type_, return_type, gb_obj, gb_name) self._identity = None - def __call__(self, left, right=None, *, left_default=None, right_default=None): - if left_default is not None or right_default is not None: - if ( - left_default is None - or right_default is None - or right is not None - or not isinstance(left, InfixExprBase) - or left.method_name != "ewise_add" - ): - raise TypeError( - "Specifying `left_default` or `right_default` keyword arguments implies " - "performing `ewise_union` operation with infix notation.\n" - "There is only one valid way to do this:\n\n" - f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " - "are Vectors or Matrices, and left_default and right_default are scalars." - ) - return left.left.ewise_union(left.right, self, left_default, right_default) - return _call_op(self, left, right) - @property def identity(self): if self._identity is None: @@ -84,6 +64,8 @@ def is_idempotent(self): """True if ``monoid(x, x) == x`` for any x.""" return self.parent.is_idempotent + __call__ = TypedBuiltinBinaryOp.__call__ + class TypedUserMonoid(TypedOpBase): __slots__ = "binaryop", "identity" diff --git a/graphblas/core/scalar.py b/graphblas/core/scalar.py index 8a95e1d71..b822bd58a 100644 --- a/graphblas/core/scalar.py +++ b/graphblas/core/scalar.py @@ -30,12 +30,12 @@ def _scalar_index(name): return self -def _s_union_s(updater, left, right, left_default, right_default, op, dtype): +def _s_union_s(updater, left, right, left_default, right_default, op): opts = updater.opts - new_left = left.dup(dtype, clear=True) + new_left = left.dup(op.type, clear=True) new_left(**opts) << binary.second(right, left_default) new_left(**opts) << binary.first(left | new_left) - new_right = right.dup(dtype, clear=True) + new_right = right.dup(op.type2, clear=True) new_right(**opts) << binary.second(left, right_default) new_right(**opts) << binary.first(right | new_right) updater << op(new_left & new_right) @@ -742,7 +742,8 @@ def ewise_union(self, other, op, left_default, right_default): c << binary.div(a | b, left_default=1, right_default=1) """ method_name = "ewise_union" - dtype = self.dtype if self.dtype._is_udt else None + right_dtype = self.dtype + dtype = right_dtype if right_dtype._is_udt else None if type(other) is not Scalar: try: other = Scalar.from_value(other, dtype, is_cscalar=False, name="") @@ -755,6 +756,13 @@ def ewise_union(self, other, op, left_default, right_default): extra_message="Literal scalars also accepted.", op=op, ) + else: + other = _as_scalar(other, dtype, is_cscalar=False) # pragma: is_grbscalar + + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -771,6 +779,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -787,9 +797,15 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - defaults_dtype = unify(left.dtype, right.dtype) - args_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, defaults_dtype, args_dtype, kind="binary") + + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop @@ -805,11 +821,10 @@ def ewise_union(self, other, op, left_default, right_default): scalar_as_vector=True, ) else: - dtype = unify(defaults_dtype, args_dtype) expr = ScalarExpression( method_name, None, - [self, left, other, right, _s_union_s, (self, other, left, right, op, dtype)], + [self, left, other, right, _s_union_s, (self, other, left, right, op)], op=op, expr_repr=expr_repr, is_cscalar=False, diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index cd5b992ba..9d19d80da 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -61,13 +61,13 @@ def _v_union_m(updater, left, right, left_default, right_default, op): updater << temp.ewise_union(right, op, left_default=left_default, right_default=right_default) -def _v_union_v(updater, left, right, left_default, right_default, op, dtype): +def _v_union_v(updater, left, right, left_default, right_default, op): mask = updater.kwargs.get("mask") opts = updater.opts - new_left = left.dup(dtype, clear=True) + new_left = left.dup(op.type, clear=True) new_left(mask=mask, **opts) << binary.second(right, left_default) new_left(mask=mask, **opts) << binary.first(left | new_left) - new_right = right.dup(dtype, clear=True) + new_right = right.dup(op.type2, clear=True) new_right(mask=mask, **opts) << binary.second(left, right_default) new_right(mask=mask, **opts) << binary.first(right | new_right) updater << op(new_left & new_right) @@ -1177,7 +1177,10 @@ def ewise_union(self, other, op, left_default, right_default): other = self._expect_type( other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op ) - dtype = self.dtype if self.dtype._is_udt else None + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -1194,6 +1197,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -1210,12 +1215,19 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - scalar_dtype = unify(left.dtype, right.dtype) - nonscalar_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 2: # Broadcast columnwise from the left @@ -1243,11 +1255,10 @@ def ewise_union(self, other, op, left_default, right_default): expr_repr=expr_repr, ) else: - dtype = unify(scalar_dtype, nonscalar_dtype, is_left_scalar=True) expr = VectorExpression( method_name, None, - [self, left, other, right, _v_union_v, (self, other, left, right, op, dtype)], + [self, left, other, right, _v_union_v, (self, other, left, right, op)], expr_repr=expr_repr, size=self._size, op=op, diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index e08f96b32..3f66e46ef 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -2827,7 +2827,10 @@ def test_auto(A, v): "__and__", "__or__", # "kronecker", + "__rand__", + "__ror__", ]: + # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() val2 = getattr(expected, method)(expr) val3 = getattr(expr, method)(expected) @@ -3138,6 +3141,10 @@ def test_ss_reshape(A): def test_autocompute_argument_messages(A, v): with pytest.raises(TypeError, match="autocompute"): A.ewise_mult(A & A) + with pytest.raises(TypeError, match="autocompute"): + A.ewise_mult(binary.plus(A & A)) + with pytest.raises(TypeError, match="autocompute"): + A.ewise_mult(A + A) with pytest.raises(TypeError, match="autocompute"): A.mxv(A @ v) diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index 2571f288b..b66bc96c9 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -1579,6 +1579,7 @@ def test_auto(v): "__rand__", "__ror__", ]: + # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() val2 = getattr(expected, method)(expr) val3 = getattr(expr, method)(expected) diff --git a/scripts/check_versions.sh b/scripts/check_versions.sh index 7c09bc168..d197f2af2 100755 --- a/scripts/check_versions.sh +++ b/scripts/check_versions.sh @@ -6,9 +6,9 @@ conda search 'flake8-bugbear[channel=conda-forge]>=23.9.16' conda search 'flake8-simplify[channel=conda-forge]>=0.21.0' conda search 'numpy[channel=conda-forge]>=1.26.0' -conda search 'pandas[channel=conda-forge]>=2.1.1' +conda search 'pandas[channel=conda-forge]>=2.1.2' conda search 'scipy[channel=conda-forge]>=1.11.3' -conda search 'networkx[channel=conda-forge]>=3.2' +conda search 'networkx[channel=conda-forge]>=3.2.1' conda search 'awkward[channel=conda-forge]>=2.4.6' conda search 'sparse[channel=conda-forge]>=0.14.0' conda search 'fast_matrix_market[channel=conda-forge]>=1.7.4'