Skip to content

Add A.setdiag(x, k) #493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions .github/workflows/test_and_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -169,17 +169,17 @@ jobs:
sparsever=$(python -c 'import random ; print(random.choice(["=0.13", "=0.14", ""]))')
fmmver=$(python -c 'import random ; print(random.choice(["=1.4", "=1.5", "=1.6", "=1.7", ""]))')
if [[ ${{ startsWith(steps.pyver.outputs.selected, '3.9') }} == true ]]; then
npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", "=1.25", ""]))')
npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", "=1.25", "=1.26", ""]))')
spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", "=1.11", ""]))')
pdver=$(python -c 'import random ; print(random.choice(["=1.2", "=1.3", "=1.4", "=1.5", "=2.0", "=2.1", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", "=2.3", "=2.4", ""]))')
elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.10') }} == true ]]; then
npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", "=1.25", ""]))')
npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", "=1.25", "=1.26", ""]))')
spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", "=1.11", ""]))')
pdver=$(python -c 'import random ; print(random.choice(["=1.3", "=1.4", "=1.5", "=2.0", "=2.1", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", "=2.3", "=2.4", ""]))')
else # Python 3.11
npver=$(python -c 'import random ; print(random.choice(["=1.23", "=1.24", "=1.25", ""]))')
npver=$(python -c 'import random ; print(random.choice(["=1.23", "=1.24", "=1.25", "=1.26", ""]))')
spver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=1.11", ""]))')
pdver=$(python -c 'import random ; print(random.choice(["=1.5", "=2.0", "=2.1", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.10", "=2.0", "=2.1", "=2.2", "=2.3", "=2.4", ""]))')
Expand All @@ -206,7 +206,7 @@ jobs:
else
psgver=""
fi
if [[ ${npver} == "=1.25" ]] ; then
if [[ ${npver} == "=1.25" || ${npver} == "=1.26" ]] ; then
numbaver=""
if [[ ${spver} == "=1.8" ]] ; then
spver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=1.11", ""]))')
Expand Down Expand Up @@ -243,7 +243,8 @@ jobs:
pdver=""
yamlver=""
fi
elif [[ ${npver} == "=1.25" ]] ; then
elif [[ ${npver} == "=1.25" || ${npver} == "=1.26" ]] ; then
# Don't install numba for unsupported versions of numpy
numba=""
numbaver=NA
sparse=""
Expand Down
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ repos:
- id: isort
# Let's keep `pyupgrade` even though `ruff --fix` probably does most of it
- repo: https://github.com/asottile/pyupgrade
rev: v3.10.1
rev: v3.12.0
hooks:
- id: pyupgrade
args: [--py39-plus]
Expand All @@ -61,12 +61,12 @@ repos:
- id: auto-walrus
args: [--line-length, "100"]
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.287
rev: v0.0.290
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
Expand All @@ -79,7 +79,7 @@ repos:
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==6.1.0
- flake8-bugbear==23.7.10
- flake8-bugbear==23.9.16
- flake8-simplify==0.20.0
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
Expand All @@ -94,7 +94,7 @@ repos:
additional_dependencies: [tomli]
files: ^(graphblas|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.287
rev: v0.0.290
hooks:
- id: ruff
- repo: https://github.com/sphinx-contrib/sphinx-lint
Expand Down
113 changes: 113 additions & 0 deletions graphblas/core/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2805,6 +2805,119 @@ def power(self, n, op=semiring.plus_times):
dtype=self.dtype,
)

def setdiag(self, values, k=0, *, mask=None, accum=None, **opts):
"""Set k'th diagonal with a Scalar, Vector, or array.

This is not a built-in GraphBLAS operation. It is implemented as a recipe.

Parameters
----------
values : Vector or list or np.ndarray or scalar
New values to assign to the diagonal. The length of Vector and array
values must match the size of the diagonal being assigned to.
k : int, default=0
Which diagonal or off-diagonal to set. For example, set the elements
``A[i, i+k] = values[i]``. The default, k=0, is the main diagonal.
mask : Mask, optional
Vector or Matrix Mask to control which diagonal elements to set.
If it is Matrix Mask, then only the diagonal is used as the mask.
accum : Monoid or BinaryOp, optional
Operator to use to combine existing diagonal values and new values.
"""
if (K := maybe_integral(k)) is None:
raise TypeError(f"k must be an integer; got bad type: {type(k)}")
k = K
if k < 0:
if (size := min(self._nrows + k, self._ncols)) <= 0 and k <= -self._nrows:
raise IndexError(
f"k={k} is too small; the k'th diagonal is out of range. "
f"Valid k for Matrix with shape {self._nrows}x{self._ncols}: "
f"{-self._nrows} {'<' if self._nrows else '<='} k "
f"{'<' if self._ncols else '<='} {self._ncols}"
)
elif (size := min(self._ncols - k, self._nrows)) <= 0 and k > 0 and k >= self._ncols:
raise IndexError(
f"k={k} is too large; the k'th diagonal is out of range. "
f"Valid k for Matrix with shape {self._nrows}x{self._ncols}: "
f"{-self._nrows} {'<' if self._nrows else '<='} k "
f"{'<' if self._ncols else '<='} {self._ncols}"
)

# Convert `values` to Vector if necessary (i.e., it's scalar or array)
is_scalar = clear_diag = False
if output_type(values) is Vector:
v = values
clear_diag = accum is None and v._nvals != v._size
elif type(values) is Scalar:
is_scalar = True
else:
dtype = self.dtype if self.dtype._is_udt else None
try:
# Try to make it a Scalar
values = Scalar.from_value(values, dtype, is_cscalar=None, name="")
is_scalar = True
except (TypeError, ValueError):
try:
# Else try to make it a numpy array
values, dtype = values_to_numpy_buffer(values, dtype)
except Exception:
self._expect_type(
values,
(Scalar, Vector, np.ndarray),
within="setdiag",
argname="values",
extra_message="Literal scalars also accepted.",
)
else:
v = Vector.from_dense(values, dtype=dtype, **opts)

if is_scalar:
v = Vector.from_scalar(values, size, **opts)
elif v._size != size:
raise DimensionMismatch(
f"Dimensions not compatible for assigning length {v._size} Vector "
f"to {k}'th diagonal of Matrix with shape {self._nrows}x{self._ncols}."
f"The Vector should be size {size}."
)

if mask is not None:
mask = _check_mask(mask)
if mask.parent.ndim == 2:
if mask.parent.shape != self.shape:
raise DimensionMismatch(
"Matrix mask in setdiag is the wrong shape; "
f"expected shape {self._nrows}x{self._ncols}, "
f"got {mask.parent._nrows}x{mask.parent._ncols}"
)
if mask.complement:
mval = type(mask)(mask.parent.diag(k)).new(**opts)
mask = mval.S
M = mval.diag()
else:
M = select.diag(mask.parent, k).new(**opts)
elif mask.parent._size != size:
raise DimensionMismatch(
"Vector mask in setdiag is the wrong length; "
f"expected size {size}, got size {mask.parent._size}."
)
else:
if mask.complement:
mask = mask.new(**opts).S
M = mask.parent.diag()
if M.shape != self.shape:
M.resize(self._nrows, self._ncols)
mask = type(mask)(M)

if clear_diag:
self(mask=mask, **opts) << select.offdiag(self, k)

Diag = v.diag(k)
if Diag.shape != self.shape:
Diag.resize(self._nrows, self._ncols)
if mask is None:
mask = Diag.S
self(accum=accum, mask=mask, **opts) << Diag
Comment on lines +2914 to +2919
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For posterity and those who browse this PR, these lines are the main recipe for setting the diagonal of a Matrix with a Vector. The rest of the function mostly deals with handing different input types and giving good error messages when necessary.


##################################
# Extract and Assign index methods
##################################
Expand Down
136 changes: 135 additions & 1 deletion graphblas/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2940,6 +2940,7 @@ def test_expr_is_like_matrix(A):
"from_scalar",
"from_values",
"resize",
"setdiag",
"update",
}
ignore = {"__sizeof__"}
Expand Down Expand Up @@ -3002,9 +3003,10 @@ def test_index_expr_is_like_matrix(A):
"from_dense",
"from_dicts",
"from_edgelist",
"from_values",
"from_scalar",
"from_values",
"resize",
"setdiag",
}
ignore = {"__sizeof__"}
assert attrs - expr_attrs - ignore == expected, (
Expand Down Expand Up @@ -4393,3 +4395,135 @@ def test_power(A):
B = A[:2, :3].new()
with pytest.raises(DimensionMismatch):
B.power(2)


def test_setdiag():
A = Matrix(int, 2, 3)
A.setdiag(1)
expected = Matrix(int, 2, 3)
expected[0, 0] = 1
expected[1, 1] = 1
assert A.isequal(expected)
A.setdiag(Scalar.from_value(2), 2)
expected[0, 2] = 2
assert A.isequal(expected)
A.setdiag(3, k=-1)
expected[1, 0] = 3
assert A.isequal(expected)
# List (or array) is treated as dense
A.setdiag([10, 20], 1)
expected[0, 1] = 10
expected[1, 2] = 20
assert A.isequal(expected)
# Size 0 diagonals, which does not set anything.
# This could be valid (esp. given a size 0 vector), but let's raise for now.
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(-1, 3)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(-1, -2)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag([], 3)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(Vector(int, 0), -2)
# Now we're definitely out of bounds
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(-1, 4)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(-1, -3)
with pytest.raises(TypeError, match="k must be an integer"):
A.setdiag(-1, 0.5)
with pytest.raises(TypeError, match="Bad type for argument `values` in Matrix.setdiag"):
A.setdiag(object())
with pytest.raises(DimensionMismatch, match="Dimensions not compatible"):
A.setdiag([10, 20, 30], 1)
with pytest.raises(DimensionMismatch, match="Dimensions not compatible"):
A.setdiag([10], 1)

# Special care for dimensions of length 0
A = Matrix(int, 0, 2, name="A")
A.setdiag(0, 0)
A.setdiag(0, 1)
A.setdiag([], 0)
A.setdiag([], 1)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(0, -1)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag([], -1)
A = Matrix(int, 2, 0, name="A")
A.setdiag(0, 0)
A.setdiag(0, -1)
A.setdiag([], 0)
A.setdiag([], -1)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(0, 1)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag([], 1)
A = Matrix(int, 0, 0, name="A")
A.setdiag(0, 0)
A.setdiag([], 0)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(0, 1)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag([], 1)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag(0, -1)
with pytest.raises(IndexError, match="diagonal is out of range"):
A.setdiag([], -1)

A = Matrix(int, 2, 2, name="A")
expected = Matrix(int, 2, 2, name="expected")
v = Vector(int, 2, name="v")
Vector(int, 2)
v[0] = 1
A.setdiag(v)
expected[0, 0] = 1
assert A.isequal(expected)
A.setdiag(v, accum=binary.plus)
expected[0, 0] = 2
assert A.isequal(expected)
A.setdiag(10, mask=v.S)
expected[0, 0] = 10
assert A.isequal(expected)
A.setdiag(10, mask=v.S, accum="+")
expected[0, 0] = 20
assert A.isequal(expected)
# Allow mask to be a matrix
A.setdiag(10, mask=A.S, accum="+")
expected[0, 0] = 30
assert A.isequal(expected)
# Test how to clear or not clear missing elements
A.clear()
A.setdiag(99)
A.setdiag(v)
expected[0, 0] = 1
assert A.isequal(expected)
A.setdiag(99)
A.setdiag(v, accum="second")
expected[1, 1] = 99
assert A.isequal(expected)
A.setdiag(99)
A.setdiag(v, mask=v.S)
assert A.isequal(expected)

# We handle complemented masks!
A.clear()
expected.clear()
A.setdiag(42, mask=~v.S)
expected[1, 1] = 42
assert A.isequal(expected)
A.setdiag(7, mask=~A.V)
expected[0, 0] = 7
assert A.isequal(expected)

with pytest.raises(DimensionMismatch, match="Matrix mask in setdiag is the wrong "):
A.setdiag(9, mask=Matrix(int, 3, 3).S)
with pytest.raises(DimensionMismatch, match="Vector mask in setdiag is the wrong "):
A.setdiag(10, mask=Vector(int, 3).S)

A.clear()
A.resize(2, 3)
expected.clear()
expected.resize(2, 3)
A.setdiag(30, mask=v.S)
expected[0, 0] = 30
assert A.isequal(expected)
5 changes: 5 additions & 0 deletions graphblas/tests/test_ssjit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pathlib
import sys

import numpy as np
Expand Down Expand Up @@ -82,6 +83,10 @@ def _setup_jit():
gb.ss.config["jit_c_libraries"] = ""
gb.ss.config["jit_c_cmake_libs"] = ""

if not pathlib.Path(gb.ss.config["jit_c_compiler_name"]).exists():
# Can't use the JIT if we don't have a compiler!
gb.ss.config["jit_c_control"] = "off"


@pytest.fixture
def v():
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ ignore-words-list = "coo,ba"
# https://github.com/charliermarsh/ruff/
line-length = 100
target-version = "py39"
unfixable = [
"F841" # unused-variable (Note: can leave useless expression)
]
select = [
# Have we enabled too many checks that they'll become a nuisance? We'll see...
"F", # pyflakes
Expand Down
Loading