Skip to content
Merged
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ install:
- pip install -e .

script:
- pytest
- pytest --runslow

notifications:
email: false
7 changes: 6 additions & 1 deletion grblas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

_init_params = None
_SPECIAL_ATTRS = ["lib", "ffi", "Matrix", "Vector", "Scalar",
"base", "exceptions", "matrix", "ops", "scalar", "vector"
"base", "exceptions", "matrix", "ops", "scalar", "vector",
"unary", "binary", "monoid", "semiring"]


Expand Down Expand Up @@ -76,3 +76,8 @@ def _init(backend, blocking, automatic=False):
ops.BinaryOp._initialize()
ops.Monoid._initialize()
ops.Semiring._initialize()

from .unary import numpy # noqa
from .binary import numpy # noqa
from .monoid import numpy # noqa
from .semiring import numpy # noqa
File renamed without changes.
70 changes: 70 additions & 0 deletions grblas/binary/numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
""" Create UDFs of numpy functions supported by numba.

See list of numpy ufuncs supported by numpy here:

https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#math-operations

"""
import numpy as np
from .. import ops

_binary_names = {
# Math operations
'add',
'subtract',
'multiply',
'divide',
'logaddexp',
'logaddexp2',
'true_divide',
'floor_divide',
'power',
'remainder',
'mod',
'fmod',
'gcd',
'lcm',

# Trigonometric functions
'arctan2',
'hypot',

# Bit-twiddling functions
'bitwise_and',
'bitwise_or',
'bitwise_xor',
'left_shift',
'right_shift',

# Comparison functions
'greater',
'greater_equal',
'less',
'less_equal',
'not_equal',
'equal',
'logical_and',
'logical_or',
'logical_xor',
'maximum',
'minimum',
'fmax',
'fmin',

# Floating functions
'copysign',
'nextafter',
'ldexp',
}


def __dir__():
return list(_binary_names)


def __getattr__(name):
if name not in _binary_names:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
numpy_func = getattr(np, name)
ops.BinaryOp.register_new(f'numpy.{name}', lambda x, y: numpy_func(x, y))
return globals()[name]
52 changes: 30 additions & 22 deletions grblas/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from . import lib
import numpy as np
import numba
from . import lib


class DataType:
Expand All @@ -15,9 +16,6 @@ def __init__(self, name, gb_type, c_type, numba_type):
def __repr__(self):
return self.name

def __hash__(self):
return hash((self.name, self.c_type))

def __eq__(self, other):
if isinstance(other, DataType):
return self.gb_type == other.gb_type
Expand Down Expand Up @@ -54,28 +52,34 @@ def from_pytype(cls, pytype):

# Used for testing user-defined functions
_sample_values = {
BOOL: True,
INT8: -3,
UINT8: 3,
INT16: -3,
UINT16: 3,
INT32: -3,
UINT32: 3,
INT64: -3,
UINT64: 3,
FP32: 3.14,
FP64: 3.14
INT8.name: np.int8(1),
UINT8.name: np.uint8(1),
INT16.name: np.int16(1),
UINT16.name: np.uint16(1),
INT32.name: np.int32(1),
UINT32.name: np.uint32(1),
INT64.name: np.int64(1),
UINT64.name: np.uint64(1),
FP32.name: np.float32(0.5),
FP64.name: np.float64(0.5),
BOOL.name: np.bool_(True),
}

# Create register to easily lookup types by name, gb_type, or c_type
_registry = {}
for x in _sample_values:
_registry[x.name] = x
_registry[x.gb_type] = x
_registry[x.c_type] = x
_registry[x.numba_type] = x
_registry[x.numba_type.name] = x
del x
for dtype in [BOOL, INT8, UINT8, INT16, UINT16, INT32, UINT32, INT64, UINT64, FP32, FP64]:
_registry[dtype.name] = dtype
_registry[dtype.gb_type] = dtype
_registry[dtype.c_type] = dtype
_registry[dtype.numba_type] = dtype
_registry[dtype.numba_type.name] = dtype
val = _sample_values[dtype.name]
_registry[val.dtype] = dtype
_registry[val.dtype.name] = dtype
# Upcast numpy float16 to float32
_registry[np.dtype(np.float16)] = FP32
_registry['float16'] = FP32

# Add some common Python types as lookup keys
_registry[int] = DataType.from_pytype(int)
_registry[float] = DataType.from_pytype(float)
Expand All @@ -92,6 +96,10 @@ def lookup(key):
if hasattr(key, 'name'):
return _registry[key.name]
else:
try:
return lookup(np.dtype(key))
except Exception:
pass
raise


Expand Down
27 changes: 7 additions & 20 deletions grblas/matrix.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from functools import partial
from .base import lib, ffi, GbContainer, GbDelayed
from .vector import Vector
from .vector import Vector, _generate_isclose
from .scalar import Scalar
from .ops import BinaryOp, find_opclass, find_return_type, reify_op
from . import dtypes, unary, binary, monoid, semiring
from . import dtypes, binary, monoid, semiring
from .exceptions import check_status, is_error, NoValue


Expand Down Expand Up @@ -57,7 +57,7 @@ def isclose(self, other, *, rel_tol=1e-7, abs_tol=0.0, check_dtype=False):
"""
Check for approximate equality (including same size and empty values)
If `check_dtype` is True, also checks that dtypes match
Closeness check is equivalent to `abs(a-b) <= max(rtol * max(abs(a), abs(b)), atol)`
Closeness check is equivalent to `abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)`
"""
if not isinstance(other, Matrix):
return False
Expand All @@ -69,28 +69,15 @@ def isclose(self, other, *, rel_tol=1e-7, abs_tol=0.0, check_dtype=False):
return False
if self.nvals != other.nvals:
return False
if check_dtype:
common_dtype = self.dtype
else:
common_dtype = dtypes.unify(self.dtype, other.dtype)

matches = Matrix.new(bool, self.nrows, self.ncols)
tmp1 = self.apply(unary.abs).new(dtype=common_dtype)
tmp2 = other.apply(unary.abs).new(dtype=common_dtype)
tmp1 << tmp1.ewise_mult(tmp2, monoid.max)
isclose = _generate_isclose(rel_tol, abs_tol)
matches = self.ewise_mult(other, isclose).new(dtype=bool)
# ewise_mult performs intersection, so nvals will indicate mismatched empty values
if tmp1.nvals != self.nvals:
if matches.nvals != self.nvals:
return False
tmp1[:, :](mask=tmp1.S, accum=binary.times) << rel_tol
tmp1[:, :](mask=tmp1.S, accum=binary.max) << abs_tol
tmp2 << self.ewise_mult(other, binary.minus)
tmp2 << tmp2.apply(unary.abs)
matches << tmp2.ewise_mult(tmp1, binary.le[common_dtype])

# Check if all results are True
result = Scalar.new(bool)
result << matches.reduce_scalar(monoid.land)
return result.value
return matches.reduce_scalar(monoid.land).value

def __len__(self):
return self.nvals
Expand Down
File renamed without changes.
100 changes: 100 additions & 0 deletions grblas/monoid/numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
""" Create UDFs of numpy functions supported by numba.

See list of numpy ufuncs supported by numpy here:

https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#math-operations

"""
import numpy as np
from .. import ops, binary

_float_dtypes = {'FP32', 'FP64'}
_int_dtypes = {'INT8', 'UINT8', 'INT16', 'UINT16', 'INT32', 'UINT32', 'INT64', 'UINT64'}
_bool_int_dtypes = _int_dtypes | {'BOOL'}

_monoid_identities = {
# Math operations
'add': 0,
'multiply': 1,
'logaddexp': dict.fromkeys(_float_dtypes, -np.inf),
'logaddexp2': dict.fromkeys(_float_dtypes, -np.inf),
'gcd': dict.fromkeys(_int_dtypes, 0),

# Trigonometric functions
'hypot': dict.fromkeys(_float_dtypes, 0.),

# Bit-twiddling functions
'bitwise_and': {dtype: True if dtype == 'BOOL' else -1 for dtype in _bool_int_dtypes},
'bitwise_or': dict.fromkeys(_bool_int_dtypes, 0),
'bitwise_xor': dict.fromkeys(_bool_int_dtypes, 0),

# Comparison functions
# 'equal': {'BOOL': True}, # Not yet supported
# 'logical_and': {'BOOL': True}, # Not yet supported
# 'logical_or': {'BOOL': True}, # Not yet supported
# 'logical_xor': {'BOOL': False}, # Not yet supported
'maximum': {
'BOOL': False,
'INT8': np.iinfo(np.int8).min,
'UINT8': 0,
'INT16': np.iinfo(np.int16).min,
'UINT16': 0,
'INT32': np.iinfo(np.int32).min,
'UINT32': 0,
'INT64': np.iinfo(np.int64).min,
'UINT64': 0,
'FP32': -np.inf,
'FP64': -np.inf,
},
'minimum': {
'BOOL': True,
'INT8': np.iinfo(np.int8).max,
'UINT8': np.iinfo(np.uint8).max,
'INT16': np.iinfo(np.int16).max,
'UINT16': np.iinfo(np.uint16).max,
'INT32': np.iinfo(np.int32).max,
'UINT32': np.iinfo(np.uint32).max,
'INT64': np.iinfo(np.int64).max,
'UINT64': np.iinfo(np.uint64).max,
'FP32': np.inf,
'FP64': np.inf,
},
'fmax': {
'BOOL': False,
'INT8': np.iinfo(np.int8).min,
'UINT8': 0,
'INT16': np.iinfo(np.int8).min,
'UINT16': 0,
'INT32': np.iinfo(np.int8).min,
'UINT32': 0,
'INT64': np.iinfo(np.int8).min,
'UINT64': 0,
'FP32': -np.inf, # or np.nan?
'FP64': -np.inf, # or np.nan?
},
'fmin': {
'BOOL': True,
'INT8': np.iinfo(np.int8).max,
'UINT8': np.iinfo(np.uint8).max,
'INT16': np.iinfo(np.int16).max,
'UINT16': np.iinfo(np.uint16).max,
'INT32': np.iinfo(np.int32).max,
'UINT32': np.iinfo(np.uint32).max,
'INT64': np.iinfo(np.int64).max,
'UINT64': np.iinfo(np.uint64).max,
'FP32': np.inf, # or np.nan?
'FP64': np.inf, # or np.nan?
},
}


def __dir__():
return list(_monoid_identities)


def __getattr__(name):
if name not in _monoid_identities:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
func = getattr(binary.numpy, name)
ops.Monoid.register_new(f'numpy.{name}', func, _monoid_identities[name])
return globals()[name]
Loading