Skip to content

Add docstrings for using SS JIT, and make better #512

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ coverage:
default:
informational: true
changes: false
comment: off
comment:
layout: "header, diff"
behavior: default
github_checks:
annotations: false
ignore:
- graphblas/viz.py
6 changes: 1 addition & 5 deletions .github/workflows/test_and_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ jobs:
use-mamba: true
python-version: ${{ steps.pyver.outputs.selected }}
channels: conda-forge,${{ contains(steps.pyver.outputs.selected, 'pypy') && 'defaults' || 'nodefaults' }}
# mamba does not yet implement strict priority
# channel-priority: ${{ contains(steps.pyver.outputs.selected, 'pypy') && 'flexible' || 'strict' }}
channel-priority: ${{ contains(steps.pyver.outputs.selected, 'pypy') && 'flexible' || 'strict' }}
activate-environment: graphblas
auto-activate-base: false
- name: Setup conda
Expand Down Expand Up @@ -412,9 +411,6 @@ jobs:
if: matrix.slowtask == 'pytest_bizarro'
run: |
# This step uses `black`
if [[ ${{ startsWith(steps.pyver.outputs.selected, '3.12') }} == true ]]; then
pip install black # Latest version of black on conda-forge does not have builds for Python 3.12
fi
coverage run -a -m graphblas.core.automethods
coverage run -a -m graphblas.core.infixmethods
git diff --exit-code
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ repos:
hooks:
- id: ruff
- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v0.7.0
rev: v0.8.0
hooks:
- id: sphinx-lint
args: [--enable, all, "--disable=line-too-long,leaked-markup"]
Expand Down
17 changes: 12 additions & 5 deletions graphblas/core/operator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,22 @@ def __getitem__(self, type_):
raise KeyError(f"{self.name} does not work with {type_}")
else:
return self._typed_ops[type_]
if not _supports_udfs:
raise KeyError(f"{self.name} does not work with {type_}")
# This is a UDT or is able to operate on UDTs such as `first` any `any`
dtype = lookup_dtype(type_)
return self._compile_udt(dtype, dtype)

def _add(self, op):
self._typed_ops[op.type] = op
self.types[op.type] = op.return_type
def _add(self, op, *, is_jit=False):
if is_jit:
if hasattr(op, "type2") or hasattr(op, "thunk_type"):
dtypes = (op.type, op._type2)
else:
dtypes = op.type
self.types[dtypes] = op.return_type # This is a different use of .types
self._udt_types[dtypes] = op.return_type
self._udt_ops[dtypes] = op
else:
self._typed_ops[op.type] = op
self.types[op.type] = op.return_type

def __delitem__(self, type_):
type_ = lookup_dtype(type_)
Expand Down
9 changes: 6 additions & 3 deletions graphblas/core/operator/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,8 @@ def _compile_udt(self, dtype, dtype2):
if dtypes in self._udt_types:
return self._udt_ops[dtypes]

nt = numba.types
if self.name == "eq" and not self._anonymous:
if self.name == "eq" and not self._anonymous and _has_numba:
nt = numba.types
# assert dtype.np_type == dtype2.np_type
itemsize = dtype.np_type.itemsize
mask = _udt_mask(dtype.np_type)
Expand Down Expand Up @@ -561,7 +561,8 @@ def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba)
# z_ptr[0] = True
z_ptr[0] = (x[mask] == y[mask]).all()

elif self.name == "ne" and not self._anonymous:
elif self.name == "ne" and not self._anonymous and _has_numba:
nt = numba.types
# assert dtype.np_type == dtype2.np_type
itemsize = dtype.np_type.itemsize
mask = _udt_mask(dtype.np_type)
Expand Down Expand Up @@ -597,6 +598,8 @@ def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba)
# z_ptr[0] = False
z_ptr[0] = (x[mask] != y[mask]).any()

elif self._numba_func is None:
raise KeyError(f"{self.name} does not work with {dtypes} types")
else:
numba_func = self._numba_func
sig = (dtype.numba_type, dtype2.numba_type)
Expand Down
7 changes: 7 additions & 0 deletions graphblas/core/operator/indexunary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def __call__(self, val, thunk=None):
thunk = False # most basic form of 0 when unifying dtypes
return _call_op(self, val, right=thunk)

@property
def thunk_type(self):
return self.type if self._type2 is None else self._type2


class TypedUserIndexUnaryOp(TypedOpBase):
__slots__ = ()
Expand All @@ -41,6 +45,7 @@ def orig_func(self):
def _numba_func(self):
return self.parent._numba_func

thunk_type = TypedBuiltinIndexUnaryOp.thunk_type
__call__ = TypedBuiltinIndexUnaryOp.__call__


Expand Down Expand Up @@ -210,6 +215,8 @@ def _compile_udt(self, dtype, dtype2):
dtypes = (dtype, dtype2)
if dtypes in self._udt_types:
return self._udt_ops[dtypes]
if self._numba_func is None:
raise KeyError(f"{self.name} does not work with {dtypes} types")

numba_func = self._numba_func
sig = (dtype.numba_type, UINT64.numba_type, UINT64.numba_type, dtype2.numba_type)
Expand Down
57 changes: 53 additions & 4 deletions graphblas/core/operator/select.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import inspect

from ... import _STANDARD_OPERATOR_NAMES, select
from ...dtypes import BOOL
from ...dtypes import BOOL, UINT64
from ...exceptions import check_status_carg
from .. import _has_numba, ffi, lib
from .base import OpBase, ParameterizedUdf, TypedOpBase, _call_op, _deserialize_parameterized
from .indexunary import IndexUnaryOp
from .indexunary import IndexUnaryOp, TypedBuiltinIndexUnaryOp

if _has_numba:
import numba

from .base import _get_udt_wrapper
ffi_new = ffi.new


class TypedBuiltinSelectOp(TypedOpBase):
Expand All @@ -15,13 +23,15 @@ def __call__(self, val, thunk=None):
thunk = False # most basic form of 0 when unifying dtypes
return _call_op(self, val, thunk=thunk)

thunk_type = TypedBuiltinIndexUnaryOp.thunk_type


class TypedUserSelectOp(TypedOpBase):
__slots__ = ()
opclass = "SelectOp"

def __init__(self, parent, name, type_, return_type, gb_obj):
super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}")
def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None):
super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2)

@property
def orig_func(self):
Expand All @@ -31,6 +41,7 @@ def orig_func(self):
def _numba_func(self):
return self.parent._numba_func

thunk_type = TypedBuiltinSelectOp.thunk_type
__call__ = TypedBuiltinSelectOp.__call__


Expand Down Expand Up @@ -120,6 +131,44 @@ def _from_indexunary(cls, iop):
obj.types[type_] = op.return_type
return obj

def _compile_udt(self, dtype, dtype2):
if dtype2 is None: # pragma: no cover
dtype2 = dtype
dtypes = (dtype, dtype2)
if dtypes in self._udt_types:
return self._udt_ops[dtypes]
if self._numba_func is None:
raise KeyError(f"{self.name} does not work with {dtypes} types")

# It would be nice if we could reuse compiling done for IndexUnaryOp
numba_func = self._numba_func
sig = (dtype.numba_type, UINT64.numba_type, UINT64.numba_type, dtype2.numba_type)
numba_func.compile(sig) # Should we catch and give additional error message?
select_wrapper, wrapper_sig = _get_udt_wrapper(
numba_func, BOOL, dtype, dtype2, include_indexes=True
)

select_wrapper = numba.cfunc(wrapper_sig, nopython=True)(select_wrapper)
new_select = ffi_new("GrB_IndexUnaryOp*")
check_status_carg(
lib.GrB_IndexUnaryOp_new(
new_select, select_wrapper.cffi, BOOL._carg, dtype._carg, dtype2._carg
),
"IndexUnaryOp",
new_select[0],
)
op = TypedUserSelectOp(
self,
self.name,
dtype,
BOOL,
new_select[0],
dtype2=dtype2,
)
self._udt_types[dtypes] = BOOL
self._udt_ops[dtypes] = op
return op

@classmethod
def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False):
"""Register a SelectOp without registering it in the ``graphblas.select`` namespace.
Expand Down
2 changes: 2 additions & 0 deletions graphblas/core/operator/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def unary_wrapper(z, x):
def _compile_udt(self, dtype, dtype2):
if dtype in self._udt_types:
return self._udt_ops[dtype]
if self._numba_func is None:
raise KeyError(f"{self.name} does not work with {dtype}")

numba_func = self._numba_func
sig = (dtype.numba_type,)
Expand Down
63 changes: 59 additions & 4 deletions graphblas/core/ss/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,47 @@ def jit_c_definition(self):


def register_new(name, jit_c_definition, left_type, right_type, ret_type):
"""Register a new BinaryOp using the SuiteSparse:GraphBLAS JIT compiler.

This creates a BinaryOp by compiling the C string definition of the function.
It requires a shell call to a C compiler. The resulting operator will be as
fast as if it were built-in to SuiteSparse:GraphBLAS and does not have the
overhead of additional function calls as when using ``gb.binary.register_new``.

This is an advanced feature that requires a C compiler and proper configuration.
Configuration is handled by ``gb.ss.config``; see its docstring for details.
By default, the JIT caches results in ``~/.SuiteSparse/``. For more information,
see the SuiteSparse:GraphBLAS user guide.

Only one type signature may be registered at a time, but repeated calls using
the same name with different input types is allowed.

Parameters
----------
name : str
The name of the operator. This will show up as ``gb.binary.ss.{name}``.
The name may contain periods, ".", which will result in nested objects
such as ``gb.binary.ss.x.y.z`` for name ``"x.y.z"``.
jit_c_definition : str
The C definition as a string of the user-defined function. For example:
``"void absdiff (double *z, double *x, double *y) { (*z) = fabs ((*x) - (*y)) ; }"``.
left_type : dtype
The dtype of the left operand of the binary operator.
right_type : dtype
The dtype of the right operand of the binary operator.
ret_type : dtype
The dtype of the result of the binary operator.

Returns
-------
BinaryOp

See Also
--------
gb.binary.register_new
gb.binary.register_anonymous
gb.unary.ss.register_new
"""
if backend != "suitesparse": # pragma: no cover (safety)
raise RuntimeError(
"`gb.binary.ss.register_new` invalid when not using 'suitesparse' backend"
Expand All @@ -47,9 +88,23 @@ def register_new(name, jit_c_definition, left_type, right_type, ret_type):
right_type = lookup_dtype(right_type)
ret_type = lookup_dtype(ret_type)
name = name if name.startswith("ss.") else f"ss.{name}"
module, funcname = BinaryOp._remove_nesting(name)

rv = BinaryOp(name)
module, funcname = BinaryOp._remove_nesting(name, strict=False)
if hasattr(module, funcname):
rv = getattr(module, funcname)
if not isinstance(rv, BinaryOp):
BinaryOp._remove_nesting(name)
if (
(left_type, right_type) in rv.types
or rv._udt_types is not None
and (left_type, right_type) in rv._udt_types
):
raise TypeError(
f"BinaryOp gb.binary.{name} already defined for "
f"({left_type}, {right_type}) input types"
)
else:
# We use `is_udt=True` to make dtype handling flexible and explicit.
rv = BinaryOp(name, is_udt=True)
gb_obj = ffi_new("GrB_BinaryOp*")
check_status_carg(
lib.GxB_BinaryOp_new(
Expand All @@ -67,6 +122,6 @@ def register_new(name, jit_c_definition, left_type, right_type, ret_type):
op = TypedJitBinaryOp(
rv, funcname, left_type, ret_type, gb_obj[0], jit_c_definition, dtype2=right_type
)
rv._add(op)
rv._add(op, is_jit=True)
setattr(module, funcname, rv)
return rv
Loading