diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3766e2e7c..97bf22889 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -66,7 +66,7 @@ repos: - id: black - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.1.4 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -94,7 +94,7 @@ repos: additional_dependencies: [tomli] files: ^(graphblas|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.1.4 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 5e1a76720..3e5f95f0c 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -94,6 +94,10 @@ def _reposition(updater, indices, chunk): def _power(updater, A, n, op): opts = updater.opts + if n == 0: + v = Vector.from_scalar(op.binaryop.monoid.identity, A._nrows, A.dtype, name="v_diag") + updater << v.diag(name="M_diag") + return if n == 1: updater << A return @@ -2773,7 +2777,11 @@ def power(self, n, op=semiring.plus_times): Parameters ---------- n : int - The exponent must be a positive integer. + The exponent must be a nonnegative integer. If n=0, the result will be a diagonal + matrix with values equal to the identity of the semiring's binary operator. + For example, ``plus_times`` will have diagonal values of 1, which is the + identity of ``times``. The binary operator must be associated with a monoid + when n=0 so the identity can be determined; otherwise, ValueError is raised. op : :class:`~graphblas.core.operator.Semiring` Semiring used in the computation @@ -2801,11 +2809,17 @@ def power(self, n, op=semiring.plus_times): if self._nrows != self._ncols: raise DimensionMismatch(f"power only works for square Matrix; shape is {self.shape}") if (N := maybe_integral(n)) is None: - raise TypeError(f"n must be a positive integer; got bad type: {type(n)}") - if N <= 0: - raise ValueError(f"n must be a positive integer; got: {N}") + raise TypeError(f"n must be a nonnegative integer; got bad type: {type(n)}") + if N < 0: + raise ValueError(f"n must be a nonnegative integer; got: {N}") op = get_typed_op(op, self.dtype, kind="semiring") self._expect_op(op, "Semiring", within=method_name, argname="op") + if N == 0 and op.binaryop.monoid is None: + raise ValueError( + f"Binary operator of {op} semiring does not have a monoid with an identity. " + "When n=0, the result is a diagonal matrix with values equal to the " + "identity of the binaryop, so the binaryop must be associated with a monoid." + ) return MatrixExpression( "power", None, diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py index 3f66e46ef..b62f6dc26 100644 --- a/graphblas/tests/test_matrix.py +++ b/graphblas/tests/test_matrix.py @@ -4391,14 +4391,22 @@ def test_power(A): result = A.power(i, semiring.min_plus).new() assert result.isequal(expected) expected << semiring.min_plus(A @ expected) + # n == 0 + result = A.power(0).new() + expected = Vector.from_scalar(1, A.nrows, A.dtype).diag() + assert result.isequal(expected) + result = A.power(0, semiring.plus_min).new() + identity = semiring.plus_min[A.dtype].binaryop.monoid.identity + assert identity != 1 + expected = Vector.from_scalar(identity, A.nrows, A.dtype).diag() + assert result.isequal(expected) # Exceptional - with pytest.raises(TypeError, match="must be a positive integer"): + with pytest.raises(TypeError, match="must be a nonnegative integer"): A.power(1.5) - with pytest.raises(ValueError, match="must be a positive integer"): + with pytest.raises(ValueError, match="must be a nonnegative integer"): A.power(-1) - with pytest.raises(ValueError, match="must be a positive integer"): - # Not implemented yet... could create identity matrix - A.power(0) + with pytest.raises(ValueError, match="binaryop must be associated with a monoid"): + A.power(0, semiring.min_first) B = A[:2, :3].new() with pytest.raises(DimensionMismatch): B.power(2)