diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 97bf22889..b1d264509 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,7 +46,7 @@ repos: # We can probably remove `isort` if we come to trust `ruff --fix`, # but we'll need to figure out the configuration to do this in `ruff` - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.1 hooks: - id: isort # Let's keep `pyupgrade` even though `ruff --fix` probably does most of it @@ -61,12 +61,12 @@ repos: - id: auto-walrus args: [--line-length, "100"] - repo: https://github.com/psf/black - rev: 23.10.1 + rev: 23.12.0 hooks: - id: black - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.4 + rev: v0.1.7 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -79,7 +79,7 @@ repos: additional_dependencies: &flake8_dependencies # These versions need updated manually - flake8==6.1.0 - - flake8-bugbear==23.9.16 + - flake8-bugbear==23.12.2 - flake8-simplify==0.21.0 - repo: https://github.com/asottile/yesqa rev: v1.5.0 @@ -94,11 +94,11 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.4 + rev: v0.1.7 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.8.1 + rev: v0.9.1 hooks: - id: sphinx-lint args: [--enable, all, "--disable=line-too-long,leaked-markup"] diff --git a/graphblas/core/operator/utils.py b/graphblas/core/operator/utils.py index cd0b82d3c..543df793e 100644 --- a/graphblas/core/operator/utils.py +++ b/graphblas/core/operator/utils.py @@ -75,6 +75,9 @@ def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scala from .agg import Aggregator, TypedAggregator if isinstance(op, Aggregator): + # agg._any_dtype basically serves the same purpose as op._custom_dtype + if op._any_dtype is not None and op._any_dtype is not True: + return op[op._any_dtype] return op[dtype] if isinstance(op, TypedAggregator): return op diff --git a/graphblas/tests/test_scalar.py b/graphblas/tests/test_scalar.py index aeb19e170..3c7bffa9a 100644 --- a/graphblas/tests/test_scalar.py +++ b/graphblas/tests/test_scalar.py @@ -250,7 +250,7 @@ def test_update(s): def test_not_hashable(s): with pytest.raises(TypeError, match="unhashable type"): - {s} + _ = {s} with pytest.raises(TypeError, match="unhashable type"): hash(s) diff --git a/graphblas/tests/test_vector.py b/graphblas/tests/test_vector.py index 1c9a8d38c..8a2cd0824 100644 --- a/graphblas/tests/test_vector.py +++ b/graphblas/tests/test_vector.py @@ -948,6 +948,21 @@ def test_reduce_agg(v): assert s.is_empty +def test_reduce_agg_count_is_int64(v): + """Aggregators that count should default to INT64 return dtype.""" + assert v.dtype == dtypes.INT64 + res = v.reduce(agg.count).new() + assert res.dtype == dtypes.INT64 + assert res == 4 + res = v.dup(dtypes.INT8).reduce(agg.count).new() + assert res.dtype == dtypes.INT64 + assert res == 4 + # Allow return dtype to be specified + res = v.dup(dtypes.INT8).reduce(agg.count[dtypes.INT16]).new() + assert res.dtype == dtypes.INT16 + assert res == 4 + + @pytest.mark.skipif("not suitesparse") def test_reduce_agg_argminmax(v): assert v.reduce(agg.ss.argmin).new() == 6 diff --git a/scripts/check_versions.sh b/scripts/check_versions.sh index d197f2af2..db786b190 100755 --- a/scripts/check_versions.sh +++ b/scripts/check_versions.sh @@ -3,7 +3,7 @@ # Use, adjust, copy/paste, etc. as necessary to answer your questions. # This may be helpful when updating dependency versions in CI. # Tip: add `--json` for more information. -conda search 'flake8-bugbear[channel=conda-forge]>=23.9.16' +conda search 'flake8-bugbear[channel=conda-forge]>=23.12.2' conda search 'flake8-simplify[channel=conda-forge]>=0.21.0' conda search 'numpy[channel=conda-forge]>=1.26.0' conda search 'pandas[channel=conda-forge]>=2.1.2'