Skip to content

Handle dtypes (esp. UDTs) better in ewise_union #517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ repos:
- id: auto-walrus
args: [--line-length, "100"]
- repo: https://github.com/psf/black
rev: 23.10.0
rev: 23.10.1
hooks:
- id: black
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.1
rev: v0.1.3
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
Expand Down Expand Up @@ -94,7 +94,7 @@ repos:
additional_dependencies: [tomli]
files: ^(graphblas|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.1
rev: v0.1.3
hooks:
- id: ruff
- repo: https://github.com/sphinx-contrib/sphinx-lint
Expand Down
123 changes: 70 additions & 53 deletions graphblas/core/infix.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .. import backend, binary
from ..dtypes import BOOL
from ..exceptions import DimensionMismatch
from ..monoid import land, lor
from ..semiring import any_pair
from . import automethods, utils
from . import automethods, recorder, utils
from .base import _expect_op, _expect_type
from .expr import InfixExprBase
from .mask import Mask
Expand Down Expand Up @@ -402,43 +403,62 @@ def __init__(self, left, right, *, nrows, ncols):
utils._output_types[MatrixMatMulExpr] = Matrix


def _dummy(obj, obj_type):
with recorder.skip_record:
return output_type(obj)(BOOL, *obj.shape, name="")


def _mismatched(left, right, method, op):
# Create dummy expression to raise on incompatible dimensions
getattr(_dummy(left) if isinstance(left, InfixExprBase) else left, method)(
_dummy(right) if isinstance(right, InfixExprBase) else right, op
)
raise DimensionMismatch # pragma: no cover


def _ewise_infix_expr(left, right, *, method, within):
left_type = output_type(left)
right_type = output_type(right)

types = {Vector, Matrix, TransposedMatrix}
if left_type in types and right_type in types:
# Create dummy expression to check compatibility of dimensions, etc.
expr = getattr(left, method)(right, binary.any)
if expr.output_type is Vector:
if method == "ewise_mult":
return VectorEwiseMultExpr(left, right)
return VectorEwiseAddExpr(left, right)
if left_type is Vector:
if right_type is Vector:
if left._size != right._size:
_mismatched(left, right, method, binary.first)
if method == "ewise_mult":
return VectorEwiseMultExpr(left, right)
return VectorEwiseAddExpr(left, right)
if left._size != right._nrows:
_mismatched(left, right, method, binary.first)
elif right_type is Vector:
if left._ncols != right._size:
_mismatched(left, right, method, binary.first)
elif left.shape != right.shape:
_mismatched(left, right, method, binary.first)
if method == "ewise_mult":
return MatrixEwiseMultExpr(left, right)
return MatrixEwiseAddExpr(left, right)

if within == "__or__" and isinstance(right, Mask):
return right.__ror__(left)
if within == "__and__" and isinstance(right, Mask):
return right.__rand__(left)
if left_type in types:
left._expect_type(right, tuple(types), within=within, argname="right")
elif right_type in types:
if right_type in types:
right._expect_type(left, tuple(types), within=within, argname="left")
elif left_type is Scalar:
# Create dummy expression to check compatibility of dimensions, etc.
expr = getattr(left, method)(right, binary.any)
if left_type is Scalar:
if method == "ewise_mult":
return ScalarEwiseMultExpr(left, right)
return ScalarEwiseAddExpr(left, right)
elif right_type is Scalar:
# Create dummy expression to check compatibility of dimensions, etc.
expr = getattr(right, method)(left, binary.any)
if right_type is Scalar:
if method == "ewise_mult":
return ScalarEwiseMultExpr(right, left)
return ScalarEwiseAddExpr(right, left)
else: # pragma: no cover (sanity)
raise TypeError(f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}")
raise TypeError( # pragma: no cover (sanity)
f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}"
)


def _matmul_infix_expr(left, right, *, within):
Expand All @@ -447,54 +467,51 @@ def _matmul_infix_expr(left, right, *, within):

if left_type is Vector:
if right_type is Matrix or right_type is TransposedMatrix:
method = "vxm"
elif right_type is Vector:
method = "inner"
else:
right = left._expect_type(
right,
(Matrix, TransposedMatrix),
within=within,
argname="right",
)
elif left_type is Matrix or left_type is TransposedMatrix:
if left._size != right._nrows:
_mismatched(left, right, "vxm", any_pair[BOOL])
return VectorMatMulExpr(left, right, method_name="vxm", size=right._ncols)
if right_type is Vector:
method = "mxv"
elif right_type is Matrix or right_type is TransposedMatrix:
method = "mxm"
else:
right = left._expect_type(
right,
(Vector, Matrix, TransposedMatrix),
within=within,
argname="right",
)
elif right_type is Vector:
left = right._expect_type(
if left._size != right._size:
_mismatched(left, right, "inner", any_pair[BOOL])
return ScalarMatMulExpr(left, right)
left._expect_type(
right,
(Matrix, TransposedMatrix, Vector),
within=within,
argname="right",
)
if left_type is Matrix or left_type is TransposedMatrix:
if right_type is Vector:
if left._ncols != right._size:
_mismatched(left, right, "mxv", any_pair[BOOL])
return VectorMatMulExpr(left, right, method_name="mxv", size=left._nrows)
if right_type is Matrix or right_type is TransposedMatrix:
if left._ncols != right._nrows:
_mismatched(left, right, "mxm", any_pair[BOOL])
return MatrixMatMulExpr(left, right, nrows=left._nrows, ncols=right._ncols)
left._expect_type(
right,
(Vector, Matrix, TransposedMatrix),
within=within,
argname="right",
)
if right_type is Vector:
right._expect_type(
left,
(Matrix, TransposedMatrix),
within=within,
argname="left",
)
elif right_type is Matrix or right_type is TransposedMatrix:
left = right._expect_type(
if right_type is Matrix or right_type is TransposedMatrix:
right._expect_type(
left,
(Vector, Matrix, TransposedMatrix),
within=within,
argname="left",
)
else: # pragma: no cover (sanity)
raise TypeError(
f"Bad types for matmul infix: {type(left).__name__}, {type(right).__name__}"
)

# Create dummy expression to check compatibility of dimensions, etc.
expr = getattr(left, method)(right, any_pair[bool])
if expr.output_type is Vector:
return VectorMatMulExpr(left, right, method_name=method, size=expr._size)
if expr.output_type is Matrix:
return MatrixMatMulExpr(left, right, nrows=expr._nrows, ncols=expr._ncols)
return ScalarMatMulExpr(left, right)
raise TypeError( # pragma: no cover (sanity)
f"Bad types for matmul infix: {type(left).__name__}, {type(right).__name__}"
)


# Import infixmethods, which has side effects
Expand Down
29 changes: 20 additions & 9 deletions graphblas/core/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def _m_mult_v(updater, left, right, op):
updater << left.mxm(right.diag(name="M_temp"), get_semiring(monoid.any, op))


def _m_union_m(updater, left, right, left_default, right_default, op, dtype):
def _m_union_m(updater, left, right, left_default, right_default, op):
mask = updater.kwargs.get("mask")
opts = updater.opts
new_left = left.dup(dtype, clear=True)
new_left = left.dup(op.type, clear=True)
new_left(mask=mask, **opts) << binary.second(right, left_default)
new_left(mask=mask, **opts) << binary.first(left | new_left)
new_right = right.dup(dtype, clear=True)
new_right = right.dup(op.type2, clear=True)
new_right(mask=mask, **opts) << binary.second(left, right_default)
new_right(mask=mask, **opts) << binary.first(right | new_right)
updater << op(new_left & new_right)
Expand Down Expand Up @@ -2078,7 +2078,10 @@ def ewise_union(self, other, op, left_default, right_default):
other = self._expect_type(
other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op
)
dtype = self.dtype if self.dtype._is_udt else None
temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary")

left_dtype = temp_op.type
dtype = left_dtype if left_dtype._is_udt else None
if type(left_default) is not Scalar:
try:
left = Scalar.from_value(
Expand All @@ -2095,6 +2098,8 @@ def ewise_union(self, other, op, left_default, right_default):
)
else:
left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar
right_dtype = temp_op.type2
dtype = right_dtype if right_dtype._is_udt else None
if type(right_default) is not Scalar:
try:
right = Scalar.from_value(
Expand All @@ -2111,12 +2116,19 @@ def ewise_union(self, other, op, left_default, right_default):
)
else:
right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar
scalar_dtype = unify(left.dtype, right.dtype)
nonscalar_dtype = unify(self.dtype, other.dtype)
op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary")

op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary")
op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary")
if op1 is not op2:
left_dtype = unify(op1.type, op2.type, is_right_scalar=True)
right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True)
op = get_typed_op(op, left_dtype, right_dtype, kind="binary")
else:
op = op1
self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op")
if op.opclass == "Monoid":
op = op.binaryop

expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})"
if other.ndim == 1:
# Broadcast rowwise from the right
Expand Down Expand Up @@ -2146,11 +2158,10 @@ def ewise_union(self, other, op, left_default, right_default):
expr_repr=expr_repr,
)
else:
dtype = unify(scalar_dtype, nonscalar_dtype, is_left_scalar=True)
expr = MatrixExpression(
method_name,
None,
[self, left, other, right, _m_union_m, (self, other, left, right, op, dtype)],
[self, left, other, right, _m_union_m, (self, other, left, right, op)],
expr_repr=expr_repr,
nrows=self._nrows,
ncols=self._ncols,
Expand Down
26 changes: 4 additions & 22 deletions graphblas/core/operator/monoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
)
from ...exceptions import check_status_carg
from .. import ffi, lib
from ..expr import InfixExprBase
from ..utils import libget
from .base import OpBase, ParameterizedUdf, TypedOpBase, _call_op, _hasop
from .binary import BinaryOp, ParameterizedBinaryOp
from .base import OpBase, ParameterizedUdf, TypedOpBase, _hasop
from .binary import BinaryOp, ParameterizedBinaryOp, TypedBuiltinBinaryOp

ffi_new = ffi.new

Expand All @@ -36,25 +35,6 @@ def __init__(self, parent, name, type_, return_type, gb_obj, gb_name):
super().__init__(parent, name, type_, return_type, gb_obj, gb_name)
self._identity = None

def __call__(self, left, right=None, *, left_default=None, right_default=None):
if left_default is not None or right_default is not None:
if (
left_default is None
or right_default is None
or right is not None
or not isinstance(left, InfixExprBase)
or left.method_name != "ewise_add"
):
raise TypeError(
"Specifying `left_default` or `right_default` keyword arguments implies "
"performing `ewise_union` operation with infix notation.\n"
"There is only one valid way to do this:\n\n"
f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y "
"are Vectors or Matrices, and left_default and right_default are scalars."
)
return left.left.ewise_union(left.right, self, left_default, right_default)
return _call_op(self, left, right)

@property
def identity(self):
if self._identity is None:
Expand Down Expand Up @@ -84,6 +64,8 @@ def is_idempotent(self):
"""True if ``monoid(x, x) == x`` for any x."""
return self.parent.is_idempotent

__call__ = TypedBuiltinBinaryOp.__call__


class TypedUserMonoid(TypedOpBase):
__slots__ = "binaryop", "identity"
Expand Down
Loading