From 0e78d7e84a8bfcce58e5426268078b1951ff4354 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Wed, 2 Aug 2023 09:09:24 -0500 Subject: [PATCH 1/4] Add tests for `gb.ss.context` --- .pre-commit-config.yaml | 10 +++--- graphblas/core/ss/context.py | 1 + graphblas/core/utils.py | 2 +- graphblas/tests/test_ss_utils.py | 59 ++++++++++++++++++++++++++++++++ 4 files changed, 66 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fef625a70..b1f93fd36 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: - id: isort # Let's keep `pyupgrade` even though `ruff --fix` probably does most of it - repo: https://github.com/asottile/pyupgrade - rev: v3.9.0 + rev: v3.10.1 hooks: - id: pyupgrade args: [--py38-plus] @@ -66,19 +66,19 @@ repos: - id: black - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.280 + rev: v0.0.282 hooks: - id: ruff args: [--fix-only, --show-fixes] # Let's keep `flake8` even though `ruff` does much of the same. # `flake8-bugbear` and `flake8-simplify` have caught things missed by `ruff`. - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 6.1.0 hooks: - id: flake8 additional_dependencies: &flake8_dependencies # These versions need updated manually - - flake8==6.0.0 + - flake8==6.1.0 - flake8-bugbear==23.7.10 - flake8-simplify==0.20.0 - repo: https://github.com/asottile/yesqa @@ -94,7 +94,7 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.280 + rev: v0.0.282 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint diff --git a/graphblas/core/ss/context.py b/graphblas/core/ss/context.py index 9b48bcaa4..f93d1ec1c 100644 --- a/graphblas/core/ss/context.py +++ b/graphblas/core/ss/context.py @@ -111,6 +111,7 @@ def disengage(self): def __enter__(self): self.engage() + return self def __exit__(self, exc_type, exc, exc_tb): self.disengage() diff --git a/graphblas/core/utils.py b/graphblas/core/utils.py index 7bb1a1fb0..42fcf0685 100644 --- a/graphblas/core/utils.py +++ b/graphblas/core/utils.py @@ -11,7 +11,7 @@ def libget(name): try: return getattr(lib, name) except AttributeError: - if name[-4:] not in {"FC32", "FC64", "error"}: + if name[-4:] not in {"FC32", "FC64", "rror"}: raise ext_name = f"GxB_{name[4:]}" try: diff --git a/graphblas/tests/test_ss_utils.py b/graphblas/tests/test_ss_utils.py index 12c8c6329..6b15246e1 100644 --- a/graphblas/tests/test_ss_utils.py +++ b/graphblas/tests/test_ss_utils.py @@ -4,6 +4,7 @@ import graphblas as gb from graphblas import Matrix, Vector, backend +from graphblas.exceptions import InvalidValue if backend != "suitesparse": pytest.skip("gb.ss and A.ss only available with suitesparse backend", allow_module_level=True) @@ -234,3 +235,61 @@ def test_global_config(): with pytest.raises(ValueError, match="Wrong number"): config["memory_pool"] = [1, 2] assert "format" in repr(config) + + +def test_context(): + context = gb.ss.Context() + prev = dict(context) + context["chunk"] += 1 + context["nthreads"] += 1 + assert context["chunk"] == prev["chunk"] + 1 + assert context["nthreads"] == prev["nthreads"] + 1 + context2 = gb.ss.Context(stack=True) + assert context2 == context + context3 = gb.ss.Context(stack=False) + assert context3 == prev + context4 = gb.ss.Context( + chunk=context["chunk"] + 1, nthreads=context["nthreads"] + 1, stack=False + ) + assert context4["chunk"] == context["chunk"] + 1 + assert context4["nthreads"] == context["nthreads"] + 1 + assert context == context.dup() + assert context4 == context.dup(chunk=context["chunk"] + 1, nthreads=context["nthreads"] + 1) + assert context.dup(gpu_id=-1)["gpu_id"] == -1 + + context.engage() + assert gb.core.ss.context.threadlocal.context is context + with gb.ss.Context(nthreads=1) as ctx: + assert gb.core.ss.context.threadlocal.context is ctx + v = Vector(int, 5) + v(nthreads=2) << v + v + assert gb.core.ss.context.threadlocal.context is ctx + assert gb.core.ss.context.threadlocal.context is context + with pytest.raises(InvalidValue): + # Wait, why does this raise?! + ctx.disengage() + assert gb.core.ss.context.threadlocal.context is context + context.disengage() + assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context + assert context._prev_context is None + + # hackery + gb.core.ss.context.threadlocal.context = context + context.disengage() + context.disengage() + context.disengage() + assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context + + # Actually engaged, but not set in threadlocal + context._engage() + assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context + context.disengage() + + context.engage() + context._engage() + assert gb.core.ss.context.threadlocal.context is context + context.disengage() + + context._context = context # This is allowed to work with config + with pytest.raises(AttributeError, match="_context"): + context._context = ctx # This is not From 13a9e9ee359e93b790768807c5ac03a1c1a7c755 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Wed, 2 Aug 2023 09:54:38 -0500 Subject: [PATCH 2/4] oops --- graphblas/tests/test_ss_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/graphblas/tests/test_ss_utils.py b/graphblas/tests/test_ss_utils.py index 6b15246e1..dd50eacbd 100644 --- a/graphblas/tests/test_ss_utils.py +++ b/graphblas/tests/test_ss_utils.py @@ -237,6 +237,7 @@ def test_global_config(): assert "format" in repr(config) +@pytest.mark.skipif("not gb.core.ss._IS_SSGB7") def test_context(): context = gb.ss.Context() prev = dict(context) From 20184e0bde439fe77aa97c368630535aca59b56f Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Wed, 2 Aug 2023 18:35:14 -0500 Subject: [PATCH 3/4] haha oops --- graphblas/tests/test_ss_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphblas/tests/test_ss_utils.py b/graphblas/tests/test_ss_utils.py index dd50eacbd..81abe5804 100644 --- a/graphblas/tests/test_ss_utils.py +++ b/graphblas/tests/test_ss_utils.py @@ -237,7 +237,7 @@ def test_global_config(): assert "format" in repr(config) -@pytest.mark.skipif("not gb.core.ss._IS_SSGB7") +@pytest.mark.skipif("gb.core.ss._IS_SSGB7") def test_context(): context = gb.ss.Context() prev = dict(context) From 55f20f6f023108af43a0ccbabdb0bf863350d351 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Tue, 15 Aug 2023 10:54:14 -0500 Subject: [PATCH 4/4] bump ruff --- .pre-commit-config.yaml | 6 +++--- graphblas/core/operator/agg.py | 4 ++-- graphblas/core/operator/binary.py | 2 +- graphblas/core/operator/utils.py | 2 +- graphblas/monoid/__init__.py | 2 +- graphblas/monoid/numpy.py | 5 ++--- graphblas/semiring/__init__.py | 4 ++-- graphblas/semiring/numpy.py | 4 ++-- graphblas/tests/test_matrix.py | 2 +- graphblas/tests/test_scalar.py | 8 ++++---- graphblas/tests/test_vector.py | 2 +- scripts/check_versions.sh | 4 ++-- 12 files changed, 22 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1f93fd36..1e1eb502e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -66,7 +66,7 @@ repos: - id: black - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.282 + rev: v0.0.284 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -94,11 +94,11 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.282 + rev: v0.0.284 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.6.7 + rev: v0.6.8 hooks: - id: sphinx-lint args: [--enable, all, "--disable=line-too-long,leaked-markup"] diff --git a/graphblas/core/operator/agg.py b/graphblas/core/operator/agg.py index 09d644c32..6b463a8a6 100644 --- a/graphblas/core/operator/agg.py +++ b/graphblas/core/operator/agg.py @@ -76,9 +76,9 @@ def __init__( @property def types(self): if self._types is None: - if type(self._semiring) is str: + if isinstance(self._semiring, str): self._semiring = semiring.from_string(self._semiring) - if type(self._types_orig[0]) is str: # pragma: no branch + if isinstance(self._types_orig[0], str): # pragma: no branch self._types_orig[0] = semiring.from_string(self._types_orig[0]) self._types = _get_types( self._types_orig, None if self._initval_orig is None else self._initdtype diff --git a/graphblas/core/operator/binary.py b/graphblas/core/operator/binary.py index 88191c39b..77a686868 100644 --- a/graphblas/core/operator/binary.py +++ b/graphblas/core/operator/binary.py @@ -200,7 +200,7 @@ def monoid(self): @property def commutes_to(self): - if type(self._commutes_to) is str: + if isinstance(self._commutes_to, str): self._commutes_to = BinaryOp._find(self._commutes_to) return self._commutes_to diff --git a/graphblas/core/operator/utils.py b/graphblas/core/operator/utils.py index 00bc86cea..00df31db8 100644 --- a/graphblas/core/operator/utils.py +++ b/graphblas/core/operator/utils.py @@ -340,7 +340,7 @@ def _from_string(string, module, mapping, example): ) if base in mapping: op = mapping[base] - if type(op) is str: + if isinstance(op, str): op = mapping[base] = module.from_string(op) elif hasattr(module, base): op = getattr(module, base) diff --git a/graphblas/monoid/__init__.py b/graphblas/monoid/__init__.py index ed028c5d9..027fc0afe 100644 --- a/graphblas/monoid/__init__.py +++ b/graphblas/monoid/__init__.py @@ -10,7 +10,7 @@ def __dir__(): def __getattr__(key): if key in _delayed: func, kwargs = _delayed.pop(key) - if type(kwargs["binaryop"]) is str: + if isinstance(kwargs["binaryop"], str): from ..binary import from_string kwargs["binaryop"] = from_string(kwargs["binaryop"]) diff --git a/graphblas/monoid/numpy.py b/graphblas/monoid/numpy.py index f46d57143..5f6895e5d 100644 --- a/graphblas/monoid/numpy.py +++ b/graphblas/monoid/numpy.py @@ -90,8 +90,7 @@ if ( _config.get("mapnumpy") or _has_numba - and type(_numba.njit(lambda x, y: _np.fmax(x, y))(1, 2)) # pragma: no branch (numba) - is not float + and not isinstance(_numba.njit(lambda x, y: _np.fmax(x, y))(1, 2), float) # pragma: no branch ): # Incorrect behavior was introduced in numba 0.56.2 and numpy 1.23 # See: https://github.com/numba/numba/issues/8478 @@ -170,7 +169,7 @@ def __dir__(): def __getattr__(name): if name in _delayed: func, kwargs = _delayed.pop(name) - if type(kwargs["binaryop"]) is str: + if isinstance(kwargs["binaryop"], str): from ..binary import from_string kwargs["binaryop"] = from_string(kwargs["binaryop"]) diff --git a/graphblas/semiring/__init__.py b/graphblas/semiring/__init__.py index 538136406..95a44261a 100644 --- a/graphblas/semiring/__init__.py +++ b/graphblas/semiring/__init__.py @@ -46,11 +46,11 @@ def __getattr__(key): return rv if key in _delayed: func, kwargs = _delayed.pop(key) - if type(kwargs["binaryop"]) is str: + if isinstance(kwargs["binaryop"], str): from ..binary import from_string kwargs["binaryop"] = from_string(kwargs["binaryop"]) - if type(kwargs["monoid"]) is str: + if isinstance(kwargs["monoid"], str): from ..monoid import from_string kwargs["monoid"] = from_string(kwargs["monoid"]) diff --git a/graphblas/semiring/numpy.py b/graphblas/semiring/numpy.py index 3a59090cc..97b90874b 100644 --- a/graphblas/semiring/numpy.py +++ b/graphblas/semiring/numpy.py @@ -151,11 +151,11 @@ def __getattr__(name): if name in _delayed: func, kwargs = _delayed.pop(name) - if type(kwargs["binaryop"]) is str: + if isinstance(kwargs["binaryop"], str): from ..binary import from_string kwargs["binaryop"] = from_string(kwargs["binaryop"]) - if type(kwargs["monoid"]) is str: + if isinstance(kwargs["monoid"], str): from ..monoid import from_string kwargs["monoid"] = from_string(kwargs["monoid"]) diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index 80a66a524..2746f1314 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -3891,7 +3891,7 @@ def test_get(A): assert compute(A.T.get(0, 1)) is None assert A.T.get(1, 0) == 2 assert A.get(0, 1, "mittens") == 2 - assert type(compute(A.get(0, 1))) is int + assert isinstance(compute(A.get(0, 1)), int) with pytest.raises(ValueError, match="Bad row, col"): # Not yet supported A.get(0, [0, 1]) diff --git a/graphblas/tests/test_scalar.py b/graphblas/tests/test_scalar.py index cf4c6fd41..ba9903169 100644 --- a/graphblas/tests/test_scalar.py +++ b/graphblas/tests/test_scalar.py @@ -128,14 +128,14 @@ def test_equal(s): def test_casting(s): assert int(s) == 5 - assert type(int(s)) is int + assert isinstance(int(s), int) assert float(s) == 5.0 - assert type(float(s)) is float + assert isinstance(float(s), float) assert range(s) == range(5) with pytest.raises(AttributeError, match="Scalar .* only .*__index__.*integral"): range(s.dup(float)) assert complex(s) == complex(5) - assert type(complex(s)) is complex + assert isinstance(complex(s), complex) def test_truthy(s): @@ -580,7 +580,7 @@ def test_record_from_dict(): def test_get(s): assert s.get() == 5 assert s.get("mittens") == 5 - assert type(compute(s.get())) is int + assert isinstance(compute(s.get()), int) s.clear() assert compute(s.get()) is None assert s.get("mittens") == "mittens" diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index e321d3e9b..2571f288b 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -2440,7 +2440,7 @@ def test_get(v): assert v.get(0, "mittens") == "mittens" assert v.get(1) == 1 assert v.get(1, "mittens") == 1 - assert type(compute(v.get(1))) is int + assert isinstance(compute(v.get(1)), int) with pytest.raises(ValueError, match="Bad index in Vector.get"): # Not yet supported v.get([0, 1]) diff --git a/scripts/check_versions.sh b/scripts/check_versions.sh index ffa440c22..bb9059158 100755 --- a/scripts/check_versions.sh +++ b/scripts/check_versions.sh @@ -3,11 +3,11 @@ # Use, adjust, copy/paste, etc. as necessary to answer your questions. # This may be helpful when updating dependency versions in CI. # Tip: add `--json` for more information. -conda search 'numpy[channel=conda-forge]>=1.25.1' +conda search 'numpy[channel=conda-forge]>=1.25.2' conda search 'pandas[channel=conda-forge]>=2.0.3' conda search 'scipy[channel=conda-forge]>=1.11.1' conda search 'networkx[channel=conda-forge]>=3.1' -conda search 'awkward[channel=conda-forge]>=2.3.1' +conda search 'awkward[channel=conda-forge]>=2.3.2' conda search 'sparse[channel=conda-forge]>=0.14.0' conda search 'fast_matrix_market[channel=conda-forge]>=1.7.2' conda search 'numba[channel=conda-forge]>=0.57.1'