Skip to content

Add tests for gb.ss.context #488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ repos:
- id: isort
# Let's keep `pyupgrade` even though `ruff --fix` probably does most of it
- repo: https://github.com/asottile/pyupgrade
rev: v3.9.0
rev: v3.10.1
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -66,19 +66,19 @@ repos:
- id: black
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.280
rev: v0.0.284
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
# Let's keep `flake8` even though `ruff` does much of the same.
# `flake8-bugbear` and `flake8-simplify` have caught things missed by `ruff`.
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==6.0.0
- flake8==6.1.0
- flake8-bugbear==23.7.10
- flake8-simplify==0.20.0
- repo: https://github.com/asottile/yesqa
Expand All @@ -94,11 +94,11 @@ repos:
additional_dependencies: [tomli]
files: ^(graphblas|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.280
rev: v0.0.284
hooks:
- id: ruff
- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v0.6.7
rev: v0.6.8
hooks:
- id: sphinx-lint
args: [--enable, all, "--disable=line-too-long,leaked-markup"]
Expand Down
4 changes: 2 additions & 2 deletions graphblas/core/operator/agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def __init__(
@property
def types(self):
if self._types is None:
if type(self._semiring) is str:
if isinstance(self._semiring, str):
self._semiring = semiring.from_string(self._semiring)
if type(self._types_orig[0]) is str: # pragma: no branch
if isinstance(self._types_orig[0], str): # pragma: no branch
self._types_orig[0] = semiring.from_string(self._types_orig[0])
self._types = _get_types(
self._types_orig, None if self._initval_orig is None else self._initdtype
Expand Down
2 changes: 1 addition & 1 deletion graphblas/core/operator/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def monoid(self):

@property
def commutes_to(self):
if type(self._commutes_to) is str:
if isinstance(self._commutes_to, str):
self._commutes_to = BinaryOp._find(self._commutes_to)
return self._commutes_to

Expand Down
2 changes: 1 addition & 1 deletion graphblas/core/operator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def _from_string(string, module, mapping, example):
)
if base in mapping:
op = mapping[base]
if type(op) is str:
if isinstance(op, str):
op = mapping[base] = module.from_string(op)
elif hasattr(module, base):
op = getattr(module, base)
Expand Down
1 change: 1 addition & 0 deletions graphblas/core/ss/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def disengage(self):

def __enter__(self):
self.engage()
return self

def __exit__(self, exc_type, exc, exc_tb):
self.disengage()
Expand Down
2 changes: 1 addition & 1 deletion graphblas/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def libget(name):
try:
return getattr(lib, name)
except AttributeError:
if name[-4:] not in {"FC32", "FC64", "error"}:
if name[-4:] not in {"FC32", "FC64", "rror"}:
raise
ext_name = f"GxB_{name[4:]}"
try:
Expand Down
2 changes: 1 addition & 1 deletion graphblas/monoid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __dir__():
def __getattr__(key):
if key in _delayed:
func, kwargs = _delayed.pop(key)
if type(kwargs["binaryop"]) is str:
if isinstance(kwargs["binaryop"], str):
from ..binary import from_string

kwargs["binaryop"] = from_string(kwargs["binaryop"])
Expand Down
5 changes: 2 additions & 3 deletions graphblas/monoid/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@
if (
_config.get("mapnumpy")
or _has_numba
and type(_numba.njit(lambda x, y: _np.fmax(x, y))(1, 2)) # pragma: no branch (numba)
is not float
and not isinstance(_numba.njit(lambda x, y: _np.fmax(x, y))(1, 2), float) # pragma: no branch
):
# Incorrect behavior was introduced in numba 0.56.2 and numpy 1.23
# See: https://github.com/numba/numba/issues/8478
Expand Down Expand Up @@ -170,7 +169,7 @@ def __dir__():
def __getattr__(name):
if name in _delayed:
func, kwargs = _delayed.pop(name)
if type(kwargs["binaryop"]) is str:
if isinstance(kwargs["binaryop"], str):
from ..binary import from_string

kwargs["binaryop"] = from_string(kwargs["binaryop"])
Expand Down
4 changes: 2 additions & 2 deletions graphblas/semiring/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def __getattr__(key):
return rv
if key in _delayed:
func, kwargs = _delayed.pop(key)
if type(kwargs["binaryop"]) is str:
if isinstance(kwargs["binaryop"], str):
from ..binary import from_string

kwargs["binaryop"] = from_string(kwargs["binaryop"])
if type(kwargs["monoid"]) is str:
if isinstance(kwargs["monoid"], str):
from ..monoid import from_string

kwargs["monoid"] = from_string(kwargs["monoid"])
Expand Down
4 changes: 2 additions & 2 deletions graphblas/semiring/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,11 @@ def __getattr__(name):

if name in _delayed:
func, kwargs = _delayed.pop(name)
if type(kwargs["binaryop"]) is str:
if isinstance(kwargs["binaryop"], str):
from ..binary import from_string

kwargs["binaryop"] = from_string(kwargs["binaryop"])
if type(kwargs["monoid"]) is str:
if isinstance(kwargs["monoid"], str):
from ..monoid import from_string

kwargs["monoid"] = from_string(kwargs["monoid"])
Expand Down
2 changes: 1 addition & 1 deletion graphblas/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3891,7 +3891,7 @@ def test_get(A):
assert compute(A.T.get(0, 1)) is None
assert A.T.get(1, 0) == 2
assert A.get(0, 1, "mittens") == 2
assert type(compute(A.get(0, 1))) is int
assert isinstance(compute(A.get(0, 1)), int)
with pytest.raises(ValueError, match="Bad row, col"):
# Not yet supported
A.get(0, [0, 1])
Expand Down
8 changes: 4 additions & 4 deletions graphblas/tests/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ def test_equal(s):

def test_casting(s):
assert int(s) == 5
assert type(int(s)) is int
assert isinstance(int(s), int)
assert float(s) == 5.0
assert type(float(s)) is float
assert isinstance(float(s), float)
assert range(s) == range(5)
with pytest.raises(AttributeError, match="Scalar .* only .*__index__.*integral"):
range(s.dup(float))
assert complex(s) == complex(5)
assert type(complex(s)) is complex
assert isinstance(complex(s), complex)


def test_truthy(s):
Expand Down Expand Up @@ -580,7 +580,7 @@ def test_record_from_dict():
def test_get(s):
assert s.get() == 5
assert s.get("mittens") == 5
assert type(compute(s.get())) is int
assert isinstance(compute(s.get()), int)
s.clear()
assert compute(s.get()) is None
assert s.get("mittens") == "mittens"
Expand Down
60 changes: 60 additions & 0 deletions graphblas/tests/test_ss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import graphblas as gb
from graphblas import Matrix, Vector, backend
from graphblas.exceptions import InvalidValue

if backend != "suitesparse":
pytest.skip("gb.ss and A.ss only available with suitesparse backend", allow_module_level=True)
Expand Down Expand Up @@ -234,3 +235,62 @@ def test_global_config():
with pytest.raises(ValueError, match="Wrong number"):
config["memory_pool"] = [1, 2]
assert "format" in repr(config)


@pytest.mark.skipif("gb.core.ss._IS_SSGB7")
def test_context():
context = gb.ss.Context()
prev = dict(context)
context["chunk"] += 1
context["nthreads"] += 1
assert context["chunk"] == prev["chunk"] + 1
assert context["nthreads"] == prev["nthreads"] + 1
context2 = gb.ss.Context(stack=True)
assert context2 == context
context3 = gb.ss.Context(stack=False)
assert context3 == prev
context4 = gb.ss.Context(
chunk=context["chunk"] + 1, nthreads=context["nthreads"] + 1, stack=False
)
assert context4["chunk"] == context["chunk"] + 1
assert context4["nthreads"] == context["nthreads"] + 1
assert context == context.dup()
assert context4 == context.dup(chunk=context["chunk"] + 1, nthreads=context["nthreads"] + 1)
assert context.dup(gpu_id=-1)["gpu_id"] == -1

context.engage()
assert gb.core.ss.context.threadlocal.context is context
with gb.ss.Context(nthreads=1) as ctx:
assert gb.core.ss.context.threadlocal.context is ctx
v = Vector(int, 5)
v(nthreads=2) << v + v
assert gb.core.ss.context.threadlocal.context is ctx
assert gb.core.ss.context.threadlocal.context is context
with pytest.raises(InvalidValue):
# Wait, why does this raise?!
ctx.disengage()
assert gb.core.ss.context.threadlocal.context is context
context.disengage()
assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context
assert context._prev_context is None

# hackery
gb.core.ss.context.threadlocal.context = context
context.disengage()
context.disengage()
context.disengage()
assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context

# Actually engaged, but not set in threadlocal
context._engage()
assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context
context.disengage()

context.engage()
context._engage()
assert gb.core.ss.context.threadlocal.context is context
context.disengage()

context._context = context # This is allowed to work with config
with pytest.raises(AttributeError, match="_context"):
context._context = ctx # This is not
2 changes: 1 addition & 1 deletion graphblas/tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2440,7 +2440,7 @@ def test_get(v):
assert v.get(0, "mittens") == "mittens"
assert v.get(1) == 1
assert v.get(1, "mittens") == 1
assert type(compute(v.get(1))) is int
assert isinstance(compute(v.get(1)), int)
with pytest.raises(ValueError, match="Bad index in Vector.get"):
# Not yet supported
v.get([0, 1])
Expand Down
4 changes: 2 additions & 2 deletions scripts/check_versions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# Use, adjust, copy/paste, etc. as necessary to answer your questions.
# This may be helpful when updating dependency versions in CI.
# Tip: add `--json` for more information.
conda search 'numpy[channel=conda-forge]>=1.25.1'
conda search 'numpy[channel=conda-forge]>=1.25.2'
conda search 'pandas[channel=conda-forge]>=2.0.3'
conda search 'scipy[channel=conda-forge]>=1.11.1'
conda search 'networkx[channel=conda-forge]>=3.1'
conda search 'awkward[channel=conda-forge]>=2.3.1'
conda search 'awkward[channel=conda-forge]>=2.3.2'
conda search 'sparse[channel=conda-forge]>=0.14.0'
conda search 'fast_matrix_market[channel=conda-forge]>=1.7.2'
conda search 'numba[channel=conda-forge]>=0.57.1'
Expand Down