diff --git a/.codecov.yml b/.codecov.yml
index 1720ac027..1894009c1 100644
--- a/.codecov.yml
+++ b/.codecov.yml
@@ -1,3 +1,16 @@
-comment: off
+coverage:
+  status:
+    project:
+      default:
+        informational: true
+    patch:
+      default:
+        informational: true
+    changes: false
+comment:
+  layout: "header, diff"
+  behavior: default
+github_checks:
+  annotations: false
 ignore:
 - graphblas/viz.py
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index b18fd2935..5ace4600a 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -1,6 +1,6 @@
 version: 2
 updates:
-  - package-ecosystem: 'github-actions'
-    directory: '/'
+  - package-ecosystem: "github-actions"
+    directory: "/"
     schedule:
-      interval: 'weekly'
+      interval: "weekly"
diff --git a/.github/workflows/debug.yml b/.github/workflows/debug.yml
index 794746f77..6c2b202b1 100644
--- a/.github/workflows/debug.yml
+++ b/.github/workflows/debug.yml
@@ -5,7 +5,7 @@ on:
   workflow_dispatch:
     inputs:
       debug_enabled:
-        description: 'Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)'
+        description: "Run the build with tmate debugging enabled (https://github.com/marketplace/actions/debugging-with-tmate)"
         required: false
         default: false
@@ -15,7 +15,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        pyver: [3.8]
+        pyver: ["3.10"]  # quoted so YAML doesn't parse 3.10 as the float 3.1
         testopts:
           - "--blocking"
           # - "--non-blocking --record --runslow"
@@ -26,9 +26,10 @@
           # - "conda-forge"
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
+          persist-credentials: false
       - name: Setup conda env
         run: |
           source "$CONDA/etc/profile.d/conda.sh"
diff --git a/.github/workflows/imports.yml b/.github/workflows/imports.yml
index a9863a213..e24d0d4db 100644
--- a/.github/workflows/imports.yml
+++ b/.github/workflows/imports.yml
@@ -7,29 +7,57 @@ on:
     - main

 jobs:
-  test_imports:
+  rngs:
     runs-on: ubuntu-latest
-    # strategy:
-    #   matrix:
-    #     python-version: ["3.8", "3.9", "3.10"]
+    outputs:
+      os: ${{ steps.os.outputs.selected }}
+      pyver: ${{ steps.pyver.outputs.selected }}
     steps:
-      - uses: actions/checkout@v3
+      - name: RNG for os
+        uses: ddradar/choose-random-action@v3.0.0
+        id: os
+        with:
+          contents: |
+            ubuntu-latest
+            macos-latest
+            windows-latest
+          weights: |
+            1
+            1
+            1
       - name: RNG for Python version
-        uses: ddradar/choose-random-action@v2.0.2
+        uses: ddradar/choose-random-action@v3.0.0
         id: pyver
         with:
           contents: |
-            3.8
-            3.9
             3.10
+            3.11
+            3.12
+            3.13
           weights: |
             1
             1
             1
-      - uses: actions/setup-python@v4
+            1
+  test_imports:
+    needs: rngs
+    runs-on: ${{ needs.rngs.outputs.os }}
+    # runs-on: ${{ matrix.os }}
+    # strategy:
+    #   matrix:
+    #     python-version: ["3.10", "3.11", "3.12", "3.13"]
+    #     os: ["ubuntu-latest", "macos-latest", "windows-latest"]
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          persist-credentials: false
+      - uses: actions/setup-python@v5
         with:
-          python-version: ${{ steps.pyver.outputs.selected }}
+          python-version: ${{ needs.rngs.outputs.pyver }}
           # python-version: ${{ matrix.python-version }}
       - run: python -m pip install --upgrade pip
-      - run: pip install -e .
- - run: ./scripts/test_imports.sh + # - run: pip install --pre suitesparse-graphblas # Use if we need pre-release + - run: pip install -e .[default] + - name: Run test imports + run: ./scripts/test_imports.sh diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5ef2b1033..655a576e5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,10 +1,12 @@ +# Rely on pre-commit.ci instead name: Lint via pre-commit on: - pull_request: - push: - branches-ignore: - - main + workflow_dispatch: + # pull_request: + # push: + # branches-ignore: + # - main permissions: contents: read @@ -14,8 +16,11 @@ jobs: name: pre-commit-hooks runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + - uses: actions/setup-python@v5 with: python-version: "3.10" - - uses: pre-commit/action@v3.0.0 + - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index 93a7e31c8..32926c5c8 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -3,7 +3,7 @@ name: Publish to PyPI on: push: tags: - - '20*' + - "20*" jobs: build_and_deploy: @@ -14,20 +14,21 @@ jobs: shell: bash -l {0} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: - python-version: "3.8" + python-version: "3.10" - name: Install build dependencies run: | python -m pip install --upgrade pip python -m pip install build twine - name: Build wheel and sdist run: python -m build --sdist --wheel - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 with: name: releases path: dist @@ -35,7 +36,7 @@ jobs: - name: Check with twine run: python -m twine check --strict dist/* - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@v1.6.4 + uses: pypa/gh-action-pypi-publish@v1.12.4 with: user: __token__ password: ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/test_and_build.yml b/.github/workflows/test_and_build.yml index 6b36da3bc..af7525928 100644 --- a/.github/workflows/test_and_build.yml +++ b/.github/workflows/test_and_build.yml @@ -17,6 +17,10 @@ on: branches: - main +# concurrency: +# group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} +# cancel-in-progress: true + jobs: rngs: # To achieve consistent coverage, we need a little bit of correlated collaboration. @@ -46,7 +50,7 @@ jobs: backend: ${{ steps.backend.outputs.selected }} steps: - name: RNG for mapnumpy - uses: ddradar/choose-random-action@v2.0.2 + uses: ddradar/choose-random-action@v3.0.0 id: mapnumpy with: contents: | @@ -60,7 +64,7 @@ jobs: 1 1 - name: RNG for backend - uses: ddradar/choose-random-action@v2.0.2 + uses: ddradar/choose-random-action@v3.0.0 id: backend with: contents: | @@ -80,57 +84,67 @@ jobs: run: shell: bash -l {0} strategy: - # To "stress test" in CI, set `fail-fast` to `false` and perhaps add more items to `matrix.slowtask` - fail-fast: true + # To "stress test" in CI, set `fail-fast` to `false` and use `repeat` in matrix below + fail-fast: false # The build matrix is [os]x[slowtask] and then randomly chooses [pyver] and [sourcetype]. # This should ensure we'll have full code coverage (i.e., no chance of getting unlucky), # since we need to run all slow tests on Windows and non-Windoes OSes. 
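      # For illustration (values here are hypothetical): the matrix below expands to
      # nine jobs such as (ubuntu-latest, pytest_normal) or (windows-latest, notebooks),
      # and the RNG steps further below might then assign a given job, e.g.,
      # pyver=3.12 and sourcetype=wheel.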
matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] slowtask: ["pytest_normal", "pytest_bizarro", "notebooks"] + # repeat: [1, 2, 3] # For stress testing + env: + # Wheels on OS X come with an OpenMP that conflicts with OpenMP from conda-forge. + # Setting this is a workaround. + KMP_DUPLICATE_LIB_OK: ${{ contains(matrix.os, 'macos') && 'TRUE' || 'FALSE' }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 + persist-credentials: false - name: RNG for Python version - uses: ddradar/choose-random-action@v2.0.2 + uses: ddradar/choose-random-action@v3.0.0 id: pyver with: - # We should support major Python versions for at least 36-42 months + # We should support major Python versions for at least 36 months as per SPEC 0 + # We may be able to support pypy if anybody asks for it + # 3.9.16 0_73_pypy contents: | - 3.8 - 3.9 3.10 + 3.11 + 3.12 + 3.13 weights: | 1 1 1 + 1 - name: RNG for source of python-suitesparse-graphblas - uses: ddradar/choose-random-action@v2.0.2 + uses: ddradar/choose-random-action@v3.0.0 id: sourcetype with: - # Set weight to 0 to skip (such as if 'upstream' is known to not work). - # Have slightly higher weight for `conda-forge` for faster CI. + # Weights must be natural numbers, so set weights to very large to skip one + # (such as if 'upstream' is known to not work). contents: | conda-forge wheel source upstream weights: | - 2 + 1 1 1 1 - name: Setup conda - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 + id: setup_conda with: - miniforge-variant: Mambaforge - miniforge-version: latest - use-mamba: true + auto-update-conda: true python-version: ${{ steps.pyver.outputs.selected }} - channels: conda-forge,nodefaults - channel-priority: strict + channels: conda-forge${{ contains(steps.pyver.outputs.selected, 'pypy') && ',defaults' || '' }} + conda-remove-defaults: ${{ contains(steps.pyver.outputs.selected, 'pypy') && 'false' || 'true' }} + channel-priority: ${{ contains(steps.pyver.outputs.selected, 'pypy') && 'flexible' || 'strict' }} activate-environment: graphblas auto-activate-base: false - name: Update env @@ -140,73 +154,243 @@ jobs: # # First let's randomly get versions of dependencies to install. # Consider removing old versions when they become problematic or very old (>=2 years). 
- nxver=$(python -c 'import random ; print(random.choice(["=2.7", "=2.8", "=3.0", ""]))') - yamlver=$(python -c 'import random ; print(random.choice(["=5.4", "=6.0", ""]))') - sparsever=$(python -c 'import random ; print(random.choice(["=0.12", "=0.13", "=0.14", ""]))') - if [[ ${{ steps.pyver.outputs.selected }} == "3.8" ]]; then - npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", ""]))') - spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", ""]))') - pdver=$(python -c 'import random ; print(random.choice(["=1.2", "=1.3", "=1.4", "=1.5", ""]))') - akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", ""]))') - elif [[ ${{ steps.pyver.outputs.selected }} == "3.9" ]]; then - npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", ""]))') - spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", ""]))') - pdver=$(python -c 'import random ; print(random.choice(["=1.2", "=1.3", "=1.4", "=1.5", ""]))') - akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", ""]))') - elif [[ ${{ steps.pyver.outputs.selected }} == "3.10" ]]; then - npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", ""]))') - spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", ""]))') - pdver=$(python -c 'import random ; print(random.choice(["=1.3", "=1.4", "=1.5", ""]))') - akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", ""]))') - else # Python 3.11 - npver=$(python -c 'import random ; print(random.choice(["=1.23", ""]))') - spver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", ""]))') - pdver=$(python -c 'import random ; print(random.choice(["=1.5", ""]))') - akver=$(python -c 'import random ; print(random.choice(["=1.10", "=2.0.5", "=2.0.6", "=2.0.7", "=2.0.8", ""]))') + + # Randomly choosing versions of dependencies based on Python version works surprisingly well... 
+ if [[ ${{ startsWith(steps.pyver.outputs.selected, '3.10') }} == true ]]; then + nxver=$(python -c 'import random ; print(random.choice(["=2.8", "=3.0", "=3.1", "=3.2", "=3.3", "=3.4", ""]))') + npver=$(python -c 'import random ; print(random.choice(["=1.24", "=1.25", "=1.26", "=2.0", "=2.1", "=2.2", ""]))') + spver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=1.11", "=1.12", "=1.13", "=1.14", "=1.15", ""]))') + pdver=$(python -c 'import random ; print(random.choice(["=1.5", "=2.0", "=2.1", "=2.2", ""]))') + akver=$(python -c 'import random ; print(random.choice(["=1.10", "=2.0", "=2.1", "=2.2", "=2.3", "=2.4", "=2.5", "=2.6", "=2.7", ""]))') + fmmver=$(python -c 'import random ; print(random.choice(["=1.4", "=1.5", "=1.6", "=1.7", ""]))') + yamlver=$(python -c 'import random ; print(random.choice(["=5.4", "=6.0", ""]))') + sparsever=$(python -c 'import random ; print(random.choice(["=0.14", "=0.15", ""]))') + elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.11') }} == true ]]; then + nxver=$(python -c 'import random ; print(random.choice(["=2.8", "=3.0", "=3.1", "=3.2", "=3.3", "=3.4", ""]))') + npver=$(python -c 'import random ; print(random.choice(["=1.24", "=1.25", "=1.26", "=2.0", "=2.1", "=2.2", ""]))') + spver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=1.11", "=1.12", "=1.13", "=1.14", "=1.15", ""]))') + pdver=$(python -c 'import random ; print(random.choice(["=1.5", "=2.0", "=2.1", "=2.2", ""]))') + akver=$(python -c 'import random ; print(random.choice(["=1.10", "=2.0", "=2.1", "=2.2", "=2.3", "=2.4", "=2.5", "=2.6", "=2.7", ""]))') + fmmver=$(python -c 'import random ; print(random.choice(["=1.4", "=1.5", "=1.6", "=1.7", ""]))') + yamlver=$(python -c 'import random ; print(random.choice(["=5.4", "=6.0", ""]))') + sparsever=$(python -c 'import random ; print(random.choice(["=0.14", "=0.15", ""]))') + elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.12') }} == true ]]; then + nxver=$(python -c 'import random ; print(random.choice(["=3.2", "=3.3", "=3.4", ""]))') + npver=$(python -c 'import random ; print(random.choice(["=1.26", "=2.0", "=2.1", "=2.2", ""]))') + spver=$(python -c 'import random ; print(random.choice(["=1.11", "=1.12", "=1.13", "=1.14", "=1.15", ""]))') + pdver=$(python -c 'import random ; print(random.choice(["=2.1", "=2.2", ""]))') + akver=$(python -c 'import random ; print(random.choice(["=2.4", "=2.5", "=2.6", "=2.7", ""]))') + fmmver=$(python -c 'import random ; print(random.choice(["=1.7", ""]))') + yamlver=$(python -c 'import random ; print(random.choice(["=6.0", ""]))') + sparsever=$(python -c 'import random ; print(random.choice(["=0.14", "=0.15", ""]))') + else # Python 3.13 + nxver=$(python -c 'import random ; print(random.choice(["=3.4", ""]))') + npver=$(python -c 'import random ; print(random.choice(["=2.1", "=2.2", ""]))') + spver=$(python -c 'import random ; print(random.choice(["=1.14", "=1.15", ""]))') + pdver=$(python -c 'import random ; print(random.choice(["=2.2", ""]))') + akver=$(python -c 'import random ; print(random.choice(["=2.7", ""]))') + fmmver=NA # Not yet supported + yamlver=$(python -c 'import random ; print(random.choice(["=6.0", ""]))') + sparsever=NA # Not yet supported fi - if [[ ${{ steps.sourcetype.outputs.selected }} == "source" || ${{ steps.sourcetype.outputs.selected }} == "upstream" ]]; then + + # But there may be edge cases of incompatibility we need to handle (more handled below) + if [[ ${{ steps.sourcetype.outputs.selected }} == "source" ]]; then # TODO: 
there are currently issues with some numpy versions when
-        # installing python-suitesparse-grphblas from source or upstream.
+        # installing python-suitesparse-graphblas from source.
         npver=""
         spver=""
         pdver=""
       fi
+
       # We can have a tight coupling with python-suitesparse-graphblas.
       # That is, we don't need to support versions of it that are two years old.
       # But, it's still useful for us to test with different versions!
-      if [[ ${{ steps.sourcetype.outputs.selected}} == "conda-forge" ]] ; then
-        psgver=$(python -c 'import random ; print(random.choice(["=7.4.0", "=7.4.1", "=7.4.2", "=7.4.3.0", "=7.4.3.1", ""]))')
-      else
+      psg=""
+      if [[ ${{ steps.sourcetype.outputs.selected}} == "upstream" ]] ; then
+        # Upstream needs to build with numpy 2
         psgver=""
+        if [[ ${{ startsWith(steps.pyver.outputs.selected, '3.13') }} == true ]]; then
+          npver=$(python -c 'import random ; print(random.choice(["=2.1", "=2.2", ""]))')
+        else
+          npver=$(python -c 'import random ; print(random.choice(["=2.0", "=2.1", "=2.2", ""]))')
+        fi
+      elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.13') }} == true ]] ; then
+        if [[ ${{ steps.sourcetype.outputs.selected}} == "conda-forge" ]] ; then
+          psgver=$(python -c 'import random ; print(random.choice(["=9.3.1.0", "=9.4.5.0", ""]))')
+          psg=python-suitesparse-graphblas${psgver}
+        else
+          psgver=$(python -c 'import random ; print(random.choice(["==9.3.1.0", "==9.4.5.0", ""]))')
+        fi
+      elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.12') }} == true ]] ; then
+        if [[ ${{ steps.sourcetype.outputs.selected}} == "conda-forge" ]] ; then
+          if [[ $npver == =1.* ]] ; then
+            psgver=$(python -c 'import random ; print(random.choice(["=8.2.0.1", "=8.2.1.0"]))')
+          else
+            psgver=$(python -c 'import random ; print(random.choice(["=9.3.1.0", "=9.4.5.0", ""]))')
+          fi
+          psg=python-suitesparse-graphblas${psgver}
+        else
+          if [[ $npver == =1.* ]] ; then
+            psgver=$(python -c 'import random ; print(random.choice(["==8.2.0.1", "==8.2.1.0"]))')
+          else
+            psgver=$(python -c 'import random ; print(random.choice(["==9.3.1.0", "==9.4.5.0", ""]))')
+          fi
+        fi
+      # python-suitesparse-graphblas support is the same for Python 3.10 and 3.11
+      elif [[ ${{ steps.sourcetype.outputs.selected}} == "conda-forge" ]] ; then
+        if [[ $npver == =1.* ]] ; then
+          psgver=$(python -c 'import random ; print(random.choice(["=7.4.0", "=7.4.1", "=7.4.2", "=7.4.3.0", "=7.4.3.1", "=7.4.3.2", "=8.0.2.1", "=8.2.0.1", "=8.2.1.0"]))')
+        else
+          psgver=$(python -c 'import random ; print(random.choice(["=9.3.1.0", "=9.4.5.0", ""]))')
+        fi
+        psg=python-suitesparse-graphblas${psgver}
+      elif [[ ${{ steps.sourcetype.outputs.selected}} == "wheel" ]] ; then
+        if [[ $npver == =1.* ]] ; then
+          psgver=$(python -c 'import random ; print(random.choice(["==7.4.3.2", "==8.0.2.1", "==8.2.0.1", "==8.2.1.0"]))')
+        else
+          psgver=$(python -c 'import random ; print(random.choice(["==9.3.1.0", "==9.4.5.0", ""]))')
+        fi
+      elif [[ ${{ steps.sourcetype.outputs.selected}} == "source" ]] ; then
+        # These should be exact versions
+        if [[ $npver == =1.* ]] ; then
+          psgver=$(python -c 'import random ; print(random.choice(["==7.4.0.0", "==7.4.1.0", "==7.4.2.0", "==7.4.3.0", "==7.4.3.1", "==7.4.3.2", "==8.0.2.1", "==8.2.0.1", "==8.2.1.0"]))')
+        else
+          psgver=$(python -c 'import random ; print(random.choice(["==9.3.1.0", "==9.4.5.0", ""]))')
+        fi
       fi
-      if [[ $npver == "=1.21" ]] ; then
-        numbaver=$(python -c 'import random ; print(random.choice(["=0.55", "=0.56", ""]))')
+
+      # Numba is tightly coupled to numpy versions
+      if [[ ${npver} == "=1.26" ]] ; then
+        numbaver=$(python -c 'import random ; print(random.choice(["=0.58", "=0.59", "=0.60", "=0.61", ""]))')
+        if [[ ${spver} == "=1.9" ]] ; then
+          spver=$(python -c 'import random ; print(random.choice(["=1.10", "=1.11", ""]))')
+        fi
+      elif [[ ${npver} == "=1.25" ]] ; then
+        numbaver=$(python -c 'import random ; print(random.choice(["=0.58", "=0.59", "=0.60", "=0.61", ""]))')
+      elif [[ ${npver} == "=1.24" || ${{ startsWith(steps.pyver.outputs.selected, '3.11') }} == true ]] ; then
+        numbaver=$(python -c 'import random ; print(random.choice(["=0.57", "=0.58", "=0.59", "=0.60", "=0.61", ""]))')
+      else
+        numbaver=""
+      fi
+      # Only numba >=0.59 supports Python 3.12
+      if [[ ${{ startsWith(steps.pyver.outputs.selected, '3.12') }} == true ]] ; then
+        numbaver=$(python -c 'import random ; print(random.choice(["=0.59", "=0.60", "=0.61", ""]))')
+      fi
+
+      # Handle NumPy 2
+      if [[ $npver != =1.* ]] ; then
+        # Only pandas >=2.2.2 supports NumPy 2
+        pdver=$(python -c 'import random ; print(random.choice(["=2.2", ""]))')
+
+        # Only awkward >=2.6.3 supports NumPy 2
+        if [[ ${{ startsWith(steps.pyver.outputs.selected, '3.13') }} == true ]] ; then
+          akver=$(python -c 'import random ; print(random.choice(["=2.7", ""]))')
+        else
+          akver=$(python -c 'import random ; print(random.choice(["=2.6", "=2.7", ""]))')
+        fi
+
+        # Only scipy >=1.13 supports NumPy 2
+        if [[ $spver == "=1.9" || $spver == "=1.10" || $spver == "=1.11" || $spver == "=1.12" ]] ; then
+          spver="=1.13"
+        fi
+      fi
+
+      fmm=fast_matrix_market${fmmver}
+      awkward=awkward${akver}
+
+      # Don't install numba and sparse for some versions
+      if [[ ${{ contains(steps.pyver.outputs.selected, 'pypy') ||
+            startsWith(steps.pyver.outputs.selected, '3.14') }} == true ||
+         ( ${{ matrix.slowtask != 'notebooks'}} == true && (
+           ( ${{ matrix.os == 'windows-latest' }} == true && $(python -c 'import random ; print(random.random() < .2)') == True ) ||
+           ( ${{ matrix.os == 'windows-latest' }} == false && $(python -c 'import random ; print(random.random() < .4)') == True ))) ]]
+      then
+        # Some packages aren't available for pypy or Python 3.13; randomly otherwise (if not running notebooks)
+        echo "skipping numba"
+        numba=""
+        numbaver=NA
+        sparse=""
+        sparsever=NA
+        if [[ ${{ contains(steps.pyver.outputs.selected, 'pypy') }} == true ]]; then
+          awkward=""
+          akver=NA
+          fmm=""
+          fmmver=NA
+          # Be more flexible until we determine what versions are supported by pypy
+          npver=""
+          spver=""
+          pdver=""
+          yamlver=""
+        fi
+      elif [[ ${npver} == =2.* ]] ; then
+        # Don't install numba for unsupported versions of numpy
+        numba=""
+        numbaver=NA
+        sparse=""
+        sparsever=NA
       else
-        numbaver=$(python -c 'import random ; print(random.choice(["=0.56", ""]))')
+        numba=numba${numbaver}
+        sparse=sparse${sparsever}
       fi
-      echo "versions: np${npver} sp${spver} pd${pdver} ak${akver} nx${nxver} numba${numbaver} yaml${yamlver} sparse${sparsever} psgver${psgver}"
-      # Once we have wheels for all OSes, we can delete the last two lines.
- mamba install packaging pytest coverage coveralls=3.3.1 pytest-randomly cffi donfig pyyaml${yamlver} sparse${sparsever} \ - pandas${pdver} scipy${spver} numpy${npver} awkward${akver} networkx${nxver} numba${numbaver} \ + # sparse does not yet support Python 3.13 + if [[ ${{ startsWith(steps.pyver.outputs.selected, '3.13') }} == true ]] ; then + sparse="" + sparsever=NA + fi + # fast_matrix_market does not yet support Python 3.13 or osx-arm64 + if [[ ${{ startsWith(steps.pyver.outputs.selected, '3.13') }} == true || + ${{ matrix.os == 'macos-latest' }} == true ]] + then + fmm="" + fmmver=NA + fi + + echo "versions: np${npver} sp${spver} pd${pdver} ak${akver} nx${nxver} numba${numbaver} yaml${yamlver} sparse${sparsever} psg${psgver}" + + set -x # echo on + $(command -v mamba || command -v conda) install -c nodefaults \ + packaging pytest coverage pytest-randomly cffi donfig tomli c-compiler make \ + pyyaml${yamlver} ${sparse} pandas${pdver} scipy${spver} numpy${npver} ${awkward} \ + networkx${nxver} ${numba} ${fmm} ${psg} \ ${{ matrix.slowtask == 'pytest_bizarro' && 'black' || '' }} \ - ${{ matrix.slowtask == 'notebooks' && 'matplotlib nbconvert jupyter "ipython>=7"' || '' }} \ + ${{ matrix.slowtask == 'notebooks' && 'matplotlib nbconvert jupyter "ipython>=7" drawsvg' || '' }} \ ${{ steps.sourcetype.outputs.selected == 'upstream' && 'cython' || '' }} \ - ${{ steps.sourcetype.outputs.selected != 'wheel' && '"graphblas>=7.4.0"' || '' }} \ - ${{ steps.sourcetype.outputs.selected == 'conda-forge' && 'python-suitesparse-graphblas' || '' }}${psgver} \ - ${{ matrix.os != 'ubuntu-latest' && '"graphblas>=7.4.0"' || '' }} \ - ${{ steps.sourcetype.outputs.selected == 'wheel' && matrix.os != 'ubuntu-latest' && 'python-suitesparse-graphblas' || '' }} + ${{ steps.sourcetype.outputs.selected != 'wheel' && '"graphblas>=7.4,<9.5"' || '' }} \ + ${{ contains(steps.pyver.outputs.selected, 'pypy') && 'pypy' || '' }} \ + ${{ matrix.os == 'windows-latest' && 'cmake' || 'm4' }} \ + # ${{ matrix.os != 'windows-latest' && 'pytest-forked' || '' }} # to investigate crashes - name: Build extension module run: | - # We only have wheels for Linux right now - if [[ ${{ steps.sourcetype.outputs.selected }} == "wheel" && ${{ matrix.os }} == "ubuntu-latest" ]]; then - pip install --no-deps suitesparse-graphblas + if [[ ${{ steps.sourcetype.outputs.selected }} == "wheel" ]]; then + # Add --pre if installing a pre-release + pip install --no-deps --only-binary ":all:" suitesparse-graphblas${psgver} + + # Add the below line to the conda install command above if installing from test.pypi.org + # ${{ steps.sourcetype.outputs.selected == 'wheel' && 'setuptools setuptools-git-versioning wheel cython' || '' }} \ + # pip install --no-deps --only-binary ":all:" --index-url https://test.pypi.org/simple/ "suitesparse-graphblas>=7.4.3" elif [[ ${{ steps.sourcetype.outputs.selected }} == "source" ]]; then - pip install --no-deps --no-binary=all suitesparse-graphblas + # Add --pre if installing a pre-release + pip install --no-deps --no-binary suitesparse-graphblas suitesparse-graphblas${psgver} + + # Add the below line to the conda install command above if installing from test.pypi.org + # ${{ steps.sourcetype.outputs.selected == 'source' && 'setuptools setuptools-git-versioning wheel cython' || '' }} \ + # pip install --no-deps --no-build-isolation --no-binary suitesparse-graphblas --index-url https://test.pypi.org/simple/ suitesparse-graphblas==7.4.3.3 elif [[ ${{ steps.sourcetype.outputs.selected }} == "upstream" ]]; then pip install 
--no-deps git+https://github.com/GraphBLAS/python-suitesparse-graphblas.git@main#egg=suitesparse-graphblas fi pip install --no-deps -e . + - name: python-suitesparse-graphblas tests + run: | + # Don't use our conftest.py ; allow `test_print_jit_config` to fail if it doesn't exist + (cd .. + pytest --pyargs suitesparse_graphblas -s -k test_print_jit_config || true + pytest -v --pyargs suitesparse_graphblas || true) + - name: Print platform and sysconfig variables + run: | + python -c "import platform ; print(platform.uname())" + python -c "import pprint, sysconfig ; pprint.pprint(sysconfig.get_config_vars())" - name: Unit tests run: | A=${{ needs.rngs.outputs.mapnumpy == 'A' || '' }} ; B=${{ needs.rngs.outputs.mapnumpy == 'B' || '' }} @@ -233,8 +417,11 @@ jobs: if [[ $G && $bizarro ]] ; then if [[ $ubuntu ]] ; then echo " $suitesparse" ; elif [[ $windows ]] ; then echo " $vanilla" ; fi ; fi)$( \ if [[ $H && $normal ]] ; then if [[ $macos ]] ; then echo " $vanilla" ; elif [[ $windows ]] ; then echo " $suitesparse" ; fi ; fi)$( \ if [[ $H && $bizarro ]] ; then if [[ $macos ]] ; then echo " $suitesparse" ; elif [[ $windows ]] ; then echo " $vanilla" ; fi ; fi) - echo $args - coverage run -m pytest --color=yes --randomly -v $args \ + echo ${args} + set -x # echo on + # pytest ${{ matrix.os != 'windows-latest' && '--forked' || '' }} \ # to investigate crashes + # --color=yes --randomly -v -s ${args} \ + coverage run -m pytest --color=yes --randomly -v ${args} \ ${{ matrix.slowtask == 'pytest_normal' && '--runslow' || '' }} - name: Unit tests (bizarro scalars) run: | @@ -268,8 +455,11 @@ jobs: if [[ $G && $bizarro ]] ; then if [[ $ubuntu ]] ; then echo " $vanilla" ; elif [[ $windows ]] ; then echo " $suitesparse" ; fi ; fi)$( \ if [[ $H && $normal ]] ; then if [[ $macos ]] ; then echo " $suitesparse" ; elif [[ $windows ]] ; then echo " $vanilla" ; fi ; fi)$( \ if [[ $H && $bizarro ]] ; then if [[ $macos ]] ; then echo " $vanilla" ; elif [[ $windows ]] ; then echo " $suitesparse" ; fi ; fi) - echo $args - coverage run -a -m pytest --color=yes --randomly -v $args \ + echo ${args} + set -x # echo on + # pytest ${{ matrix.os != 'windows-latest' && '--forked' || '' }} \ # to investigate crashes + # --color=yes --randomly -v -s ${args} \ + coverage run -a -m pytest --color=yes --randomly -v ${args} \ ${{ matrix.slowtask == 'pytest_bizarro' && '--runslow' || '' }} git checkout . # Undo changes to scalar default - name: Miscellaneous tests @@ -282,16 +472,26 @@ jobs: echo "from graphblas.agg import count" > script.py coverage run -a script.py echo "from graphblas import agg" > script.py # Does this still cover? 
- echo "from graphblas.core import agg" >> script.py + echo "from graphblas.core.operator import agg" >> script.py coverage run -a script.py # Tests lazy loading of lib, ffi, and NULL in gb.core echo "from graphblas.core import base" > script.py coverage run -a script.py + # Test another code pathway for loading lib + echo "from graphblas.core import lib" > script.py + coverage run -a script.py rm script.py # Tests whose coverage depend on order of tests :/ # TODO: understand why these are order-dependent and try to fix coverage run -a -m pytest --color=yes -x --no-mapnumpy --runslow -k test_binaryop_attributes_numpy graphblas/tests/test_op.py # coverage run -a -m pytest --color=yes -x --no-mapnumpy -k test_npmonoid graphblas/tests/test_numpyops.py --runslow + - name: More tests for coverage + if: matrix.slowtask == 'notebooks' && matrix.os == 'windows-latest' + run: | + # We use 'notebooks' slow task b/c it should have numba installed + coverage run -a -m pytest --color=yes --runslow --no-mapnumpy -p no:randomly -v -k 'test_commutes or test_bool_doesnt_get_too_large or test_npbinary or test_npmonoid or test_npsemiring' + coverage run -a -m pytest --color=yes --runslow --mapnumpy -p no:randomly -k 'test_bool_doesnt_get_too_large or test_npunary or test_binaryop_monoid_numpy' + coverage run -a -m pytest --color=yes -x --no-mapnumpy --runslow -k test_binaryop_attributes_numpy graphblas/tests/test_op.py - name: Auto-generated code check if: matrix.slowtask == 'pytest_bizarro' run: | @@ -300,31 +500,15 @@ jobs: coverage run -a -m graphblas.core.infixmethods git diff --exit-code - name: Coverage - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COVERALLS_FLAG_NAME: ${{ matrix.os }}/${{ matrix.slowtask }} - COVERALLS_PARALLEL: true run: | coverage xml coverage report --show-missing - coveralls --service=github - name: codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v5 - name: Notebooks Execution check if: matrix.slowtask == 'notebooks' - run: jupyter nbconvert --to notebook --execute notebooks/*ipynb - - finish: - needs: build_and_test - if: always() - runs-on: ubuntu-latest - steps: - - uses: actions/setup-python@v4 - with: - python-version: "3.10" - - run: python -m pip install --upgrade pip - - run: pip install coveralls - - name: Coveralls Finished - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: coveralls --finish + run: | + # Run notebooks only if numba is installed + if python -c 'import numba' 2> /dev/null ; then + jupyter nbconvert --to notebook --execute notebooks/*ipynb + fi diff --git a/.github/zizmor.yml b/.github/zizmor.yml new file mode 100644 index 000000000..61f32c2e0 --- /dev/null +++ b/.github/zizmor.yml @@ -0,0 +1,16 @@ +rules: + use-trusted-publishing: + # TODO: we should update to use trusted publishing + ignore: + - publish_pypi.yml + excessive-permissions: + # It is probably good practice to use narrow permissions + ignore: + - debug.yml + - imports.yml + - publish_pypi.yml + - test_and_build.yml + template-injection: + # We use templates pretty heavily + ignore: + - test_and_build.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c9c94988..43e28b8fe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,81 +4,145 @@ # To run: `pre-commit run --all-files` # To update: `pre-commit autoupdate` # - &flake8_dependencies below needs updated manually -fail_fast: true +ci: + # See: https://pre-commit.ci/#configuration + autofix_prs: false + autoupdate_schedule: quarterly + autoupdate_commit_msg: 
"chore: update pre-commit hooks" + autofix_commit_msg: "style: pre-commit fixes" + skip: [pylint, no-commit-to-branch] +fail_fast: false default_language_version: - python: python3 + python: python3 repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + # - id: check-symlinks - id: check-ast - id: check-toml - id: check-yaml + - id: check-executables-have-shebangs + - id: check-vcs-permalinks + - id: destroyed-symlinks + - id: detect-private-key - id: debug-statements - id: end-of-file-fixer + exclude_types: [svg] - id: mixed-line-ending - id: trailing-whitespace + - id: name-tests-test + args: ["--pytest-test-first"] - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.12.1 + rev: v0.23 hooks: - id: validate-pyproject name: Validate pyproject.toml - - repo: https://github.com/myint/autoflake - rev: v2.0.1 + # I don't yet trust ruff to do what autoflake does + - repo: https://github.com/PyCQA/autoflake + rev: v2.3.1 hooks: - id: autoflake args: [--in-place] + # We can probably remove `isort` if we come to trust `ruff --fix`, + # but we'll need to figure out the configuration to do this in `ruff` - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 6.0.0 hooks: - id: isort + # Let's keep `pyupgrade` even though `ruff --fix` probably does most of it - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 + rev: v3.19.1 hooks: - id: pyupgrade - args: [--py38-plus] + args: [--py310-plus] - repo: https://github.com/MarcoGorelli/auto-walrus - rev: v0.2.2 + rev: 0.3.4 hooks: - id: auto-walrus args: [--line-length, "100"] - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 25.1.0 hooks: - id: black - id: black-jupyter + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.6 + hooks: + - id: ruff + args: [--fix-only, --show-fixes] + # Let's keep `flake8` even though `ruff` does much of the same. + # `flake8-bugbear` and `flake8-simplify` have caught things missed by `ruff`. - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 7.1.2 hooks: - id: flake8 - additional_dependencies: &flake8_dependencies - # These versions need updated manually - - flake8==6.0.0 - - flake8-comprehensions==3.10.1 - - flake8-bugbear==23.2.13 - - flake8-simplify==0.19.3 - - repo: https://github.com/asottile/yesqa - rev: v1.4.0 - hooks: - - id: yesqa - additional_dependencies: *flake8_dependencies + args: ["--config=.flake8"] + additional_dependencies: + &flake8_dependencies # These versions need updated manually + - flake8==7.1.2 + - flake8-bugbear==24.12.12 + - flake8-simplify==0.21.0 - repo: https://github.com/codespell-project/codespell - rev: v2.2.2 + rev: v2.4.1 hooks: - id: codespell types_or: [python, rst, markdown] additional_dependencies: [tomli] files: ^(graphblas|docs)/ - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.252 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.6 hooks: - id: ruff - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.6.7 + rev: v1.0.0 hooks: - id: sphinx-lint args: [--enable, all, "--disable=line-too-long,leaked-markup"] + # `pyroma` may help keep our package standards up to date if best practices change. + # This is probably a "low value" check though and safe to remove if we want faster pre-commit. + - repo: https://github.com/regebro/pyroma + rev: "4.2" + hooks: + - id: pyroma + args: [-n, "10", .] 
+ - repo: https://github.com/shellcheck-py/shellcheck-py + rev: "v0.10.0.1" + hooks: + - id: shellcheck + - repo: https://github.com/rbubley/mirrors-prettier + rev: v3.5.1 + hooks: + - id: prettier + - repo: https://github.com/ComPWA/taplo-pre-commit + rev: v0.9.3 + hooks: + - id: taplo-format + - repo: https://github.com/rhysd/actionlint + rev: v1.7.7 + hooks: + - id: actionlint + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.31.1 + hooks: + - id: check-dependabot + - id: check-github-workflows + - id: check-readthedocs + - repo: https://github.com/adrienverge/yamllint + rev: v1.35.1 + hooks: + - id: yamllint + - repo: https://github.com/woodruffw/zizmor-pre-commit + rev: v1.3.1 + hooks: + - id: zizmor + - repo: meta + hooks: + - id: check-hooks-apply + - id: check-useless-excludes - repo: local hooks: # Add `--hook-stage manual` to pre-commit command to run (very slow) @@ -92,9 +156,9 @@ repos: args: [graphblas/] pass_filenames: false - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: - - id: no-commit-to-branch # no commit directly to main + - id: no-commit-to-branch # no commit directly to main # # Maybe: # @@ -111,8 +175,10 @@ repos: # additional_dependencies: [tomli] # # - repo: https://github.com/PyCQA/bandit -# rev: 1.7.4 +# rev: 1.8.2 # hooks: # - id: bandit +# args: ["-c", "pyproject.toml"] +# additional_dependencies: ["bandit[toml]"] # -# blacken-docs, blackdoc mypy, pydocstringformatter, velin, flynt, yamllint +# blacken-docs, blackdoc, mypy, pydocstringformatter, velin, flynt diff --git a/.yamllint.yaml b/.yamllint.yaml new file mode 100644 index 000000000..54e656293 --- /dev/null +++ b/.yamllint.yaml @@ -0,0 +1,6 @@ +--- +extends: default +rules: + document-start: disable + line-length: disable + truthy: disable diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 7cfcb10f9..eebd2c372 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -13,13 +13,13 @@ educational level, family status, culture, or political belief. Examples of unacceptable behavior by participants include: -* The use of sexualized language or imagery -* Personal attacks -* Trolling or insulting/derogatory comments -* Public or private harassment -* Publishing other's private information, such as physical or electronic +- The use of sexualized language or imagery +- Personal attacks +- Trolling or insulting/derogatory comments +- Public or private harassment +- Publishing other's private information, such as physical or electronic addresses, without explicit permission -* Other unethical or unprofessional conduct +- Other unethical or unprofessional conduct Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions @@ -52,12 +52,12 @@ that is deemed necessary and appropriate to the circumstances. Maintainers are obligated to maintain confidentiality with regard to the reporter of an incident. -This Code of Conduct is adapted from the [Numba Code of Conduct][numba], which is based on the [Contributor Covenant][homepage], +This Code of Conduct is adapted from the [Numba Code of Conduct][numba], which is based on the [Contributor Covenant][homepage], version 1.3.0, available at -[http://contributor-covenant.org/version/1/3/0/][version], +[https://contributor-covenant.org/version/1/3/0/][version], and the [Swift Code of Conduct][swift]. 
[numba]: https://github.com/numba/numba-governance/blob/accepted/code-of-conduct.md -[homepage]: http://contributor-covenant.org -[version]: http://contributor-covenant.org/version/1/3/0/ +[homepage]: https://contributor-covenant.org +[version]: https://contributor-covenant.org/version/1/3/0/ [swift]: https://swift.org/community/#code-of-conduct diff --git a/LICENSE b/LICENSE index 74a8ba6c6..21c605c21 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ Apache License Version 2.0, January 2004 - http://www.apache.org/licenses/ + https://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION @@ -186,13 +186,13 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2020 Anaconda, Inc + Copyright 2020-2023 Anaconda, Inc. and contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/MANIFEST.in b/MANIFEST.in index f3f4b04bb..27cd3f0c4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,8 @@ recursive-include graphblas *.py +prune docs +prune scripts include setup.py +include conftest.py include README.md include LICENSE include MANIFEST.in diff --git a/README.md b/README.md index 34c1c1994..1080314c7 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,71 @@ -# Python-graphblas +![Python-graphblas](https://raw.githubusercontent.com/python-graphblas/python-graphblas/main/docs/_static/img/logo-horizontal-medium-big.svg) +[![Powered by NumFOCUS](https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](https://numfocus.org) +[![pyOpenSci](https://tinyurl.com/y22nb8up)](https://github.com/pyOpenSci/software-review/issues/81) +[![Discord](https://img.shields.io/badge/Chat-Discord-Blue?color=5865f2)](https://discord.com/invite/vur45CbwMz) +
[![conda-forge](https://img.shields.io/conda/vn/conda-forge/python-graphblas.svg)](https://anaconda.org/conda-forge/python-graphblas) [![pypi](https://img.shields.io/pypi/v/python-graphblas.svg)](https://pypi.python.org/pypi/python-graphblas/) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/python-graphblas)](https://pypi.python.org/pypi/python-graphblas/) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/python-graphblas/python-graphblas/blob/main/LICENSE) -[![Tests](https://github.com/python-graphblas/python-graphblas/workflows/Tests/badge.svg?branch=main)](https://github.com/python-graphblas/python-graphblas/actions) +
+[![Tests](https://github.com/python-graphblas/python-graphblas/actions/workflows/test_and_build.yml/badge.svg?branch=main)](https://github.com/python-graphblas/python-graphblas/actions) [![Docs](https://readthedocs.org/projects/python-graphblas/badge/?version=latest)](https://python-graphblas.readthedocs.io/en/latest/) -[![Coverage](https://coveralls.io/repos/python-graphblas/python-graphblas/badge.svg?branch=main)](https://coveralls.io/r/python-graphblas/python-graphblas) +[![Coverage](https://codecov.io/gh/python-graphblas/python-graphblas/graph/badge.svg?token=D7HHLDPQ2Q)](https://codecov.io/gh/python-graphblas/python-graphblas) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7328791.svg)](https://doi.org/10.5281/zenodo.7328791) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/python-graphblas/python-graphblas/HEAD?filepath=notebooks%2FIntro%20to%20GraphBLAS%20%2B%20SSSP%20example.ipynb) -[![Discord](https://img.shields.io/badge/Chat-Discord-blue)](https://discord.com/invite/vur45CbwMz) Python library for GraphBLAS: high-performance sparse linear algebra for scalable graph analytics. +For algorithms, see +[`graphblas-algorithms`](https://github.com/python-graphblas/graphblas-algorithms). - **Documentation:** [https://python-graphblas.readthedocs.io/](https://python-graphblas.readthedocs.io/) + - **FAQ:** [https://python-graphblas.readthedocs.io/en/stable/getting_started/faq.html](https://python-graphblas.readthedocs.io/en/stable/getting_started/faq.html) - **GraphBLAS C API:** [https://graphblas.org/docs/GraphBLAS_API_C_v2.0.0.pdf](https://graphblas.org/docs/GraphBLAS_API_C_v2.0.0.pdf) - **SuiteSparse:GraphBLAS User Guide:** [https://github.com/DrTimothyAldenDavis/GraphBLAS/raw/stable/Doc/GraphBLAS_UserGuide.pdf](https://github.com/DrTimothyAldenDavis/GraphBLAS/raw/stable/Doc/GraphBLAS_UserGuide.pdf) - **Source:** [https://github.com/python-graphblas/python-graphblas](https://github.com/python-graphblas/python-graphblas) - **Bug reports:** [https://github.com/python-graphblas/python-graphblas/issues](https://github.com/python-graphblas/python-graphblas/issues) - **Github discussions:** [https://github.com/python-graphblas/python-graphblas/discussions](https://github.com/python-graphblas/python-graphblas/discussions) -- **Weekly community call:** [https://github.com/python-graphblas/python-graphblas/issues/247](https://github.com/python-graphblas/python-graphblas/issues/247) +- **Weekly community call:** [python-graphblas#247](https://github.com/python-graphblas/python-graphblas/issues/247) or [https://scientific-python.org/calendars/](https://scientific-python.org/calendars/) - **Chat via Discord:** [https://discord.com/invite/vur45CbwMz](https://discord.com/invite/vur45CbwMz) in the [#graphblas channel](https://discord.com/channels/786703927705862175/1024732940233605190) +

+ Directed graph + Adjacency matrix +

+ ## Install + Install the latest version of Python-graphblas via conda: + ``` $ conda install -c conda-forge python-graphblas ``` + or pip: + ``` -$ pip install python-graphblas +$ pip install 'python-graphblas[default]' ``` + This will also install the [SuiteSparse:GraphBLAS](https://github.com/DrTimothyAldenDavis/GraphBLAS) compiled C library. +We currently support the [GraphBLAS C API 2.0 specification](https://graphblas.org/docs/GraphBLAS_API_C_v2.0.0.pdf). + +### Optional Dependencies + +The following are not required by python-graphblas, but may be needed for certain functionality to work. + +- `pandas` – required for nicer `__repr__`; +- `matplotlib` – required for basic plotting of graphs; +- `scipy` – used in `io` module to read/write `scipy.sparse` format; +- `networkx` – used in `io` module to interface with `networkx` graphs; +- `fast-matrix-market` - for faster read/write of Matrix Market files with `gb.io.mmread` and `gb.io.mmwrite`. ## Description + Currently works with [SuiteSparse:GraphBLAS](https://github.com/DrTimothyAldenDavis/GraphBLAS), but the goal is to make it work with all implementations of the GraphBLAS spec. -The approach taken with this library is to follow the C-API specification as closely as possible while making improvements +The approach taken with this library is to follow the C-API 2.0 specification as closely as possible while making improvements allowed with the Python syntax. Because the spec always passes in the output object to be written to, we follow the same, which is very different from the way Python normally operates. In fact, many who are familiar with other Python data libraries (numpy, pandas, etc) will find it strange to not create new objects for every call. @@ -46,10 +76,12 @@ with how Python handles assignment, so instead we (ab)use the left-shift `<<` no assignment. This opens up all kinds of nice possibilities. This is an example of how the mapping works: + ```C // C call GrB_Matrix_mxm(M, mask, GrB_PLUS_INT64, GrB_MIN_PLUS_INT64, A, B, NULL) ``` + ```python # Python call M(mask.V, accum=binary.plus) << A.mxm(B, semiring.min_plus) @@ -67,10 +99,12 @@ is a much better approach, even if it doesn't feel very Pythonic. Descriptor flags are set on the appropriate elements to keep logic close to what it affects. Here is the same call with descriptor bits set. `ttcsr` indicates transpose the first and second matrices, complement the structure of the mask, and do a replacement on the output. + ```C // C call GrB_Matrix_mxm(M, mask, GrB_PLUS_INT64, GrB_MIN_PLUS_INT64, A, B, desc.ttcsr) ``` + ```python # Python call M(~mask.S, accum=binary.plus, replace=True) << A.T.mxm(B.T, semiring.min_plus) @@ -80,16 +114,20 @@ The objects receiving the flag operations (A.T, ~mask, etc) are also delayed obj do no computation, allowing the correct descriptor bits to be set in a single GraphBLAS call. **If no mask or accumulator is used, the call looks like this**: + ```python M << A.mxm(B, semiring.min_plus) ``` + The use of `<<` to indicate updating is actually just syntactic sugar for a real `.update()` method. 
The above expression could be written as: + ```python M.update(A.mxm(B, semiring.min_plus)) ``` ## Operations + ```python M(mask, accum) << A.mxm(B, semiring) # mxm w(mask, accum) << A.mxv(v, semiring) # mxv @@ -99,14 +137,18 @@ M(mask, accum) << A.ewise_mult(B, binaryop) # eWiseMult M(mask, accum) << A.kronecker(B, binaryop) # kronecker M(mask, accum) << A.T # transpose ``` + ## Extract + ```python M(mask, accum) << A[rows, cols] # rows and cols are a list or a slice w(mask, accum) << A[rows, col_index] # extract column w(mask, accum) << A[row_index, cols] # extract row s = A[row_index, col_index].value # extract single element ``` + ## Assign + ```python M(mask, accum)[rows, cols] << A # rows and cols are a list or a slice M(mask, accum)[rows, col_index] << v # assign column @@ -116,31 +158,42 @@ M[row_index, col_index] << s # assign scalar to single element # (mask and accum not allowed) del M[row_index, col_index] # remove single element ``` + ## Apply + ```python M(mask, accum) << A.apply(unaryop) M(mask, accum) << A.apply(binaryop, left=s) # bind-first M(mask, accum) << A.apply(binaryop, right=s) # bind-second ``` + ## Reduce + ```python v(mask, accum) << A.reduce_rowwise(op) # reduce row-wise v(mask, accum) << A.reduce_columnwise(op) # reduce column-wise s(accum) << A.reduce_scalar(op) s(accum) << v.reduce(op) ``` + ## Creating new Vectors / Matrices + ```python A = Matrix.new(dtype, num_rows, num_cols) # new_type B = A.dup() # dup A = Matrix.from_coo([row_indices], [col_indices], [values]) # build ``` + ## New from delayed + Delayed objects can be used to create a new object using `.new()` method + ```python C = A.mxm(B, semiring).new() ``` + ## Properties + ```python size = v.size # size nrows = M.nrows # nrows @@ -148,23 +201,30 @@ ncols = M.ncols # ncols nvals = M.nvals # nvals rindices, cindices, vals = M.to_coo() # extractTuples ``` + ## Initialization + There is a mechanism to initialize `graphblas` with a context prior to use. This allows for setting the backend to use as well as the blocking/non-blocking mode. If the context is not initialized, a default initialization will be performed automatically. + ```python import graphblas as gb + # Context initialization must happen before any other imports -gb.init('suitesparse', blocking=True) +gb.init("suitesparse", blocking=True) # Now we can import other items from graphblas from graphblas import binary, semiring from graphblas import Matrix, Vector, Scalar ``` + ## Performant User Defined Functions + Python-graphblas requires `numba` which enables compiling user-defined Python functions to native C for use in GraphBLAS. Example customized UnaryOp: + ```python from graphblas import unary @@ -173,22 +233,42 @@ def force_odd_func(x): return x + 1 return x -unary.register_new('force_odd', force_odd_func) +unary.register_new("force_odd", force_odd_func) v = Vector.from_coo([0, 1, 3], [1, 2, 3]) w = v.apply(unary.force_odd).new() w # indexes=[0, 1, 3], values=[1, 3, 3] ``` + Similar methods exist for BinaryOp, Monoid, and Semiring. +## Relation to other network analysis libraries + +Python-graphblas aims to provide an efficient and consistent expression +of graph operations using linear algebra. This allows the development of +high-performance implementations of existing and new graph algorithms +(also see [`graphblas-algorithms`](https://github.com/python-graphblas/graphblas-algorithms)). 
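+
+For example, here is a minimal sketch of single-source shortest paths in this
+style (the tiny graph below is hypothetical; the Binder notebook linked above
+walks through a fuller version of the same idea):
+
+```python
+from graphblas import Matrix, Vector, binary, semiring
+
+# Hypothetical toy digraph: edges 0->1 (2.0), 1->2 (3.0), 0->2 (10.0)
+A = Matrix.from_coo([0, 1, 0], [1, 2, 2], [2.0, 3.0, 10.0], nrows=3, ncols=3)
+d = Vector.from_coo([0], [0.0], size=3)  # known distances; source is node 0
+
+# Bellman-Ford-style relaxation: d[j] = min(d[j], min_i(d[i] + A[i, j]))
+for _ in range(A.nrows - 1):
+    d(binary.min) << d.vxm(A, semiring.min_plus)
+
+d  # indexes=[0, 1, 2], values=[0.0, 2.0, 5.0]
+```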
+ +While end-to-end analysis can be done using `python-graphblas`, users +might find that other libraries in the Python ecosystem provide a more +convenient high-level interface for data pre-processing and transformation +(e.g. `pandas`, `scipy.sparse`), visualization (e.g. `networkx`, `igraph`), +interactive exploration and analysis (e.g. `networkx`, `igraph`) or for +algorithms that are not (yet) implemented in `graphblas-algorithms` (e.g. +`networkx`, `igraph`, `scipy.sparse.csgraph`). To facilitate communication with +other libraries, `graphblas.io` contains multiple connectors, see the +following section. + ## Import/Export connectors to the Python ecosystem + `graphblas.io` contains functions for converting to and from: + ```python import graphblas as gb # scipy.sparse matrices A = gb.io.from_scipy_sparse(m) -m = gb.io.to_scipy_sparse(m, format='csr') +m = gb.io.to_scipy_sparse(m, format="csr") # networkx graphs A = gb.io.from_networkx(g) diff --git a/binder/environment.yml b/binder/environment.yml index ef72a4d2b..9548f2126 100644 --- a/binder/environment.yml +++ b/binder/environment.yml @@ -1,10 +1,12 @@ name: graphblas channels: - - conda-forge + - conda-forge dependencies: - - python=3.10 - - python-graphblas - - matplotlib - - networkx - - pandas - - scipy + - python=3.11 + - python-graphblas + - matplotlib + - networkx + - pandas + - scipy + - drawsvg + - cairosvg diff --git a/dev-requirements.txt b/dev-requirements.txt index b84c0e849..a281672ec 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -6,6 +6,7 @@ pyyaml pandas # For I/O awkward +fast_matrix_market networkx scipy sparse @@ -16,7 +17,9 @@ matplotlib # For linting pre-commit # For testing +packaging pytest-cov +tomli # For debugging icecream ipykernel diff --git a/docs/_static/custom.css b/docs/_static/custom.css index 07834b3bc..f7dd59b74 100644 --- a/docs/_static/custom.css +++ b/docs/_static/custom.css @@ -1,62 +1,78 @@ - /* Main Page Stylings */ -.container-xl { - max-width: 1400px; -} - .intro-card { - margin-bottom: 30px; + background-color: var(--pst-color-background); + margin-bottom: 30px; } .intro-card:hover { - box-shadow: 0.2rem 0.5rem 1rem var(--pst-color-link) !important; + box-shadow: 0.2rem 0.5rem 1rem var(--pst-color-link) !important; } .intro-card .card-header { - background-color: inherit; + background-color: inherit; } .intro-card .card-header .card-text { - font-weight: bold; + font-weight: bold; } .intro-card .card-body { - margin-top: 0; + margin-top: 0; } .intro-card .card-body .card-text:first-child { - margin-bottom: 0; + margin-bottom: 0; } .shadow { - box-shadow: 0.2rem 0.5rem 1rem var(--pst-color-text-muted) !important; + box-shadow: 0.2rem 0.5rem 1rem var(--pst-color-text-muted) !important; } .table { - font-size: smaller; - width: inherit; + font-size: smaller; + width: inherit; } -.table td, .table th { - padding: 0 .75rem; +.table td, +.table th { + padding: 0 0.75rem; } .table.inline { - display: inline-table; - margin-right: 30px; + display: inline-table; + margin-right: 30px; } p.rubric { - border-bottom: none; + border-bottom: none; } -/* Styling for Jupyter Notebook ReST Exports */ +button.navbar-btn.rounded-circle { + padding: 0.25rem; +} + +button.navbar-btn.search-button { + color: var(--pst-color-text-muted); + padding: 0; +} -.dataframe tbody th, .dataframe tbody td { - padding: 10px; +button.navbar-btn:hover { + color: var(--pst-color-primary); } -.bd-sidebar-primary, .bd-sidebar-secondary { - position: sticky; +button.theme-switch-button { + font-size: 
calc(var(--pst-font-size-icon) - 0.1rem); + border: none; +} + +button span.theme-switch:hover { + color: var(--pst-color-primary); +} + +/* Styling for Jupyter Notebook ReST Exports */ + +.dataframe tbody th, +.dataframe tbody td { + padding: 10px; } diff --git a/docs/_static/img/GraphBLAS-API-example.png b/docs/_static/img/GraphBLAS-API-example.png index c6dd48182..1edc91988 100644 Binary files a/docs/_static/img/GraphBLAS-API-example.png and b/docs/_static/img/GraphBLAS-API-example.png differ diff --git a/docs/_static/img/GraphBLAS-mapping.png b/docs/_static/img/GraphBLAS-mapping.png index 7ef73c88d..c5d1a1d4e 100644 Binary files a/docs/_static/img/GraphBLAS-mapping.png and b/docs/_static/img/GraphBLAS-mapping.png differ diff --git a/docs/_static/img/Matrix-A-strictly-upper.png b/docs/_static/img/Matrix-A-strictly-upper.png index 9b127aa84..0fedf2617 100644 Binary files a/docs/_static/img/Matrix-A-strictly-upper.png and b/docs/_static/img/Matrix-A-strictly-upper.png differ diff --git a/docs/_static/img/Matrix-A-upper.png b/docs/_static/img/Matrix-A-upper.png index 1b930a9a3..e3703710a 100644 Binary files a/docs/_static/img/Matrix-A-upper.png and b/docs/_static/img/Matrix-A-upper.png differ diff --git a/docs/_static/img/Recorder-output.png b/docs/_static/img/Recorder-output.png index 355cc1376..525221c55 100644 Binary files a/docs/_static/img/Recorder-output.png and b/docs/_static/img/Recorder-output.png differ diff --git a/docs/_static/img/adj-graph.png b/docs/_static/img/adj-graph.png index da9f36447..13a05fcc2 100644 Binary files a/docs/_static/img/adj-graph.png and b/docs/_static/img/adj-graph.png differ diff --git a/docs/_static/img/directed-graph.svg b/docs/_static/img/directed-graph.svg index a08f346d1..c7a9cadad 100644 --- a/docs/_static/img/directed-graph.svg +++ b/docs/_static/img/directed-graph.svg @@ -1 +1 @@ -AC5.0B2.3G1.9F6.2E3.0D4.61.43.9H2.72.08.61.04.45.11.7 +AC5.0B2.3G1.9F6.2E3.0D4.61.43.9H2.72.08.61.04.45.11.7 diff --git a/docs/_static/img/draw-example.png b/docs/_static/img/draw-example.png index 3c5e6c008..90c5917d9 100644 Binary files a/docs/_static/img/draw-example.png and b/docs/_static/img/draw-example.png differ diff --git a/docs/_static/img/logo-horizontal-dark.svg b/docs/_static/img/logo-horizontal-dark.svg new file mode 100644 index 000000000..be9e5ccca --- /dev/null +++ b/docs/_static/img/logo-horizontal-dark.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/img/logo-horizontal-light.svg b/docs/_static/img/logo-horizontal-light.svg new file mode 100644 index 000000000..5894eed9a --- /dev/null +++ b/docs/_static/img/logo-horizontal-light.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/img/logo-horizontal-medium-big.svg b/docs/_static/img/logo-horizontal-medium-big.svg new file mode 100644 index 000000000..649c2aef3 --- /dev/null +++ b/docs/_static/img/logo-horizontal-medium-big.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/img/logo-horizontal-medium.svg b/docs/_static/img/logo-horizontal-medium.svg new file mode 100644 index 000000000..038781a3f --- /dev/null +++ b/docs/_static/img/logo-horizontal-medium.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/img/logo-name-dark.svg b/docs/_static/img/logo-name-dark.svg index 039eb7e25..35d4d2970 100644 --- a/docs/_static/img/logo-name-dark.svg +++ b/docs/_static/img/logo-name-dark.svg @@ -1,37 +1 @@ - - - - - - - graphblas - python- - + diff --git a/docs/_static/img/logo-name-light.svg 
b/docs/_static/img/logo-name-light.svg index 6e16adfbe..3331ae561 100644 --- a/docs/_static/img/logo-name-light.svg +++ b/docs/_static/img/logo-name-light.svg @@ -1,37 +1 @@ - - - - - - - graphblas - python- - + diff --git a/docs/_static/img/logo-name-medium-big.svg b/docs/_static/img/logo-name-medium-big.svg new file mode 100644 index 000000000..7bb245898 --- /dev/null +++ b/docs/_static/img/logo-name-medium-big.svg @@ -0,0 +1 @@ + diff --git a/docs/_static/img/logo-name-medium.svg b/docs/_static/img/logo-name-medium.svg new file mode 100644 index 000000000..3128fda35 --- /dev/null +++ b/docs/_static/img/logo-name-medium.svg @@ -0,0 +1 @@ + diff --git a/docs/_static/img/logo-vertical-dark.svg b/docs/_static/img/logo-vertical-dark.svg new file mode 100644 index 000000000..25dcefc17 --- /dev/null +++ b/docs/_static/img/logo-vertical-dark.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/img/logo-vertical-light.svg b/docs/_static/img/logo-vertical-light.svg new file mode 100644 index 000000000..1cb22644d --- /dev/null +++ b/docs/_static/img/logo-vertical-light.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/img/logo-vertical-medium.svg b/docs/_static/img/logo-vertical-medium.svg new file mode 100644 index 000000000..db2fcaefe --- /dev/null +++ b/docs/_static/img/logo-vertical-medium.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/img/min-plus-semiring.png b/docs/_static/img/min-plus-semiring.png index b6e9075c9..250f15b90 100644 Binary files a/docs/_static/img/min-plus-semiring.png and b/docs/_static/img/min-plus-semiring.png differ diff --git a/docs/_static/img/plus-times-semiring.png b/docs/_static/img/plus-times-semiring.png index 5cea4301f..bb2f527e8 100644 Binary files a/docs/_static/img/plus-times-semiring.png and b/docs/_static/img/plus-times-semiring.png differ diff --git a/docs/_static/img/python-graphblas-logo.svg b/docs/_static/img/python-graphblas-logo.svg new file mode 100644 index 000000000..2422973ff --- /dev/null +++ b/docs/_static/img/python-graphblas-logo.svg @@ -0,0 +1 @@ + diff --git a/docs/_static/img/repr-matrix.png b/docs/_static/img/repr-matrix.png index 39c9a42d5..34f766c9e 100644 Binary files a/docs/_static/img/repr-matrix.png and b/docs/_static/img/repr-matrix.png differ diff --git a/docs/_static/img/repr-scalar.png b/docs/_static/img/repr-scalar.png index faab22d17..8f9ba16f3 100644 Binary files a/docs/_static/img/repr-scalar.png and b/docs/_static/img/repr-scalar.png differ diff --git a/docs/_static/img/repr-vector.png b/docs/_static/img/repr-vector.png index fbcfc97bb..c5d6a0883 100644 Binary files a/docs/_static/img/repr-vector.png and b/docs/_static/img/repr-vector.png differ diff --git a/docs/_static/img/social-network.svg b/docs/_static/img/social-network.svg index a62230fa2..2e0335c54 100644 --- a/docs/_static/img/social-network.svg +++ b/docs/_static/img/social-network.svg @@ -1 +1 @@ -AnnaPriyaBlakeDanXavierYsabelle +AnnaPriyaBlakeDanXavierYsabelle diff --git a/docs/_static/img/sssp-result.png b/docs/_static/img/sssp-result.png index 18a5e1345..6b9bd1604 100644 Binary files a/docs/_static/img/sssp-result.png and b/docs/_static/img/sssp-result.png differ diff --git a/docs/_static/img/super-simple.svg b/docs/_static/img/super-simple.svg index c79530f87..73ae2ee9b 100644 --- a/docs/_static/img/super-simple.svg +++ b/docs/_static/img/super-simple.svg @@ -1 +1 @@ -025.012.030.51.54.25 +025.012.030.51.54.25 diff --git a/docs/_static/img/task-graph.svg b/docs/_static/img/task-graph.svg 
index e06017e9f..f48284d93 100644 --- a/docs/_static/img/task-graph.svg +++ b/docs/_static/img/task-graph.svg @@ -1 +1 @@ -StartLoad File 1Load File 2Load File 3MergeCleanNormalizeWeekly SummaryDaily SummarySerializeReport 1Report 2DashboardReport 3 +StartLoad File 1Load File 2Load File 3MergeCleanNormalizeWeekly SummaryDaily SummarySerializeReport 1Report 2DashboardReport 3 diff --git a/docs/_static/img/undirected-graph.svg b/docs/_static/img/undirected-graph.svg index 96ac206f9..e29eb261d 100644 --- a/docs/_static/img/undirected-graph.svg +++ b/docs/_static/img/undirected-graph.svg @@ -1 +1 @@ -015.622.334.61.96.23.041.454.461.02.8 +015.622.334.61.96.23.041.454.461.02.8 diff --git a/docs/_static/matrix.css b/docs/_static/matrix.css new file mode 100644 index 000000000..1937178e5 --- /dev/null +++ b/docs/_static/matrix.css @@ -0,0 +1,104 @@ +/* Based on the stylesheet used by matrepr (https://github.com/alugowski/matrepr) and modified for sphinx */ + +table.matrix { + border-collapse: collapse; + border: 0px; +} + +/* Disable a horizontal line from the default stylesheet */ +.table.matrix > :not(caption) > * > * { + border-bottom-width: 0px; +} + +/* row indices */ +table.matrix > tbody tr th { + font-size: smaller; + font-weight: bolder; + vertical-align: middle; + text-align: right; +} +/* row indices are often made bold in the source data; here make them match the boldness of the th column label style */ +table.matrix strong { + font-weight: bold; +} + +/* column indices */ +table.matrix > thead tr th { + font-size: smaller; + font-weight: bolder; + vertical-align: middle; + text-align: center; +} + +/* cells */ +table.matrix > tbody tr td { + vertical-align: middle; + text-align: center; + position: relative; +} + +/* left border */ +table.matrix > tbody tr td:first-of-type { + border-left: solid 2px var(--pst-color-text-base); +} +/* right border */ +table.matrix > tbody tr td:last-of-type { + border-right: solid 2px var(--pst-color-text-base); +} + +/* prevents empty cells from collapsing, especially empty rows */ +table.matrix > tbody tr td:empty::before { + /* basically fills empty cells with   */ + content: "\00a0\00a0\00a0"; + visibility: hidden; +} +table.matrix > tbody tr td:empty::after { + content: "\00a0\00a0\00a0"; + visibility: hidden; +} + +/* matrix bracket ticks */ +table.matrix > tbody > tr:first-child > td:first-of-type::before { + content: ""; + width: 4px; + position: absolute; + top: 0; + bottom: 0; + visibility: visible; + left: 0; + right: auto; + border-top: solid 2px var(--pst-color-text-base); +} +table.matrix > tbody > tr:last-child > td:first-of-type::before { + content: ""; + width: 4px; + position: absolute; + top: 0; + bottom: 0; + visibility: visible; + left: 0; + right: auto; + border-bottom: solid 2px var(--pst-color-text-base); +} +table.matrix > tbody > tr:first-child > td:last-of-type::after { + content: ""; + width: 4px; + position: absolute; + top: 0; + bottom: 0; + visibility: visible; + left: auto; + right: 0; + border-top: solid 2px var(--pst-color-text-base); +} +table.matrix > tbody > tr:last-child > td:last-of-type::after { + content: ""; + width: 4px; + position: absolute; + top: 0; + bottom: 0; + visibility: visible; + left: auto; + right: 0; + border-bottom: solid 2px var(--pst-color-text-base); +} diff --git a/docs/api_reference/collections.rst b/docs/api_reference/collections.rst new file mode 100644 index 000000000..83cabfd21 --- /dev/null +++ b/docs/api_reference/collections.rst @@ -0,0 +1,23 @@ +Collections +----------- + +Matrix
+~~~~~~ + +.. autoclass:: graphblas.Matrix + :members: + :special-members: __getitem__, __setitem__, __delitem__, __contains__, __iter__ + +Vector +~~~~~~ + +.. autoclass:: graphblas.Vector + :members: + :special-members: __getitem__, __setitem__, __delitem__, __contains__, __iter__ + +Scalar +~~~~~~ + +.. autoclass:: graphblas.Scalar + :members: + :special-members: __eq__, __bool__ diff --git a/docs/api_reference/exceptions.rst b/docs/api_reference/exceptions.rst new file mode 100644 index 000000000..7968f854c --- /dev/null +++ b/docs/api_reference/exceptions.rst @@ -0,0 +1,7 @@ +Exceptions +---------- + +.. automodule:: graphblas.exceptions + :members: InvalidObject, InvalidIndex, DomainMismatch, DimensionMismatch, + OutputNotEmpty, OutOfMemory, IndexOutOfBound, Panic, EmptyObject, + NotImplementedException, UdfParseError diff --git a/docs/api_reference/index.rst b/docs/api_reference/index.rst index 2f829e29a..84e7d65eb 100644 --- a/docs/api_reference/index.rst +++ b/docs/api_reference/index.rst @@ -4,131 +4,10 @@ API Reference ============= -Collections ------------ +.. toctree:: + :maxdepth: 2 -Matrix -~~~~~~ - -.. autoclass:: graphblas.Matrix - :members: - :special-members: __getitem__, __setitem__, __delitem__, __contains__, __iter__ - -Vector -~~~~~~ - -.. autoclass:: graphblas.Vector - :members: - :special-members: __getitem__, __setitem__, __delitem__, __contains__, __iter__ - -Scalar -~~~~~~ - -.. autoclass:: graphblas.Scalar - :members: - :special-members: __eq__, __bool__ - -Operators ---------- - -UnaryOp -~~~~~~~ - -.. autoclass:: graphblas.core.operator.UnaryOp() - :members: - -BinaryOp -~~~~~~~~ - -.. autoclass:: graphblas.core.operator.BinaryOp() - :members: - -Monoid -~~~~~~ - -.. autoclass:: graphblas.core.operator.Monoid() - :members: - -Semiring -~~~~~~~~ - -.. autoclass:: graphblas.core.operator.Semiring() - :members: - -IndexUnaryOp -~~~~~~~~~~~~ - -.. autoclass:: graphblas.core.operator.IndexUnaryOp() - :members: - -SelectOp -~~~~~~~~ - -.. autoclass:: graphblas.core.operator.SelectOp() - :members: - - -Input/Output ------------- - -NetworkX -~~~~~~~~ - -These methods require `networkx `_ to be installed. - -.. autofunction:: graphblas.io.from_networkx - -.. autofunction:: graphblas.io.to_networkx - -Numpy -~~~~~ - -These methods require `scipy `_ to be installed, as some -of the scipy.sparse machinery is used during the conversion process. - -.. autofunction:: graphblas.io.from_numpy - -.. autofunction:: graphblas.io.to_numpy - -Scipy Sparse -~~~~~~~~~~~~ - -These methods require `scipy `_ to be installed. - -.. autofunction:: graphblas.io.from_scipy_sparse - -.. autofunction:: graphblas.io.to_scipy_sparse - -PyData Sparse -~~~~~~~~~~~~~ - -These methods require `sparse `_ to be installed. - -.. autofunction:: graphblas.io.from_pydata_sparse - -.. autofunction:: graphblas.io.to_pydata_sparse - -Matrix Market -~~~~~~~~~~~~~ - -Matrix Market is a `plain-text format `_ for storing graphs. - -These methods require `scipy `_ to be installed. - -.. autofunction:: graphblas.io.mmread - -.. autofunction:: graphblas.io.mmwrite - -Visualization -~~~~~~~~~~~~~ - -.. autofunction:: graphblas.io.draw - - -Exceptions ----------- - -.. 
automodule:: graphblas.exceptions - :members: InvalidObject, InvalidIndex, DomainMismatch, DimensionMismatch, - OutputNotEmpty, OutOfMemory, IndexOutOfBound, Panic, EmptyObject, - NotImplementedException, UdfParseError + collections + operators + io + exceptions diff --git a/docs/api_reference/io.rst b/docs/api_reference/io.rst new file mode 100644 index 000000000..1cfc98516 --- /dev/null +++ b/docs/api_reference/io.rst @@ -0,0 +1,75 @@ +Input/Output +------------ + +NetworkX +~~~~~~~~ + +These methods require `networkx `_ to be installed. + +.. autofunction:: graphblas.io.from_networkx + +.. autofunction:: graphblas.io.to_networkx + +NumPy +~~~~~ + +These methods convert to and from dense arrays. For more, see :ref:`IO in the user guide `. + +.. automethod:: graphblas.core.matrix.Matrix.from_dense + +.. automethod:: graphblas.core.matrix.Matrix.to_dense + +.. automethod:: graphblas.core.vector.Vector.from_dense + +.. automethod:: graphblas.core.vector.Vector.to_dense + +Scipy Sparse +~~~~~~~~~~~~ + +These methods require `scipy `_ to be installed. + +.. autofunction:: graphblas.io.from_scipy_sparse + +.. autofunction:: graphblas.io.to_scipy_sparse + +PyData Sparse +~~~~~~~~~~~~~ + +These methods require `sparse `_ to be installed. + +.. autofunction:: graphblas.io.from_pydata_sparse + +.. autofunction:: graphblas.io.to_pydata_sparse + +Matrix Market +~~~~~~~~~~~~~ + +Matrix Market is a `plain-text format `_ for storing graphs. + +These methods require `scipy `_ to be installed. + +.. autofunction:: graphblas.io.mmread + +.. autofunction:: graphblas.io.mmwrite + +Awkward Array +~~~~~~~~~~~~~ + +`Awkward Array `_ is a library for nested, +variable-sized data, including arbitrary-length lists, records, mixed types, +and missing data, using NumPy-like idioms. Note that the intended use of the +``awkward-array``-related ``io`` functions is to convert ``graphblas`` objects to awkward, +perform necessary computations/transformations and, if required, convert the +awkward array back to ``graphblas`` format. To facilitate this conversion process, +``graphblas.io.to_awkward`` adds a top-level attribute ``format``, describing the +format of the ``graphblas`` object (this attribute is used by the +``graphblas.io.from_awkward`` function to reconstruct the ``graphblas`` object). + +.. autofunction:: graphblas.io.to_awkward + +.. autofunction:: graphblas.io.from_awkward + +Visualization +~~~~~~~~~~~~~ + +.. autofunction:: graphblas.viz.draw diff --git a/docs/api_reference/operators.rst b/docs/api_reference/operators.rst new file mode 100644 index 000000000..8836bb638 --- /dev/null +++ b/docs/api_reference/operators.rst @@ -0,0 +1,38 @@ +Operators +--------- + +UnaryOp +~~~~~~~ + +.. autoclass:: graphblas.core.operator.UnaryOp() + :members: + +BinaryOp +~~~~~~~~ + +.. autoclass:: graphblas.core.operator.BinaryOp() + :members: + +Monoid +~~~~~~ + +.. autoclass:: graphblas.core.operator.Monoid() + :members: + +Semiring +~~~~~~~~ + +.. autoclass:: graphblas.core.operator.Semiring() + :members: + +IndexUnaryOp +~~~~~~~~~~~~ + +.. autoclass:: graphblas.core.operator.IndexUnaryOp() + :members: + +SelectOp +~~~~~~~~ + +.. autoclass:: graphblas.core.operator.SelectOp() + :members: diff --git a/docs/conf.py b/docs/conf.py index ddd360326..283f6d047 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,13 +19,13 @@ # -- Project information ----------------------------------------------------- project = "python-graphblas" -copyright = "2022, Anaconda, Inc" +copyright = "2020-2023, Anaconda, Inc.
and contributors" author = "Anaconda, Inc" # The full version, including alpha/beta/rc tags # release = "1.3.2" # See: https://github.com/pypa/setuptools_scm/#usage-from-sphinx -from importlib.metadata import version # noqa: E402 isort: skip +from importlib.metadata import version # noqa: E402 isort:skip release = version("python-graphblas") del version @@ -36,7 +36,7 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = ["sphinx.ext.autodoc", "numpydoc", "sphinx_panels", "nbsphinx"] -html_css_files = ["custom.css"] +html_css_files = ["custom.css", "matrix.css"] html_js_files = ["custom.js"] # Add any paths that contain templates here, relative to this directory. @@ -55,17 +55,20 @@ # html_theme = "pydata_sphinx_theme" +html_favicon = "_static/img/python-graphblas-logo.svg" + # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] html_theme_options = { "logo": { - "image_light": "img/logo-name-light.svg", - "image_dark": "img/logo-name-dark.svg", + "image_light": "_static/img/logo-horizontal-light.svg", + "image_dark": "_static/img/logo-horizontal-dark.svg", }, "github_url": "https://github.com/python-graphblas/python-graphblas", } +html_show_sourcelink = False autodoc_member_order = "groupwise" diff --git a/docs/contributor_guide/index.rst b/docs/contributor_guide/index.rst index e8078f933..3b94f2f35 100644 --- a/docs/contributor_guide/index.rst +++ b/docs/contributor_guide/index.rst @@ -58,7 +58,7 @@ Here are instructions for two popular environment managers: :: # Create a conda environment named ``graphblas-dev`` using environment.yml in the repository root - conda create -f environment.yml + conda env create -f environment.yml # Activate it conda activate graphblas-dev # Install python-graphblas from source diff --git a/docs/env.yml b/docs/env.yml index 7bc373c06..78a50afbe 100644 --- a/docs/env.yml +++ b/docs/env.yml @@ -1,23 +1,23 @@ name: python-graphblas-docs channels: - - conda-forge - - nodefaults + - conda-forge + - nodefaults dependencies: - - python=3.10 - - pip - # python-graphblas dependencies - - donfig - - numba - - python-suitesparse-graphblas>=7.4.0.0 - - pyyaml - # extra dependencies - - matplotlib - - networkx - - pandas - - scipy>=1.7.0 - # docs dependencies - - commonmark # For RTD - - nbsphinx - - numpydoc - - pydata-sphinx-theme - - sphinx-panels + - python=3.10 + - pip + # python-graphblas dependencies + - donfig + - numba + - python-suitesparse-graphblas>=7.4.0.0 + - pyyaml + # extra dependencies + - matplotlib + - networkx + - pandas + - scipy>=1.7.0 + # docs dependencies + - commonmark # For RTD + - nbsphinx + - numpydoc + - pydata-sphinx-theme=0.13.1 + - sphinx-panels=0.6.0 diff --git a/docs/getting_started/faq.rst b/docs/getting_started/faq.rst index ab905050c..2609e7929 100644 --- a/docs/getting_started/faq.rst +++ b/docs/getting_started/faq.rst @@ -101,17 +101,28 @@ Bugs are not considered deprecations and may be fixed immediately. What is the version support policy? +++++++++++++++++++++++++++++++++++ -Each major Python version will be supported for at least 36 to 42 months. +Each major Python version will be supported for at least 36. Major dependencies such as NumPy should be supported for at least 24 months. 
-This is motivated by these guidelines: +We aim to follow SPEC 0: -- https://numpy.org/neps/nep-0029-deprecation_policy.html - https://scientific-python.org/specs/spec-0000/ ``python-graphblas`` itself follows a "single trunk" versioning strategy. For example, if a CVE is discovered, we won't retroactively apply the fix to previous releases. Instead, the fix will only be available starting with the next release. +The `GraphBLAS C API specification `_ is expected to change slowly, but it does change. +We aim to support the latest version of the GraphBLAS spec and of implementations. +We will announce plans to drop support of *old* versions of the spec or major versions of implementations +*before* we do so. We will make the announcements in the +`release notes `_ and in our Discord channel. +If the proposed changes will negatively affect you, please +`let us know `_ +so we may work together towards a solution. + +To see which versions of SuiteSparse:GraphBLAS we support, look at the version specification +of ``suitesparse`` under ``[project.optional-dependencies]`` in ``pyproject.toml``. + What is the relationship between python-graphblas and pygraphblas? ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ diff --git a/docs/getting_started/index.rst b/docs/getting_started/index.rst index 661550803..3726131d2 100644 --- a/docs/getting_started/index.rst +++ b/docs/getting_started/index.rst @@ -15,7 +15,7 @@ Using pip: :: - pip install python-graphblas + pip install python-graphblas[default] Whether installing with conda or pip, the underlying package that is imported in Python is named ``graphblas``. The convention is to import as: @@ -34,6 +34,7 @@ to work. - `matplotlib `__ -- required for basic plotting of graphs - `scipy `__ -- used in ``io`` module to read/write ``scipy.sparse`` format - `networkx `__ -- used in ``io`` module to interface with networkx graphs + - `fast-matrix-market `__ -- for faster read/write of Matrix Market files with ``gb.io.mmread`` and ``gb.io.mmwrite`` GraphBLAS Fundamentals ---------------------- diff --git a/docs/getting_started/primer.rst b/docs/getting_started/primer.rst index 710dca702..b5bec26ee 100644 --- a/docs/getting_started/primer.rst +++ b/docs/getting_started/primer.rst @@ -89,26 +89,13 @@ makes for faster graph algorithms. # networkx-style storage of an undirected graph G = { - 0: {1: {'weight': 5.6}, - 2: {'weight': 2.3}, - 3: {'weight': 4.6}}, - 1: {0: {'weight': 5.6}, - 2: {'weight': 1.9}, - 3: {'weight': 6.2}}, - 2: {0: {'weight': 2.3}, - 1: {'weight': 1.9}, - 3: {'weight': 3.0}}, - 3: {0: {'weight': 4.6}, - 1: {'weight': 6.2}, - 2: {'weight': 3.0}, - 4: {'weight': 1.4}}, - 4: {3: {'weight': 1.4}, - 5: {'weight': 4.4}, - 6: {'weight': 1.0}}, - 5: {4: {'weight': 4.4}, - 6: {'weight': 2.8}}, - 6: {4: {'weight': 1.0}, - 5: {'weight': 2.8}} + 0: {1: {"weight": 5.6}, 2: {"weight": 2.3}, 3: {"weight": 4.6}}, + 1: {0: {"weight": 5.6}, 2: {"weight": 1.9}, 3: {"weight": 6.2}}, + 2: {0: {"weight": 2.3}, 1: {"weight": 1.9}, 3: {"weight": 3.0}}, + 3: {0: {"weight": 4.6}, 1: {"weight": 6.2}, 2: {"weight": 3.0}, 4: {"weight": 1.4}}, + 4: {3: {"weight": 1.4}, 5: {"weight": 4.4}, 6: {"weight": 1.0}}, + 5: {4: {"weight": 4.4}, 6: {"weight": 2.8}}, + 6: {4: {"weight": 1.0}, 5: {"weight": 2.8}}, } An alternative way to store a graph is as an adjacency matrix. Each node becomes both a row @@ -240,7 +227,9 @@ node 0.
[0, 0, 1, 1, 2], [1, 2, 2, 3, 3], [2.0, 5.0, 1.5, 4.25, 0.5], - nrows=4, ncols=4) + nrows=4, + ncols=4 + ) v = Vector.from_coo([start_node], [0.0], size=4) # Compute SSSP @@ -274,7 +263,7 @@ and showing that linear algebra can be used to compute graph algorithms with the of semirings. This is a somewhat new field of research, so many academic papers and talks are being given every year. -`Graphblas.org `_ remains the best source for keeping up-to-date with the latest +`Graphblas.org `_ remains the best source for keeping up-to-date with the latest developments in this area. Many people will benefit from faster graph algorithms written in GraphBLAS, but for those that want diff --git a/docs/make.bat b/docs/make.bat index 2119f5109..153be5e2f 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -21,7 +21,7 @@ if errorlevel 9009 ( echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ + echo.https://www.sphinx-doc.org/ exit /b 1 ) diff --git a/docs/user_guide/collections.rst b/docs/user_guide/collections.rst index 2ce759bf4..de7469c6d 100644 --- a/docs/user_guide/collections.rst +++ b/docs/user_guide/collections.rst @@ -145,7 +145,7 @@ The shape and dtype remain unchanged, but the collection will be fully sparse (i to_coo ~~~~~~ -To go from a collection back to the index and values, ``.to_coo()`` can be called. Numpy arrays +To go from a collection back to the index and values, ``.to_coo()`` can be called. NumPy arrays will be returned in a tuple. .. code-block:: python diff --git a/docs/user_guide/fundamentals.rst b/docs/user_guide/fundamentals.rst index f47296c47..6e4a5d195 100644 --- a/docs/user_guide/fundamentals.rst +++ b/docs/user_guide/fundamentals.rst @@ -37,7 +37,7 @@ The descriptor is a set of bitwise flags. - Replace mode indicates that elements outside the mask area should be cleared in the final output. When not in replace mode, elements outside the mask are left untouched. -For more details, look at the official API spec at `graphblas.org `_. +For more details, look at the official API spec at `graphblas.org `_. C-to-Python Mapping ------------------- diff --git a/docs/user_guide/init.rst b/docs/user_guide/init.rst index 62f81b50f..ffb6a3463 100644 --- a/docs/user_guide/init.rst +++ b/docs/user_guide/init.rst @@ -8,8 +8,9 @@ GraphBLAS must be initialized before it can be used. This is done with the .. code-block:: python import graphblas as gb + # Context initialization must happen before any other imports - gb.init('suitesparse', blocking=False) + gb.init("suitesparse", blocking=False) # Now we can import other items from graphblas from graphblas import binary, semiring diff --git a/docs/user_guide/io.rst b/docs/user_guide/io.rst index 9431ff413..f27b40bd3 100644 --- a/docs/user_guide/io.rst +++ b/docs/user_guide/io.rst @@ -4,6 +4,8 @@ Input/Output There are several ways to get data into and out of python-graphblas. +.. _from-to-values: + From/To Values -------------- @@ -29,6 +31,7 @@ array will match the collection dtype. v = gb.Vector.from_coo([1, 3, 6], [2, 3, 4], float, size=10) .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5,6,7,8,9,10 ,2.0,,3.0,,,4.0,,, @@ -129,3 +132,19 @@ Note that A is unchanged in the above example. The SuiteSparse export has a ``give_ownership`` option. This performs a zero-copy move operation and invalidates the original python-graphblas object. When extreme speed is needed or memory is too limited to make a copy, this option may be needed. 
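A minimal sketch of the zero-copy move described above, assuming the SuiteSparse-specific ``ss.export`` and ``ss.import_any`` names (check the API of your installed version; this is an illustration, not part of the patch):

```python
import graphblas as gb

A = gb.Matrix.from_coo([0, 1], [1, 0], [1.0, 2.0])

# The default export copies the data; A remains valid afterwards.
parts = A.ss.export("csr")

# give_ownership=True moves the underlying buffers instead of copying,
# which invalidates A but avoids the extra memory and time.
parts = A.ss.export("csr", give_ownership=True)
B = gb.Matrix.ss.import_any(**parts)  # rebuild a Matrix from the raw parts
```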
+
+Matrix Market files
+-------------------
+
+The `Matrix Market file format `_ is a common
+file format for storing sparse arrays in human-readable ASCII.
+Matrix Market files--also called MM files--often use the ".mtx" file extension.
+For example, many datasets in MM format can be found in `the SuiteSparse Matrix Collection `_.
+
+Use ``gb.io.mmread()`` to read a Matrix Market file to a python-graphblas Matrix,
+and ``gb.io.mmwrite()`` to write a Matrix to a Matrix Market file.
+These names match the equivalent functions in `scipy.sparse `_.
+
+``scipy`` is required to be installed to read Matrix Market files.
+If ``fast_matrix_market`` is installed, it will be used by default for
+`much better performance `_.
diff --git a/docs/user_guide/operations.rst b/docs/user_guide/operations.rst index 41e4fc2c6..18d0352d7 100644 --- a/docs/user_guide/operations.rst +++ b/docs/user_guide/operations.rst @@ -8,7 +8,7 @@ Matrix Multiply The GraphBLAS spec contains three methods for matrix multiplication, depending on whether the inputs are Matrix or Vector. - - **mxm** -- Matrix-Matrix multiplication - **mxv** -- Matrix-Vector multiplication - **vxm** -- Vector-Matrix multiplication @@ -26,18 +26,28 @@ a Vector is treated as an nx1 column matrix. .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2], [1, 2, 2, 3, 3], - [2., 5., 1.5, 4.25, 0.5], nrows=4, ncols=4) - B = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 1, 1, 2, 0, 1], - [3., 2., 9., 6., 3., 1., 0., 5.]) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2], + [1, 2, 2, 3, 3], + [2., 5., 1.5, 4.25, 0.5], + nrows=4, + ncols=4 + ) + B = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2, 3, 3], + [1, 2, 0, 1, 1, 2, 0, 1], + [3., 2., 9., 6., 3., 1., 0., 5.] + ) + C = gb.Matrix(float, A.nrows, B.ncols) # These are equivalent - C << A.mxm(B, op='min_plus') # method style + C << A.mxm(B, op="min_plus") # method style C << gb.semiring.min_plus(A @ B) # functional style .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2,3 + :stub-columns: 1 **0**,,2.0,5.0, **1**,,,1.5,4.25 @@ -45,8 +55,9 @@ a Vector is treated as an nx1 column matrix. **3**,,,, .. csv-table:: B - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,3.0,2.0 **1**,9.0,6.0, @@ -54,8 +65,9 @@ a Vector is treated as an nx1 column matrix. **3**,0.0,5.0, .. csv-table:: C << min_plus(A @ B) - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,11.0,8.0,6.0 **1**,4.25,4.5,2.5 @@ -66,17 +78,24 @@ a Vector is treated as an nx1 column matrix. .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2], [1, 2, 2, 3, 3], - [2., 5., 1.5, 4.25, 0.5], nrows=4, ncols=4) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2], + [1, 2, 2, 3, 3], + [2., 5., 1.5, 4.25, 0.5], + nrows=4, + ncols=4 + ) v = gb.Vector.from_coo([0, 1, 3], [10., 20., 40.]) + w = gb.Vector(float, A.nrows) # These are equivalent - w << A.mxv(v, op='plus_times') # method style + w << A.mxv(v, op="plus_times") # method style w << gb.semiring.plus_times(A @ v) # functional style .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2,3 + :stub-columns: 1 **0**,,2.0,5.0, **1**,,,1.5,4.25 @@ -84,13 +103,13 @@ a Vector is treated as an nx1 column matrix. **3**,,,, .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3 10.0,20.0,,40.0 ..
csv-table:: w << plus_times(A @ v) - :class: inline + :class: inline matrix :header: 0,1,2,3 40.0,170.0,20.0, @@ -100,22 +119,27 @@ a Vector is treated as an nx1 column matrix. .. code-block:: python v = gb.Vector.from_coo([0, 1, 3], [10., 20., 40.]) - B = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 1, 1, 2, 0, 1], - [3., 2., 9., 6., 3., 1., 0., 5.]) + B = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2, 3, 3], + [1, 2, 0, 1, 1, 2, 0, 1], + [3., 2., 9., 6., 3., 1., 0., 5.] + ) + u = gb.Vector(float, B.ncols) # These are equivalent - u << v.vxm(B, op='plus_plus') # method style + u << v.vxm(B, op="plus_plus") # method style u << gb.semiring.plus_plus(v @ B) # functional style .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3 10.0,20.0,,40.0 .. csv-table:: B - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,3.0,2.0 **1**,9.0,6.0, @@ -123,7 +147,7 @@ a Vector is treated as an nx1 column matrix. **3**,0.0,5.0, .. csv-table:: u << plus_plus(v @ B) - :class: inline + :class: inline matrix :header: 0,1,2 69.0,84.0,12.0 @@ -145,34 +169,44 @@ Example usage: .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2], [1, 2, 0, 2, 1], - [2.0, 5.0, 1.5, 4.0, 0.5]) - B = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 1, 2], - [3., -2., 0., 6., 3., 1.]) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2], + [1, 2, 0, 2, 1], + [2., 5., 1.5, 4., 0.5] + ) + B = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 2, 0, 1, 1, 2], + [3., -2., 0., 6., 3., 1.] + ) + C = gb.Matrix(float, A.nrows, A.ncols) # These are equivalent - C << A.ewise_mult(B, op='min') # method style + C << A.ewise_mult(B, op="min") # method style C << gb.binary.min(A & B) # functional style .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,1.5,,4.0 **2**,,0.5, .. csv-table:: B - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,3.0,-2.0 **1**,0.0,6.0, **2**,,3.0,1.0 .. csv-table:: C << min(A & B) - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,-2.0 **1**,0.0,, @@ -221,34 +255,45 @@ should be used with the functional syntax, ``left_default`` and ``right_default` .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 2], - [9.0, 2.0, 5.0, 1.5, 4.0], nrows=3) - B = gb.Matrix.from_coo([0, 0, 0, 2, 2, 2], [0, 1, 2, 0, 1, 2], - [4., 0., -2., 6., 3., 1.]) + A = gb.Matrix.from_coo( + [0, 0, 0, 1, 1], + [0, 1, 2, 0, 2], + [9., 2., 5., 1.5, 4.], + nrows=3 + ) + B = gb.Matrix.from_coo( + [0, 0, 0, 2, 2, 2], + [0, 1, 2, 0, 1, 2], + [4., 0., -2., 6., 3., 1.] + ) + C = gb.Matrix(float, A.nrows, A.ncols) # These are equivalent - C << A.ewise_add(B, op='minus') # method style + C << A.ewise_add(B, op="minus") # method style C << gb.binary.minus(A | B) # functional style .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,9.0,2.0,5.0 **1**,1.5,,4.0 **2**,,, .. csv-table:: B - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,4.0,0.0,-2.0 **1**,,, **2**,6.0,3.0,1.0 .. csv-table:: C << A.ewise_add(B, 'minus') - :class: inline + :class: inline matrix :header: ,0,1,2, + :stub-columns: 1 **0**,5.0,2.0,7.0 **1**,1.5,,4.0 @@ -258,34 +303,45 @@ should be used with the functional syntax, ``left_default`` and ``right_default` .. 
code-block:: python - A = gb.Matrix.from_coo([0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 2], - [9.0, 2.0, 5.0, 1.5, 4.0], nrows=3) - B = gb.Matrix.from_coo([0, 0, 0, 2, 2, 2], [0, 1, 2, 0, 1, 2], - [4., 0., -2., 6., 3., 1.]) + A = gb.Matrix.from_coo( + [0, 0, 0, 1, 1], + [0, 1, 2, 0, 2], + [9., 2., 5., 1.5, 4.], + nrows=3 + ) + B = gb.Matrix.from_coo( + [0, 0, 0, 2, 2, 2], + [0, 1, 2, 0, 1, 2], + [4., 0., -2., 6., 3., 1.] + ) + C = gb.Matrix(float, A.nrows, A.ncols) # These are equivalent - C << A.ewise_union(B, op='minus', left_default=0, right_default=0) # method style + C << A.ewise_union(B, op="minus", left_default=0, right_default=0) # method style C << gb.binary.minus(A | B, left_default=0, right_default=0) # functional style .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,9.0,2.0,5.0 **1**,1.5,,4.0 **2**,,, .. csv-table:: B - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,4.0,0.0,-2.0 **1**,,, **2**,6.0,3.0,1.0 .. csv-table:: C << A.ewise_union(B, 'minus', 0, 0) - :class: inline + :class: inline matrix :header: ,0,1,2, + :stub-columns: 1 **0**,5.0,2.0,7.0 **1**,1.5,,4.0 @@ -315,17 +371,18 @@ Vector Slice Example: .. code-block:: python v = gb.Vector.from_coo([0, 1, 3, 4, 6], [10., 2., 40., -5., 24.]) + w = gb.Vector(float, 4) w << v[:4] .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3,4,5,6 10.0,2.0,,40.0,-5.0,,24.0 .. csv-table:: w << v[:4] - :class: inline + :class: inline matrix :header: 0,1,2,3 10.0,2.0,,40.0 @@ -334,22 +391,28 @@ Matrix List Example: .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 0, 2], - [2.0, 5.0, 1.5, 4.0, 0.5, -7.0]) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 2, 0, 1, 0, 2], + [2., 5., 1.5, 4., 0.5, -7.] + ) + C = gb.Matrix(float, 2, A.ncols) C << A[[0, 2], :] .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,1.5,4.0, **2**,0.5,,-7.0 .. csv-table:: C << A[[0, 2], :] - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,0.5,,-7.0 @@ -374,31 +437,39 @@ Matrix-Matrix Assignment Example: .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 0, 2], - [2.0, 5.0, 1.5, 4.0, 0.5, -7.0]) - B = gb.Matrix.from_coo([0, 0, 1, 1], [0, 1, 0, 1], - [-99., -98., -97., -96.]) - + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 2, 0, 1, 0, 2], + [2., 5., 1.5, 4., 0.5, -7.] + ) + B = gb.Matrix.from_coo( + [0, 0, 1, 1], + [0, 1, 0, 1], + [-99., -98., -97., -96.] + ) A[::2, ::2] << B .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,1.5,4.0, **2**,0.5,,-7.0 .. csv-table:: B - :class: inline + :class: inline matrix :header: ,0,1 + :stub-columns: 1 **0**,-99.0,-98.0 **1**,-97.0,-96.0 .. csv-table:: A[::2, ::2] << B - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,-99.0,2.0,-98.0 **1**,1.5,4.0, @@ -408,29 +479,34 @@ Matrix-Vector Assignment Example: .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 0, 2], - [2.0, 5.0, 1.5, 4.0, 0.5, -7.0]) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 2, 0, 1, 0, 2], + [2., 5., 1.5, 4., 0.5, -7.] + ) v = gb.Vector.from_coo([2], [-99.]) A[1, :] << v .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,1.5,4.0, **2**,0.5,,-7.0 .. 
csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2 ,,-99.0 .. csv-table:: A[1, :] << v - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,,,-99.0 @@ -445,13 +521,13 @@ Vector-Scalar Assignment Example: v[:4] << 99 .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3,4,5,6 10,2,,40,-5,,24 .. csv-table:: v[:4] << 99 - :class: inline + :class: inline matrix :header: 0,1,2,3,4,5,6 99,99,99,99,-5,,24 @@ -473,19 +549,20 @@ function with the collection as the argument. .. code-block:: python v = gb.Vector.from_coo([0, 1, 3], [10., 20., 40.]) + w = gb.Vector(float, v.size) # These are equivalent w << v.apply(gb.unary.minv) w << gb.unary.minv(v) .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3 10.0,20.0,,40.0 .. csv-table:: w << minv(v) - :class: inline + :class: inline matrix :header: 0,1,2,3 0.1,0.05,,0.025 @@ -495,19 +572,20 @@ function with the collection as the argument. .. code-block:: python v = gb.Vector.from_coo([0, 1, 3], [10., 20., 40.]) + w = gb.Vector(int, v.size) # These are equivalent w << v.apply(gb.indexunary.index) w << gb.indexunary.index(v) .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3 10.0,20.0,,40.0 .. csv-table:: w << index(v) - :class: inline + :class: inline matrix :header: 0,1,2,3 0,1,,3 @@ -517,20 +595,21 @@ function with the collection as the argument. .. code-block:: python v = gb.Vector.from_coo([0, 1, 3], [10., 20., 40.]) + w = gb.Vector(float, v.size) # These are all equivalent - w << v.apply('minus', right=15) + w << v.apply("minus", right=15) w << gb.binary.minus(v, right=15) w << v - 15 .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3 10.0,20.0,,40.0 .. csv-table:: w << v.apply('minus', right=15) - :class: inline + :class: inline matrix :header: 0,1,2,3, -5.0,5.0,,25.0 @@ -546,24 +625,30 @@ Upper Triangle Example: .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 1, 2], - [2.0, 5.0, 1.5, 4.0, 0.5, -7.0]) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 2, 0, 2, 1, 2], + [2., 5., 1.5, 4., 0.5, -7.] + ) + C = gb.Matrix(float, A.nrows, A.ncols) # These are equivalent - C << A.select('triu') + C << A.select("triu") C << gb.select.triu(A) .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,1.5,,4.0 **2**,,0.5,-7.0 .. csv-table:: C << select.triu(A) - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,,,4.0 @@ -574,19 +659,20 @@ Select by Value Example: .. code-block:: python v = gb.Vector.from_coo([0, 1, 3, 4, 6], [10., 2., 40., -5., 24.]) + w = gb.Vector(float, v.size) # These are equivalent - w << v.select('>=', 5) + w << v.select(">=", 5) w << gb.select.value(v >= 5) .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3,4,5,6 10.0,2.0,,40.0,-5.0,,24.0 .. csv-table:: w << select.value(v >= 5) - :class: inline + :class: inline matrix :header: 0,1,2,3,4,5,6 10.0,,,40.0,,,24.0 @@ -605,21 +691,26 @@ A monoid or aggregator is used to perform the reduction. .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 3, 0, 1, 0, 1], - [2.0, 5.0, 1.5, 4.0, 0.5, -7.0]) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 3, 0, 1, 0, 1], + [2., 5., 1.5, 4., 0.5, -7.] + ) + w = gb.Vector(float, A.ncols) - w << A.reduce_columnwise('times') + w << A.reduce_columnwise("times") .. 
csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2,3 + :stub-columns: 1 **0**,,2.0,,5.0 **1**,1.5,4.0,, **2**,0.5,-7.0,, .. csv-table:: w << A.reduce_columnwise('times') - :class: inline + :class: inline matrix :header: ,0,1,2,3 ,0.75,-56.0,,5.0 @@ -628,21 +719,26 @@ A monoid or aggregator is used to perform the reduction. .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 3, 0, 1, 0, 1], - [2.0, 5.0, 1.5, 4.0, 0.5, -7.0]) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 3, 0, 1, 0, 1], + [2., 5., 1.5, 4., 0.5, -7.] + ) + s = gb.Scalar(float) - s << A.reduce_scalar('max') + s << A.reduce_scalar("max") .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2,3 + :stub-columns: 1 **0**,,2.0,,5.0 **1**,1.5,4.0,, **2**,0.5,-7.0,, .. csv-table:: s << A.reduce_scalar('max') - :class: inline + :class: inline matrix :header: ,,,, 5.0 @@ -652,19 +748,20 @@ A monoid or aggregator is used to perform the reduction. .. code-block:: python v = gb.Vector.from_coo([0, 1, 3, 4, 6], [10., 2., 40., -5., 24.]) + s = gb.Scalar(int) # These are equivalent - s << v.reduce('argmin') + s << v.reduce("argmin") s << gb.agg.argmin(v) .. csv-table:: v - :class: inline + :class: inline matrix :header: 0,1,2,3,4,5,6 10.0,2.0,,40.0,-5.0,,24.0 .. csv-table:: s << argmin(v) - :class: inline + :class: inline matrix :header: ,,, 4 @@ -679,22 +776,28 @@ To force the transpose to be computed by itself, use it by itself as the right-h .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 3, 0, 1, 0, 2], - [2.0, 5.0, 1.5, 4.0, 0.5, -7.0]) + A = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 3, 0, 1, 0, 2], + [2., 5., 1.5, 4., 0.5, -7.] + ) + C = gb.Matrix(float, A.ncols, A.nrows) C << A.T .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1,2,3 + :stub-columns: 1 **0**,,2.0,,5.0 **1**,1.5,4.0,, **2**,0.5,,-7.0, .. csv-table:: C << A.T - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,1.5,0.5 **1**,2.0,4.0, @@ -711,30 +814,41 @@ The Kronecker product uses a binary operator. .. code-block:: python - A = gb.Matrix.from_coo([0, 0, 1], [0, 1, 0], [1., -2., 3.]) - B = gb.Matrix.from_coo([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 0, 2], - [2.0, 5.0, 1.5, 4.0, 0.5, -7.0]) + A = gb.Matrix.from_coo( + [0, 0, 1], + [0, 1, 0], + [1., -2., 3.] + ) + B = gb.Matrix.from_coo( + [0, 0, 1, 1, 2, 2], + [1, 2, 0, 1, 0, 2], + [2., 5., 1.5, 4., 0.5, -7.] + ) + C = gb.Matrix(float, A.nrows * B.nrows, A.ncols * B.ncols) - C << A.kronecker(B, 'times') + C << A.kronecker(B, "times") .. csv-table:: A - :class: inline + :class: inline matrix :header: ,0,1 + :stub-columns: 1 **0**,1.0,-2.0 **1**,3.0, .. csv-table:: B - :class: inline + :class: inline matrix :header: ,0,1,2 + :stub-columns: 1 **0**,,2.0,5.0 **1**,1.5,4.0, **2**,0.5,,-7.0 .. csv-table:: C << A.kronecker(B, 'times') - :class: inline + :class: inline matrix :header: ,0,1,2,3,4,5 + :stub-columns: 1 **0**,,2.0,5.0,,-4.0,-10.0 **1**,1.5,4.0,,-3.0,-8.0, diff --git a/docs/user_guide/operators.rst b/docs/user_guide/operators.rst index 84fe9312c..8bb5e9fa8 100644 --- a/docs/user_guide/operators.rst +++ b/docs/user_guide/operators.rst @@ -89,9 +89,12 @@ registered from numpy are located in ``graphblas.binary.numpy``. Monoids ------- -Monoids extend the concept of a binary operator to require a single domain for all inputs and -the output. Monoids are also associative, so the order of the inputs does not matter. 
And finally, -monoids have a default identity such that ``A op identity == A``. +Monoids extend the concept of a binary operator to require a single domain for all inputs and the output. +Monoids are also associative, so the order of operations does not matter +(for example, ``(a + b) + c == a + (b + c)``). +GraphBLAS primarily uses *commutative monoids* (for example, ``a + b == b + a``), +and all standard monoids in python-graphblas commute. +And finally, monoids have a default identity such that ``A op identity == A``. Monoids are commonly used for reductions, collapsing all elements down to a single value. @@ -273,7 +276,7 @@ Example usage: minval = v.reduce(gb.monoid.min).value # This will force the FP32 version of min to be used, possibly type casting the elements - minvalFP32 = v.reduce(gb.monoid.min['FP32']).value + minvalFP32 = v.reduce(gb.monoid.min["FP32"]).value The gb.op Namespace @@ -311,12 +314,14 @@ each symbol. Each is detailed below. The following objects will be used to demonstrate the behavior. .. csv-table:: Vector v + :class: matrix :header: 0,1,2,3,4,5 1.0,,2.0,3.5,,9.0 .. csv-table:: Vector w + :class: matrix :header: 0,1,2,3,4,5 7.0,5.2,,3.0,,2.5 @@ -340,6 +345,7 @@ Addition performs an element-wise union between collections, adding overlapping v + w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 8.0,5.2,2.0,6.5,,11.5 @@ -355,6 +361,7 @@ and negating any standalone elements from the right-hand object. v - w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 -6.0,-5.2,2.0,0.5,,6.5 @@ -370,6 +377,7 @@ overlapping elements. v * w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 7.0,,,10.5,,22.5 @@ -389,6 +397,7 @@ elements and always results in a floating-point dtype. v / w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 0.142857,,,1.166667,,3.6 @@ -404,6 +413,7 @@ Dividing by zero with floor division will raise a ``ZeroDivisionError``. v // w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 0.0,,,1.0,,3.0 @@ -419,6 +429,7 @@ of dividing overlapping elements. v % w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 1.0,,,0.5,,1.5 @@ -431,9 +442,10 @@ the power of y for overlapping elements. .. code-block:: python - v ** w + v**w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 1.0,,,42.875,,243.0 @@ -452,6 +464,7 @@ rather than ``all(A == B)`` v > w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 False,,,True,,True @@ -461,6 +474,7 @@ rather than ``all(A == B)`` v == w .. csv-table:: + :class: matrix :header: 0,1,2,3,4,5 False,,,False,,False diff --git a/docs/user_guide/recorder.rst b/docs/user_guide/recorder.rst index ee6d2bbb9..3355d93ce 100644 --- a/docs/user_guide/recorder.rst +++ b/docs/user_guide/recorder.rst @@ -25,7 +25,9 @@ Instead, only the calls from the last iteration will be returned. [0, 0, 1, 1, 2], [1, 2, 2, 3, 3], [2.0, 5.0, 1.5, 4.25, 0.5], - nrows=4, ncols=4) + nrows=4, + ncols=4 + ) v = Vector.from_coo([start_node], [0.0], size=4) # Compute SSSP diff --git a/docs/user_guide/udf.rst b/docs/user_guide/udf.rst index 6c72535fc..e7b984b44 100644 --- a/docs/user_guide/udf.rst +++ b/docs/user_guide/udf.rst @@ -21,12 +21,13 @@ Example user-defined UnaryOp: return x + 1 return x - unary.register_new('force_odd', force_odd_func) + unary.register_new("force_odd", force_odd_func) v = Vector.from_coo([0, 1, 3, 4, 5], [1, 2, 3, 8, 14]) w = v.apply(unary.force_odd).new() .. csv-table:: w + :class: matrix :header: 0,1,2,3,4,5 1,3,,3,9,15 @@ -48,6 +49,7 @@ Example lambda usage: v.apply(lambda x: x % 5 - 2).new() ..
csv-table:: + :class: matrix :header: 0,1,2,3,4,5 -1,0,,1,1,2 diff --git a/environment.yml b/environment.yml index f327a6980..2bae0b76e 100644 --- a/environment.yml +++ b/environment.yml @@ -11,95 +11,100 @@ # It is okay to comment out sections below that you don't need such as viz or building docs. name: graphblas-dev channels: - - conda-forge - - nodefaults # Only install packages from conda-forge for faster solving + - conda-forge + - nodefaults # Only install packages from conda-forge for faster solving dependencies: - - python - - donfig - - numba - - python-suitesparse-graphblas - - pyyaml - # For repr - - pandas - # For I/O - - awkward - # - fast_matrix_market # Coming soon... - - networkx - - scipy - - sparse - # For viz - - datashader - - hvplot - - matplotlib - # For linting - - pre-commit - # For testing - - packaging - - pytest-cov - # For debugging - - icecream - - ipykernel - - ipython - # For type annotations - - mypy - # For building docs - - nbsphinx - - numpydoc - - pydata-sphinx-theme - - sphinx-panels - # EXTRA (optional; uncomment as desired) - # - autoflake - # - black - # - black-jupyter - # - build - # - codespell - # - commonmark - # - cython - # - cytoolz - # - distributed - # - flake8 - # - flake8-bugbear - # - flake8-comprehensions - # - flake8-print - # - flake8-quotes - # - flake8-simplify - # - gcc - # - gh - # - graph-tool - # - xorg-libxcursor # for graph-tool - # - grayskull - # - h5py - # - hiveplot - # - igraph - # - ipycytoscape - # - isort - # - jupyter - # - jupyterlab - # - line_profiler - # - lxml - # - make - # - memory_profiler - # - nbqa - # - netcdf4 - # - networkit - # - nxviz - # - pycodestyle - # - pydot - # - pygraphviz - # - pylint - # - pytest-runner - # - pytest-xdist - # - python-graphviz - # - python-igraph - # - python-louvain - # - pyupgrade - # - ruff - # - scalene - # - setuptools-git-versioning - # - snakeviz - # - sphinx-lint - # - sympy - # - twine - # - vim - # - yesqa - # - zarr + - python + - donfig + - numba + - python-suitesparse-graphblas + - pyyaml + # For repr + - pandas + # For I/O + - awkward + - networkx + - scipy + - sparse + # For viz + - datashader + - hvplot + - matplotlib + # For linting + - pre-commit + # For testing + - packaging + - pytest-cov + - tomli + # For debugging + - icecream + - ipykernel + - ipython + # For type annotations + - mypy + # For building docs + - nbsphinx + - numpydoc + - pydata-sphinx-theme + - sphinx-panels + # For building logo + - drawsvg + - cairosvg + # EXTRA (optional; uncomment as desired) + # - autoflake + # - black + # - black-jupyter + # - codespell + # - commonmark + # - cython + # - cytoolz + # - distributed + # - flake8 + # - flake8-bugbear + # - flake8-comprehensions + # - flake8-print + # - flake8-quotes + # - flake8-simplify + # - gcc + # - gh + # - git + # - graph-tool + # - xorg-libxcursor # for graph-tool + # - grayskull + # - h5py + # - hiveplot + # - igraph + # - ipycytoscape + # - isort + # - jupyter + # - jupyterlab + # - line_profiler + # - lxml + # - make + # - memory_profiler + # - nbqa + # - netcdf4 + # - networkit + # - nxviz + # - pycodestyle + # - pydot + # - pygraphviz + # - pylint + # - pytest-runner + # - pytest-xdist + # - python-graphviz + # - python-igraph + # - python-louvain + # - pyupgrade + # - rich + # - ruff + # - scalene + # - scikit-network + # - setuptools-git-versioning + # - snakeviz + # - sphinx-lint + # - sympy + # - tuna + # - twine + # - vim + # - zarr diff --git a/graphblas/__init__.py b/graphblas/__init__.py index 87311599c..63110eeeb 
100644 --- a/graphblas/__init__.py +++ b/graphblas/__init__.py @@ -39,6 +39,7 @@ def get_config(): backend = None _init_params = None _SPECIAL_ATTRS = { + "MAX_SIZE", # The maximum size of Vector and Matrix dimensions (GrB_INDEX_MAX + 1) "Matrix", "Recorder", "Scalar", @@ -137,7 +138,21 @@ def _init(backend_arg, blocking, automatic=False): backend = backend_arg if backend in {"suitesparse", "suitesparse-vanilla"}: - from suitesparse_graphblas import ffi, initialize, is_initialized, lib + try: + from suitesparse_graphblas import ffi, initialize, is_initialized, lib + except ImportError: # pragma: no cover (import) + raise ImportError( + f"suitesparse_graphblas is required for {backend!r} backend. " + "It may be installed with pip or conda:\n\n" + " $ pip install suitesparse-graphblas\n" + " $ conda install -c conda-forge python-suitesparse-graphblas\n\n" + "SuiteSparse:GraphBLAS is the primary C implementation and backend of " + "python-graphblas and is what we recommend to most users. If you are " + "installing python-graphblas with pip, we recommend installing with one " + "of the following to automatically include suitesparse-graphblas:\n\n" + " $ pip install python-graphblas[suitesparse]\n" + " $ pip install python-graphblas[default]" + ) from None if is_initialized(): mode = ffi.new("int32_t*") @@ -191,6 +206,10 @@ def _load(name): if name in {"Matrix", "Vector", "Scalar", "Recorder"}: module = _import_module(f".core.{name.lower()}", __name__) globals()[name] = getattr(module, name) + elif name == "MAX_SIZE": + from .core import lib + + globals()[name] = lib.GrB_INDEX_MAX + 1 else: # Everything else is a module globals()[name] = _import_module(f".{name}", __name__) diff --git a/graphblas/agg/__init__.py b/graphblas/agg/__init__.py index f2dddb851..da7c13591 100644 --- a/graphblas/agg/__init__.py +++ b/graphblas/agg/__init__.py @@ -1,4 +1,4 @@ -"""`graphblas.agg` is an experimental module for exploring Aggregators. +"""``graphblas.agg`` is an experimental module for exploring Aggregators. Aggregators may be used in reduce methods: - Matrix.reduce_rowwise @@ -59,9 +59,9 @@ - ss.argmax .. deprecated:: 2023.1.0 - Aggregators `first`, `last`, `first_index`, `last_index`, `argmin`, and `argmax` are - deprecated in the `agg` namespace such as `agg.first`. Use them from `agg.ss` namespace - instead such as `agg.ss.first`. Will be removed in version 2023.9.0 or later. + Aggregators ``first``, ``last``, ``first_index``, ``last_index``, ``argmin``, and ``argmax`` + are deprecated in the ``agg`` namespace such as ``agg.first``. Use them from ``agg.ss`` + namespace instead such as ``agg.ss.first``. Will be removed in version 2023.9.0 or later.
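A short sketch of the migration this deprecation note asks for, using the ``agg.ss`` names listed in the docstring above (illustrative only, not part of the patch):

```python
import graphblas as gb

v = gb.Vector.from_coo([0, 1, 3], [10.0, 20.0, 40.0])

# Deprecated since 2023.1.0 (removal planned for 2023.9.0 or later):
# s = v.reduce(gb.agg.argmin).new()

# Preferred SuiteSparse-specific namespace:
s = v.reduce(gb.agg.ss.argmin).new()  # -> 0, the index of the smallest value
```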
# Possible aggregators: # - absolute_deviation, sum(abs(x - mean(x))), sum_absminus(x, mean(x)) @@ -73,7 +73,8 @@ # - bxnor monoid: even bits # - bnor monoid: odd bits """ -# All items are dynamically added by classes in core/agg.py + +# All items are dynamically added by classes in core/operator/agg.py # This module acts as a container of Aggregator instances _deprecated = {} @@ -111,6 +112,6 @@ def __getattr__(key): raise AttributeError(f"module {__name__!r} has no attribute {key!r}") -from ..core import agg # noqa: E402 isort:skip +from ..core import operator # noqa: E402 isort:skip -del agg +del operator diff --git a/graphblas/agg/ss.py b/graphblas/agg/ss.py index c3f06c0a7..e45cbcda0 100644 --- a/graphblas/agg/ss.py +++ b/graphblas/agg/ss.py @@ -1,3 +1,3 @@ -from ..core import agg +from ..core import operator -del agg +del operator diff --git a/graphblas/binary/__init__.py b/graphblas/binary/__init__.py index e59c0405e..1b8985f73 100644 --- a/graphblas/binary/__init__.py +++ b/graphblas/binary/__init__.py @@ -1,5 +1,7 @@ # All items are dynamically added by classes in operator.py # This module acts as a container of BinaryOp instances +from ..core import _supports_udfs + _delayed = {} _delayed_commutes_to = { "absfirst": "abssecond", @@ -9,6 +11,15 @@ "rpow": "pow", } _deprecated = {} +_udfs = { + "absfirst", + "abssecond", + "binom", + "floordiv", + "isclose", + "rfloordiv", + "rpow", +} def __dir__(): @@ -50,6 +61,11 @@ def __getattr__(key): ss = import_module(".ss", __name__) globals()["ss"] = ss return ss + if not _supports_udfs and key in _udfs: + raise AttributeError( + f"module {__name__!r} unable to compile UDF for {key!r}; " + "install numba for UDF support" + ) raise AttributeError(f"module {__name__!r} has no attribute {key!r}") diff --git a/graphblas/binary/numpy.py b/graphblas/binary/numpy.py index 21ed568ea..bb22d0b07 100644 --- a/graphblas/binary/numpy.py +++ b/graphblas/binary/numpy.py @@ -5,11 +5,13 @@ https://numba.readthedocs.io/en/stable/reference/numpysupported.html#math-operations """ + import numpy as _np from .. import _STANDARD_OPERATOR_NAMES from .. import binary as _binary from .. 
import config as _config +from ..core import _supports_udfs _delayed = {} _binary_names = { @@ -130,7 +132,13 @@ def __dir__(): - return globals().keys() | _delayed.keys() | _binary_names + if not _supports_udfs and not _config["mapnumpy"]: + # float_power is special: it's constructed from builtin operators + return globals().keys() | {"float_power"} # FLAKY COVERAGE + attrs = _delayed.keys() | _binary_names + if not _supports_udfs: + attrs &= _numpy_to_graphblas.keys() + return attrs | globals().keys() def __getattr__(name): @@ -141,19 +149,20 @@ def __getattr__(name): return rv if name not in _binary_names: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - if _config.get("mapnumpy") and name in _numpy_to_graphblas: + if _config.get("mapnumpy") and name in _numpy_to_graphblas or name == "float_power": if name == "float_power": - from ..core import operator + from ..core.operator import binary + from ..dtypes import FP64 - new_op = operator.BinaryOp(f"numpy.{name}") + new_op = binary.BinaryOp(f"numpy.{name}") builtin_op = _binary.pow for dtype in builtin_op.types: if dtype.name in {"FP32", "FC32", "FC64"}: orig_dtype = dtype else: - orig_dtype = operator.FP64 + orig_dtype = FP64 orig_op = builtin_op[orig_dtype] - cur_op = operator.TypedBuiltinBinaryOp( + cur_op = binary.TypedBuiltinBinaryOp( new_op, new_op.name, dtype, @@ -165,15 +174,18 @@ def __getattr__(name): globals()[name] = new_op else: globals()[name] = getattr(_binary, _numpy_to_graphblas[name]) + elif not _supports_udfs: + raise AttributeError( + f"module {__name__!r} unable to compile UDF for {name!r}; " + "install numba for UDF support" + ) else: - from ..core import operator - numpy_func = getattr(_np, name) def func(x, y): # pragma: no cover (numba) return numpy_func(x, y) - operator.BinaryOp.register_new(f"numpy.{name}", func) + _binary.register_new(f"numpy.{name}", func) rv = globals()[name] if name in _commutative: rv._commutes_to = rv diff --git a/graphblas/binary/ss.py b/graphblas/binary/ss.py index e45cbcda0..0c294e322 100644 --- a/graphblas/binary/ss.py +++ b/graphblas/binary/ss.py @@ -1,3 +1,6 @@ from ..core import operator +from ..core.ss.binary import register_new # noqa: F401 + +_delayed = {} del operator diff --git a/graphblas/core/__init__.py b/graphblas/core/__init__.py index c079a6e2f..7fd2dc526 100644 --- a/graphblas/core/__init__.py +++ b/graphblas/core/__init__.py @@ -1,3 +1,12 @@ +try: + import numba +except ImportError: + _has_numba = _supports_udfs = False +else: + _has_numba = _supports_udfs = True + del numba + + def __getattr__(name): if name in {"ffi", "lib", "NULL"}: from .. import _autoinit diff --git a/graphblas/core/automethods.py b/graphblas/core/automethods.py index 98dc61137..600a6e139 100644 --- a/graphblas/core/automethods.py +++ b/graphblas/core/automethods.py @@ -1,12 +1,13 @@ """Define functions to use as property methods on expressions. -These will automatically compute the value and avoid the need for `.new()`. +These will automatically compute the value and avoid the need for ``.new()``. To automatically create the functions, run: $ python -m graphblas.core.automethods """ + from .. 
import config @@ -213,6 +214,10 @@ def outer(self): return self._get_value("outer") +def power(self): + return self._get_value("power") + + def reduce(self): return self._get_value("reduce") @@ -277,10 +282,6 @@ def to_edgelist(self): return self._get_value("to_edgelist") -def to_values(self): - return self._get_value("to_values") - - def value(self): return self._get_value("value") @@ -394,7 +395,6 @@ def _main(): "ss", "to_coo", "to_dense", - "to_values", } vector = { "_as_matrix", @@ -410,6 +410,7 @@ def _main(): "kronecker", "mxm", "mxv", + "power", "reduce_columnwise", "reduce_rowwise", "reduce_scalar", diff --git a/graphblas/core/base.py b/graphblas/core/base.py index a4e48b612..24a49ba1a 100644 --- a/graphblas/core/base.py +++ b/graphblas/core/base.py @@ -263,23 +263,31 @@ def __call__( ) def __or__(self, other): - from .infix import _ewise_infix_expr + from .infix import _ewise_infix_expr, _ewise_mult_expr_types + if isinstance(other, _ewise_mult_expr_types): + raise TypeError("XXX") return _ewise_infix_expr(self, other, method="ewise_add", within="__or__") def __ror__(self, other): - from .infix import _ewise_infix_expr + from .infix import _ewise_infix_expr, _ewise_mult_expr_types + if isinstance(other, _ewise_mult_expr_types): + raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_add", within="__ror__") def __and__(self, other): - from .infix import _ewise_infix_expr + from .infix import _ewise_add_expr_types, _ewise_infix_expr + if isinstance(other, _ewise_add_expr_types): + raise TypeError("XXX") return _ewise_infix_expr(self, other, method="ewise_mult", within="__and__") def __rand__(self, other): - from .infix import _ewise_infix_expr + from .infix import _ewise_add_expr_types, _ewise_infix_expr + if isinstance(other, _ewise_add_expr_types): + raise TypeError("XXX") return _ewise_infix_expr(other, self, method="ewise_mult", within="__rand__") def __matmul__(self, other): @@ -348,7 +356,7 @@ def _update(self, expr, mask=None, accum=None, replace=False, input_mask=None, * return if opts: # Ignore opts for now - descriptor_lookup(**opts) + desc = descriptor_lookup(**opts) self.value = expr return @@ -371,7 +379,7 @@ def _update(self, expr, mask=None, accum=None, replace=False, input_mask=None, * else: if opts: # Ignore opts for now - descriptor_lookup(**opts) + desc = descriptor_lookup(**opts) self.value = expr return else: @@ -505,7 +513,7 @@ def _name_html(self): _expect_op = _expect_op # Don't let non-scalars be coerced to numpy arrays - def __array__(self, dtype=None): + def __array__(self, dtype=None, *, copy=None): raise TypeError( f"{type(self).__name__} can't be directly converted to a numpy array; " f"perhaps use `{self.name}.to_coo()` method instead." 
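The ``__array__`` override above makes implicit NumPy coercion fail loudly rather than silently densifying. A small sketch of the intended pattern, following the error message's own advice (``to_dense`` is the dense alternative documented in ``docs/api_reference/io.rst`` above):

```python
import numpy as np
import graphblas as gb

v = gb.Vector.from_coo([0, 2], [1.5, 2.5], size=4)

# np.asarray(v) would raise TypeError via the __array__ guard above.
# Extract the sparse data explicitly instead:
indices, values = v.to_coo()

# Or materialize a dense array when that is really what is wanted:
dense = v.to_dense(fill_value=0.0)
```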
@@ -571,7 +579,7 @@ def _new(self, dtype, mask, name, is_cscalar=None, **opts): ): if opts: # Ignore opts for now - descriptor_lookup(**opts) + desc = descriptor_lookup(**opts) # noqa: F841 (keep desc in scope for context) if self._is_scalar and self._value._is_cscalar != is_cscalar: return self._value.dup(is_cscalar=is_cscalar, name=name) rv = self._value diff --git a/graphblas/core/descriptor.py b/graphblas/core/descriptor.py index 1e195e3fe..11f634afd 100644 --- a/graphblas/core/descriptor.py +++ b/graphblas/core/descriptor.py @@ -26,6 +26,7 @@ def __init__( self.mask_structure = mask_structure self.transpose_first = transpose_first self.transpose_second = transpose_second + self._context = None # Used by SuiteSparse:GraphBLAS 8 @property def _carg(self): diff --git a/graphblas/dtypes.py b/graphblas/core/dtypes.py similarity index 59% rename from graphblas/dtypes.py rename to graphblas/core/dtypes.py index 22d98b8f1..2d4178b14 100644 --- a/graphblas/dtypes.py +++ b/graphblas/core/dtypes.py @@ -1,17 +1,17 @@ -import warnings as _warnings +import warnings +from ast import literal_eval -import numba as _numba -import numpy as _np -from numpy import find_common_type as _find_common_type -from numpy import promote_types as _promote_types +import numpy as np +from numpy import promote_types, result_type -from . import backend -from .core import NULL as _NULL -from .core import ffi as _ffi -from .core import lib as _lib +from .. import backend, dtypes +from ..core import NULL, _has_numba, ffi, lib + +if _has_numba: + import numba # Default assumption unless FC32/FC64 are found in lib -_supports_complex = hasattr(_lib, "GrB_FC64") or hasattr(_lib, "GxB_FC64") +_supports_complex = hasattr(lib, "GrB_FC64") or hasattr(lib, "GxB_FC64") class DataType: @@ -23,7 +23,7 @@ def __init__(self, name, gb_obj, gb_name, c_type, numba_type, np_type): self.gb_name = gb_name self.c_type = c_type self.numba_type = numba_type - self.np_type = _np.dtype(np_type) + self.np_type = np.dtype(np_type) if np_type is not None else None def __repr__(self): return self.name @@ -59,7 +59,7 @@ def _carg(self): @property def _is_anonymous(self): - return globals().get(self.name) is not self + return getattr(dtypes, self.name, None) is not self @property def _is_udt(self): @@ -77,27 +77,29 @@ def _deserialize(name, dtype, is_anonymous): def register_new(name, dtype): if not name.isidentifier(): raise ValueError(f"`name` argument must be a valid Python identifier; got: {name!r}") - if name in _registry or name in globals(): + if name in _registry or hasattr(dtypes, name): raise ValueError(f"{name!r} name for dtype is unavailable") rv = register_anonymous(dtype, name) _registry[name] = rv - globals()[name] = rv + setattr(dtypes, name, rv) return rv def register_anonymous(dtype, name=None): try: - dtype = _np.dtype(dtype) + dtype = np.dtype(dtype) except TypeError: if isinstance(dtype, dict): # Allow dtypes such as `{'x': int, 'y': float}` for convenience - dtype = _np.dtype([(key, lookup_dtype(val).np_type) for key, val in dtype.items()]) + dtype = np.dtype( + [(key, lookup_dtype(val).np_type) for key, val in dtype.items()], align=True + ) elif isinstance(dtype, str) and "[" in dtype and dtype.endswith("]"): # Allow dtypes such as `"INT64[3, 4]"` for convenience base_dtype, shape = dtype.split("[", 1) base_dtype = lookup_dtype(base_dtype) - shape = _np.lib.format.safe_eval(f"[{shape}") - dtype = _np.dtype((base_dtype.np_type, shape)) + shape = literal_eval(f"[{shape}") + dtype = np.dtype((base_dtype.np_type, shape)) else: raise 
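The dtype-parsing branches above accept two convenience spellings in addition to anything ``np.dtype`` already understands. A short sketch of the intended usage (illustrative only; assumes a backend that supports user-defined types):

.. code-block:: python

    from graphblas import dtypes

    # dict shorthand: builds an aligned struct dtype, resolving each field
    Point = dtypes.register_anonymous({"x": int, "y": float}, name="Point")

    # string shorthand: "INT64[3, 4]" parses the shape with literal_eval
    Block = dtypes.register_anonymous("INT64[3, 4]")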
if dtype in _registry:
@@ -111,112 +113,204 @@ def register_anonymous(dtype, name=None):
     if dtype.hasobject:
         raise ValueError("dtype must not allow Python objects")
 
-    from .exceptions import check_status_carg
+    from ..exceptions import check_status_carg
 
-    gb_obj = _ffi.new("GrB_Type*")
-    if backend == "suitesparse":
+    gb_obj = ffi.new("GrB_Type*")
+
+    if hasattr(lib, "GrB_Type_set_String"):
+        # We name this so that we can serialize and deserialize UDTs
+        # We don't yet have C definitions
+        np_repr = _dtype_to_string(dtype)
+        status = lib.GrB_Type_new(gb_obj, dtype.itemsize)
+        check_status_carg(status, "Type", gb_obj[0])
+        val_obj = ffi.new("char[]", np_repr.encode())
+        status = lib.GrB_Type_set_String(gb_obj[0], val_obj, lib.GrB_NAME)
+    elif backend == "suitesparse":
+        # For SuiteSparse < 9
         # We name this so that we can serialize and deserialize UDTs
         # We don't yet have C definitions
         np_repr = _dtype_to_string(dtype).encode()
-        if len(np_repr) > _lib.GxB_MAX_NAME_LEN:
+        if len(np_repr) > lib.GxB_MAX_NAME_LEN:
             msg = (
                 f"UDT repr is too large to serialize ({len(repr(dtype).encode())} > "
-                f"{_lib.GxB_MAX_NAME_LEN})."
+                f"{lib.GxB_MAX_NAME_LEN})"
             )
             if name is not None:
-                np_repr = name.encode()[: _lib.GxB_MAX_NAME_LEN]
+                np_repr = name.encode()[: lib.GxB_MAX_NAME_LEN]
             else:
-                np_repr = np_repr[: _lib.GxB_MAX_NAME_LEN]
-            _warnings.warn(
+                np_repr = np_repr[: lib.GxB_MAX_NAME_LEN]
+            warnings.warn(
                 f"{msg}. It will use the following name, "
                 f"and the dtype may need to be specified when deserializing: {np_repr}",
                 stacklevel=2,
             )
-        status = _lib.GxB_Type_new(gb_obj, dtype.itemsize, np_repr, _NULL)
+        status = lib.GxB_Type_new(gb_obj, dtype.itemsize, np_repr, NULL)
     else:
-        status = _lib.GrB_Type_new(gb_obj, dtype.itemsize)
+        status = lib.GrB_Type_new(gb_obj, dtype.itemsize)
     check_status_carg(status, "Type", gb_obj[0])
 
     # For now, let's use "opaque" unsigned bytes for the c type.
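As the comments above note, the stored name is what lets a UDT survive a serialize/deserialize round-trip. For reference, a hypothetical sketch of registering a named UDT through the public API (names assumed for illustration):

.. code-block:: python

    import numpy as np
    from graphblas import dtypes

    # register_new(name, dtype) also exposes the UDT on the dtypes namespace
    Pair = dtypes.register_new("Pair", np.dtype([("a", np.int64), ("b", np.float64)]))
    assert dtypes.Pair is Pair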
if name is None: name = _default_name(dtype) - numba_type = _numba.typeof(dtype).dtype + numba_type = numba.typeof(dtype).dtype if _has_numba else None rv = DataType(name, gb_obj, None, f"uint8_t[{dtype.itemsize}]", numba_type, dtype) _registry[gb_obj] = rv _registry[dtype] = rv - _registry[numba_type] = rv - _registry[numba_type.name] = rv + if _has_numba: + _registry[numba_type] = rv + _registry[numba_type.name] = rv return rv -BOOL = DataType("BOOL", _lib.GrB_BOOL, "GrB_BOOL", "_Bool", _numba.types.bool_, _np.bool_) -INT8 = DataType("INT8", _lib.GrB_INT8, "GrB_INT8", "int8_t", _numba.types.int8, _np.int8) -UINT8 = DataType("UINT8", _lib.GrB_UINT8, "GrB_UINT8", "uint8_t", _numba.types.uint8, _np.uint8) -INT16 = DataType("INT16", _lib.GrB_INT16, "GrB_INT16", "int16_t", _numba.types.int16, _np.int16) +BOOL = DataType( + "BOOL", + lib.GrB_BOOL, + "GrB_BOOL", + "_Bool", + numba.types.bool_ if _has_numba else None, + np.bool_, +) +INT8 = DataType( + "INT8", lib.GrB_INT8, "GrB_INT8", "int8_t", numba.types.int8 if _has_numba else None, np.int8 +) +UINT8 = DataType( + "UINT8", + lib.GrB_UINT8, + "GrB_UINT8", + "uint8_t", + numba.types.uint8 if _has_numba else None, + np.uint8, +) +INT16 = DataType( + "INT16", + lib.GrB_INT16, + "GrB_INT16", + "int16_t", + numba.types.int16 if _has_numba else None, + np.int16, +) UINT16 = DataType( - "UINT16", _lib.GrB_UINT16, "GrB_UINT16", "uint16_t", _numba.types.uint16, _np.uint16 + "UINT16", + lib.GrB_UINT16, + "GrB_UINT16", + "uint16_t", + numba.types.uint16 if _has_numba else None, + np.uint16, +) +INT32 = DataType( + "INT32", + lib.GrB_INT32, + "GrB_INT32", + "int32_t", + numba.types.int32 if _has_numba else None, + np.int32, ) -INT32 = DataType("INT32", _lib.GrB_INT32, "GrB_INT32", "int32_t", _numba.types.int32, _np.int32) UINT32 = DataType( - "UINT32", _lib.GrB_UINT32, "GrB_UINT32", "uint32_t", _numba.types.uint32, _np.uint32 + "UINT32", + lib.GrB_UINT32, + "GrB_UINT32", + "uint32_t", + numba.types.uint32 if _has_numba else None, + np.uint32, +) +INT64 = DataType( + "INT64", + lib.GrB_INT64, + "GrB_INT64", + "int64_t", + numba.types.int64 if _has_numba else None, + np.int64, ) -INT64 = DataType("INT64", _lib.GrB_INT64, "GrB_INT64", "int64_t", _numba.types.int64, _np.int64) # _Index (like UINT64) is for internal use only and shouldn't be exposed to the user _INDEX = DataType( - "UINT64", _lib.GrB_UINT64, "GrB_Index", "GrB_Index", _numba.types.uint64, _np.uint64 + "UINT64", + lib.GrB_UINT64, + "GrB_Index", + "GrB_Index", + numba.types.uint64 if _has_numba else None, + np.uint64, ) UINT64 = DataType( - "UINT64", _lib.GrB_UINT64, "GrB_UINT64", "uint64_t", _numba.types.uint64, _np.uint64 + "UINT64", + lib.GrB_UINT64, + "GrB_UINT64", + "uint64_t", + numba.types.uint64 if _has_numba else None, + np.uint64, +) +FP32 = DataType( + "FP32", + lib.GrB_FP32, + "GrB_FP32", + "float", + numba.types.float32 if _has_numba else None, + np.float32, +) +FP64 = DataType( + "FP64", + lib.GrB_FP64, + "GrB_FP64", + "double", + numba.types.float64 if _has_numba else None, + np.float64, ) -FP32 = DataType("FP32", _lib.GrB_FP32, "GrB_FP32", "float", _numba.types.float32, _np.float32) -FP64 = DataType("FP64", _lib.GrB_FP64, "GrB_FP64", "double", _numba.types.float64, _np.float64) -if _supports_complex and hasattr(_lib, "GxB_FC32"): +if _supports_complex and hasattr(lib, "GxB_FC32"): FC32 = DataType( - "FC32", _lib.GxB_FC32, "GxB_FC32", "float _Complex", _numba.types.complex64, _np.complex64 + "FC32", + lib.GxB_FC32, + "GxB_FC32", + "float _Complex", + numba.types.complex64 if 
_has_numba else None, + np.complex64, ) -if _supports_complex and hasattr(_lib, "GrB_FC32"): # pragma: no cover (unused) +if _supports_complex and hasattr(lib, "GrB_FC32"): # pragma: no cover (unused) FC32 = DataType( - "FC32", _lib.GrB_FC32, "GrB_FC32", "float _Complex", _numba.types.complex64, _np.complex64 + "FC32", + lib.GrB_FC32, + "GrB_FC32", + "float _Complex", + numba.types.complex64 if _has_numba else None, + np.complex64, ) -if _supports_complex and hasattr(_lib, "GxB_FC64"): +if _supports_complex and hasattr(lib, "GxB_FC64"): FC64 = DataType( "FC64", - _lib.GxB_FC64, + lib.GxB_FC64, "GxB_FC64", "double _Complex", - _numba.types.complex128, - _np.complex128, + numba.types.complex128 if _has_numba else None, + np.complex128, ) -if _supports_complex and hasattr(_lib, "GrB_FC64"): # pragma: no cover (unused) +if _supports_complex and hasattr(lib, "GrB_FC64"): # pragma: no cover (unused) FC64 = DataType( "FC64", - _lib.GrB_FC64, + lib.GrB_FC64, "GrB_FC64", "double _Complex", - _numba.types.complex128, - _np.complex128, + numba.types.complex128 if _has_numba else None, + np.complex128, ) # Used for testing user-defined functions _sample_values = { - INT8: _np.int8(1), - UINT8: _np.uint8(1), - INT16: _np.int16(1), - UINT16: _np.uint16(1), - INT32: _np.int32(1), - UINT32: _np.uint32(1), - INT64: _np.int64(1), - UINT64: _np.uint64(1), - FP32: _np.float32(0.5), - FP64: _np.float64(0.5), - BOOL: _np.bool_(True), + INT8: np.int8(1), + UINT8: np.uint8(1), + INT16: np.int16(1), + UINT16: np.uint16(1), + INT32: np.int32(1), + UINT32: np.uint32(1), + INT64: np.int64(1), + UINT64: np.uint64(1), + FP32: np.float32(0.5), + FP64: np.float64(0.5), + BOOL: np.bool_(True), } if _supports_complex: _sample_values.update( { - FC32: _np.complex64(complex(0, 0.5)), - FC64: _np.complex128(complex(0, 0.5)), + FC32: np.complex64(complex(0, 0.5)), + FC64: np.complex128(complex(0, 0.5)), } ) @@ -246,8 +340,9 @@ def register_anonymous(dtype, name=None): _registry[dtype.gb_name.lower()] = dtype _registry[dtype.c_type] = dtype _registry[dtype.c_type.upper()] = dtype - _registry[dtype.numba_type] = dtype - _registry[dtype.numba_type.name] = dtype + if _has_numba: + _registry[dtype.numba_type] = dtype + _registry[dtype.numba_type.name] = dtype val = _sample_values[dtype] _registry[val.dtype] = dtype _registry[val.dtype.name] = dtype @@ -291,8 +386,7 @@ def lookup_dtype(key, value=None): def unify(type1, type2, *, is_left_scalar=False, is_right_scalar=False): - """ - Returns a type that can hold both type1 and type2. + """Returns a type that can hold both type1 and type2. 
For example: unify(INT32, INT64) -> INT64 @@ -303,19 +397,11 @@ def unify(type1, type2, *, is_left_scalar=False, is_right_scalar=False): if type1 is type2: return type1 if is_left_scalar: - scalar_types = [type1.np_type] - array_types = [] - elif not is_right_scalar: - # Using `promote_types` is faster than `find_common_type` - return lookup_dtype(_promote_types(type1.np_type, type2.np_type)) - else: - scalar_types = [] - array_types = [type1.np_type] - if is_right_scalar: - scalar_types.append(type2.np_type) - else: - array_types.append(type2.np_type) - return lookup_dtype(_find_common_type(array_types, scalar_types)) + if not is_right_scalar: + return lookup_dtype(result_type(np.array(0, type1.np_type), type2.np_type)) + elif is_right_scalar: + return lookup_dtype(result_type(type1.np_type, np.array(0, type2.np_type))) + return lookup_dtype(promote_types(type1.np_type, type2.np_type)) def _default_name(dtype): @@ -345,7 +431,7 @@ def _dtype_to_string(dtype): >>> dtype == new_dtype True """ - if isinstance(dtype, _np.dtype) and dtype not in _registry: + if isinstance(dtype, np.dtype) and dtype not in _registry: np_type = dtype else: dtype = lookup_dtype(dtype) @@ -354,11 +440,11 @@ def _dtype_to_string(dtype): np_type = dtype.np_type s = str(np_type) try: - if _np.dtype(_np.lib.format.safe_eval(s)) == np_type: # pragma: no branch (safety) + if np.dtype(literal_eval(s)) == np_type: # pragma: no branch (safety) return s except Exception: pass - if _np.dtype(np_type.str) != np_type: # pragma: no cover (safety) + if np.dtype(np_type.str) != np_type: # pragma: no cover (safety) raise ValueError(f"Unable to reliably convert dtype to string and back: {dtype}") return repr(np_type.str) @@ -373,5 +459,5 @@ def _string_to_dtype(s): return lookup_dtype(s) except Exception: pass - np_type = _np.dtype(_np.lib.format.safe_eval(s)) + np_type = np.dtype(literal_eval(s)) return lookup_dtype(np_type) diff --git a/graphblas/core/expr.py b/graphblas/core/expr.py index 9046795db..efec2db5f 100644 --- a/graphblas/core/expr.py +++ b/graphblas/core/expr.py @@ -147,22 +147,21 @@ def py_indices(self): return self.indices[0]._py_index() def parse_indices(self, indices, shape): - """ - Returns + """Returns ------- [(rows, rowsize), (cols, colsize)] for Matrix [(idx, idx_size)] for Vector Within each tuple, if the index is of type int, the size will be None + """ if len(shape) == 1: if type(indices) is tuple: raise TypeError(f"Index for {type(self.obj).__name__} cannot be a tuple") # Convert to tuple for consistent processing indices = (indices,) - else: # len(shape) == 2 - if type(indices) is not tuple or len(indices) != 2: - raise TypeError(f"Index for {type(self.obj).__name__} must be a 2-tuple") + elif type(indices) is not tuple or len(indices) != 2: + raise TypeError(f"Index for {type(self.obj).__name__} must be a 2-tuple") out = [] for i, idx in enumerate(indices): @@ -313,8 +312,8 @@ def update(self, expr, **opts): Updater(self.parent, opts=opts)._setitem(self.resolved_indexes, expr, is_submask=False) def new(self, dtype=None, *, mask=None, input_mask=None, name=None, **opts): - """ - Force extraction of the indexes into a new object + """Force extraction of the indexes into a new object. + dtype and mask are the only controllable parameters. 
""" if input_mask is not None: @@ -422,7 +421,7 @@ def _setitem(self, resolved_indexes, obj, *, is_submask): # Fast path using assignElement if self.opts: # Ignore opts for now - descriptor_lookup(**self.opts) + desc = descriptor_lookup(**self.opts) # noqa: F841 (keep desc in scope for context) self.parent._assign_element(resolved_indexes, obj) else: mask = self.kwargs.get("mask") @@ -479,33 +478,34 @@ def __bool__(self): class InfixExprBase: - __slots__ = "left", "right", "_value", "__weakref__" + __slots__ = "left", "right", "_expr", "__weakref__" _is_scalar = False def __init__(self, left, right): self.left = left self.right = right - self._value = None + self._expr = None def new(self, dtype=None, *, mask=None, name=None, **opts): if ( mask is None - and self._value is not None - and (dtype is None or self._value.dtype == dtype) + and self._expr is not None + and self._expr._value is not None + and (dtype is None or self._expr._value.dtype == dtype) ): - rv = self._value + rv = self._expr._value if name is not None: rv.name = name - self._value = None + self._expr._value = None return rv expr = self._to_expr() return expr.new(dtype, mask=mask, name=name, **opts) def _to_expr(self): - if self._value is None: + if self._expr is None: # Rely on the default operator for `x @ y` - self._value = getattr(self.left, self.method_name)(self.right) - return self._value + self._expr = getattr(self.left, self.method_name)(self.right) + return self._expr def _get_value(self, attr=None, default=None): expr = self._to_expr() @@ -537,10 +537,18 @@ def __repr__(self): @property def dtype(self): - if self._value is not None: - return self._value.dtype return self._to_expr().dtype + @property + def _value(self): + if self._expr is None: + return None + return self._expr._value + + @_value.setter + def _value(self, val): + self._to_expr()._value = val + # Mistakes utils._output_types[AmbiguousAssignOrExtract] = AmbiguousAssignOrExtract diff --git a/graphblas/core/formatting.py b/graphblas/core/formatting.py index 305df05ae..0b6252101 100644 --- a/graphblas/core/formatting.py +++ b/graphblas/core/formatting.py @@ -1,3 +1,4 @@ +# This file imports pandas, so it should only be imported when formatting import numpy as np from .. import backend, config, monoid, unary @@ -629,7 +630,7 @@ def create_header(type_name, keys, vals, *, lower_border=False, name="", quote=T name = f'"{name}"' key_text = [] val_text = [] - for key, val in zip(keys, vals): + for key, val in zip(keys, vals, strict=True): width = max(len(key), len(val)) + 2 key_text.append(key.rjust(width)) val_text.append(val.rjust(width)) @@ -879,6 +880,7 @@ def format_index_expression_html(expr): computed = get_expr_result(expr, html=True) if "__EXPR__" in computed: return computed.replace("__EXPR__", topline) + # BRANCH NOT COVERED keys = [] values = [] diff --git a/graphblas/core/infix.py b/graphblas/core/infix.py index 1fc7caa95..24c109639 100644 --- a/graphblas/core/infix.py +++ b/graphblas/core/infix.py @@ -1,8 +1,9 @@ from .. import backend, binary from ..dtypes import BOOL +from ..exceptions import DimensionMismatch from ..monoid import land, lor from ..semiring import any_pair -from . import automethods, utils +from . 
import automethods, recorder, utils
from .base import _expect_op, _expect_type
from .expr import InfixExprBase
from .mask import Mask
@@ -16,11 +17,11 @@
 
 
 def _ewise_add_to_expr(self):
-    if self._value is not None:
-        return self._value
+    if self._expr is not None:
+        return self._expr
     if self.left.dtype == BOOL and self.right.dtype == BOOL:
-        self._value = self.left.ewise_add(self.right, lor)
-        return self._value
+        self._expr = self.left.ewise_add(self.right, lor)
+        return self._expr
     raise TypeError(
         "Bad dtypes for `x | y`! Automatic computation of `x | y` infix expressions is only valid "
         f"for BOOL dtypes. The argument dtypes are {self.left.dtype} and {self.right.dtype}.\n\n"
@@ -30,11 +31,11 @@ def _ewise_add_to_expr(self):
 
 
 def _ewise_mult_to_expr(self):
-    if self._value is not None:
-        return self._value
+    if self._expr is not None:
+        return self._expr
     if self.left.dtype == BOOL and self.right.dtype == BOOL:
-        self._value = self.left.ewise_mult(self.right, land)
-        return self._value
+        self._expr = self.left.ewise_mult(self.right, land)
+        return self._expr
     raise TypeError(
         "Bad dtypes for `x & y`! Automatic computation of `x & y` infix expressions is only valid "
         f"for BOOL dtypes. The argument dtypes are {self.left.dtype} and {self.right.dtype}.\n\n"
@@ -125,6 +126,19 @@ class ScalarEwiseAddExpr(ScalarInfixExpr):
 
     _to_expr = _ewise_add_to_expr
 
+    # Allow e.g. `plus(x | y | z)`
+    __or__ = Scalar.__or__
+    __ror__ = Scalar.__ror__
+    _ewise_add = Scalar._ewise_add
+    _ewise_union = Scalar._ewise_union
+
+    # Don't allow e.g. `plus(x | y & z)`
+    def __and__(self, other):
+        raise TypeError("Cannot use & within an | infix expression such as `plus(x | y & z)`")
+
+    def __rand__(self, other):
+        raise TypeError("Cannot use & within an | infix expression such as `plus(x | y & z)`")
+
 
 class ScalarEwiseMultExpr(ScalarInfixExpr):
     __slots__ = ()
@@ -134,6 +148,18 @@ class ScalarEwiseMultExpr(ScalarInfixExpr):
 
     _to_expr = _ewise_mult_to_expr
 
+    # Allow e.g. `plus(x & y & z)`
+    __and__ = Scalar.__and__
+    __rand__ = Scalar.__rand__
+    _ewise_mult = Scalar._ewise_mult
+
+    # Don't allow e.g. `plus(x | y & z)`
+    def __or__(self, other):
+        raise TypeError("Cannot use | within an & infix expression such as `plus(x & y | z)`")
+
+    def __ror__(self, other):
+        raise TypeError("Cannot use | within an & infix expression such as `plus(x & y | z)`")
+
 
 class ScalarMatMulExpr(ScalarInfixExpr):
     __slots__ = ()
@@ -210,7 +236,6 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts):
     to_coo = wrapdoc(Vector.to_coo)(property(automethods.to_coo))
     to_dense = wrapdoc(Vector.to_dense)(property(automethods.to_dense))
     to_dict = wrapdoc(Vector.to_dict)(property(automethods.to_dict))
-    to_values = wrapdoc(Vector.to_values)(property(automethods.to_values))
     vxm = wrapdoc(Vector.vxm)(property(automethods.vxm))
     wait = wrapdoc(Vector.wait)(property(automethods.wait))
     # These raise exceptions
@@ -238,6 +263,15 @@ class VectorEwiseAddExpr(VectorInfixExpr):
 
     _to_expr = _ewise_add_to_expr
 
+    # Allow e.g. `plus(x | y | z)`
+    __or__ = Vector.__or__
+    __ror__ = Vector.__ror__
+    _ewise_add = Vector._ewise_add
+    _ewise_union = Vector._ewise_union
+    # Don't allow e.g. `plus(x | y & z)`
+    __and__ = ScalarEwiseAddExpr.__and__  # raises
+    __rand__ = ScalarEwiseAddExpr.__rand__  # raises
+
 
 class VectorEwiseMultExpr(VectorInfixExpr):
     __slots__ = ()
@@ -247,6 +281,14 @@ class VectorEwiseMultExpr(VectorInfixExpr):
 
     _to_expr = _ewise_mult_to_expr
 
+    # Allow e.g. `plus(x & y & z)`
+    __and__ = Vector.__and__
+    __rand__ = Vector.__rand__
+    _ewise_mult = Vector._ewise_mult
+    # Don't allow e.g.
`plus(x | y & z)` + __or__ = ScalarEwiseMultExpr.__or__ # raises + __ror__ = ScalarEwiseMultExpr.__ror__ # raises + class VectorMatMulExpr(VectorInfixExpr): __slots__ = "method_name" @@ -258,6 +300,11 @@ def __init__(self, left, right, *, method_name, size): self.method_name = method_name self._size = size + __matmul__ = Vector.__matmul__ + __rmatmul__ = Vector.__rmatmul__ + _inner = Vector._inner + _vxm = Vector._vxm + utils._output_types[VectorEwiseAddExpr] = Vector utils._output_types[VectorEwiseMultExpr] = Vector @@ -269,6 +316,7 @@ class MatrixInfixExpr(InfixExprBase): ndim = 2 output_type = MatrixExpression _is_transposed = False + __networkx_backend__ = "graphblas" __networkx_plugin__ = "graphblas" def __init__(self, left, right): @@ -330,6 +378,7 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): mxv = wrapdoc(Matrix.mxv)(property(automethods.mxv)) name = wrapdoc(Matrix.name)(property(automethods.name)).setter(automethods._set_name) nvals = wrapdoc(Matrix.nvals)(property(automethods.nvals)) + power = wrapdoc(Matrix.power)(property(automethods.power)) reduce_columnwise = wrapdoc(Matrix.reduce_columnwise)(property(automethods.reduce_columnwise)) reduce_rowwise = wrapdoc(Matrix.reduce_rowwise)(property(automethods.reduce_rowwise)) reduce_scalar = wrapdoc(Matrix.reduce_scalar)(property(automethods.reduce_scalar)) @@ -347,7 +396,6 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): to_dense = wrapdoc(Matrix.to_dense)(property(automethods.to_dense)) to_dicts = wrapdoc(Matrix.to_dicts)(property(automethods.to_dicts)) to_edgelist = wrapdoc(Matrix.to_edgelist)(property(automethods.to_edgelist)) - to_values = wrapdoc(Matrix.to_values)(property(automethods.to_values)) wait = wrapdoc(Matrix.wait)(property(automethods.wait)) # These raise exceptions __array__ = Matrix.__array__ @@ -374,6 +422,15 @@ class MatrixEwiseAddExpr(MatrixInfixExpr): _to_expr = _ewise_add_to_expr + # Allow e.g. `plus(x | y | z)` + __or__ = Matrix.__or__ + __ror__ = Matrix.__ror__ + _ewise_add = Matrix._ewise_add + _ewise_union = Matrix._ewise_union + # Don't allow e.g. `plus(x | y & z)` + __and__ = VectorEwiseAddExpr.__and__ # raises + __rand__ = VectorEwiseAddExpr.__rand__ # raises + class MatrixEwiseMultExpr(MatrixInfixExpr): __slots__ = () @@ -383,6 +440,14 @@ class MatrixEwiseMultExpr(MatrixInfixExpr): _to_expr = _ewise_mult_to_expr + # Allow e.g. `plus(x & y & z)` + __and__ = Matrix.__and__ + __rand__ = Matrix.__rand__ + _ewise_mult = Matrix._ewise_mult + # Don't allow e.g. 
`plus(x | y & z)`
+    __or__ = VectorEwiseMultExpr.__or__  # raises
+    __ror__ = VectorEwiseMultExpr.__ror__  # raises
+
 
 class MatrixMatMulExpr(MatrixInfixExpr):
     __slots__ = ()
@@ -395,49 +460,73 @@ def __init__(self, left, right, *, nrows, ncols):
         self._nrows = nrows
         self._ncols = ncols
 
+    __matmul__ = Matrix.__matmul__
+    __rmatmul__ = Matrix.__rmatmul__
+    _mxm = Matrix._mxm
+    _mxv = Matrix._mxv
+
 
 utils._output_types[MatrixEwiseAddExpr] = Matrix
 utils._output_types[MatrixEwiseMultExpr] = Matrix
 utils._output_types[MatrixMatMulExpr] = Matrix
 
 
+def _dummy(obj):
+    # Create a throwaway object of the same output type and shape, without recording
+    with recorder.skip_record:
+        return output_type(obj)(BOOL, *obj.shape, name="")
+
+
+def _mismatched(left, right, method, op):
+    # Create dummy expression to raise on incompatible dimensions
+    getattr(_dummy(left) if isinstance(left, InfixExprBase) else left, method)(
+        _dummy(right) if isinstance(right, InfixExprBase) else right, op
+    )
+    raise DimensionMismatch  # pragma: no cover
+
+
 def _ewise_infix_expr(left, right, *, method, within):
     left_type = output_type(left)
     right_type = output_type(right)
 
     types = {Vector, Matrix, TransposedMatrix}
     if left_type in types and right_type in types:
-        # Create dummy expression to check compatibility of dimensions, etc.
-        expr = getattr(left, method)(right, binary.any)
-        if expr.output_type is Vector:
-            if method == "ewise_mult":
-                return VectorEwiseMultExpr(left, right)
-            return VectorEwiseAddExpr(left, right)
+        if left_type is Vector:
+            if right_type is Vector:
+                if left._size != right._size:
+                    _mismatched(left, right, method, binary.first)
+                if method == "ewise_mult":
+                    return VectorEwiseMultExpr(left, right)
+                return VectorEwiseAddExpr(left, right)
+            if left._size != right._nrows:
+                _mismatched(left, right, method, binary.first)
+        elif right_type is Vector:
+            if left._ncols != right._size:
+                _mismatched(left, right, method, binary.first)
+        elif left.shape != right.shape:
+            _mismatched(left, right, method, binary.first)
         if method == "ewise_mult":
             return MatrixEwiseMultExpr(left, right)
         return MatrixEwiseAddExpr(left, right)
+
     if within == "__or__" and isinstance(right, Mask):
         return right.__ror__(left)
     if within == "__and__" and isinstance(right, Mask):
         return right.__rand__(left)
     if left_type in types:
         left._expect_type(right, tuple(types), within=within, argname="right")
-    elif right_type in types:
+    if right_type in types:
         right._expect_type(left, tuple(types), within=within, argname="left")
-    elif left_type is Scalar:
-        # Create dummy expression to check compatibility of dimensions, etc.
-        expr = getattr(left, method)(right, binary.any)
+    if left_type is Scalar:
        if method == "ewise_mult":
             return ScalarEwiseMultExpr(left, right)
         return ScalarEwiseAddExpr(left, right)
-    elif right_type is Scalar:
-        # Create dummy expression to check compatibility of dimensions, etc.
- expr = getattr(right, method)(left, binary.any) + if right_type is Scalar: if method == "ewise_mult": return ScalarEwiseMultExpr(right, left) return ScalarEwiseAddExpr(right, left) - else: # pragma: no cover (sanity) - raise TypeError(f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}") + raise TypeError( # pragma: no cover (sanity) + f"Bad types for ewise infix: {type(left).__name__}, {type(right).__name__}" + ) def _matmul_infix_expr(left, right, *, within): @@ -446,55 +535,55 @@ def _matmul_infix_expr(left, right, *, within): if left_type is Vector: if right_type is Matrix or right_type is TransposedMatrix: - method = "vxm" - elif right_type is Vector: - method = "inner" - else: - right = left._expect_type( - right, - (Matrix, TransposedMatrix), - within=within, - argname="right", - ) - elif left_type is Matrix or left_type is TransposedMatrix: + if left._size != right._nrows: + _mismatched(left, right, "vxm", any_pair[BOOL]) + return VectorMatMulExpr(left, right, method_name="vxm", size=right._ncols) if right_type is Vector: - method = "mxv" - elif right_type is Matrix or right_type is TransposedMatrix: - method = "mxm" - else: - right = left._expect_type( - right, - (Vector, Matrix, TransposedMatrix), - within=within, - argname="right", - ) - elif right_type is Vector: - left = right._expect_type( + if left._size != right._size: + _mismatched(left, right, "inner", any_pair[BOOL]) + return ScalarMatMulExpr(left, right) + left._expect_type( + right, + (Matrix, TransposedMatrix, Vector), + within=within, + argname="right", + ) + if left_type is Matrix or left_type is TransposedMatrix: + if right_type is Vector: + if left._ncols != right._size: + _mismatched(left, right, "mxv", any_pair[BOOL]) + return VectorMatMulExpr(left, right, method_name="mxv", size=left._nrows) + if right_type is Matrix or right_type is TransposedMatrix: + if left._ncols != right._nrows: + _mismatched(left, right, "mxm", any_pair[BOOL]) + return MatrixMatMulExpr(left, right, nrows=left._nrows, ncols=right._ncols) + left._expect_type( + right, + (Vector, Matrix, TransposedMatrix), + within=within, + argname="right", + ) + if right_type is Vector: + right._expect_type( left, (Matrix, TransposedMatrix), within=within, argname="left", ) - elif right_type is Matrix or right_type is TransposedMatrix: - left = right._expect_type( + if right_type is Matrix or right_type is TransposedMatrix: + right._expect_type( left, (Vector, Matrix, TransposedMatrix), within=within, argname="left", ) - else: # pragma: no cover (sanity) - raise TypeError( - f"Bad types for matmul infix: {type(left).__name__}, {type(right).__name__}" - ) + raise TypeError( # pragma: no cover (sanity) + f"Bad types for matmul infix: {type(left).__name__}, {type(right).__name__}" + ) - # Create dummy expression to check compatibility of dimensions, etc. - expr = getattr(left, method)(right, any_pair[bool]) - if expr.output_type is Vector: - return VectorMatMulExpr(left, right, method_name=method, size=expr._size) - if expr.output_type is Matrix: - return MatrixMatMulExpr(left, right, nrows=expr._nrows, ncols=expr._ncols) - return ScalarMatMulExpr(left, right) +_ewise_add_expr_types = (MatrixEwiseAddExpr, VectorEwiseAddExpr, ScalarEwiseAddExpr) +_ewise_mult_expr_types = (MatrixEwiseMultExpr, VectorEwiseMultExpr, ScalarEwiseMultExpr) # Import infixmethods, which has side effects from . 
import infixmethods # noqa: E402, F401 isort:skip diff --git a/graphblas/core/mask.py b/graphblas/core/mask.py index 9ad209095..3bda2188a 100644 --- a/graphblas/core/mask.py +++ b/graphblas/core/mask.py @@ -35,7 +35,7 @@ def new(self, dtype=None, *, complement=False, mask=None, name=None, **opts): """Return a new object with True values determined by the mask(s). By default, the result is True wherever the mask(s) would have been applied, - and empty otherwise. If `complement` is True, then these are switched: + and empty otherwise. If ``complement`` is True, then these are switched: the result is empty where the mask(s) would have been applied, and True otherwise. In other words, these are equivalent if complement is False (and mask keyword is None): @@ -48,14 +48,14 @@ def new(self, dtype=None, *, complement=False, mask=None, name=None, **opts): >>> C(self) << expr >>> C(~result.S) << expr # equivalent when complement is True - This can also efficiently merge two masks by using the `mask=` argument. + This can also efficiently merge two masks by using the ``mask=`` argument. This is equivalent to the following (but uses more efficient recipes): >>> val = Matrix(...) >>> val(self) << True >>> val(mask, replace=True) << val - If `complement=` argument is True, then the *complement* will be returned. + If ``complement=`` argument is True, then the *complement* will be returned. This is equivalent to the following (but uses more efficient recipes): >>> val = Matrix(...) @@ -83,7 +83,7 @@ def new(self, dtype=None, *, complement=False, mask=None, name=None, **opts): def __and__(self, other, **opts): """Return the intersection of two masks as a new mask. - `new_mask = mask1 & mask2` is equivalent to the following: + ``new_mask = mask1 & mask2`` is equivalent to the following: >>> val = Matrix(bool, nrows, ncols) >>> val(mask1) << True @@ -109,7 +109,7 @@ def __and__(self, other, **opts): def __or__(self, other, **opts): """Return the union of two masks as a new mask. - `new_mask = mask1 | mask2` is equivalent to the following: + ``new_mask = mask1 | mask2`` is equivalent to the following: >>> val = Matrix(bool, nrows, ncols) >>> val(mask1) << True diff --git a/graphblas/core/matrix.py b/graphblas/core/matrix.py index 8b9b4b678..bf20cc953 100644 --- a/graphblas/core/matrix.py +++ b/graphblas/core/matrix.py @@ -1,5 +1,4 @@ import itertools -import warnings from collections.abc import Sequence import numpy as np @@ -7,12 +6,19 @@ from .. import backend, binary, monoid, select, semiring from ..dtypes import _INDEX, FP64, INT64, lookup_dtype, unify from ..exceptions import DimensionMismatch, InvalidValue, NoValue, check_status -from . import automethods, ffi, lib, utils +from . 
import _supports_udfs, automethods, ffi, lib, utils from .base import BaseExpression, BaseType, _check_mask, call from .descriptor import lookup as descriptor_lookup -from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater +from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, InfixExprBase, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _get_typed_op_from_exprs, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -28,12 +34,13 @@ class_property, get_order, ints_to_numpy_buffer, + maybe_integral, normalize_values, output_type, values_to_numpy_buffer, wrapdoc, ) -from .vector import Vector, VectorExpression, VectorIndexExpr, _select_mask +from .vector import Vector, VectorExpression, VectorIndexExpr, _isclose_recipe, _select_mask if backend == "suitesparse": from .ss.matrix import ss @@ -66,13 +73,13 @@ def _m_mult_v(updater, left, right, op): updater << left.mxm(right.diag(name="M_temp"), get_semiring(monoid.any, op)) -def _m_union_m(updater, left, right, left_default, right_default, op, dtype): +def _m_union_m(updater, left, right, left_default, right_default, op): mask = updater.kwargs.get("mask") opts = updater.opts - new_left = left.dup(dtype, clear=True) + new_left = left.dup(op.type, clear=True) new_left(mask=mask, **opts) << binary.second(right, left_default) new_left(mask=mask, **opts) << binary.first(left | new_left) - new_right = right.dup(dtype, clear=True) + new_right = right.dup(op.type2, clear=True) new_right(mask=mask, **opts) << binary.second(left, right_default) new_right(mask=mask, **opts) << binary.first(right | new_right) updater << op(new_left & new_right) @@ -91,6 +98,72 @@ def _reposition(updater, indices, chunk): updater[indices] = chunk +def _power(updater, A, n, op): + opts = updater.opts + if n == 0: + v = Vector.from_scalar(op.binaryop.monoid.identity, A._nrows, A.dtype, name="v_diag") + updater << v.diag(name="M_diag") + return + if n == 1: + updater << A + return + # Use repeated squaring: compute A^2, A^4, A^8, etc., and combine terms as needed. + # See `numpy.linalg.matrix_power` for a simpler implementation to understand how this works. + # We reuse `result` and `square` outputs, and use `square_expr` so masks can be applied. + result = square = square_expr = None + n, bit = divmod(n, 2) + while True: + if bit != 0: + # Need to multiply `square_expr` or `A` into the result + if square_expr is not None: + # Need to evaluate `square_expr`; either into final result, or into `square` + if n == 0 and result is None: + # Handle `updater << A @ A` without an intermediate value + updater << square_expr + return + if square is None: + # Create `square = A @ A` + square = square_expr.new(name="Squares", **opts) + else: + # Compute `square << square @ square` + square(**opts) << square_expr + square_expr = None + if result is None: + # First time needing the intermediate result! + if square is None: + # Use `A` if possible to avoid unnecessary copying + # We will detect and handle `result is A` below + result = A + else: + # Copy square as intermediate result + result = square.dup(name="Power", **opts) + elif n == 0: + # All done! 
No more terms to compute + updater << op(result @ square) + return + elif result is A: + # Now we need to create a new matrix for the intermediate result + result = op(result @ square).new(name="Power", **opts) + else: + # Main branch: multiply `square` into `result` + result(**opts) << op(result @ square) + n, bit = divmod(n, 2) + if square_expr is not None: + # We need to perform another squaring, so evaluate current `square_expr` first + if square is None: + # Create `square` + square = square_expr.new(name="Squares", **opts) + else: + # Compute `square` + square << square_expr + if square is None: + # First iteration! Create expression for first square + square_expr = op(A @ A) + else: + # Expression for repeated squaring + square_expr = op(square @ square) + + class Matrix(BaseType): """Create a new GraphBLAS Sparse Matrix. @@ -104,12 +177,14 @@ class Matrix(BaseType): Number of columns. name : str, optional Name to give the Matrix. This will be displayed in the ``__repr__``. + """ __slots__ = "_nrows", "_ncols", "_parent", "ss" ndim = 2 _is_transposed = False _name_counter = itertools.count() + __networkx_backend__ = "graphblas" __networkx_plugin__ = "graphblas" def __new__(cls, dtype=FP64, nrows=0, ncols=0, *, name=None): @@ -155,8 +230,6 @@ def _as_vector(self, *, name=None): This is SuiteSparse-specific and may change in the future. This does not copy the matrix. """ - from .vector import Vector - if self._ncols != 1: raise ValueError( f"Matrix must have a single column (not {self._ncols}) to be cast to a Vector" @@ -225,6 +298,7 @@ def __delitem__(self, keys, **opts): Examples -------- >>> del M[1, 5] + """ del Updater(self, opts=opts)[keys] @@ -239,6 +313,7 @@ def __getitem__(self, keys): .. code-block:: python subM = M[[1, 3, 5], :].new() + """ resolved_indexes = IndexerResolver(self, keys) shape = resolved_indexes.shape @@ -260,6 +335,7 @@ def __setitem__(self, keys, expr, **opts): .. code-block:: python M[0, 0:3] = 17 + """ Updater(self, opts=opts)[keys] = expr @@ -271,6 +347,7 @@ def __contains__(self, index): .. code-block:: python (10, 15) in M + """ extractor = self[index] if not extractor._is_scalar: @@ -284,7 +361,7 @@ def __contains__(self, index): def __iter__(self): """Iterate over (row, col) indices which are present in the matrix.""" rows, columns, _ = self.to_coo(values=False) - return zip(rows.flat, columns.flat) + return zip(rows.flat, columns.flat, strict=True) def __sizeof__(self): if backend == "suitesparse": @@ -310,6 +387,7 @@ def isequal(self, other, *, check_dtype=False, **opts): See Also -------- :meth:`isclose` : For equality check of floating point dtypes + """ other = self._expect_type( other, (Matrix, TransposedMatrix), within="isequal", argname="other" @@ -355,7 +433,8 @@ def isclose(self, other, *, rel_tol=1e-7, abs_tol=0.0, check_dtype=False, **opts Returns ------- bool - Whether all values of the Matrix are close to the values in `other`. + Whether all values of the Matrix are close to the values in ``other``. 
+ """ other = self._expect_type( other, (Matrix, TransposedMatrix), within="isclose", argname="other" @@ -368,6 +447,8 @@ def isclose(self, other, *, rel_tol=1e-7, abs_tol=0.0, check_dtype=False, **opts return False if self._nvals != other._nvals: return False + if not _supports_udfs: + return _isclose_recipe(self, other, rel_tol, abs_tol, **opts) matches = self.ewise_mult(other, binary.isclose(rel_tol, abs_tol)).new( bool, name="M_isclose", **opts @@ -441,42 +522,6 @@ def resize(self, nrows, ncols): self._nrows = nrows.value self._ncols = ncols.value - def to_values(self, dtype=None, *, rows=True, columns=True, values=True, sort=True): - """Extract the indices and values as a 3-tuple of numpy arrays - corresponding to the COO format of the Matrix. - - .. deprecated:: 2022.11.0 - `Matrix.to_values` will be removed in a future release. - Use `Matrix.to_coo` instead. Will be removed in version 2023.9.0 or later - - Parameters - ---------- - dtype : - Requested dtype for the output values array. - rows : bool, default=True - Whether to return rows; will return `None` for rows if `False` - columns :bool, default=True - Whether to return columns; will return `None` for columns if `False` - values : bool, default=True - Whether to return values; will return `None` for values if `False` - sort : bool, default=True - Whether to require sorted indices. - If internally stored rowwise, the sorting will be first by rows, then by column. - If internally stored columnwise, the sorting will be first by column, then by row. - - Returns - ------- - np.ndarray[dtype=uint64] : Rows - np.ndarray[dtype=uint64] : Columns - np.ndarray : Values - """ - warnings.warn( - "`Matrix.to_values(...)` is deprecated; please use `Matrix.to_coo(...)` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.to_coo(dtype, rows=rows, columns=columns, values=values, sort=sort) - def to_coo(self, dtype=None, *, rows=True, columns=True, values=True, sort=True): """Extract the indices and values as a 3-tuple of numpy arrays corresponding to the COO format of the Matrix. @@ -486,11 +531,11 @@ def to_coo(self, dtype=None, *, rows=True, columns=True, values=True, sort=True) dtype : Requested dtype for the output values array. rows : bool, default=True - Whether to return rows; will return `None` for rows if `False` + Whether to return rows; will return ``None`` for rows if ``False`` columns :bool, default=True - Whether to return columns; will return `None` for columns if `False` + Whether to return columns; will return ``None`` for columns if ``False`` values : bool, default=True - Whether to return values; will return `None` for values if `False` + Whether to return values; will return ``None`` for values if ``False`` sort : bool, default=True Whether to require sorted indices. If internally stored rowwise, the sorting will be first by rows, then by column. @@ -507,6 +552,7 @@ def to_coo(self, dtype=None, *, rows=True, columns=True, values=True, sort=True) np.ndarray[dtype=uint64] : Rows np.ndarray[dtype=uint64] : Columns np.ndarray : Values + """ if sort and backend == "suitesparse": self.wait() # sort in SS @@ -557,7 +603,7 @@ def to_edgelist(self, dtype=None, *, values=True, sort=True): dtype : Requested dtype for the output values array. values : bool, default=True - Whether to return values; will return `None` for values if `False` + Whether to return values; will return ``None`` for values if ``False`` sort : bool, default=True Whether to require sorted indices. 
If internally stored rowwise, the sorting will be first by rows, then by column. @@ -573,6 +619,7 @@ def to_edgelist(self, dtype=None, *, values=True, sort=True): ------- np.ndarray[dtype=uint64] : Edgelist np.ndarray : Values + """ rows, columns, values = self.to_coo(dtype, values=values, sort=sort) return (np.column_stack([rows, columns]), values) @@ -583,7 +630,7 @@ def build(self, rows, columns, values, *, dup_op=None, clear=False, nrows=None, The typical use case is to create a new Matrix and insert values at the same time using :meth:`from_coo`. - All the arguments are used identically in :meth:`from_coo`, except for `clear`, which + All the arguments are used identically in :meth:`from_coo`, except for ``clear``, which indicates whether to clear the Matrix prior to adding the new values. """ # TODO: accept `dtype` keyword to match the dtype of `values`? @@ -611,14 +658,15 @@ def build(self, rows, columns, values, *, dup_op=None, clear=False, nrows=None, if not dup_op_given: if not self.dtype._is_udt: dup_op = binary.plus - else: + elif backend != "suitesparse": dup_op = binary.any - # SS:SuiteSparse-specific: we could use NULL for dup_op - dup_op = get_typed_op(dup_op, self.dtype, kind="binary") - if dup_op.opclass == "Monoid": - dup_op = dup_op.binaryop - else: - self._expect_op(dup_op, "BinaryOp", within="build", argname="dup_op") + # SS:SuiteSparse-specific: we use NULL for dup_op + if dup_op is not None: + dup_op = get_typed_op(dup_op, self.dtype, kind="binary") + if dup_op.opclass == "Monoid": + dup_op = dup_op.binaryop + else: + self._expect_op(dup_op, "BinaryOp", within="build", argname="dup_op") rows = _CArray(rows) columns = _CArray(columns) @@ -652,6 +700,7 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): Returns ------- Matrix + """ if dtype is not None or mask is not None or clear: if dtype is None: @@ -662,7 +711,7 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): else: if opts: # Ignore opts for now - descriptor_lookup(**opts) + desc = descriptor_lookup(**opts) # noqa: F841 (keep desc in scope for context) new_mat = ffi_new("GrB_Matrix*") rv = Matrix._from_obj(new_mat, self.dtype, self._nrows, self._ncols, name=name) call("GrB_Matrix_dup", [_Pointer(rv), self]) @@ -683,6 +732,7 @@ def diag(self, k=0, dtype=None, *, name=None, **opts): Returns ------- :class:`~graphblas.Vector` + """ if backend == "suitesparse": from ..ss._core import diag @@ -726,6 +776,7 @@ def wait(self, how="materialize"): Use wait to force completion of the Matrix. Has no effect in `blocking mode <../user_guide/init.html#graphblas-modes>`__. + """ how = how.lower() if how == "materialize": @@ -752,6 +803,7 @@ def get(self, row, col, default=None): Returns ------- Python scalar + """ expr = self[row, col] if expr._is_scalar: @@ -762,61 +814,6 @@ def get(self, row, col, default=None): "Indices should get a single element, which will be extracted as a Python scalar." ) - @classmethod - def from_values( - cls, - rows, - columns, - values, - dtype=None, - *, - nrows=None, - ncols=None, - dup_op=None, - name=None, - ): - """Create a new Matrix from row and column indices and values. - - .. deprecated:: 2022.11.0 - `Matrix.from_values` will be removed in a future release. - Use `Matrix.from_coo` instead. Will be removed in version 2023.9.0 or later - - Parameters - ---------- - rows : list or np.ndarray - Row indices. - columns : list or np.ndarray - Column indices. - values : list or np.ndarray or scalar - List of values. 
If a scalar is provided, all values will be set to this single value. - dtype : - Data type of the Matrix. If not provided, the values will be inspected - to choose an appropriate dtype. - nrows : int, optional - Number of rows in the Matrix. If not provided, ``nrows`` is computed - from the maximum row index found in ``rows``. - ncols : int, optional - Number of columns in the Matrix. If not provided, ``ncols`` is computed - from the maximum column index found in ``columns``. - dup_op : :class:`~graphblas.core.operator.BinaryOp`, optional - Function used to combine values if duplicate indices are found. - Leaving ``dup_op=None`` will raise an error if duplicates are found. - name : str, optional - Name to give the Matrix. - - Returns - ------- - Matrix - """ - warnings.warn( - "`Matrix.from_values(...)` is deprecated; please use `Matrix.from_coo(...)` instead.", - DeprecationWarning, - stacklevel=2, - ) - return cls.from_coo( - rows, columns, values, dtype, nrows=nrows, ncols=ncols, dup_op=dup_op, name=name - ) - @classmethod def from_coo( cls, @@ -864,6 +861,7 @@ def from_coo( Returns ------- Matrix + """ rows = ints_to_numpy_buffer(rows, np.uint64, name="row indices") columns = ints_to_numpy_buffer(columns, np.uint64, name="column indices") @@ -943,6 +941,7 @@ def from_edgelist( Returns ------- Matrix + """ edgelist_values = None if isinstance(edgelist, np.ndarray): @@ -963,7 +962,7 @@ def from_edgelist( rows = edgelist[:, 0] cols = edgelist[:, 1] else: - unzipped = list(zip(*edgelist)) + unzipped = list(zip(*edgelist, strict=True)) if len(unzipped) == 2: rows, cols = unzipped elif len(unzipped) == 3: @@ -1083,7 +1082,7 @@ def from_csr( Parameters ---------- indptr : list or np.ndarray - Pointers for each row into col_indices and values; `indptr.size == nrows + 1`. + Pointers for each row into col_indices and values; ``indptr.size == nrows + 1``. col_indices : list or np.ndarray Column indices. values : list or np.ndarray or scalar, default 1.0 @@ -1112,6 +1111,7 @@ def from_csr( to_csr Matrix.ss.import_csr io.from_scipy_sparse + """ return cls._from_csx(_CSR_FORMAT, indptr, col_indices, values, dtype, ncols, nrows, name) @@ -1130,7 +1130,7 @@ def from_csc( Parameters ---------- indptr : list or np.ndarray - Pointers for each column into row_indices and values; `indptr.size == ncols + 1`. + Pointers for each column into row_indices and values; ``indptr.size == ncols + 1``. col_indices : list or np.ndarray Column indices. 
values : list or np.ndarray or scalar, default 1.0 @@ -1159,6 +1159,7 @@ def from_csc( to_csc Matrix.ss.import_csc io.from_scipy_sparse + """ return cls._from_csx(_CSC_FORMAT, indptr, row_indices, values, dtype, nrows, ncols, name) @@ -1219,6 +1220,7 @@ def from_dcsr( to_dcsr Matrix.ss.import_hypercsr io.from_scipy_sparse + """ if backend == "suitesparse": return cls.ss.import_hypercsr( @@ -1303,6 +1305,7 @@ def from_dcsc( to_dcsc Matrix.ss.import_hypercsc io.from_scipy_sparse + """ if backend == "suitesparse": return cls.ss.import_hypercsc( @@ -1364,6 +1367,7 @@ def from_scalar(cls, value, nrows, ncols, dtype=None, *, name=None, **opts): Returns ------- Matrix + """ if type(value) is not Scalar: try: @@ -1417,6 +1421,7 @@ def from_dense(cls, values, missing_value=None, *, dtype=None, name=None, **opts Returns ------- Matrix + """ values, dtype = values_to_numpy_buffer(values, dtype, subarray_after=2) if values.ndim == 0: @@ -1476,6 +1481,7 @@ def to_dense(self, fill_value=None, dtype=None, **opts): Returns ------- np.ndarray + """ max_nvals = self._nrows * self._ncols if fill_value is None or self._nvals == max_nvals: @@ -1551,6 +1557,7 @@ def from_dicts( Returns ------- Matrix + """ order = get_order(order) if isinstance(nested_dicts, Sequence): @@ -1584,7 +1591,7 @@ def from_dicts( # If we know the dtype, then using `np.fromiter` is much faster dtype = lookup_dtype(dtype) if dtype.np_type.subdtype is not None and np.__version__[:5] in {"1.21.", "1.22."}: - values, dtype = values_to_numpy_buffer(list(iter_values), dtype) + values, dtype = values_to_numpy_buffer(list(iter_values), dtype) # FLAKY COVERAGE else: values = np.fromiter(iter_values, dtype.np_type) return getattr(cls, methodname)( @@ -1660,6 +1667,7 @@ def to_csr(self, dtype=None, *, sort=True): from_csr Matrix.ss.export io.to_scipy_sparse + """ if backend == "suitesparse": info = self.ss.export("csr", sort=sort) @@ -1691,6 +1699,7 @@ def to_csc(self, dtype=None, *, sort=True): from_csc Matrix.ss.export io.to_scipy_sparse + """ if backend == "suitesparse": info = self.ss.export("csc", sort=sort) @@ -1725,6 +1734,7 @@ def to_dcsr(self, dtype=None, *, sort=True): from_dcsc Matrix.ss.export io.to_scipy_sparse + """ if backend == "suitesparse": info = self.ss.export("hypercsr", sort=sort) @@ -1767,6 +1777,7 @@ def to_dcsc(self, dtype=None, *, sort=True): from_dcsc Matrix.ss.export io.to_scipy_sparse + """ if backend == "suitesparse": info = self.ss.export("hypercsc", sort=sort) @@ -1804,6 +1815,7 @@ def to_dicts(self, order="rowwise"): Returns ------- dict + """ order = get_order(order) if order == "rowwise": @@ -1815,10 +1827,11 @@ def to_dicts(self, order="rowwise"): cols = cols.tolist() values = values.tolist() return { - row: dict(zip(cols[start:stop], values[start:stop])) + row: dict(zip(cols[start:stop], values[start:stop], strict=True)) for row, (start, stop) in zip( compressed_rows.tolist(), np.lib.stride_tricks.sliding_window_view(indptr, 2).tolist(), + strict=True, ) } # Alternative @@ -1873,18 +1886,41 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax C << monoid.max(A | B) + """ + return self._ewise_add(other, op) + + def _ewise_add(self, other, op=monoid.plus, is_infix=False): method_name = "ewise_add" - other = self._expect_type( - other, - (Matrix, TransposedMatrix, Vector), - within=method_name, - argname="other", - op=op, - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. 
- self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, MatrixEwiseAddExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 1: # Broadcast rowwise from the right if self._ncols != other._size: @@ -1941,14 +1977,41 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax C << binary.gt(A & B) + """ + return self._ewise_mult(other, op) + + def _ewise_mult(self, other, op=binary.times, is_infix=False): method_name = "ewise_mult" - other = self._expect_type( - other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if is_infix: + from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseMultExpr, VectorEwiseMultExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, MatrixEwiseMultExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. 
+ self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 1: # Broadcast rowwise from the right if self._ncols != other._size: @@ -2009,12 +2072,35 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax C << binary.div(A | B, left_default=1, right_default=1) + """ + return self._ewise_union(other, op, left_default, right_default) + + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): method_name = "ewise_union" - other = self._expect_type( - other, (Matrix, TransposedMatrix, Vector), within=method_name, argname="other", op=op - ) - dtype = self.dtype if self.dtype._is_udt else None + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + else: + other = self._expect_type( + other, + (Matrix, TransposedMatrix, Vector), + within=method_name, + argname="other", + op=op, + ) + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -2031,6 +2117,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -2047,12 +2135,29 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - scalar_dtype = unify(left.dtype, right.dtype) - nonscalar_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + + if is_infix: + op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") + op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + else: + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop + + if is_infix: + if isinstance(self, MatrixEwiseAddExpr): + self = op(self, left_default=left, right_default=right).new() + if isinstance(other, InfixExprBase): + other = op(other, left_default=left, right_default=right).new() + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 1: # Broadcast rowwise from the right @@ -2082,11 +2187,10 @@ def ewise_union(self, other, op, left_default, right_default): expr_repr=expr_repr, ) else: - dtype = unify(scalar_dtype, nonscalar_dtype, is_left_scalar=True) expr = MatrixExpression( method_name, None, - [self, left, other, right, _m_union_m, (self, other, left, right, op, dtype)], + [self, left, other, right, _m_union_m, (self, other, left, right, op)], expr_repr=expr_repr, nrows=self._nrows, ncols=self._ncols, @@ -2122,11 +2226,29 @@ def mxv(self, 
other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(A @ v) + """ + return self._mxv(other, op) + + def _mxv(self, other, op=semiring.plus_times, is_infix=False): method_name = "mxv" - other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") + if is_infix: + from .infix import MatrixMatMulExpr, VectorMatMulExpr + + other = self._expect_type( + other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, MatrixMatMulExpr): + self = op(self).new() + if isinstance(other, VectorMatMulExpr): + other = op(other).new() + else: + other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = VectorExpression( method_name, "GrB_mxv", @@ -2165,13 +2287,35 @@ def mxm(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(A @ B) + """ + return self._mxm(other, op) + + def _mxm(self, other, op=semiring.plus_times, is_infix=False): method_name = "mxm" - other = self._expect_type( - other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") + if is_infix: + from .infix import MatrixMatMulExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, MatrixMatMulExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, MatrixMatMulExpr): + self = op(self).new() + if isinstance(other, MatrixMatMulExpr): + other = op(other).new() + else: + other = self._expect_type( + other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = MatrixExpression( method_name, "GrB_mxm", @@ -2208,6 +2352,7 @@ def kronecker(self, other, op=binary.times): .. code-block:: python C << A.kronecker(B, op=binary.times) + """ method_name = "kronecker" other = self._expect_type( @@ -2264,6 +2409,7 @@ def apply(self, op, right=None, *, left=None): # Functional syntax C << op.abs(A) + """ method_name = "apply" extra_message = ( @@ -2412,6 +2558,7 @@ def select(self, op, thunk=None): # Functional syntax C << select.value(A >= 1) + """ method_name = "select" if isinstance(op, str): @@ -2466,6 +2613,7 @@ def select(self, op, thunk=None): self._expect_op(op, ("SelectOp", "IndexUnaryOp"), within=method_name, argname="op") if thunk._is_cscalar: if thunk.dtype._is_udt: + # NOT COVERED dtype_name = "UDT" thunk = _Pointer(thunk) else: @@ -2505,6 +2653,7 @@ def reduce_rowwise(self, op=monoid.plus): .. code-block:: python w << A.reduce_rowwise(monoid.plus) + """ method_name = "reduce_rowwise" op = get_typed_op(op, self.dtype, kind="binary|aggregator") @@ -2542,6 +2691,7 @@ def reduce_columnwise(self, op=monoid.plus): .. 
code-block:: python w << A.reduce_columnwise(monoid.plus) + """ method_name = "reduce_columnwise" op = get_typed_op(op, self.dtype, kind="binary|aggregator") @@ -2560,8 +2710,7 @@ def reduce_columnwise(self, op=monoid.plus): ) def reduce_scalar(self, op=monoid.plus, *, allow_empty=True): - """ - Reduce all values in the Matrix into a single value using ``op``. + """Reduce all values in the Matrix into a single value using ``op``. See the `Reduce <../user_guide/operations.html#reduce>`__ section in the User Guide for more details. @@ -2583,6 +2732,7 @@ def reduce_scalar(self, op=monoid.plus, *, allow_empty=True): .. code-block:: python total << A.reduce_scalar(monoid.plus) + """ method_name = "reduce_scalar" op = get_typed_op(op, self.dtype, kind="binary|aggregator") @@ -2643,6 +2793,7 @@ def reposition(self, row_offset, column_offset, *, nrows=None, ncols=None): .. code-block:: python C = A.reposition(1, 2).new() + """ if nrows is None: nrows = self._nrows @@ -2686,6 +2837,185 @@ def reposition(self, row_offset, column_offset, *, nrows=None, ncols=None): dtype=self.dtype, ) + def power(self, n, op=semiring.plus_times): + """Raise a square Matrix to the (positive integer) power ``n``. + + Matrix power is computed by repeated matrix squaring and matrix multiplication. + For a graph as an adjacency matrix, matrix power with default ``plus_times`` + semiring computes the number of walks connecting each pair of nodes. + The result can grow very quickly for large matrices and with larger ``n``. + + Parameters + ---------- + n : int + The exponent must be a nonnegative integer. If n=0, the result will be a diagonal + matrix with values equal to the identity of the semiring's binary operator. + For example, ``plus_times`` will have diagonal values of 1, which is the + identity of ``times``. The binary operator must be associated with a monoid + when n=0 so the identity can be determined; otherwise, ValueError is raised. + op : :class:`~graphblas.core.operator.Semiring` + Semiring used in the computation + + Returns + ------- + MatrixExpression + + Examples + -------- + .. code-block:: python + + C << A.power(4, op=semiring.plus_times) + + # Is equivalent to: + tmp = (A @ A).new() + tmp << tmp @ tmp + C << tmp @ tmp + + # And is more efficient than the naive implementation: + C = A.dup() + for i in range(1, 4): + C << A @ C + + """ + method_name = "power" + if self._nrows != self._ncols: + raise DimensionMismatch(f"power only works for square Matrix; shape is {self.shape}") + if (N := maybe_integral(n)) is None: + raise TypeError(f"n must be a nonnegative integer; got bad type: {type(n)}") + if N < 0: + raise ValueError(f"n must be a nonnegative integer; got: {N}") + op = get_typed_op(op, self.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if N == 0 and op.binaryop.monoid is None: + raise ValueError( + f"Binary operator of {op} semiring does not have a monoid with an identity. " + "When n=0, the result is a diagonal matrix with values equal to the " + "identity of the binaryop, so the binaryop must be associated with a monoid." + ) + return MatrixExpression( + "power", + None, + [self, _power, (self, N, op)], # [*expr_args, func, args] + expr_repr=f"{{0.name}}.power({N}, op={op})", + nrows=self._nrows, + ncols=self._ncols, + dtype=self.dtype, + ) + + def setdiag(self, values, k=0, *, mask=None, accum=None, **opts): + """Set k'th diagonal with a Scalar, Vector, or array. + + This is not a built-in GraphBLAS operation. It is implemented as a recipe. 
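A usage sketch for the new ``Matrix.power`` defined above (``A`` is a hypothetical square adjacency Matrix; the repeated-squaring equivalence follows the docstring):

.. code-block:: python

    import graphblas as gb
    from graphblas import semiring

    A = gb.Matrix.from_coo([0, 1, 2], [1, 2, 0], [1, 1, 1], nrows=3, ncols=3)

    C = A.power(4, op=semiring.plus_times).new()

    # Same result by repeated squaring with mxm, as the docstring shows:
    tmp = (A @ A).new()   # A**2
    tmp << tmp @ tmp      # A**4
    assert C.isequal(tmp)

    # n=0 gives a diagonal Matrix of the binaryop's monoid identity (1 for times)
    D = A.power(0).new()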
+ + Parameters + ---------- + values : Vector or list or np.ndarray or scalar + New values to assign to the diagonal. The length of Vector and array + values must match the size of the diagonal being assigned to. + k : int, default=0 + Which diagonal or off-diagonal to set. For example, set the elements + ``A[i, i+k] = values[i]``. The default, k=0, is the main diagonal. + mask : Mask, optional + Vector or Matrix Mask to control which diagonal elements to set. + If it is Matrix Mask, then only the diagonal is used as the mask. + accum : Monoid or BinaryOp, optional + Operator to use to combine existing diagonal values and new values. + + """ + if (K := maybe_integral(k)) is None: + raise TypeError(f"k must be an integer; got bad type: {type(k)}") + k = K + if k < 0: + if (size := min(self._nrows + k, self._ncols)) <= 0 and k <= -self._nrows: + raise IndexError( + f"k={k} is too small; the k'th diagonal is out of range. " + f"Valid k for Matrix with shape {self._nrows}x{self._ncols}: " + f"{-self._nrows} {'<' if self._nrows else '<='} k " + f"{'<' if self._ncols else '<='} {self._ncols}" + ) + elif (size := min(self._ncols - k, self._nrows)) <= 0 and k > 0 and k >= self._ncols: + raise IndexError( + f"k={k} is too large; the k'th diagonal is out of range. " + f"Valid k for Matrix with shape {self._nrows}x{self._ncols}: " + f"{-self._nrows} {'<' if self._nrows else '<='} k " + f"{'<' if self._ncols else '<='} {self._ncols}" + ) + + # Convert `values` to Vector if necessary (i.e., it's scalar or array) + is_scalar = clear_diag = False + if output_type(values) is Vector: + v = values + clear_diag = accum is None and v._nvals != v._size + elif type(values) is Scalar: + is_scalar = True + else: + dtype = self.dtype if self.dtype._is_udt else None + try: + # Try to make it a Scalar + values = Scalar.from_value(values, dtype, is_cscalar=None, name="") + is_scalar = True + except (TypeError, ValueError): + try: + # Else try to make it a numpy array + values, dtype = values_to_numpy_buffer(values, dtype) + except Exception: + self._expect_type( + values, + (Scalar, Vector, np.ndarray), + within="setdiag", + argname="values", + extra_message="Literal scalars also accepted.", + ) + else: + v = Vector.from_dense(values, dtype=dtype, **opts) + + if is_scalar: + v = Vector.from_scalar(values, size, **opts) + elif v._size != size: + raise DimensionMismatch( + f"Dimensions not compatible for assigning length {v._size} Vector " + f"to {k}'th diagonal of Matrix with shape {self._nrows}x{self._ncols}." + f"The Vector should be size {size}." + ) + + if mask is not None: + mask = _check_mask(mask) + if mask.parent.ndim == 2: + if mask.parent.shape != self.shape: + raise DimensionMismatch( + "Matrix mask in setdiag is the wrong shape; " + f"expected shape {self._nrows}x{self._ncols}, " + f"got {mask.parent._nrows}x{mask.parent._ncols}" + ) + if mask.complement: + mval = type(mask)(mask.parent.diag(k)).new(**opts) + mask = mval.S + M = mval.diag() + else: + M = select.diag(mask.parent, k).new(**opts) + elif mask.parent._size != size: + raise DimensionMismatch( + "Vector mask in setdiag is the wrong length; " + f"expected size {size}, got size {mask.parent._size}." 
+ ) + else: + if mask.complement: + mask = mask.new(**opts).S + M = mask.parent.diag() + if M.shape != self.shape: + M.resize(self._nrows, self._ncols) + mask = type(mask)(M) + + if clear_diag: + self(mask=mask, **opts) << select.offdiag(self, k) + + Diag = v.diag(k) + if Diag.shape != self.shape: + Diag.resize(self._nrows, self._ncols) + if mask is None: + mask = Diag.S + self(accum=accum, mask=mask, **opts) << Diag + ################################## # Extract and Assign index methods ################################## @@ -2703,7 +3033,7 @@ def _extract_element( result = Scalar(dtype, is_cscalar=is_cscalar, name=name) if opts: # Ignore opts for now - descriptor_lookup(**opts) + desc = descriptor_lookup(**opts) # noqa: F841 (keep desc in scope for context) if is_cscalar: dtype_name = "UDT" if dtype._is_udt else dtype.name if ( @@ -3154,14 +3484,11 @@ def _prep_for_assign(self, resolved_indexes, value, mask, is_submask, replace, o mask = _vanilla_subassign_mask( self, mask, rowidx, colidx, replace, opts ) + elif backend == "suitesparse": + cfunc_name = "GxB_Matrix_subassign_Scalar" else: - if backend == "suitesparse": - cfunc_name = "GxB_Matrix_subassign_Scalar" - else: - cfunc_name = "GrB_Matrix_assign_Scalar" - mask = _vanilla_subassign_mask( - self, mask, rowidx, colidx, replace, opts - ) + cfunc_name = "GrB_Matrix_assign_Scalar" + mask = _vanilla_subassign_mask(self, mask, rowidx, colidx, replace, opts) expr_repr = ( "[[{2._expr_name} rows], [{4._expr_name} cols]]" f"({mask.name})" @@ -3257,6 +3584,7 @@ class MatrixExpression(BaseExpression): ndim = 2 output_type = Matrix _is_transposed = False + __networkx_backend__ = "graphblas" __networkx_plugin__ = "graphblas" def __init__( @@ -3357,6 +3685,7 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): mxv = wrapdoc(Matrix.mxv)(property(automethods.mxv)) name = wrapdoc(Matrix.name)(property(automethods.name)).setter(automethods._set_name) nvals = wrapdoc(Matrix.nvals)(property(automethods.nvals)) + power = wrapdoc(Matrix.power)(property(automethods.power)) reduce_columnwise = wrapdoc(Matrix.reduce_columnwise)(property(automethods.reduce_columnwise)) reduce_rowwise = wrapdoc(Matrix.reduce_rowwise)(property(automethods.reduce_rowwise)) reduce_scalar = wrapdoc(Matrix.reduce_scalar)(property(automethods.reduce_scalar)) @@ -3374,7 +3703,6 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): to_dense = wrapdoc(Matrix.to_dense)(property(automethods.to_dense)) to_dicts = wrapdoc(Matrix.to_dicts)(property(automethods.to_dicts)) to_edgelist = wrapdoc(Matrix.to_edgelist)(property(automethods.to_edgelist)) - to_values = wrapdoc(Matrix.to_values)(property(automethods.to_values)) wait = wrapdoc(Matrix.wait)(property(automethods.wait)) # These raise exceptions __array__ = Matrix.__array__ @@ -3398,6 +3726,7 @@ class MatrixIndexExpr(AmbiguousAssignOrExtract): ndim = 2 output_type = Matrix _is_transposed = False + __networkx_backend__ = "graphblas" __networkx_plugin__ = "graphblas" def __init__(self, parent, resolved_indexes, nrows, ncols): @@ -3457,6 +3786,7 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): mxv = wrapdoc(Matrix.mxv)(property(automethods.mxv)) name = wrapdoc(Matrix.name)(property(automethods.name)).setter(automethods._set_name) nvals = wrapdoc(Matrix.nvals)(property(automethods.nvals)) + power = wrapdoc(Matrix.power)(property(automethods.power)) reduce_columnwise = wrapdoc(Matrix.reduce_columnwise)(property(automethods.reduce_columnwise)) reduce_rowwise = 
wrapdoc(Matrix.reduce_rowwise)(property(automethods.reduce_rowwise)) reduce_scalar = wrapdoc(Matrix.reduce_scalar)(property(automethods.reduce_scalar)) @@ -3474,7 +3804,6 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): to_dense = wrapdoc(Matrix.to_dense)(property(automethods.to_dense)) to_dicts = wrapdoc(Matrix.to_dicts)(property(automethods.to_dicts)) to_edgelist = wrapdoc(Matrix.to_edgelist)(property(automethods.to_edgelist)) - to_values = wrapdoc(Matrix.to_values)(property(automethods.to_values)) wait = wrapdoc(Matrix.wait)(property(automethods.wait)) # These raise exceptions __array__ = Matrix.__array__ @@ -3498,6 +3827,7 @@ class TransposedMatrix: ndim = 2 _is_scalar = False _is_transposed = True + __networkx_backend__ = "graphblas" __networkx_plugin__ = "graphblas" def __init__(self, matrix): @@ -3549,13 +3879,6 @@ def to_coo(self, dtype=None, *, rows=True, columns=True, values=True, sort=True) ) return cols, rows, vals - @wrapdoc(Matrix.to_values) - def to_values(self, dtype=None, *, rows=True, columns=True, values=True, sort=True): - rows, cols, vals = self._matrix.to_values( - dtype, rows=rows, columns=columns, values=values, sort=sort - ) - return cols, rows, vals - @wrapdoc(Matrix.diag) def diag(self, k=0, dtype=None, *, name=None, **opts): return self._matrix.diag(-k, dtype, name=name, **opts) @@ -3618,6 +3941,13 @@ def to_dicts(self, order="rowwise"): reduce_columnwise = Matrix.reduce_columnwise reduce_scalar = Matrix.reduce_scalar reposition = Matrix.reposition + power = Matrix.power + + _ewise_add = Matrix._ewise_add + _ewise_mult = Matrix._ewise_mult + _ewise_union = Matrix._ewise_union + _mxv = Matrix._mxv + _mxm = Matrix._mxm # Operator sugar __or__ = Matrix.__or__ diff --git a/graphblas/core/operator.py b/graphblas/core/operator.py deleted file mode 100644 index bfd03d9df..000000000 --- a/graphblas/core/operator.py +++ /dev/null @@ -1,3598 +0,0 @@ -import inspect -import itertools -import re -from collections.abc import Mapping -from functools import lru_cache, reduce -from operator import getitem, mul -from types import BuiltinFunctionType, FunctionType, ModuleType - -import numba -import numpy as np - -from .. import ( - _STANDARD_OPERATOR_NAMES, - backend, - binary, - config, - indexunary, - monoid, - op, - select, - semiring, - unary, -) -from ..dtypes import ( - BOOL, - FP32, - FP64, - INT8, - INT16, - INT32, - INT64, - UINT8, - UINT16, - UINT32, - UINT64, - _sample_values, - _supports_complex, - lookup_dtype, - unify, -) -from ..exceptions import UdfParseError, check_status_carg -from . import ffi, lib -from .expr import InfixExprBase -from .utils import libget, output_type - -if _supports_complex: - from ..dtypes import FC32, FC64 - -ffi_new = ffi.new -UNKNOWN_OPCLASS = "UnknownOpClass" - -# These now live as e.g. `gb.unary.ss.positioni` -# Deprecations such as `gb.unary.positioni` will be removed in 2023.9.0 or later. 
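Returning to the ``Matrix.setdiag`` recipe added above, a small usage sketch (data is hypothetical; the sparse-Vector behavior follows the ``clear_diag`` logic in the recipe):

.. code-block:: python

    import graphblas as gb

    A = gb.Matrix.from_coo([0, 1], [1, 0], [7.0, 8.0], nrows=3, ncols=3)

    A.setdiag(1.0)                 # main diagonal from a scalar
    A.setdiag([10.0, 20.0], k=1)   # superdiagonal from an array-like (length 2)

    # A sparse Vector clears unassigned diagonal positions when accum is None
    # (the clear_diag branch above); pass accum to combine with existing values.
    v = gb.Vector.from_coo([0, 2], [5.0, 6.0], size=3)
    A.setdiag(v)
    A.setdiag(v, accum=gb.binary.plus)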
-_SS_OPERATORS = { - # unary - "erf", # scipy.special.erf - "erfc", # scipy.special.erfc - "frexpe", # np.frexp[1] - "frexpx", # np.frexp[0] - "lgamma", # scipy.special.loggamma - "tgamma", # scipy.special.gamma - # Positional - # unary - "positioni", - "positioni1", - "positionj", - "positionj1", - # binary - "firsti", - "firsti1", - "firstj", - "firstj1", - "secondi", - "secondi1", - "secondj", - "secondj1", - # semiring - "any_firsti", - "any_firsti1", - "any_firstj", - "any_firstj1", - "any_secondi", - "any_secondi1", - "any_secondj", - "any_secondj1", - "max_firsti", - "max_firsti1", - "max_firstj", - "max_firstj1", - "max_secondi", - "max_secondi1", - "max_secondj", - "max_secondj1", - "min_firsti", - "min_firsti1", - "min_firstj", - "min_firstj1", - "min_secondi", - "min_secondi1", - "min_secondj", - "min_secondj1", - "plus_firsti", - "plus_firsti1", - "plus_firstj", - "plus_firstj1", - "plus_secondi", - "plus_secondi1", - "plus_secondj", - "plus_secondj1", - "times_firsti", - "times_firsti1", - "times_firstj", - "times_firstj1", - "times_secondi", - "times_secondi1", - "times_secondj", - "times_secondj1", -} - - -def _hasop(module, name): - return ( - name in module.__dict__ - or name in module._delayed - or name in getattr(module, "_deprecated", ()) - ) - - -class OpPath: - def __init__(self, parent, name): - self._parent = parent - self._name = name - self._delayed = {} - self._delayed_commutes_to = {} - - def __getattr__(self, key): - if key in self._delayed: - func, kwargs = self._delayed.pop(key) - return func(**kwargs) - self.__getattribute__(key) # raises - - -def _call_op(op, left, right=None, thunk=None, **kwargs): - if right is None and thunk is None: - if isinstance(left, InfixExprBase): - # op(A & B), op(A | B), op(A @ B) - return getattr(left.left, left.method_name)(left.right, op, **kwargs) - if find_opclass(op)[1] == "Semiring": - raise TypeError( - f"Bad type when calling {op!r}. Got type: {type(left)}.\n" - f"Expected an infix expression, such as: {op!r}(A @ B)" - ) - raise TypeError( - f"Bad type when calling {op!r}. Got type: {type(left)}.\n" - "Expected an infix expression or an apply with a Vector or Matrix and a scalar:\n" - f" - {op!r}(A & B)\n" - f" - {op!r}(A, 1)\n" - f" - {op!r}(1, A)" - ) - - # op(A, 1) -> apply (or select if thunk provided) - from .matrix import Matrix, TransposedMatrix - from .vector import Vector - - if (left_type := output_type(left)) in {Vector, Matrix, TransposedMatrix}: - if thunk is not None: - return left.select(op, thunk=thunk, **kwargs) - return left.apply(op, right=right, **kwargs) - if (right_type := output_type(right)) in {Vector, Matrix, TransposedMatrix}: - return right.apply(op, left=left, **kwargs) - - from .scalar import Scalar, _as_scalar - - if left_type is Scalar: - if thunk is not None: - return left.select(op, thunk=thunk, **kwargs) - return left.apply(op, right=right, **kwargs) - if right_type is Scalar: - return right.apply(op, left=left, **kwargs) - try: - left_scalar = _as_scalar(left, is_cscalar=False) - except Exception: - pass - else: - if thunk is not None: - return left_scalar.select(op, thunk=thunk, **kwargs) - return left_scalar.apply(op, right=right, **kwargs) - raise TypeError( - f"Bad types when calling {op!r}. 
Got types: {type(left)}, {type(right)}.\n" - "Expected an infix expression or an apply with a Vector or Matrix and a scalar:\n" - f" - {op!r}(A & B)\n" - f" - {op!r}(A, 1)\n" - f" - {op!r}(1, A)" - ) - - -_udt_mask_cache = {} - - -def _udt_mask(dtype): - """Create mask to determine which bytes of UDTs to use for equality check.""" - if dtype in _udt_mask_cache: - return _udt_mask_cache[dtype] - if dtype.subdtype is not None: - mask = _udt_mask(dtype.subdtype[0]) - N = reduce(mul, dtype.subdtype[1]) - rv = np.concatenate([mask] * N) - elif dtype.names is not None: - prev_offset = mask = None - masks = [] - for name in dtype.names: - dtype2, offset = dtype.fields[name] - if mask is not None: - masks.append(np.pad(mask, (0, offset - prev_offset - mask.size))) - mask = _udt_mask(dtype2) - prev_offset = offset - masks.append(np.pad(mask, (0, dtype.itemsize - prev_offset - mask.size))) - rv = np.concatenate(masks) - else: - rv = np.ones(dtype.itemsize, dtype=bool) - # assert rv.size == dtype.itemsize - _udt_mask_cache[dtype] = rv - return rv - - -class TypedOpBase: - __slots__ = ( - "parent", - "name", - "type", - "return_type", - "gb_obj", - "gb_name", - "_type2", - "__weakref__", - ) - - def __init__(self, parent, name, type_, return_type, gb_obj, gb_name, dtype2=None): - self.parent = parent - self.name = name - self.type = type_ - self.return_type = return_type - self.gb_obj = gb_obj - self.gb_name = gb_name - self._type2 = dtype2 - - def __repr__(self): - classname = self.opclass.lower() - if classname.endswith("op"): - classname = classname[:-2] - dtype2 = "" if self._type2 is None else f", {self._type2.name}" - return f"{classname}.{self.name}[{self.type.name}{dtype2}]" - - @property - def _carg(self): - return self.gb_obj - - @property - def is_positional(self): - return self.parent.is_positional - - def __reduce__(self): - if self._type2 is None or self.type == self._type2: - return (getitem, (self.parent, self.type)) - return (getitem, (self.parent, (self.type, self._type2))) - - -class TypedBuiltinUnaryOp(TypedOpBase): - __slots__ = () - opclass = "UnaryOp" - - def __call__(self, val): - from .matrix import Matrix, TransposedMatrix - from .vector import Vector - - if (typ := output_type(val)) in {Vector, Matrix, TransposedMatrix}: - return val.apply(self) - from .scalar import Scalar, _as_scalar - - if typ is Scalar: - return val.apply(self) - try: - scalar = _as_scalar(val, is_cscalar=False) - except Exception: - pass - else: - return scalar.apply(self) - raise TypeError( - f"Bad type when calling {self!r}.\n" - " - Expected type: Scalar, Vector, Matrix, TransposedMatrix.\n" - f" - Got: {type(val)}.\n" - "Calling a UnaryOp is syntactic sugar for calling apply. " - f"For example, `A.apply({self!r})` is the same as `{self!r}(A)`." 
- ) - - -class TypedBuiltinIndexUnaryOp(TypedOpBase): - __slots__ = () - opclass = "IndexUnaryOp" - - def __call__(self, val, thunk=None): - if thunk is None: - thunk = False # most basic form of 0 when unifying dtypes - return _call_op(self, val, right=thunk) - - -class TypedBuiltinSelectOp(TypedOpBase): - __slots__ = () - opclass = "SelectOp" - - def __call__(self, val, thunk=None): - if thunk is None: - thunk = False # most basic form of 0 when unifying dtypes - return _call_op(self, val, thunk=thunk) - - -class TypedBuiltinBinaryOp(TypedOpBase): - __slots__ = () - opclass = "BinaryOp" - - def __call__(self, left, right=None, *, left_default=None, right_default=None): - if left_default is not None or right_default is not None: - if ( - left_default is None - or right_default is None - or right is not None - or not isinstance(left, InfixExprBase) - or left.method_name != "ewise_add" - ): - raise TypeError( - "Specifying `left_default` or `right_default` keyword arguments implies " - "performing `ewise_union` operation with infix notation.\n" - "There is only one valid way to do this:\n\n" - f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " - "are Vectors or Matrices, and left_default and right_default are scalars." - ) - return left.left.ewise_union(left.right, self, left_default, right_default) - return _call_op(self, left, right) - - @property - def monoid(self): - rv = getattr(monoid, self.name, None) - if rv is not None and self.type in rv._typed_ops: - return rv[self.type] - - @property - def commutes_to(self): - commutes_to = self.parent.commutes_to - if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt): - return commutes_to[self.type] - - @property - def _semiring_commutes_to(self): - commutes_to = self.parent._semiring_commutes_to - if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt): - return commutes_to[self.type] - - @property - def is_commutative(self): - return self.commutes_to is self - - @property - def type2(self): - return self.type if self._type2 is None else self._type2 - - -class TypedBuiltinMonoid(TypedOpBase): - __slots__ = "_identity" - opclass = "Monoid" - is_commutative = True - - def __init__(self, parent, name, type_, return_type, gb_obj, gb_name): - super().__init__(parent, name, type_, return_type, gb_obj, gb_name) - self._identity = None - - def __call__(self, left, right=None, *, left_default=None, right_default=None): - if left_default is not None or right_default is not None: - if ( - left_default is None - or right_default is None - or right is not None - or not isinstance(left, InfixExprBase) - or left.method_name != "ewise_add" - ): - raise TypeError( - "Specifying `left_default` or `right_default` keyword arguments implies " - "performing `ewise_union` operation with infix notation.\n" - "There is only one valid way to do this:\n\n" - f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " - "are Vectors or Matrices, and left_default and right_default are scalars." 
- ) - return left.left.ewise_union(left.right, self, left_default, right_default) - return _call_op(self, left, right) - - @property - def identity(self): - if self._identity is None: - from .recorder import skip_record - from .vector import Vector - - with skip_record: - self._identity = ( - Vector(self.type, size=1, name="").reduce(self, allow_empty=False).new().value - ) - return self._identity - - @property - def binaryop(self): - return getattr(binary, self.name)[self.type] - - @property - def commutes_to(self): - return self - - @property - def type2(self): - return self.type - - @property - def is_idempotent(self): - """True if ``monoid(x, x) == x`` for any x.""" - return self.parent.is_idempotent - - -class TypedBuiltinSemiring(TypedOpBase): - __slots__ = () - opclass = "Semiring" - - def __call__(self, left, right=None): - if right is not None: - raise TypeError( - f"Bad types when calling {self!r}. Got types: {type(left)}, {type(right)}.\n" - f"Expected an infix expression, such as: {self!r}(A @ B)" - ) - return _call_op(self, left) - - @property - def binaryop(self): - name = self.name.split("_", 1)[1] - if name in _SS_OPERATORS: - binop = binary._deprecated[name] - else: - binop = getattr(binary, name) - return binop[self.type] - - @property - def monoid(self): - monoid_name, binary_name = self.name.split("_", 1) - if binary_name in _SS_OPERATORS: - binop = binary._deprecated[binary_name] - else: - binop = getattr(binary, binary_name) - binop = binop[self.type] - val = getattr(monoid, monoid_name) - return val[binop.return_type] - - @property - def commutes_to(self): - binop = self.binaryop - commutes_to = binop._semiring_commutes_to or binop.commutes_to - if commutes_to is None: - return - if commutes_to is binop: - return self - return get_semiring(self.monoid, commutes_to) - - @property - def is_commutative(self): - return self.binaryop.is_commutative - - type2 = TypedBuiltinBinaryOp.type2 - - -class TypedUserUnaryOp(TypedOpBase): - __slots__ = () - opclass = "UnaryOp" - - def __init__(self, parent, name, type_, return_type, gb_obj): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") - - @property - def orig_func(self): - return self.parent.orig_func - - @property - def _numba_func(self): - return self.parent._numba_func - - __call__ = TypedBuiltinUnaryOp.__call__ - - -class TypedUserIndexUnaryOp(TypedOpBase): - __slots__ = () - opclass = "IndexUnaryOp" - - def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) - - @property - def orig_func(self): - return self.parent.orig_func - - @property - def _numba_func(self): - return self.parent._numba_func - - __call__ = TypedBuiltinIndexUnaryOp.__call__ - - -class TypedUserSelectOp(TypedOpBase): - __slots__ = () - opclass = "SelectOp" - - def __init__(self, parent, name, type_, return_type, gb_obj): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") - - @property - def orig_func(self): - return self.parent.orig_func - - @property - def _numba_func(self): - return self.parent._numba_func - - __call__ = TypedBuiltinSelectOp.__call__ - - -class TypedUserBinaryOp(TypedOpBase): - __slots__ = "_monoid" - opclass = "BinaryOp" - - def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) - self._monoid = None - - @property - def monoid(self): - if self._monoid 
is None: - monoid = self.parent.monoid - if monoid is not None and self.type in monoid: - self._monoid = monoid[self.type] - return self._monoid - - commutes_to = TypedBuiltinBinaryOp.commutes_to - _semiring_commutes_to = TypedBuiltinBinaryOp._semiring_commutes_to - is_commutative = TypedBuiltinBinaryOp.is_commutative - orig_func = TypedUserUnaryOp.orig_func - _numba_func = TypedUserUnaryOp._numba_func - type2 = TypedBuiltinBinaryOp.type2 - __call__ = TypedBuiltinBinaryOp.__call__ - - -class TypedUserMonoid(TypedOpBase): - __slots__ = "binaryop", "identity" - opclass = "Monoid" - is_commutative = True - - def __init__(self, parent, name, type_, return_type, gb_obj, binaryop, identity): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") - self.binaryop = binaryop - self.identity = identity - binaryop._monoid = self - - commutes_to = TypedBuiltinMonoid.commutes_to - type2 = TypedBuiltinMonoid.type2 - is_idempotent = TypedBuiltinMonoid.is_idempotent - __call__ = TypedBuiltinMonoid.__call__ - - -class TypedUserSemiring(TypedOpBase): - __slots__ = "monoid", "binaryop" - opclass = "Semiring" - - def __init__(self, parent, name, type_, return_type, gb_obj, monoid, binaryop, dtype2=None): - super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) - self.monoid = monoid - self.binaryop = binaryop - - commutes_to = TypedBuiltinSemiring.commutes_to - is_commutative = TypedBuiltinSemiring.is_commutative - type2 = TypedBuiltinBinaryOp.type2 - __call__ = TypedBuiltinSemiring.__call__ - - -def _deserialize_parameterized(parameterized_op, args, kwargs): - return parameterized_op(*args, **kwargs) - - -class ParameterizedUdf: - __slots__ = "name", "__call__", "_anonymous", "__weakref__" - is_positional = False - _custom_dtype = None - - def __init__(self, name, anonymous): - self.name = name - self._anonymous = anonymous - # lru_cache per instance - method = self._call.__get__(self, type(self)) - self.__call__ = lru_cache(maxsize=1024)(method) - - def _call(self, *args, **kwargs): - raise NotImplementedError - - -class ParameterizedUnaryOp(ParameterizedUdf): - __slots__ = "func", "__signature__", "_is_udt" - - def __init__(self, name, func, *, anonymous=False, is_udt=False): - self.func = func - self.__signature__ = inspect.signature(func) - self._is_udt = is_udt - if name is None: - name = getattr(func, "__name__", name) - super().__init__(name, anonymous) - - def _call(self, *args, **kwargs): - unary = self.func(*args, **kwargs) - unary._parameterized_info = (self, args, kwargs) - return UnaryOp.register_anonymous(unary, self.name, is_udt=self._is_udt) - - def __reduce__(self): - name = f"unary.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover - return name - return (self._deserialize, (self.name, self.func, self._anonymous)) - - @staticmethod - def _deserialize(name, func, anonymous): - if anonymous: - return UnaryOp.register_anonymous(func, name, parameterized=True) - if (rv := UnaryOp._find(name)) is not None: - return rv - return UnaryOp.register_new(name, func, parameterized=True) - - -class ParameterizedIndexUnaryOp(ParameterizedUdf): - __slots__ = "func", "__signature__", "_is_udt" - - def __init__(self, name, func, *, anonymous=False, is_udt=False): - self.func = func - self.__signature__ = inspect.signature(func) - self._is_udt = is_udt - if name is None: - name = getattr(func, "__name__", name) - super().__init__(name, anonymous) - - def _call(self, *args, **kwargs): - indexunary = 
self.func(*args, **kwargs) - indexunary._parameterized_info = (self, args, kwargs) - return IndexUnaryOp.register_anonymous(indexunary, self.name, is_udt=self._is_udt) - - def __reduce__(self): - name = f"indexunary.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.func, self._anonymous)) - - @staticmethod - def _deserialize(name, func, anonymous): - if anonymous: - return IndexUnaryOp.register_anonymous(func, name, parameterized=True) - if (rv := IndexUnaryOp._find(name)) is not None: - return rv - return IndexUnaryOp.register_new(name, func, parameterized=True) - - -class ParameterizedSelectOp(ParameterizedUdf): - __slots__ = "func", "__signature__", "_is_udt" - - def __init__(self, name, func, *, anonymous=False, is_udt=False): - self.func = func - self.__signature__ = inspect.signature(func) - self._is_udt = is_udt - if name is None: - name = getattr(func, "__name__", name) - super().__init__(name, anonymous) - - def _call(self, *args, **kwargs): - sel = self.func(*args, **kwargs) - sel._parameterized_info = (self, args, kwargs) - return SelectOp.register_anonymous(sel, self.name, is_udt=self._is_udt) - - def __reduce__(self): - name = f"select.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.func, self._anonymous)) - - @staticmethod - def _deserialize(name, func, anonymous): - if anonymous: - return SelectOp.register_anonymous(func, name, parameterized=True) - if (rv := SelectOp._find(name)) is not None: - return rv - return SelectOp.register_new(name, func, parameterized=True) - - -class ParameterizedBinaryOp(ParameterizedUdf): - __slots__ = "func", "__signature__", "_monoid", "_cached_call", "_commutes_to", "_is_udt" - - def __init__(self, name, func, *, anonymous=False, is_udt=False): - self.func = func - self.__signature__ = inspect.signature(func) - self._monoid = None - self._is_udt = is_udt - if name is None: - name = getattr(func, "__name__", name) - super().__init__(name, anonymous) - method = self._call_to_cache.__get__(self, type(self)) - self._cached_call = lru_cache(maxsize=1024)(method) - self.__call__ = self._call - self._commutes_to = None - - def _call_to_cache(self, *args, **kwargs): - binary = self.func(*args, **kwargs) - binary._parameterized_info = (self, args, kwargs) - return BinaryOp.register_anonymous(binary, self.name, is_udt=self._is_udt) - - def _call(self, *args, **kwargs): - binop = self._cached_call(*args, **kwargs) - if self._monoid is not None and binop._monoid is None: - # This is all a bit funky. We try our best to associate a binaryop - # to a monoid. So, if we made a ParameterizedMonoid using this object, - # then try to create a monoid with the given arguments. - binop._monoid = binop # temporary! - try: - # If this call is successful, then it will set `binop._monoid` - self._monoid(*args, **kwargs) # pylint: disable=not-callable - except Exception: - binop._monoid = None - # assert binop._monoid is not binop - if self.is_commutative: - binop._commutes_to = binop - # Don't bother yet with creating `binop.commutes_to` (but we could!) 
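The ``left_default``/``right_default`` guard in ``__call__`` above admits exactly one calling pattern, per its error message; a sketch using the same ``binary.div`` example as the ``ewise_union`` docstring:

.. code-block:: python

    import graphblas as gb
    from graphblas import binary

    x = gb.Vector.from_coo([0, 1], [2.0, 4.0], size=3)
    y = gb.Vector.from_coo([1, 2], [8.0, 16.0], size=3)

    # The only valid defaults form: op(x | y, left_default=..., right_default=...)
    z = binary.div(x | y, left_default=1.0, right_default=1.0).new()

    # This dispatches to ewise_union; the method form is equivalent:
    z2 = x.ewise_union(y, binary.div, left_default=1.0, right_default=1.0).new()
    assert z.isequal(z2)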
- return binop - - @property - def monoid(self): - return self._monoid - - @property - def commutes_to(self): - if type(self._commutes_to) is str: - self._commutes_to = BinaryOp._find(self._commutes_to) - return self._commutes_to - - is_commutative = TypedBuiltinBinaryOp.is_commutative - - def __reduce__(self): - name = f"binary.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.func, self._anonymous)) - - @staticmethod - def _deserialize(name, func, anonymous): - if anonymous: - return BinaryOp.register_anonymous(func, name, parameterized=True) - if (rv := BinaryOp._find(name)) is not None: - return rv - return BinaryOp.register_new(name, func, parameterized=True) - - -class ParameterizedMonoid(ParameterizedUdf): - __slots__ = "binaryop", "identity", "_is_idempotent", "__signature__" - is_commutative = True - - def __init__(self, name, binaryop, identity, *, is_idempotent=False, anonymous=False): - if type(binaryop) is not ParameterizedBinaryOp: - raise TypeError("binaryop must be parameterized") - self.binaryop = binaryop - self.__signature__ = binaryop.__signature__ - if callable(identity): - # assume it must be parameterized as well, so signature must match - sig = inspect.signature(identity) - if sig != self.__signature__: - raise ValueError( - "Signatures of binaryop and identity passed to " - f"{type(self).__name__} must be the same. Got:\n" - f" binaryop{self.__signature__}\n" - " !=\n" - f" identity{sig}" - ) - self.identity = identity - self._is_idempotent = is_idempotent - if name is None: - name = binaryop.name - super().__init__(name, anonymous) - binaryop._monoid = self - # clear binaryop cache so it can be associated with this monoid - binaryop._cached_call.cache_clear() - - def _call(self, *args, **kwargs): - binary = self.binaryop(*args, **kwargs) - identity = self.identity - if callable(identity): - identity = identity(*args, **kwargs) - return Monoid.register_anonymous( - binary, identity, self.name, is_idempotent=self._is_idempotent - ) - - commutes_to = TypedBuiltinMonoid.commutes_to - - @property - def is_idempotent(self): - """True if ``monoid(x, x) == x`` for any x.""" - return self._is_idempotent - - def __reduce__(self): - name = f"monoid.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover - return name - return (self._deserialize, (self.name, self.binaryop, self.identity, self._anonymous)) - - @staticmethod - def _deserialize(name, binaryop, identity, anonymous): - if anonymous: - return Monoid.register_anonymous(binaryop, identity, name) - if (rv := Monoid._find(name)) is not None: - return rv - return Monoid.register_new(name, binaryop, identity) - - -class ParameterizedSemiring(ParameterizedUdf): - __slots__ = "monoid", "binaryop", "__signature__" - - def __init__(self, name, monoid, binaryop, *, anonymous=False): - if type(monoid) not in {ParameterizedMonoid, Monoid}: - raise TypeError("monoid must be of type Monoid or ParameterizedMonoid") - if type(binaryop) is ParameterizedBinaryOp: - self.__signature__ = binaryop.__signature__ - if type(monoid) is ParameterizedMonoid and monoid.__signature__ != self.__signature__: - raise ValueError( - "Signatures of monoid and binaryop passed to " - f"{type(self).__name__} must be the same. Got:\n" - f" monoid{monoid.__signature__}\n" - " !=\n" - f" binaryop{self.__signature__}\n\n" - "Perhaps call monoid or binaryop with parameters before creating the semiring." 
- ) - elif type(binaryop) is BinaryOp: - if type(monoid) is Monoid: - raise TypeError("At least one of monoid or binaryop must be parameterized") - self.__signature__ = monoid.__signature__ - else: - raise TypeError("binaryop must be of type BinaryOp or ParameterizedBinaryOp") - self.monoid = monoid - self.binaryop = binaryop - if name is None: - name = f"{monoid.name}_{binaryop.name}" - super().__init__(name, anonymous) - - def _call(self, *args, **kwargs): - monoid = self.monoid - if type(monoid) is ParameterizedMonoid: - monoid = monoid(*args, **kwargs) - binary = self.binaryop - if type(binary) is ParameterizedBinaryOp: - binary = binary(*args, **kwargs) - return Semiring.register_anonymous(monoid, binary, self.name) - - commutes_to = TypedBuiltinSemiring.commutes_to - is_commutative = TypedBuiltinSemiring.is_commutative - - def __reduce__(self): - name = f"semiring.{self.name}" - if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover - return name - return (self._deserialize, (self.name, self.monoid, self.binaryop, self._anonymous)) - - @staticmethod - def _deserialize(name, monoid, binaryop, anonymous): - if anonymous: - return Semiring.register_anonymous(monoid, binaryop, name) - if (rv := Semiring._find(name)) is not None: - return rv - return Semiring.register_new(name, monoid, binaryop) - - -_VARNAMES = tuple(x for x in dir(lib) if x[0] != "_") - - -class OpBase: - __slots__ = ( - "name", - "_typed_ops", - "types", - "coercions", - "_anonymous", - "_udt_types", - "_udt_ops", - "__weakref__", - ) - _parse_config = None - _initialized = False - _module = None - _positional = None - - def __init__(self, name, *, anonymous=False): - self.name = name - self._typed_ops = {} - self.types = {} - self.coercions = {} - self._anonymous = anonymous - self._udt_types = None - self._udt_ops = None - - def __repr__(self): - return f"{self._modname}.{self.name}" - - def __getitem__(self, type_): - if type(type_) is tuple: - dtype1, dtype2 = type_ - dtype1 = lookup_dtype(dtype1) - dtype2 = lookup_dtype(dtype2) - return get_typed_op(self, dtype1, dtype2) - if not self._is_udt: - type_ = lookup_dtype(type_) - if type_ not in self._typed_ops: - if self._udt_types is None: - if self.is_positional: - return self._typed_ops[UINT64] - raise KeyError(f"{self.name} does not work with {type_}") - else: - return self._typed_ops[type_] - # This is a UDT or is able to operate on UDTs such as `first` any `any` - dtype = lookup_dtype(type_) - return self._compile_udt(dtype, dtype) - - def _add(self, op): - self._typed_ops[op.type] = op - self.types[op.type] = op.return_type - - def __delitem__(self, type_): - type_ = lookup_dtype(type_) - del self._typed_ops[type_] - del self.types[type_] - - def __contains__(self, type_): - try: - self[type_] - except (TypeError, KeyError, numba.NumbaError): - return False - return True - - @classmethod - def _remove_nesting(cls, funcname, *, module=None, modname=None, strict=True): - if module is None: - module = cls._module - if modname is None: - modname = cls._modname - if "." not in funcname: - if strict and _hasop(module, funcname): - raise AttributeError(f"{modname}.{funcname} is already defined") - else: - path, funcname = funcname.rsplit(".", 1) - for folder in path.split("."): - if not _hasop(module, folder): - setattr(module, folder, OpPath(module, folder)) - module = getattr(module, folder) - modname = f"{modname}.{folder}" - if not isinstance(module, (OpPath, ModuleType)): - raise AttributeError( - f"{modname} is already defined. 
Cannot use as a nested path." - ) - if strict and _hasop(module, funcname): - raise AttributeError(f"{path}.{funcname} is already defined") - return module, funcname - - @classmethod - def _find(cls, funcname): - rv = cls._module - for attr in funcname.split("."): - if attr in getattr(rv, "_deprecated", ()): - rv = rv._deprecated[attr] - else: - rv = getattr(rv, attr, None) - if rv is None: - break - return rv - - @classmethod - def _initialize(cls, include_in_ops=True): - """ - include_in_ops determines whether the operators are included in the - `gb.ops` namespace in addition to the defined module. - """ - if cls._initialized: # pragma: no cover (safety) - return - # Read in the parse configs - trim_from_front = cls._parse_config.get("trim_from_front", 0) - delete_exact = cls._parse_config.get("delete_exact", None) - num_underscores = cls._parse_config["num_underscores"] - - for re_str, return_prefix in [ - ("re_exprs", None), - ("re_exprs_return_bool", "BOOL"), - ("re_exprs_return_float", "FP"), - ("re_exprs_return_complex", "FC"), - ]: - if re_str not in cls._parse_config: - continue - if "complex" in re_str and not _supports_complex: - continue - for r in reversed(cls._parse_config[re_str]): - for varname in _VARNAMES: - m = r.match(varname) - if m: - # Parse function into name and datatype - gb_name = m.string - splitname = gb_name[trim_from_front:].split("_") - if delete_exact and delete_exact in splitname: - splitname.remove(delete_exact) - if len(splitname) == num_underscores + 1: - *splitname, type_ = splitname - else: - type_ = None - name = "_".join(splitname).lower() - # Create object for name unless it already exists - if not _hasop(cls._module, name): - if backend == "suitesparse" and name in _SS_OPERATORS: - fullname = f"ss.{name}" - else: - fullname = name - if cls._positional is None: - obj = cls(fullname) - else: - obj = cls(fullname, is_positional=name in cls._positional) - if name in _SS_OPERATORS: - if backend == "suitesparse": - setattr(cls._module.ss, name, obj) - cls._module._deprecated[name] = obj - if include_in_ops and not _hasop(op, name): # pragma: no branch - op._deprecated[name] = obj - if backend == "suitesparse": - setattr(op.ss, name, obj) - else: - setattr(cls._module, name, obj) - if include_in_ops and not _hasop(op, name): - setattr(op, name, obj) - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{fullname}") - elif name in _SS_OPERATORS: - obj = cls._module._deprecated[name] - else: - obj = getattr(cls._module, name) - gb_obj = getattr(lib, varname) - # Determine return type - if return_prefix == "BOOL": - return_type = BOOL - if type_ is None: - type_ = BOOL - else: - if type_ is None: # pragma: no cover - raise TypeError(f"Unable to determine return type for {varname}") - if return_prefix is None: - return_type = type_ - else: - # Grab the number of bits from type_ - num_bits = type_[-2:] - if num_bits not in {"32", "64"}: # pragma: no cover (safety) - raise TypeError(f"Unexpected number of bits: {num_bits}") - return_type = f"{return_prefix}{num_bits}" - builtin_op = cls._typed_class( - obj, - name, - lookup_dtype(type_), - lookup_dtype(return_type), - gb_obj, - gb_name, - ) - obj._add(builtin_op) - - @classmethod - def _deserialize(cls, name, *args): - if (rv := cls._find(name)) is not None: - return rv # Should we verify this is what the user expects? 
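The ``__getitem__`` lookup earlier in this class resolves dtypes (including two-dtype tuples) to typed operator instances; a sketch with built-in operators:

.. code-block:: python

    from graphblas import binary, dtypes

    typed = binary.plus[dtypes.INT64]   # a TypedBuiltinBinaryOp
    typed2 = binary.plus["FP32"]        # strings go through lookup_dtype
    assert typed.return_type == dtypes.INT64
    assert dtypes.INT64 in binary.plus  # __contains__ simply tries the lookup

    # The tuple branch handles mixed input dtypes via get_typed_op:
    mixed = binary.plus[dtypes.INT32, dtypes.FP64]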
- return cls.register_new(name, *args) - - -def _identity(x): - return x # pragma: no cover (numba) - - -def _one(x): - return 1 # pragma: no cover (numba) - - -class UnaryOp(OpBase): - """Takes one input and returns one output, possibly of a different data type. - - Built-in and registered UnaryOps are located in the ``graphblas.unary`` namespace - as well as in the ``graphblas.ops`` combined namespace. - """ - - __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" - _custom_dtype = None - _module = unary - _modname = "unary" - _typed_class = TypedBuiltinUnaryOp - _parse_config = { - "trim_from_front": 4, - "num_underscores": 1, - "re_exprs": [ - re.compile( - "^GrB_(IDENTITY|AINV|MINV|ABS|BNOT)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" - ), - re.compile( - "^GxB_(LNOT|ONE|POSITIONI1|POSITIONI|POSITIONJ1|POSITIONJ)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(SQRT|LOG|EXP|LOG2|SIN|COS|TAN|ACOS|ASIN|ATAN|SINH|COSH|TANH|ACOSH" - "|ASINH|ATANH|SIGNUM|CEIL|FLOOR|ROUND|TRUNC|EXP2|EXPM1|LOG10|LOG1P)" - "_(FP32|FP64|FC32|FC64)$" - ), - re.compile("^GxB_(LGAMMA|TGAMMA|ERF|ERFC|FREXPX|FREXPE|CBRT)_(FP32|FP64)$"), - re.compile("^GxB_(IDENTITY|AINV|MINV|ONE|CONJ)_(FC32|FC64)$"), - ], - "re_exprs_return_bool": [ - re.compile("^GrB_LNOT$"), - re.compile("^GxB_(ISINF|ISNAN|ISFINITE)_(FP32|FP64|FC32|FC64)$"), - ], - "re_exprs_return_float": [re.compile("^GxB_(CREAL|CIMAG|CARG|ABS)_(FC32|FC64)$")], - } - _positional = {"positioni", "positioni1", "positionj", "positionj1"} - - @classmethod - def _build(cls, name, func, *, anonymous=False, is_udt=False): - if type(func) is not FunctionType: - raise TypeError(f"UDF argument must be a function, not {type(func)}") - if name is None: - name = getattr(func, "__name__", "") - success = False - unary_udf = numba.njit(func) - new_type_obj = cls(name, func, anonymous=anonymous, is_udt=is_udt, numba_func=unary_udf) - return_types = {} - nt = numba.types - if not is_udt: - for type_ in _sample_values: - sig = (type_.numba_type,) - try: - unary_udf.compile(sig) - except numba.TypingError: - continue - ret_type = lookup_dtype(unary_udf.overloads[sig].signature.return_type) - if ret_type != type_ and ( - ("INT" in ret_type.name and "INT" in type_.name) - or ("FP" in ret_type.name and "FP" in type_.name) - or ("FC" in ret_type.name and "FC" in type_.name) - or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) - ): - # Downcast `ret_type` to `type_`. - # This is what users want most of the time, but we can't make a perfect rule. - # There should be a way for users to be explicit. 
- ret_type = type_ - elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: - ret_type = INT8 - - # Numba is unable to handle BOOL correctly right now, but we have a workaround - # See: https://github.com/numba/numba/issues/5395 - # We're relying on coercion behaving correctly here - input_type = INT8 if type_ == BOOL else type_ - return_type = INT8 if ret_type == BOOL else ret_type - - # Build wrapper because GraphBLAS wants pointers and void return - wrapper_sig = nt.void( - nt.CPointer(return_type.numba_type), - nt.CPointer(input_type.numba_type), - ) - - if type_ == BOOL: - if ret_type == BOOL: - - def unary_wrapper(z, x): - z[0] = bool(unary_udf(bool(x[0]))) # pragma: no cover (numba) - - else: - - def unary_wrapper(z, x): - z[0] = unary_udf(bool(x[0])) # pragma: no cover (numba) - - elif ret_type == BOOL: - - def unary_wrapper(z, x): - z[0] = bool(unary_udf(x[0])) # pragma: no cover (numba) - - else: - - def unary_wrapper(z, x): - z[0] = unary_udf(x[0]) # pragma: no cover (numba) - - unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper) - new_unary = ffi_new("GrB_UnaryOp*") - check_status_carg( - lib.GrB_UnaryOp_new( - new_unary, unary_wrapper.cffi, ret_type.gb_obj, type_.gb_obj - ), - "UnaryOp", - new_unary, - ) - op = TypedUserUnaryOp(new_type_obj, name, type_, ret_type, new_unary[0]) - new_type_obj._add(op) - success = True - return_types[type_] = ret_type - if success or is_udt: - return new_type_obj - raise UdfParseError("Unable to parse function using Numba") - - def _compile_udt(self, dtype, dtype2): - if dtype in self._udt_types: - return self._udt_ops[dtype] - - numba_func = self._numba_func - sig = (dtype.numba_type,) - numba_func.compile(sig) # Should we catch and give additional error message? - ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) - - unary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype) - unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper) - new_unary = ffi_new("GrB_UnaryOp*") - check_status_carg( - lib.GrB_UnaryOp_new(new_unary, unary_wrapper.cffi, ret_type._carg, dtype._carg), - "UnaryOp", - new_unary, - ) - op = TypedUserUnaryOp(self, self.name, dtype, ret_type, new_unary[0]) - self._udt_types[dtype] = ret_type - self._udt_ops[dtype] = op - return op - - @classmethod - def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): - """Register a UnaryOp without registering it in the ``graphblas.unary`` namespace. - - Because it is not registered in the namespace, the name is optional. - """ - if parameterized: - return ParameterizedUnaryOp(name, func, anonymous=True, is_udt=is_udt) - return cls._build(name, func, anonymous=True, is_udt=is_udt) - - @classmethod - def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): - """Register a UnaryOp. The name will be used to identify the UnaryOp in the - ``graphblas.unary`` namespace. - - >>> gb.core.operator.UnaryOp.register_new("plus_one", lambda x: x + 1) - >>> dir(gb.unary) - [..., 'plus_one', ...] 
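Expanding the ``register_new`` docstring example just above into a runnable sketch (``plus_one`` is the docstring's own name; the UDF is compiled with numba as in ``_build``):

.. code-block:: python

    import graphblas as gb

    gb.core.operator.UnaryOp.register_new("plus_one", lambda x: x + 1)

    v = gb.Vector.from_coo([0, 2], [10, 20], size=3)
    w = v.apply(gb.unary.plus_one).new()  # also registered under gb.op
    assert list(w.to_coo()[1]) == [11, 21]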
- """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "func": func, "parameterized": parameterized}, - ) - elif parameterized: - unary_op = ParameterizedUnaryOp(name, func, is_udt=is_udt) - setattr(module, funcname, unary_op) - else: - unary_op = cls._build(name, func, is_udt=is_udt) - setattr(module, funcname, unary_op) - # Also save it to `graphblas.op` if not yet defined - opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) - if not _hasop(opmodule, funcname): - if lazy: - opmodule._delayed[funcname] = module - else: - setattr(opmodule, funcname, unary_op) - if not cls._initialized: # pragma: no cover - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return unary_op - - @classmethod - def _initialize(cls): - if cls._initialized: - return - super()._initialize() - # Update type information with sane coercion - position_dtypes = [ - BOOL, - FP32, - FP64, - INT8, - INT16, - UINT8, - UINT16, - UINT32, - UINT64, - ] - if _supports_complex: - position_dtypes.extend([FC32, FC64]) - for names, *types in [ - # fmt: off - ( - ( - "erf", "erfc", "lgamma", "tgamma", "acos", "acosh", "asin", "asinh", - "atan", "atanh", "ceil", "cos", "cosh", "exp", "exp2", "expm1", "floor", - "log", "log10", "log1p", "log2", "round", "signum", "sin", "sinh", "sqrt", - "tan", "tanh", "trunc", "cbrt", - ), - ((BOOL, INT8, INT16, UINT8, UINT16), FP32), - ((INT32, INT64, UINT32, UINT64), FP64), - ), - ( - ("positioni", "positioni1", "positionj", "positionj1"), - ( - position_dtypes, - INT64, - ), - ), - # fmt: on - ]: - for name in names: - if name in _SS_OPERATORS: - op = unary._deprecated[name] - else: - op = getattr(unary, name) - for input_types, target_type in types: - typed_op = op._typed_ops[target_type] - output_type = op.types[target_type] - for dtype in input_types: - if dtype not in op.types: # pragma: no branch (safety) - op.types[dtype] = output_type - op._typed_ops[dtype] = typed_op - op.coercions[dtype] = target_type - # Allow some functions to work on UDTs - for unop, func in [ - (unary.identity, _identity), - (unary.one, _one), - ]: - unop.orig_func = func - unop._numba_func = numba.njit(func) - unop._udt_types = {} - unop._udt_ops = {} - cls._initialized = True - - def __init__( - self, - name, - func=None, - *, - anonymous=False, - is_positional=False, - is_udt=False, - numba_func=None, - ): - super().__init__(name, anonymous=anonymous) - self.orig_func = func - self._numba_func = numba_func - self.is_positional = is_positional - self._is_udt = is_udt - if is_udt: - self._udt_types = {} # {dtype: DataType} - self._udt_ops = {} # {dtype: TypedUserUnaryOp} - - def __reduce__(self): - if self._anonymous: - if hasattr(self.orig_func, "_parameterized_info"): - return (_deserialize_parameterized, self.orig_func._parameterized_info) - return (self.register_anonymous, (self.orig_func, self.name)) - if (name := f"unary.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.orig_func)) - - __call__ = TypedBuiltinUnaryOp.__call__ - - -class IndexUnaryOp(OpBase): - """Takes one input and a thunk and returns one output, possibly of a different data type. - Along with the input value, the index(es) of the element are given to the function. - - This is an advanced form of a unary operation that allows, for example, converting - elements of a Vector to their index position to build a ramp structure. 
Another use - case is returning a boolean value indicating whether the element is part of the upper - triangular structure of a Matrix. - - Built-in and registered IndexUnaryOps are located in the ``graphblas.indexunary`` namespace. - """ - - __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" - _module = indexunary - _modname = "indexunary" - _custom_dtype = None - _typed_class = TypedBuiltinIndexUnaryOp - _typed_user_class = TypedUserIndexUnaryOp - _parse_config = { - "trim_from_front": 4, - "num_underscores": 1, - "re_exprs": [ - re.compile("^GrB_(ROWINDEX|COLINDEX|DIAGINDEX)_(INT32|INT64)$"), - ], - "re_exprs_return_bool": [ - re.compile("^GrB_(TRIL|TRIU|DIAG|OFFDIAG|COLLE|COLGT|ROWLE|ROWGT)$"), - re.compile( - "^GrB_(VALUEEQ|VALUENE|VALUEGT|VALUEGE|VALUELT|VALUELE)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile("^GxB_(VALUEEQ|VALUENE)_(FC32|FC64)$"), - ], - } - _positional = {"tril", "triu", "diag", "offdiag", "colle", "colgt", "rowle", "rowgt", - "rowindex", "colindex"} # fmt: skip - - @classmethod - def _build(cls, name, func, *, is_udt=False, anonymous=False): - if not isinstance(func, FunctionType): - raise TypeError(f"UDF argument must be a function, not {type(func)}") - if name is None: - name = getattr(func, "__name__", "") - success = False - indexunary_udf = numba.njit(func) - new_type_obj = cls( - name, func, anonymous=anonymous, is_udt=is_udt, numba_func=indexunary_udf - ) - return_types = {} - nt = numba.types - if not is_udt: - for type_ in _sample_values: - sig = (type_.numba_type, UINT64.numba_type, UINT64.numba_type, type_.numba_type) - try: - indexunary_udf.compile(sig) - except numba.TypingError: - continue - ret_type = lookup_dtype(indexunary_udf.overloads[sig].signature.return_type) - if ret_type != type_ and ( - ("INT" in ret_type.name and "INT" in type_.name) - or ("FP" in ret_type.name and "FP" in type_.name) - or ("FC" in ret_type.name and "FC" in type_.name) - or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) - ): - # Downcast `ret_type` to `type_`. - # This is what users want most of the time, but we can't make a perfect rule. - # There should be a way for users to be explicit. 
- ret_type = type_ - elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: - ret_type = INT8 - - # Numba is unable to handle BOOL correctly right now, but we have a workaround - # See: https://github.com/numba/numba/issues/5395 - # We're relying on coercion behaving correctly here - input_type = INT8 if type_ == BOOL else type_ - return_type = INT8 if ret_type == BOOL else ret_type - - # Build wrapper because GraphBLAS wants pointers and void return - wrapper_sig = nt.void( - nt.CPointer(return_type.numba_type), - nt.CPointer(input_type.numba_type), - UINT64.numba_type, - UINT64.numba_type, - nt.CPointer(input_type.numba_type), - ) - - if type_ == BOOL: - if ret_type == BOOL: - - def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) - z[0] = bool(indexunary_udf(bool(x[0]), row, col, bool(y[0]))) - - else: - - def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) - z[0] = indexunary_udf(bool(x[0]), row, col, bool(y[0])) - - elif ret_type == BOOL: - - def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) - z[0] = bool(indexunary_udf(x[0], row, col, y[0])) - - else: - - def indexunary_wrapper(z, x, row, col, y): # pragma: no cover (numba) - z[0] = indexunary_udf(x[0], row, col, y[0]) - - indexunary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(indexunary_wrapper) - new_indexunary = ffi_new("GrB_IndexUnaryOp*") - check_status_carg( - lib.GrB_IndexUnaryOp_new( - new_indexunary, - indexunary_wrapper.cffi, - ret_type.gb_obj, - type_.gb_obj, - type_.gb_obj, - ), - "IndexUnaryOp", - new_indexunary, - ) - op = cls._typed_user_class(new_type_obj, name, type_, ret_type, new_indexunary[0]) - new_type_obj._add(op) - success = True - return_types[type_] = ret_type - if success or is_udt: - return new_type_obj - raise UdfParseError("Unable to parse function using Numba") - - def _compile_udt(self, dtype, dtype2): - if dtype2 is None: # pragma: no cover - dtype2 = dtype - dtypes = (dtype, dtype2) - if dtypes in self._udt_types: - return self._udt_ops[dtypes] - - numba_func = self._numba_func - sig = (dtype.numba_type, UINT64.numba_type, UINT64.numba_type, dtype2.numba_type) - numba_func.compile(sig) # Should we catch and give additional error message? - ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) - indexunary_wrapper, wrapper_sig = _get_udt_wrapper( - numba_func, ret_type, dtype, dtype2, include_indexes=True - ) - - indexunary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(indexunary_wrapper) - new_indexunary = ffi_new("GrB_IndexUnaryOp*") - check_status_carg( - lib.GrB_IndexUnaryOp_new( - new_indexunary, indexunary_wrapper.cffi, ret_type._carg, dtype._carg, dtype2._carg - ), - "IndexUnaryOp", - new_indexunary, - ) - op = TypedUserIndexUnaryOp( - self, - self.name, - dtype, - ret_type, - new_indexunary[0], - dtype2=dtype2, - ) - self._udt_types[dtypes] = ret_type - self._udt_ops[dtypes] = op - return op - - @classmethod - def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): - """Register an IndexUnaryOp without registering it in the - ``graphblas.indexunary`` namespace. - - Because it is not registered in the namespace, the name is optional. - """ - if parameterized: - return ParameterizedIndexUnaryOp(name, func, anonymous=True, is_udt=is_udt) - return cls._build(name, func, anonymous=True, is_udt=is_udt) - - @classmethod - def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): - """Register an IndexUnaryOp. 
The name will be used to identify the IndexUnaryOp in the - ``graphblas.indexunary`` namespace. - - If the return type is Boolean, the function will also be registered as a SelectOp - with the same name. - - >>> gb.indexunary.register_new("row_mod", lambda x, i, j, thunk: i % max(thunk, 2)) - >>> dir(gb.indexunary) - [..., 'row_mod', ...] - """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "func": func, "parameterized": parameterized}, - ) - elif parameterized: - indexunary_op = ParameterizedIndexUnaryOp(name, func, is_udt=is_udt) - setattr(module, funcname, indexunary_op) - else: - indexunary_op = cls._build(name, func, is_udt=is_udt) - setattr(module, funcname, indexunary_op) - # If return type is BOOL, register additionally as a SelectOp - if all(x == BOOL for x in indexunary_op.types.values()): - setattr(select, funcname, SelectOp._from_indexunary(indexunary_op)) - - if not cls._initialized: - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return indexunary_op - - @classmethod - def _initialize(cls): - if cls._initialized: - return - super()._initialize(include_in_ops=False) - # Update type information to include UINT64 for positional ops - for name in ["tril", "triu", "diag", "offdiag", "colle", "colgt", "rowle", "rowgt"]: - op = getattr(indexunary, name) - typed_op = op._typed_ops[BOOL] - output_type = op.types[BOOL] - if UINT64 not in op.types: # pragma: no branch (safety) - op.types[UINT64] = output_type - op._typed_ops[UINT64] = typed_op - op.coercions[UINT64] = BOOL - for name in ["rowindex", "colindex"]: - op = getattr(indexunary, name) - typed_op = op._typed_ops[INT64] - output_type = op.types[INT64] - if UINT64 not in op.types: # pragma: no branch (safety) - op.types[UINT64] = output_type - op._typed_ops[UINT64] = typed_op - op.coercions[UINT64] = INT64 - # Add index->row alias to make it more intuitive which to use for vectors - indexunary.indexle = indexunary.rowle - indexunary.indexgt = indexunary.rowgt - indexunary.index = indexunary.rowindex - # fmt: off - # Add SelectOp when it makes sense - for name in ["tril", "triu", "diag", "offdiag", - "colle", "colgt", "rowle", "rowgt", "indexle", "indexgt", - "valueeq", "valuene", "valuegt", "valuege", "valuelt", "valuele"]: - iop = getattr(indexunary, name) - setattr(select, name, SelectOp._from_indexunary(iop)) - # fmt: on - cls._initialized = True - - def __init__( - self, - name, - func=None, - *, - anonymous=False, - is_positional=False, - is_udt=False, - numba_func=None, - ): - super().__init__(name, anonymous=anonymous) - self.orig_func = func - self._numba_func = numba_func - self.is_positional = is_positional - self._is_udt = is_udt - if is_udt: - self._udt_types = {} # {dtype: DataType} - self._udt_ops = {} # {dtype: TypedUserIndexUnaryOp} - - def __reduce__(self): - if self._anonymous: - if hasattr(self.orig_func, "_parameterized_info"): - return (_deserialize_parameterized, self.orig_func._parameterized_info) - return (self.register_anonymous, (self.orig_func, self.name)) - if (name := f"indexunary.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.orig_func)) - - __call__ = TypedBuiltinIndexUnaryOp.__call__ - - -class SelectOp(OpBase): - """Identical to an :class:`IndexUnaryOp `, - but must have a Boolean return type. - - A SelectOp is used exclusively to select a subset of values from a collection where - the function returns True. 
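# A sketch of that behavior with built-in ops (illustrative values; `select` keeps
# only the entries for which the op returns True and drops the rest):
import graphblas as gb

A = gb.Matrix.from_coo([0, 0, 1], [0, 1, 1], [1, 2, 3])
lower = A.select(gb.select.tril).new()      # entries on or below the diagonal
big = A.select(gb.select.valuegt, 1).new()  # entries whose value exceeds the thunk 1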
-
-    Built-in and registered SelectOps are located in the ``graphblas.select`` namespace.
-    """
-
-    __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func"
-    _module = select
-    _modname = "select"
-    _custom_dtype = None
-    _typed_class = TypedBuiltinSelectOp
-    _typed_user_class = TypedUserSelectOp
-
-    @classmethod
-    def _from_indexunary(cls, iop):
-        obj = cls(
-            iop.name,
-            iop.orig_func,
-            anonymous=iop._anonymous,
-            is_positional=iop.is_positional,
-            is_udt=iop._is_udt,
-            numba_func=iop._numba_func,
-        )
-        if not all(x == BOOL for x in iop.types.values()):
-            raise ValueError("SelectOp must have BOOL return type")
-        for type_, t in iop._typed_ops.items():
-            if iop.orig_func is not None:
-                op = cls._typed_user_class(
-                    obj,
-                    iop.name,
-                    t.type,
-                    t.return_type,
-                    t.gb_obj,
-                )
-            else:
-                op = cls._typed_class(
-                    obj,
-                    iop.name,
-                    t.type,
-                    t.return_type,
-                    t.gb_obj,
-                    t.gb_name,
-                )
-            # type is not always equal to t.type, so can't use op._add
-            # but otherwise perform the same logic
-            obj._typed_ops[type_] = op
-            obj.types[type_] = op.return_type
-        return obj
-
-    @classmethod
-    def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False):
-        """Register a SelectOp without registering it in the ``graphblas.select`` namespace.
-
-        Because it is not registered in the namespace, the name is optional.
-        """
-        if parameterized:
-            return ParameterizedSelectOp(name, func, anonymous=True, is_udt=is_udt)
-        iop = IndexUnaryOp._build(name, func, anonymous=True, is_udt=is_udt)
-        return SelectOp._from_indexunary(iop)
-
-    @classmethod
-    def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False):
-        """Register a SelectOp. The name will be used to identify the SelectOp in the
-        ``graphblas.select`` namespace.
-
-        The function will also be registered as an IndexUnaryOp with the same name.
-
-        >>> gb.select.register_new("upper_left_triangle", lambda x, i, j, thunk: i + j <= thunk)
-        >>> dir(gb.select)
-        [..., 'upper_left_triangle', ...]
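# Once registered as above, the op is usable anywhere a select op is accepted;
# a sketch (the thunk 4 is what the lambda compares i + j against):
import graphblas as gb

A = gb.Matrix.from_coo([0, 1, 5], [0, 2, 5], [1.0, 2.0, 3.0])
result = A.select("upper_left_triangle", 4).new()  # keeps A[0, 0] and A[1, 2]
# equivalently: A.select(gb.select.upper_left_triangle, 4)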
- """ - iop = IndexUnaryOp.register_new( - name, func, parameterized=parameterized, is_udt=is_udt, lazy=lazy - ) - if not all(x == BOOL for x in iop.types.values()): - raise ValueError("SelectOp must have BOOL return type") - if lazy: - return getattr(select, iop.name) - - @classmethod - def _initialize(cls): - if cls._initialized: # pragma: no cover (safety) - return - # IndexUnaryOp adds it boolean-returning objects to SelectOp - IndexUnaryOp._initialize() - cls._initialized = True - - def __init__( - self, - name, - func=None, - *, - anonymous=False, - is_positional=False, - is_udt=False, - numba_func=None, - ): - super().__init__(name, anonymous=anonymous) - self.orig_func = func - self._numba_func = numba_func - self.is_positional = is_positional - self._is_udt = is_udt - if is_udt: - self._udt_types = {} # {dtype: DataType} - self._udt_ops = {} # {dtype: TypedUserIndexUnaryOp} - - def __reduce__(self): - if self._anonymous: - if hasattr(self.orig_func, "_parameterized_info"): - return (_deserialize_parameterized, self.orig_func._parameterized_info) - return (self.register_anonymous, (self.orig_func, self.name)) - if (name := f"select.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.orig_func)) - - __call__ = TypedBuiltinSelectOp.__call__ - - -def _floordiv(x, y): - return x // y # pragma: no cover (numba) - - -def _rfloordiv(x, y): - return y // x # pragma: no cover (numba) - - -def _absfirst(x, y): - return np.abs(x) # pragma: no cover (numba) - - -def _abssecond(x, y): - return np.abs(y) # pragma: no cover (numba) - - -def _rpow(x, y): - return y**x # pragma: no cover (numba) - - -def _isclose(rel_tol=1e-7, abs_tol=0.0): - def inner(x, y): # pragma: no cover (numba) - return x == y or abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol) - - return inner - - -_MAX_INT64 = np.iinfo(np.int64).max - - -def _binom(N, k): # pragma: no cover (numba) - # Returns 0 if overflow or out-of-bounds - if k > N or k < 0: - return 0 - val = np.int64(1) - for i in range(min(k, N - k)): - if val > _MAX_INT64 // (N - i): # Overflow - return 0 - val *= N - i - val //= i + 1 - return val - - -# Kinda complicated, but works for now -def _register_binom(): - # "Fake" UDT so we only compile once for INT64 - op = BinaryOp.register_new("binom", _binom, is_udt=True) - typed_op = op[INT64, INT64] - # Make this look like a normal operator - for dtype in [UINT8, UINT16, UINT32, UINT64, INT8, INT16, INT32, INT64]: - op.types[dtype] = INT64 - op._typed_ops[dtype] = typed_op - if dtype != INT64: - op.coercions[dtype] = typed_op - # And make it not look like it operates on UDTs - typed_op._type2 = None - op._is_udt = False - op._udt_types = None - op._udt_ops = None - return op - - -def _first(x, y): - return x # pragma: no cover (numba) - - -def _second(x, y): - return y # pragma: no cover (numba) - - -def _pair(x, y): - return 1 # pragma: no cover (numba) - - -def _first_dtype(op, dtype, dtype2): - if dtype._is_udt or dtype2._is_udt: - return op._compile_udt(dtype, dtype2) - - -def _second_dtype(op, dtype, dtype2): - if dtype._is_udt or dtype2._is_udt: - return op._compile_udt(dtype, dtype2) - - -def _pair_dtype(op, dtype, dtype2): - return op[INT64] - - -def _get_udt_wrapper(numba_func, return_type, dtype, dtype2=None, *, include_indexes=False): - ztype = INT8 if return_type == BOOL else return_type - xtype = INT8 if dtype == BOOL else dtype - nt = numba.types - wrapper_args = [nt.CPointer(ztype.numba_type), nt.CPointer(xtype.numba_type)] - if 
include_indexes: - wrapper_args.extend([UINT64.numba_type, UINT64.numba_type]) - if dtype2 is not None: - ytype = INT8 if dtype2 == BOOL else dtype2 - wrapper_args.append(nt.CPointer(ytype.numba_type)) - wrapper_sig = nt.void(*wrapper_args) - - zarray = xarray = yarray = BL = BR = yarg = yname = rcidx = "" - if return_type._is_udt: - if return_type.np_type.subdtype is None: - zarray = " z = numba.carray(z_ptr, 1)\n" - zname = "z[0]" - else: - zname = "z_ptr[0]" - BR = "[0]" - else: - zname = "z_ptr[0]" - if return_type == BOOL: - BL = "bool(" - BR = ")" - - if dtype._is_udt: - if dtype.np_type.subdtype is None: - xarray = " x = numba.carray(x_ptr, 1)\n" - xname = "x[0]" - else: - xname = "x_ptr" - elif dtype == BOOL: - xname = "bool(x_ptr[0])" - else: - xname = "x_ptr[0]" - - if dtype2 is not None: - yarg = ", y_ptr" - if dtype2._is_udt: - if dtype2.np_type.subdtype is None: - yarray = " y = numba.carray(y_ptr, 1)\n" - yname = ", y[0]" - else: - yname = ", y_ptr" - elif dtype2 == BOOL: - yname = ", bool(y_ptr[0])" - else: - yname = ", y_ptr[0]" - - if include_indexes: - rcidx = ", row, col" - - d = {"numba": numba, "numba_func": numba_func} - text = ( - f"def wrapper(z_ptr, x_ptr{rcidx}{yarg}):\n" - f"{zarray}{xarray}{yarray}" - f" {zname} = {BL}numba_func({xname}{rcidx}{yname}){BR}\n" - ) - exec(text, d) # pylint: disable=exec-used - return d["wrapper"], wrapper_sig - - -class BinaryOp(OpBase): - """Takes two inputs and returns one output, possibly of a different data type. - - Built-in and registered BinaryOps are located in the ``graphblas.binary`` namespace - as well as in the ``graphblas.ops`` combined namespace. - """ - - __slots__ = ( - "_monoid", - "_commutes_to", - "_semiring_commutes_to", - "orig_func", - "is_positional", - "_is_udt", - "_numba_func", - "_custom_dtype", - ) - _module = binary - _modname = "binary" - _typed_class = TypedBuiltinBinaryOp - _parse_config = { - "trim_from_front": 4, - "num_underscores": 1, - "re_exprs": [ - re.compile( - "^GrB_(FIRST|SECOND|PLUS|MINUS|TIMES|DIV|MIN|MAX)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" - ), - re.compile( - "GrB_(BOR|BAND|BXOR|BXNOR)_(INT8|INT16|INT32|INT64|UINT8|UINT16|UINT32|UINT64)$" - ), - re.compile( - "^GxB_(POW|RMINUS|RDIV|PAIR|ANY|ISEQ|ISNE|ISGT|ISLT|ISGE|ISLE|LOR|LAND|LXOR)" - "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" - ), - re.compile("^GxB_(FIRST|SECOND|PLUS|MINUS|TIMES|DIV)_(FC32|FC64)$"), - re.compile("^GxB_(ATAN2|HYPOT|FMOD|REMAINDER|LDEXP|COPYSIGN)_(FP32|FP64)$"), - re.compile( - "GxB_(BGET|BSET|BCLR|BSHIFT|FIRSTI1|FIRSTI|FIRSTJ1|FIRSTJ" - "|SECONDI1|SECONDI|SECONDJ1|SECONDJ)" - "_(INT8|INT16|INT32|INT64|UINT8|UINT16|UINT32|UINT64)$" - ), - # These are coerced to 0 or 1, but don't return BOOL - re.compile( - "^GxB_(LOR|LAND|LXOR|LXNOR)_" - "(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - ], - "re_exprs_return_bool": [ - re.compile("^GrB_(LOR|LAND|LXOR|LXNOR)$"), - re.compile( - "^GrB_(EQ|NE|GT|LT|GE|LE)_" - "(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile("^GxB_(EQ|NE)_(FC32|FC64)$"), - ], - "re_exprs_return_complex": [re.compile("^GxB_(CMPLX)_(FP32|FP64)$")], - } - _commutes = { - # builtins - "cdiv": "rdiv", - "first": "second", - "ge": "le", - "gt": "lt", - "isge": "isle", - "isgt": "islt", - "minus": "rminus", - "pow": "rpow", - # special - "firsti": "secondi", - "firsti1": "secondi1", - "firstj": "secondj", - "firstj1": "secondj1", - # custom - # "absfirst": 
"abssecond", # handled in graphblas.binary - # "floordiv": "rfloordiv", - "truediv": "rtruediv", - } - _commutes_to_in_semiring = { - "firsti": "secondj", - "firsti1": "secondj1", - "firstj": "secondi", - "firstj1": "secondi1", - } - _commutative = { - # monoids - "any", - "band", - "bor", - "bxnor", - "bxor", - "eq", - "land", - "lor", - "lxnor", - "lxor", - "max", - "min", - "plus", - "times", - # other - "hypot", - "isclose", - "iseq", - "isne", - "ne", - "pair", - } - # Don't commute: atan2, bclr, bget, bset, bshift, cmplx, copysign, fmod, ldexp, remainder - _positional = { - "firsti", - "firsti1", - "firstj", - "firstj1", - "secondi", - "secondi1", - "secondj", - "secondj1", - } - - @classmethod - def _build(cls, name, func, *, is_udt=False, anonymous=False): - if not isinstance(func, FunctionType): - raise TypeError(f"UDF argument must be a function, not {type(func)}") - if name is None: - name = getattr(func, "__name__", "") - success = False - binary_udf = numba.njit(func) - new_type_obj = cls(name, func, anonymous=anonymous, is_udt=is_udt, numba_func=binary_udf) - return_types = {} - nt = numba.types - if not is_udt: - for type_ in _sample_values: - sig = (type_.numba_type, type_.numba_type) - try: - binary_udf.compile(sig) - except numba.TypingError: - continue - ret_type = lookup_dtype(binary_udf.overloads[sig].signature.return_type) - if ret_type != type_ and ( - ("INT" in ret_type.name and "INT" in type_.name) - or ("FP" in ret_type.name and "FP" in type_.name) - or ("FC" in ret_type.name and "FC" in type_.name) - or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) - ): - # Downcast `ret_type` to `type_`. - # This is what users want most of the time, but we can't make a perfect rule. - # There should be a way for users to be explicit. 
- ret_type = type_ - elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: - ret_type = INT8 - - # Numba is unable to handle BOOL correctly right now, but we have a workaround - # See: https://github.com/numba/numba/issues/5395 - # We're relying on coercion behaving correctly here - input_type = INT8 if type_ == BOOL else type_ - return_type = INT8 if ret_type == BOOL else ret_type - - # Build wrapper because GraphBLAS wants pointers and void return - wrapper_sig = nt.void( - nt.CPointer(return_type.numba_type), - nt.CPointer(input_type.numba_type), - nt.CPointer(input_type.numba_type), - ) - - if type_ == BOOL: - if ret_type == BOOL: - - def binary_wrapper(z, x, y): # pragma: no cover (numba) - z[0] = bool(binary_udf(bool(x[0]), bool(y[0]))) - - else: - - def binary_wrapper(z, x, y): # pragma: no cover (numba) - z[0] = binary_udf(bool(x[0]), bool(y[0])) - - elif ret_type == BOOL: - - def binary_wrapper(z, x, y): # pragma: no cover (numba) - z[0] = bool(binary_udf(x[0], y[0])) - - else: - - def binary_wrapper(z, x, y): # pragma: no cover (numba) - z[0] = binary_udf(x[0], y[0]) - - binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper) - new_binary = ffi_new("GrB_BinaryOp*") - check_status_carg( - lib.GrB_BinaryOp_new( - new_binary, - binary_wrapper.cffi, - ret_type.gb_obj, - type_.gb_obj, - type_.gb_obj, - ), - "BinaryOp", - new_binary, - ) - op = TypedUserBinaryOp(new_type_obj, name, type_, ret_type, new_binary[0]) - new_type_obj._add(op) - success = True - return_types[type_] = ret_type - if success or is_udt: - return new_type_obj - raise UdfParseError("Unable to parse function using Numba") - - def _compile_udt(self, dtype, dtype2): - if dtype2 is None: - dtype2 = dtype - dtypes = (dtype, dtype2) - if dtypes in self._udt_types: - return self._udt_ops[dtypes] - - nt = numba.types - if self.name == "eq" and not self._anonymous: - # assert dtype.np_type == dtype2.np_type - itemsize = dtype.np_type.itemsize - mask = _udt_mask(dtype.np_type) - ret_type = BOOL - wrapper_sig = nt.void( - nt.CPointer(INT8.numba_type), - nt.CPointer(UINT8.numba_type), - nt.CPointer(UINT8.numba_type), - ) - # PERF: we can probably make this faster - if mask.all(): - - def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) - x = numba.carray(x_ptr, itemsize) - y = numba.carray(y_ptr, itemsize) - # for i in range(itemsize): - # if x[i] != y[i]: - # z_ptr[0] = False - # break - # else: - # z_ptr[0] = True - z_ptr[0] = (x == y).all() - - else: - - def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) - x = numba.carray(x_ptr, itemsize) - y = numba.carray(y_ptr, itemsize) - # for i in range(itemsize): - # if mask[i] and x[i] != y[i]: - # z_ptr[0] = False - # break - # else: - # z_ptr[0] = True - z_ptr[0] = (x[mask] == y[mask]).all() - - elif self.name == "ne" and not self._anonymous: - # assert dtype.np_type == dtype2.np_type - itemsize = dtype.np_type.itemsize - mask = _udt_mask(dtype.np_type) - ret_type = BOOL - wrapper_sig = nt.void( - nt.CPointer(INT8.numba_type), - nt.CPointer(UINT8.numba_type), - nt.CPointer(UINT8.numba_type), - ) - if mask.all(): - - def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) - x = numba.carray(x_ptr, itemsize) - y = numba.carray(y_ptr, itemsize) - # for i in range(itemsize): - # if x[i] != y[i]: - # z_ptr[0] = True - # break - # else: - # z_ptr[0] = False - z_ptr[0] = (x != y).any() - - else: - - def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) - x = numba.carray(x_ptr, itemsize) - 
y = numba.carray(y_ptr, itemsize) - # for i in range(itemsize): - # if mask[i] and x[i] != y[i]: - # z_ptr[0] = True - # break - # else: - # z_ptr[0] = False - z_ptr[0] = (x[mask] != y[mask]).any() - - else: - numba_func = self._numba_func - sig = (dtype.numba_type, dtype2.numba_type) - numba_func.compile(sig) # Should we catch and give additional error message? - ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) - binary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype, dtype2) - - binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper) - new_binary = ffi_new("GrB_BinaryOp*") - check_status_carg( - lib.GrB_BinaryOp_new( - new_binary, binary_wrapper.cffi, ret_type._carg, dtype._carg, dtype2._carg - ), - "BinaryOp", - new_binary, - ) - op = TypedUserBinaryOp( - self, - self.name, - dtype, - ret_type, - new_binary[0], - dtype2=dtype2, - ) - self._udt_types[dtypes] = ret_type - self._udt_ops[dtypes] = op - return op - - @classmethod - def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): - """Register a BinaryOp without registering it in the ``graphblas.binary`` namespace. - - Because it is not registered in the namespace, the name is optional. - """ - if parameterized: - return ParameterizedBinaryOp(name, func, anonymous=True, is_udt=is_udt) - return cls._build(name, func, anonymous=True, is_udt=is_udt) - - @classmethod - def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): - """Register a BinaryOp. The name will be used to identify the BinaryOp in the - ``graphblas.binary`` namespace. - - >>> def max_zero(x, y): - r = 0 - if x > r: - r = x - if y > r: - r = y - return r - >>> gb.core.operator.BinaryOp.register_new("max_zero", max_zero) - >>> dir(gb.binary) - [..., 'max_zero', ...] 
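# Usage sketch for the op registered above (illustrative vectors; `ewise_mult`
# applies the binary op over the intersection of the two structures):
import graphblas as gb

v = gb.Vector.from_coo([0, 1, 2], [1, -2, 3])
w = gb.Vector.from_coo([0, 1, 2], [-4, 5, -6])
result = v.ewise_mult(w, gb.binary.max_zero).new()  # [1, 5, 3]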
- """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "func": func, "parameterized": parameterized}, - ) - elif parameterized: - binary_op = ParameterizedBinaryOp(name, func, is_udt=is_udt) - setattr(module, funcname, binary_op) - else: - binary_op = cls._build(name, func, is_udt=is_udt) - setattr(module, funcname, binary_op) - # Also save it to `graphblas.op` if not yet defined - opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) - if not _hasop(opmodule, funcname): - if lazy: - opmodule._delayed[funcname] = module - else: - setattr(opmodule, funcname, binary_op) - if not cls._initialized: - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return binary_op - - @classmethod - def _initialize(cls): - if cls._initialized: # pragma: no cover (safety) - return - super()._initialize() - # Rename div to cdiv - cdiv = binary.cdiv = op.cdiv = BinaryOp("cdiv") - for dtype, ret_type in binary.div.types.items(): - orig_op = binary.div[dtype] - cur_op = TypedBuiltinBinaryOp( - cdiv, "cdiv", dtype, ret_type, orig_op.gb_obj, orig_op.gb_name - ) - cdiv._add(cur_op) - del binary.div - del op.div - # Add truediv which always points to floating point cdiv - # We are effectively hacking cdiv to always return floating point values - # If the inputs are FP32, we use DIV_FP32; use DIV_FP64 for all other input dtypes - truediv = binary.truediv = op.truediv = BinaryOp("truediv") - rtruediv = binary.rtruediv = op.rtruediv = BinaryOp("rtruediv") - for new_op, builtin_op in [(truediv, binary.cdiv), (rtruediv, binary.rdiv)]: - for dtype in builtin_op.types: - if dtype.name in {"FP32", "FC32", "FC64"}: - orig_dtype = dtype - else: - orig_dtype = FP64 - orig_op = builtin_op[orig_dtype] - cur_op = TypedBuiltinBinaryOp( - new_op, - new_op.name, - dtype, - builtin_op.types[orig_dtype], - orig_op.gb_obj, - orig_op.gb_name, - ) - new_op._add(cur_op) - # Add floordiv - # cdiv truncates towards 0, while floordiv truncates towards -inf - BinaryOp.register_new("floordiv", _floordiv, lazy=True) # cast to integer - BinaryOp.register_new("rfloordiv", _rfloordiv, lazy=True) # cast to integer - - # For aggregators - BinaryOp.register_new("absfirst", _absfirst, lazy=True) - BinaryOp.register_new("abssecond", _abssecond, lazy=True) - BinaryOp.register_new("rpow", _rpow, lazy=True) - - # For algorithms - binary._delayed["binom"] = (_register_binom, {}) # Lazy with custom creation - op._delayed["binom"] = binary - - BinaryOp.register_new("isclose", _isclose, parameterized=True) - - # Update type information with sane coercion - position_dtypes = [ - BOOL, - FP32, - FP64, - INT8, - INT16, - UINT8, - UINT16, - UINT32, - UINT64, - ] - if _supports_complex: - position_dtypes.extend([FC32, FC64]) - name_types = [ - # fmt: off - ( - ("atan2", "copysign", "fmod", "hypot", "ldexp", "remainder"), - ((BOOL, INT8, INT16, UINT8, UINT16), FP32), - ((INT32, INT64, UINT32, UINT64), FP64), - ), - ( - ( - "firsti", "firsti1", "firstj", "firstj1", "secondi", "secondi1", - "secondj", "secondj1"), - ( - position_dtypes, - INT64, - ), - ), - ( - ["lxnor"], - ( - ( - FP32, FP64, INT8, INT16, INT32, INT64, - UINT8, UINT16, UINT32, UINT64, - ), - BOOL, - ), - ), - # fmt: on - ] - if _supports_complex: - name_types.append( - ( - ["cmplx"], - ((BOOL, INT8, INT16, UINT8, UINT16), FP32), - ((INT32, INT64, UINT32, UINT64), FP64), - ) - ) - for names, *types in name_types: - for name in names: - if name in _SS_OPERATORS: - cur_op = 
binary._deprecated[name] - else: - cur_op = getattr(binary, name) - for input_types, target_type in types: - typed_op = cur_op._typed_ops[target_type] - output_type = cur_op.types[target_type] - for dtype in input_types: - if dtype not in cur_op.types: # pragma: no branch (safety) - cur_op.types[dtype] = output_type - cur_op._typed_ops[dtype] = typed_op - cur_op.coercions[dtype] = target_type - # Not valid input dtypes - del binary.ldexp[FP32] - del binary.ldexp[FP64] - # Fill in commutes info - for left_name, right_name in cls._commutes.items(): - if left_name in _SS_OPERATORS: - left = binary._deprecated[left_name] - else: - left = getattr(binary, left_name) - if backend == "suitesparse" and right_name in _SS_OPERATORS: - left._commutes_to = f"ss.{right_name}" - else: - left._commutes_to = right_name - if right_name not in binary._delayed: - if right_name in _SS_OPERATORS: - right = binary._deprecated[right_name] - else: - right = getattr(binary, right_name) - if backend == "suitesparse" and left_name in _SS_OPERATORS: - right._commutes_to = f"ss.{left_name}" - else: - right._commutes_to = left_name - for name in cls._commutative: - cur_op = getattr(binary, name) - cur_op._commutes_to = name - for left_name, right_name in cls._commutes_to_in_semiring.items(): - if left_name in _SS_OPERATORS: - left = binary._deprecated[left_name] - else: # pragma: no cover (safety) - left = getattr(binary, left_name) - if right_name in _SS_OPERATORS: - right = binary._deprecated[right_name] - else: # pragma: no cover (safety) - right = getattr(binary, right_name) - left._semiring_commutes_to = right - right._semiring_commutes_to = left - # Allow some functions to work on UDTs - for binop, func in [ - (binary.first, _first), - (binary.second, _second), - (binary.pair, _pair), - (binary.any, _first), - ]: - binop.orig_func = func - binop._numba_func = numba.njit(func) - binop._udt_types = {} - binop._udt_ops = {} - binary.any._numba_func = binary.first._numba_func - binary.eq._udt_types = {} - binary.eq._udt_ops = {} - binary.ne._udt_types = {} - binary.ne._udt_ops = {} - # Set custom dtype handling - binary.first._custom_dtype = _first_dtype - binary.second._custom_dtype = _second_dtype - binary.pair._custom_dtype = _pair_dtype - cls._initialized = True - - def __init__( - self, - name, - func=None, - *, - anonymous=False, - is_positional=False, - is_udt=False, - numba_func=None, - ): - super().__init__(name, anonymous=anonymous) - self._monoid = None - self._commutes_to = None - self._semiring_commutes_to = None - self.orig_func = func - self._numba_func = numba_func - self._is_udt = is_udt - self.is_positional = is_positional - self._custom_dtype = None - if is_udt: - self._udt_types = {} # {(dtype, dtype): DataType} - self._udt_ops = {} # {(dtype, dtype): TypedUserBinaryOp} - - def __reduce__(self): - if self._anonymous: - if hasattr(self.orig_func, "_parameterized_info"): - return (_deserialize_parameterized, self.orig_func._parameterized_info) - return (self.register_anonymous, (self.orig_func, self.name)) - if (name := f"binary.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self.orig_func)) - - __call__ = TypedBuiltinBinaryOp.__call__ - is_commutative = TypedBuiltinBinaryOp.is_commutative - commutes_to = ParameterizedBinaryOp.commutes_to - - @property - def monoid(self): - if self._monoid is None and not self._anonymous: - self._monoid = Monoid._find(self.name) - return self._monoid - - -class Monoid(OpBase): - """Takes two inputs and returns one 
output, all of the same data type.
-
-    Built-in and registered Monoids are located in the ``graphblas.monoid`` namespace
-    as well as in the ``graphblas.ops`` combined namespace.
-    """
-
-    __slots__ = "_binaryop", "_identity", "_is_idempotent"
-    is_commutative = True
-    is_positional = False
-    _custom_dtype = None
-    _module = monoid
-    _modname = "monoid"
-    _typed_class = TypedBuiltinMonoid
-    _parse_config = {
-        "trim_from_front": 4,
-        "delete_exact": "MONOID",
-        "num_underscores": 1,
-        "re_exprs": [
-            re.compile(
-                "^GrB_(MIN|MAX|PLUS|TIMES|LOR|LAND|LXOR|LXNOR)_MONOID"
-                "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$"
-            ),
-            re.compile(
-                "^GxB_(ANY)_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)_MONOID$"
-            ),
-            re.compile("^GxB_(PLUS|TIMES|ANY)_(FC32|FC64)_MONOID$"),
-            re.compile("^GxB_(EQ|ANY)_BOOL_MONOID$"),
-            re.compile("^GxB_(BOR|BAND|BXOR|BXNOR)_(UINT8|UINT16|UINT32|UINT64)_MONOID$"),
-        ],
-    }
-
-    @classmethod
-    def _build(cls, name, binaryop, identity, *, is_idempotent=False, anonymous=False):
-        if type(binaryop) is not BinaryOp:
-            raise TypeError(f"binaryop must be a BinaryOp, not {type(binaryop)}")
-        if name is None:
-            name = binaryop.name
-        new_type_obj = cls(
-            name, binaryop, identity, is_idempotent=is_idempotent, anonymous=anonymous
-        )
-        if not binaryop._is_udt:
-            if not isinstance(identity, Mapping):
-                identities = dict.fromkeys(binaryop.types, identity)
-                explicit_identities = False
-            else:
-                identities = {lookup_dtype(key): val for key, val in identity.items()}
-                explicit_identities = True
-            for type_, ident in identities.items():
-                ret_type = binaryop[type_].return_type
-                # If there is a domain mismatch, then DomainMismatch will be raised
-                # below if identities were explicitly given.
-                if type_ != ret_type and not explicit_identities:
-                    continue
-                new_monoid = ffi_new("GrB_Monoid*")
-                func = libget(f"GrB_Monoid_new_{type_.name}")
-                zcast = ffi.cast(type_.c_type, ident)
-                check_status_carg(
-                    func(new_monoid, binaryop[type_].gb_obj, zcast), "Monoid", new_monoid[0]
-                )
-                op = TypedUserMonoid(
-                    new_type_obj,
-                    name,
-                    type_,
-                    ret_type,
-                    new_monoid[0],
-                    binaryop[type_],
-                    ident,
-                )
-                new_type_obj._add(op)
-        return new_type_obj
-
-    def _compile_udt(self, dtype, dtype2):
-        if dtype2 is None:
-            dtype2 = dtype
-        elif dtype != dtype2:
-            raise TypeError(
-                f"Monoid inputs must be the same dtype (got {dtype} and {dtype2}); "
-                "unable to coerce when using UDTs."
-            )
-        if dtype in self._udt_types:
-            return self._udt_ops[dtype]
-        binaryop = self.binaryop._compile_udt(dtype, dtype2)
-        from .scalar import Scalar
-
-        ret_type = binaryop.return_type
-        identity = Scalar.from_value(self._identity, dtype=ret_type, is_cscalar=True)
-        new_monoid = ffi_new("GrB_Monoid*")
-        status = lib.GrB_Monoid_new_UDT(new_monoid, binaryop.gb_obj, identity.gb_obj)
-        check_status_carg(status, "Monoid", new_monoid[0])
-        op = TypedUserMonoid(
-            new_monoid,
-            self.name,
-            dtype,
-            ret_type,
-            new_monoid[0],
-            binaryop,
-            identity,
-        )
-        self._udt_types[dtype] = dtype
-        self._udt_ops[dtype] = op
-        return op
-
-    @classmethod
-    def register_anonymous(cls, binaryop, identity, name=None, *, is_idempotent=False):
-        """Register a Monoid without registering it in the ``graphblas.monoid`` namespace.
-
-        Because it is not registered in the namespace, the name is optional.
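# Sketch (hypothetical op and data): an anonymous monoid built from the `max_zero`
# BinaryOp registered earlier; 0 works as the identity here because max_zero never
# produces a negative result:
import graphblas as gb

m = gb.core.operator.Monoid.register_anonymous(gb.binary.max_zero, 0)
v = gb.Vector.from_coo([0, 1, 2], [4, -1, 7])
total = v.reduce(m).new()  # max_zero folded over all entries -> 7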
- - Parameters - ---------- - binaryop : BinaryOp - Builtin or registered binary operator - identity : - Identity value of the monoid - name : str, optional - Name associated with the monoid - is_idempotent : bool, default False - Does ``op(x, x) == x`` for any x? - - Returns - ------- - Function handle - """ - if type(binaryop) is ParameterizedBinaryOp: - return ParameterizedMonoid( - name, binaryop, identity, is_idempotent=is_idempotent, anonymous=True - ) - return cls._build(name, binaryop, identity, is_idempotent=is_idempotent, anonymous=True) - - @classmethod - def register_new(cls, name, binaryop, identity, *, is_idempotent=False, lazy=False): - """Register a Monoid. The name will be used to identify the Monoid in the - ``graphblas.monoid`` namespace. - - >>> gb.core.operator.Monoid.register_new("max_zero", gb.binary.max_zero, 0) - >>> dir(gb.monoid) - [..., 'max_zero', ...] - """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "binaryop": binaryop, "identity": identity}, - ) - elif type(binaryop) is ParameterizedBinaryOp: - monoid = ParameterizedMonoid(name, binaryop, identity, is_idempotent=is_idempotent) - setattr(module, funcname, monoid) - else: - monoid = cls._build(name, binaryop, identity, is_idempotent=is_idempotent) - setattr(module, funcname, monoid) - # Also save it to `graphblas.op` if not yet defined - opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) - if not _hasop(opmodule, funcname): - if lazy: - opmodule._delayed[funcname] = module - else: - setattr(opmodule, funcname, monoid) - if not cls._initialized: # pragma: no cover - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return monoid - - def __init__(self, name, binaryop=None, identity=None, *, is_idempotent=False, anonymous=False): - super().__init__(name, anonymous=anonymous) - self._binaryop = binaryop - self._identity = identity - self._is_idempotent = is_idempotent - if binaryop is not None: - binaryop._monoid = self - if binaryop._is_udt: - self._udt_types = {} # {dtype: DataType} - self._udt_ops = {} # {dtype: TypedUserMonoid} - - def __reduce__(self): - if self._anonymous: - return (self.register_anonymous, (self._binaryop, self._identity, self.name)) - if (name := f"monoid.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self._binaryop, self._identity)) - - @property - def binaryop(self): - """The :class:`BinaryOp` associated with the Monoid.""" - if self._binaryop is not None: - return self._binaryop - # Must be builtin - return getattr(binary, self.name) - - @property - def identities(self): - """The per-dtype identity values for the Monoid.""" - return {dtype: val.identity for dtype, val in self._typed_ops.items()} - - @property - def is_idempotent(self): - """True if ``monoid(x, x) == x`` for any x.""" - return self._is_idempotent - - @property - def _is_udt(self): - return self._binaryop is not None and self._binaryop._is_udt - - @classmethod - def _initialize(cls): - if cls._initialized: # pragma: no cover (safety) - return - super()._initialize() - lor = monoid.lor._typed_ops[BOOL] - land = monoid.land._typed_ops[BOOL] - for cur_op, typed_op in [ - (monoid.max, lor), - (monoid.min, land), - # (monoid.plus, lor), # two choices: lor, or plus[int] - (monoid.times, land), - ]: - if BOOL not in cur_op.types: # pragma: no branch (safety) - cur_op.types[BOOL] = BOOL - cur_op.coercions[BOOL] = BOOL - 
cur_op._typed_ops[BOOL] = typed_op - - for cur_op in [monoid.lor, monoid.land, monoid.lxnor, monoid.lxor]: - bool_op = cur_op._typed_ops[BOOL] - for dtype in [ - FP32, - FP64, - INT8, - INT16, - INT32, - INT64, - UINT8, - UINT16, - UINT32, - UINT64, - ]: - if dtype in cur_op.types: # pragma: no cover (safety) - continue - cur_op.types[dtype] = BOOL - cur_op.coercions[dtype] = BOOL - cur_op._typed_ops[dtype] = bool_op - - # Builtin monoids that are idempotent; i.e., `op(x, x) == x` for any x - for name in ["any", "band", "bor", "land", "lor", "max", "min"]: - getattr(monoid, name)._is_idempotent = True - for name in [ - "bitwise_and", - "bitwise_or", - "fmax", - "fmin", - "gcd", - "logical_and", - "logical_or", - "maximum", - "minimum", - ]: - getattr(monoid.numpy, name)._is_idempotent = True - - # Allow some functions to work on UDTs - any_ = monoid.any - any_._identity = 0 - any_._udt_types = {} - any_._udt_ops = {} - cls._initialized = True - - commutes_to = TypedBuiltinMonoid.commutes_to - __call__ = TypedBuiltinMonoid.__call__ - - -class Semiring(OpBase): - """Combination of a :class:`Monoid` and a :class:`BinaryOp`. - - Semirings are most commonly used for performing matrix multiplication, - with the BinaryOp taking the place of the standard multiplication operator - and the Monoid taking the place of the standard addition operator. - - Built-in and registered Semirings are located in the ``graphblas.semiring`` namespace - as well as in the ``graphblas.ops`` combined namespace. - """ - - __slots__ = "_monoid", "_binaryop" - _module = semiring - _modname = "semiring" - _typed_class = TypedBuiltinSemiring - _parse_config = { - "trim_from_front": 4, - "delete_exact": "SEMIRING", - "num_underscores": 2, - "re_exprs": [ - re.compile( - "^GrB_(PLUS|MIN|MAX)_(PLUS|TIMES|FIRST|SECOND|MIN|MAX)_SEMIRING" - "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(MIN|MAX|PLUS|TIMES|ANY)" - "_(FIRST|SECOND|PAIR|MIN|MAX|PLUS|MINUS|RMINUS|TIMES" - "|DIV|RDIV|ISEQ|ISNE|ISGT|ISLT|ISGE|ISLE|LOR|LAND|LXOR" - "|FIRSTI1|FIRSTI|FIRSTJ1|FIRSTJ|SECONDI1|SECONDI|SECONDJ1|SECONDJ)" - "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(PLUS|TIMES|ANY)_(FIRST|SECOND|PAIR|PLUS|MINUS|TIMES|DIV|RDIV|RMINUS)" - "_(FC32|FC64)$" - ), - re.compile( - "^GxB_(BOR|BAND|BXOR|BXNOR)_(BOR|BAND|BXOR|BXNOR)_(UINT8|UINT16|UINT32|UINT64)$" - ), - ], - "re_exprs_return_bool": [ - re.compile("^GrB_(LOR|LAND|LXOR|LXNOR)_(LOR|LAND)_SEMIRING_BOOL$"), - re.compile( - "^GxB_(LOR|LAND|LXOR|EQ|ANY)_(EQ|NE|GT|LT|GE|LE)" - "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" - ), - re.compile( - "^GxB_(LOR|LAND|LXOR|EQ|ANY)_(FIRST|SECOND|PAIR|LOR|LAND|LXOR|EQ|GT|LT|GE|LE)_BOOL$" - ), - ], - } - - @classmethod - def _build(cls, name, monoid, binaryop, *, anonymous=False): - if type(monoid) is not Monoid: - raise TypeError(f"monoid must be a Monoid, not {type(monoid)}") - if type(binaryop) is not BinaryOp: - raise TypeError(f"binaryop must be a BinaryOp, not {type(binaryop)}") - if name is None: - name = f"{monoid.name}_{binaryop.name}".replace(".", "_") - new_type_obj = cls(name, monoid, binaryop, anonymous=anonymous) - if binaryop._is_udt: - return new_type_obj - for binary_in, binary_func in binaryop._typed_ops.items(): - binary_out = binary_func.return_type - # Unfortunately, we can't have user-defined monoids over bools yet - # because numba can't compile correctly. 
- if ( - binary_out not in monoid.types - # Are all coercions bad, or just to bool? - or monoid.coercions.get(binary_out, binary_out) != binary_out - ): - continue - new_semiring = ffi_new("GrB_Semiring*") - check_status_carg( - lib.GrB_Semiring_new(new_semiring, monoid[binary_out].gb_obj, binary_func.gb_obj), - "Semiring", - new_semiring, - ) - ret_type = monoid[binary_out].return_type - op = TypedUserSemiring( - new_type_obj, - name, - binary_in, - ret_type, - new_semiring[0], - monoid[binary_out], - binary_func, - ) - new_type_obj._add(op) - return new_type_obj - - def _compile_udt(self, dtype, dtype2): - if dtype2 is None: - dtype2 = dtype - dtypes = (dtype, dtype2) - if dtypes in self._udt_types: - return self._udt_ops[dtypes] - binaryop = self.binaryop._compile_udt(dtype, dtype2) - monoid = self.monoid[binaryop.return_type] - ret_type = monoid.return_type - new_semiring = ffi_new("GrB_Semiring*") - status = lib.GrB_Semiring_new(new_semiring, monoid.gb_obj, binaryop.gb_obj) - check_status_carg(status, "Semiring", new_semiring) - op = TypedUserSemiring( - new_semiring, - self.name, - dtype, - ret_type, - new_semiring[0], - monoid, - binaryop, - dtype2=dtype2, - ) - self._udt_types[dtypes] = dtype - self._udt_ops[dtypes] = op - return op - - @classmethod - def register_anonymous(cls, monoid, binaryop, name=None): - """Register a Semiring without registering it in the ``graphblas.semiring`` namespace. - - Because it is not registered in the namespace, the name is optional. - - Parameters - ---------- - monoid : Monoid - Builtin or registered monoid - binaryop : BinaryOp - Builtin or registered binary operator - name : str, optional - Name associated with the semiring - - Returns - ------- - Function handle - """ - if type(monoid) is ParameterizedMonoid or type(binaryop) is ParameterizedBinaryOp: - return ParameterizedSemiring(name, monoid, binaryop, anonymous=True) - return cls._build(name, monoid, binaryop, anonymous=True) - - @classmethod - def register_new(cls, name, monoid, binaryop, *, lazy=False): - """Register a Semiring. The name will be used to identify the Semiring in the - ``graphblas.semiring`` namespace. - - >>> gb.core.operator.Semiring.register_new("max_max", gb.monoid.max, gb.binary.max) - >>> dir(gb.semiring) - [..., 'max_max', ...] 
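# Usage sketch (illustrative matrices): a registered semiring drops straight into
# matrix multiplication, the monoid acting as "add" and the binary op as "multiply":
import graphblas as gb

A = gb.Matrix.from_coo([0, 0], [0, 1], [1, 2])
B = gb.Matrix.from_coo([0, 1], [0, 0], [3, 4])
C = A.mxm(B, gb.semiring.max_max).new()  # C[0, 0] = max(max(1, 3), max(2, 4)) = 4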
- """ - module, funcname = cls._remove_nesting(name) - if lazy: - module._delayed[funcname] = ( - cls.register_new, - {"name": name, "monoid": monoid, "binaryop": binaryop}, - ) - elif type(monoid) is ParameterizedMonoid or type(binaryop) is ParameterizedBinaryOp: - semiring = ParameterizedSemiring(name, monoid, binaryop) - setattr(module, funcname, semiring) - else: - semiring = cls._build(name, monoid, binaryop) - setattr(module, funcname, semiring) - # Also save it to `graphblas.op` if not yet defined - opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) - if not _hasop(opmodule, funcname): - if lazy: - opmodule._delayed[funcname] = module - else: - setattr(opmodule, funcname, semiring) - if not cls._initialized: - _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") - if not lazy: - return semiring - - @classmethod - def _initialize(cls): - if cls._initialized: # pragma: no cover (safety) - return - super()._initialize() - # Rename div to cdiv (truncate towards 0) - div_semirings = { - attr: val - for attr, val in vars(semiring).items() - if type(val) is Semiring and attr.endswith("_div") - } - for orig_name, orig in div_semirings.items(): - name = f"{orig_name[:-3]}cdiv" - cdiv_semiring = Semiring(name) - setattr(semiring, name, cdiv_semiring) - setattr(op, name, cdiv_semiring) - delattr(semiring, orig_name) - delattr(op, orig_name) - for dtype, ret_type in orig.types.items(): - orig_semiring = orig[dtype] - new_semiring = TypedBuiltinSemiring( - cdiv_semiring, - name, - dtype, - ret_type, - orig_semiring.gb_obj, - orig_semiring.gb_name, - ) - cdiv_semiring._add(new_semiring) - # Also add truediv (always floating point) and floordiv (truncate towards -inf) - for orig_name, orig in div_semirings.items(): - cls.register_new(f"{orig_name[:-3]}truediv", orig.monoid, binary.truediv, lazy=True) - cls.register_new(f"{orig_name[:-3]}rtruediv", orig.monoid, "rtruediv", lazy=True) - cls.register_new(f"{orig_name[:-3]}floordiv", orig.monoid, "floordiv", lazy=True) - cls.register_new(f"{orig_name[:-3]}rfloordiv", orig.monoid, "rfloordiv", lazy=True) - # For aggregators - cls.register_new("plus_pow", monoid.plus, binary.pow) - cls.register_new("plus_rpow", monoid.plus, "rpow", lazy=True) - cls.register_new("plus_absfirst", monoid.plus, "absfirst", lazy=True) - cls.register_new("max_absfirst", monoid.max, "absfirst", lazy=True) - cls.register_new("plus_abssecond", monoid.plus, "abssecond", lazy=True) - cls.register_new("max_abssecond", monoid.max, "abssecond", lazy=True) - - # Update type information with sane coercion - for lname in ["any", "eq", "land", "lor", "lxnor", "lxor"]: - target_name = f"{lname}_ne" - source_name = f"{lname}_lxor" - if not _hasop(semiring, target_name): - continue - target_op = getattr(semiring, target_name) - if BOOL not in target_op.types: # pragma: no branch (safety) - source_op = getattr(semiring, source_name) - typed_op = source_op._typed_ops[BOOL] - target_op.types[BOOL] = BOOL - target_op._typed_ops[BOOL] = typed_op - target_op.coercions[dtype] = BOOL - - position_dtypes = [ - BOOL, - FP32, - FP64, - INT8, - INT16, - UINT8, - UINT16, - UINT32, - UINT64, - ] - notbool_dtypes = [ - FP32, - FP64, - INT8, - INT16, - INT32, - INT64, - UINT8, - UINT16, - UINT32, - UINT64, - ] - if _supports_complex: - position_dtypes.extend([FC32, FC64]) - notbool_dtypes.extend([FC32, FC64]) - for lnames, rnames, *types in [ - # fmt: off - ( - ("any", "max", "min", "plus", "times"), - ( - "firsti", "firsti1", "firstj", "firstj1", - "secondi", 
"secondi1", "secondj", "secondj1", - ), - ( - position_dtypes, - INT64, - ), - ), - ( - ("eq", "land", "lor", "lxnor", "lxor"), - ("first", "pair", "second"), - # TODO: check if FC coercion works here - ( - notbool_dtypes, - BOOL, - ), - ), - ( - ("band", "bor", "bxnor", "bxor"), - ("band", "bor", "bxnor", "bxor"), - ([INT8], UINT16), - ([INT16], UINT32), - ([INT32], UINT64), - ([INT64], UINT64), - ), - ( - ("any", "eq", "land", "lor", "lxnor", "lxor"), - ("eq", "land", "lor", "lxnor", "lxor", "ne"), - ( - ( - FP32, FP64, INT8, INT16, INT32, INT64, - UINT8, UINT16, UINT32, UINT64, - ), - BOOL, - ), - ), - # fmt: on - ]: - for left, right in itertools.product(lnames, rnames): - name = f"{left}_{right}" - if not _hasop(semiring, name): - continue - if name in _SS_OPERATORS: - cur_op = semiring._deprecated[name] - else: - cur_op = getattr(semiring, name) - for input_types, target_type in types: - typed_op = cur_op._typed_ops[target_type] - output_type = cur_op.types[target_type] - for dtype in input_types: - if dtype not in cur_op.types: - cur_op.types[dtype] = output_type - cur_op._typed_ops[dtype] = typed_op - cur_op.coercions[dtype] = target_type - - # Handle a few boolean cases - for opname, targetname in [ - ("max_first", "lor_first"), - ("max_second", "lor_second"), - ("max_land", "lor_land"), - ("max_lor", "lor_lor"), - ("max_lxor", "lor_lxor"), - ("min_first", "land_first"), - ("min_second", "land_second"), - ("min_land", "land_land"), - ("min_lor", "land_lor"), - ("min_lxor", "land_lxor"), - ]: - cur_op = getattr(semiring, opname) - target = getattr(semiring, targetname) - if BOOL in cur_op.types or BOOL not in target.types: # pragma: no cover (safety) - continue - cur_op.types[BOOL] = target.types[BOOL] - cur_op._typed_ops[BOOL] = target._typed_ops[BOOL] - cur_op.coercions[BOOL] = BOOL - cls._initialized = True - - def __init__(self, name, monoid=None, binaryop=None, *, anonymous=False): - super().__init__(name, anonymous=anonymous) - self._monoid = monoid - self._binaryop = binaryop - try: - if self.binaryop._udt_types is not None: - self._udt_types = {} # {(dtype, dtype): DataType} - self._udt_ops = {} # {(dtype, dtype): TypedUserSemiring} - except AttributeError: - # `*_div` semirings raise here, but don't need `_udt_types` - pass - - def __reduce__(self): - if self._anonymous: - return (self.register_anonymous, (self._monoid, self._binaryop, self.name)) - if (name := f"semiring.{self.name}") in _STANDARD_OPERATOR_NAMES: - return name - return (self._deserialize, (self.name, self._monoid, self._binaryop)) - - @property - def binaryop(self): - """The :class:`BinaryOp` associated with the Semiring.""" - if self._binaryop is not None: - return self._binaryop - # Must be builtin - name = self.name.split("_")[1] - if name in _SS_OPERATORS: - return binary._deprecated[name] - return getattr(binary, name) - - @property - def monoid(self): - """The :class:`Monoid` associated with the Semiring.""" - if self._monoid is not None: - return self._monoid - # Must be builtin - return getattr(monoid, self.name.split("_")[0].split(".")[-1]) - - @property - def is_positional(self): - return self.binaryop.is_positional - - @property - def _is_udt(self): - return self._binaryop is not None and self._binaryop._is_udt - - @property - def _custom_dtype(self): - return self.binaryop._custom_dtype - - commutes_to = TypedBuiltinSemiring.commutes_to - is_commutative = TypedBuiltinSemiring.is_commutative - __call__ = TypedBuiltinSemiring.__call__ - - -def get_typed_op(op, dtype, dtype2=None, *, 
is_left_scalar=False, is_right_scalar=False, kind=None): - if isinstance(op, OpBase): - # UDTs always get compiled - if op._is_udt: - return op._compile_udt(dtype, dtype2) - # Single dtype is simple lookup - if dtype2 is None: - return op[dtype] - # Handle special cases such as first and second (may have UDTs) - if op._custom_dtype is not None and (rv := op._custom_dtype(op, dtype, dtype2)) is not None: - return rv - # Generic case: try to unify the two dtypes - try: - return op[ - unify(dtype, dtype2, is_left_scalar=is_left_scalar, is_right_scalar=is_right_scalar) - ] - except (TypeError, AttributeError): - # Failure to unify implies a dtype is UDT; some builtin operators can handle UDTs - if op.is_positional: - return op[UINT64] - if op._udt_types is None: - raise - return op._compile_udt(dtype, dtype2) - if isinstance(op, ParameterizedUdf): - op = op() # Use default parameters of parameterized UDFs - return get_typed_op( - op, - dtype, - dtype2, - is_left_scalar=is_left_scalar, - is_right_scalar=is_right_scalar, - kind=kind, - ) - if isinstance(op, TypedOpBase): - return op - - from .agg import Aggregator, TypedAggregator - - if isinstance(op, Aggregator): - return op[dtype] - if isinstance(op, TypedAggregator): - return op - if isinstance(op, str): - if kind == "unary": - op = unary_from_string(op) - elif kind == "select": - op = select_from_string(op) - elif kind == "binary": - op = binary_from_string(op) - elif kind == "monoid": - op = monoid_from_string(op) - elif kind == "semiring": - op = semiring_from_string(op) - elif kind == "binary|aggregator": - try: - op = binary_from_string(op) - except ValueError: - try: - op = aggregator_from_string(op) - except ValueError: - raise ValueError( - f"Unknown binary or aggregator string: {op!r}. Example usage: '+[int]'" - ) from None - - else: - raise ValueError( - f"Unable to get op from string {op!r}. `kind=` argument must be provided as " - '"unary", "binary", "monoid", "semiring", "indexunary", "select", ' - 'or "binary|aggregator".' - ) - return get_typed_op( - op, - dtype, - dtype2, - is_left_scalar=is_left_scalar, - is_right_scalar=is_right_scalar, - kind=kind, - ) - if isinstance(op, FunctionType): - if kind == "unary": - op = UnaryOp.register_anonymous(op, is_udt=True) - return op._compile_udt(dtype, dtype2) - if kind.startswith("binary"): - op = BinaryOp.register_anonymous(op, is_udt=True) - return op._compile_udt(dtype, dtype2) - if isinstance(op, BuiltinFunctionType) and op in _builtin_to_op: - return get_typed_op( - _builtin_to_op[op], - dtype, - dtype2, - is_left_scalar=is_left_scalar, - is_right_scalar=is_right_scalar, - kind=kind, - ) - raise TypeError(f"Unable to get typed operator from object with type {type(op)}") - - -def find_opclass(gb_op): - if isinstance(gb_op, OpBase): - opclass = type(gb_op).__name__ - elif isinstance(gb_op, TypedOpBase): - opclass = gb_op.opclass - elif isinstance(gb_op, ParameterizedUdf): - gb_op = gb_op() # Use default parameters of parameterized UDFs - gb_op, opclass = find_opclass(gb_op) - elif isinstance(gb_op, BuiltinFunctionType) and gb_op in _builtin_to_op: - gb_op, opclass = find_opclass(_builtin_to_op[gb_op]) - else: - opclass = UNKNOWN_OPCLASS - return gb_op, opclass - - -def get_semiring(monoid, binaryop, name=None): - """Get or create a Semiring object from a monoid and binaryop. - - If either are typed, then the returned semiring will also be typed. 
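# Sketch of the typing behavior (illustrative dtypes): `get_semiring` finds or
# creates the canonical `{monoid}_{binaryop}` semiring and propagates any dtype:
import graphblas as gb

s = gb.semiring.get_semiring(gb.monoid.plus, gb.binary.times)           # plus_times
st = gb.semiring.get_semiring(gb.monoid.plus["FP64"], gb.binary.times)  # plus_times[FP64]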
- - See Also - -------- - semiring.register_anonymous - semiring.register_new - semiring.from_string - """ - monoid, opclass = find_opclass(monoid) - switched = False - if opclass == "BinaryOp" and monoid.monoid is not None: - switched = True - monoid = monoid.monoid - elif opclass != "Monoid": - raise TypeError(f"Expected a Monoid for the monoid argument. Got type: {type(monoid)}") - binaryop, opclass = find_opclass(binaryop) - if opclass == "Monoid": - if switched: - raise TypeError( - "Got a BinaryOp for the monoid argument and a Monoid for the binaryop argument. " - "Are the arguments switched? Hint: you can do `mymonoid.binaryop` to get the " - "binaryop from a monoid." - ) - binaryop = binaryop.binaryop - elif opclass != "BinaryOp": - raise TypeError( - f"Expected a BinaryOp for the binaryop argument. Got type: {type(binaryop)}" - ) - if isinstance(monoid, Monoid): - monoid_type = None - else: - monoid_type = monoid.type - monoid = monoid.parent - if isinstance(binaryop, BinaryOp): - binary_type = None - else: - binary_type = binaryop.type - binaryop = binaryop.parent - if monoid._anonymous or binaryop._anonymous: - rv = Semiring.register_anonymous(monoid, binaryop, name=name) - else: - *monoid_prefix, monoid_name = monoid.name.rsplit(".", 1) - *binary_prefix, binary_name = binaryop.name.rsplit(".", 1) - if ( - monoid_prefix - and binary_prefix - and monoid_prefix == binary_prefix - or config.get("mapnumpy") - and ( - monoid_prefix == ["numpy"] - and not binary_prefix - or binary_prefix == ["numpy"] - and not monoid_prefix - ) - or backend == "suitesparse" - and binary_name in _SS_OPERATORS - ): - canonical_name = ( - ".".join(monoid_prefix or binary_prefix) + f".{monoid_name}_{binary_name}" - ) - else: - canonical_name = f"{monoid.name}_{binaryop.name}".replace(".", "_") - if name is None: - name = canonical_name - - module, funcname = Semiring._remove_nesting(canonical_name, strict=False) - rv = ( - getattr(module, funcname) - if funcname in module.__dict__ or funcname in module._delayed - else getattr(module, "_deprecated", {}).get(funcname) - ) - if rv is None and name != canonical_name: - module, funcname = Semiring._remove_nesting(name, strict=False) - rv = ( - getattr(module, funcname) - if funcname in module.__dict__ or funcname in module._delayed - else getattr(module, "_deprecated", {}).get(funcname) - ) - if rv is None: - rv = Semiring.register_new(canonical_name, monoid, binaryop) - elif rv.monoid is not monoid or rv.binaryop is not binaryop: # pragma: no cover - # It's not the object we expect (can this happen?) - rv = Semiring.register_anonymous(monoid, binaryop, name=name) - if name != canonical_name: - module, funcname = Semiring._remove_nesting(name, strict=False) - if not _hasop(module, funcname): # pragma: no branch (safety) - setattr(module, funcname, rv) - - if binary_type is not None: - return rv[binary_type] - if monoid_type is not None: - return rv[monoid_type] - return rv - - -# Now initialize all the things! 
-try: - UnaryOp._initialize() - IndexUnaryOp._initialize() - SelectOp._initialize() - BinaryOp._initialize() - Monoid._initialize() - Semiring._initialize() -except Exception: # pragma: no cover (debug) - # Exceptions here can often get ignored by Python - import traceback - - traceback.print_exc() - raise - -unary.register_new = UnaryOp.register_new -unary.register_anonymous = UnaryOp.register_anonymous -indexunary.register_new = IndexUnaryOp.register_new -indexunary.register_anonymous = IndexUnaryOp.register_anonymous -select.register_new = SelectOp.register_new -select.register_anonymous = SelectOp.register_anonymous -binary.register_new = BinaryOp.register_new -binary.register_anonymous = BinaryOp.register_anonymous -monoid.register_new = Monoid.register_new -monoid.register_anonymous = Monoid.register_anonymous -semiring.register_new = Semiring.register_new -semiring.register_anonymous = Semiring.register_anonymous -semiring.get_semiring = get_semiring - -select._binary_to_select.update( - { - binary.eq: select.valueeq, - binary.ne: select.valuene, - binary.le: select.valuele, - binary.lt: select.valuelt, - binary.ge: select.valuege, - binary.gt: select.valuegt, - binary.iseq: select.valueeq, - binary.isne: select.valuene, - binary.isle: select.valuele, - binary.islt: select.valuelt, - binary.isge: select.valuege, - binary.isgt: select.valuegt, - } -) - -_builtin_to_op = { - abs: unary.abs, - max: binary.max, - min: binary.min, - # Maybe someday: all, any, pow, sum -} - -_str_to_unary = { - "-": unary.ainv, - "~": unary.lnot, -} -_str_to_select = { - "<": select.valuelt, - ">": select.valuegt, - "<=": select.valuele, - ">=": select.valuege, - "!=": select.valuene, - "==": select.valueeq, - "col<=": select.colle, - "col>": select.colgt, - "row<=": select.rowle, - "row>": select.rowgt, - "index<=": select.indexle, - "index>": select.indexgt, -} -_str_to_binary = { - "<": binary.lt, - ">": binary.gt, - "<=": binary.le, - ">=": binary.ge, - "!=": binary.ne, - "==": binary.eq, - "+": binary.plus, - "-": binary.minus, - "*": binary.times, - "/": binary.truediv, - "//": "floordiv", - "%": "numpy.mod", - "**": binary.pow, - "&": binary.land, - "|": binary.lor, - "^": binary.lxor, -} -_str_to_monoid = { - "==": monoid.eq, - "+": monoid.plus, - "*": monoid.times, - "&": monoid.land, - "|": monoid.lor, - "^": monoid.lxor, -} - - -def _from_string(string, module, mapping, example): - s = string.lower().strip() - base, *dtype = s.split("[") - if len(dtype) > 1: - name = module.__name__.split(".")[-1] - raise ValueError( - f'Bad {name} string: {string!r}. Contains too many "[". Example usage: {example!r}' - ) - if dtype: - dtype = dtype[0] - if not dtype.endswith("]"): - name = module.__name__.split(".")[-1] - raise ValueError( - f'Bad {name} string: {string!r}. Datatype specification does not end with "]". ' - f"Example usage: {example!r}" - ) - dtype = lookup_dtype(dtype[:-1]) - if "]" in base: - name = module.__name__.split(".")[-1] - raise ValueError( - f'Bad {name} string: {string!r}. "]" not matched by "[". 
Example usage: {example!r}' - ) - if base in mapping: - op = mapping[base] - if type(op) is str: - op = mapping[base] = module.from_string(op) - elif hasattr(module, base): - op = getattr(module, base) - elif hasattr(module, "numpy") and hasattr(module.numpy, base): - op = getattr(module.numpy, base) - else: - *paths, attr = base.split(".") - op = None - cur = module - for path in paths: - cur = getattr(cur, path, None) - if not isinstance(cur, (OpPath, ModuleType)): - cur = None - break - op = getattr(cur, attr, None) - if op is None: - name = module.__name__.split(".")[-1] - raise ValueError(f"Unknown {name} string: {string!r}. Example usage: {example!r}") - if dtype: - op = op[dtype] - return op - - -def unary_from_string(string): - return _from_string(string, unary, _str_to_unary, "abs[int]") - - -def indexunary_from_string(string): - # "select" is a variant of IndexUnary, so the string abbreviations in - # _str_to_select are appropriate to reuse here - return _from_string(string, indexunary, _str_to_select, "row_index") - - -def select_from_string(string): - return _from_string(string, select, _str_to_select, "tril") - - -def binary_from_string(string): - return _from_string(string, binary, _str_to_binary, "+[int]") - - -def monoid_from_string(string): - return _from_string(string, monoid, _str_to_monoid, "+[int]") - - -def semiring_from_string(string): - split = string.split(".") - if len(split) == 1: - try: - return _from_string(string, semiring, {}, "min.+[int]") - except Exception: - pass - if len(split) != 2: - raise ValueError( - f"Bad semiring string: {string!r}. " - 'The monoid and binaryop should be separated by exactly one period, ".". ' - "Example usage: min.+[int]" - ) - cur_monoid = monoid_from_string(split[0]) - cur_binary = binary_from_string(split[1]) - return get_semiring(cur_monoid, cur_binary) - - -def op_from_string(string): - for func in [ - # Note: order matters here - unary_from_string, - binary_from_string, - monoid_from_string, - semiring_from_string, - indexunary_from_string, - select_from_string, - aggregator_from_string, - ]: - try: - return func(string) - except Exception: - pass - raise ValueError(f"Unknown op string: {string!r}. Example usage: 'abs[int]'") - - -unary.from_string = unary_from_string -indexunary.from_string = indexunary_from_string -select.from_string = select_from_string -binary.from_string = binary_from_string -monoid.from_string = monoid_from_string -semiring.from_string = semiring_from_string -op.from_string = op_from_string - -_str_to_agg = { - "+": "sum", - "*": "prod", - "&": "all", - "|": "any", -} - - -def aggregator_from_string(string): - return _from_string(string, agg, _str_to_agg, "sum[int]") - - -from .. 
import agg # noqa: E402 isort:skip - -agg.from_string = aggregator_from_string diff --git a/graphblas/core/operator/__init__.py b/graphblas/core/operator/__init__.py new file mode 100644 index 000000000..d59c835b3 --- /dev/null +++ b/graphblas/core/operator/__init__.py @@ -0,0 +1,22 @@ +from .base import UNKNOWN_OPCLASS, OpBase, OpPath, ParameterizedUdf, TypedOpBase, find_opclass +from .binary import BinaryOp, ParameterizedBinaryOp +from .indexunary import IndexUnaryOp, ParameterizedIndexUnaryOp +from .monoid import Monoid, ParameterizedMonoid +from .select import ParameterizedSelectOp, SelectOp +from .semiring import ParameterizedSemiring, Semiring +from .unary import ParameterizedUnaryOp, UnaryOp +from .utils import ( + _get_typed_op_from_exprs, + aggregator_from_string, + binary_from_string, + get_semiring, + get_typed_op, + indexunary_from_string, + monoid_from_string, + op_from_string, + select_from_string, + semiring_from_string, + unary_from_string, +) + +from .agg import Aggregator # isort:skip diff --git a/graphblas/core/agg.py b/graphblas/core/operator/agg.py similarity index 93% rename from graphblas/core/agg.py rename to graphblas/core/operator/agg.py index 3afcbc408..6b463a8a6 100644 --- a/graphblas/core/agg.py +++ b/graphblas/core/operator/agg.py @@ -3,9 +3,10 @@ import numpy as np -from .. import agg, backend, binary, monoid, semiring, unary -from ..dtypes import INT64, lookup_dtype -from .utils import output_type +from ... import agg, backend, binary, monoid, semiring, unary +from ...dtypes import INT64, lookup_dtype +from .. import _supports_udfs +from ..utils import output_type def _get_types(ops, initdtype): @@ -38,6 +39,7 @@ def __init__( semiring=None, switch=False, semiring2=None, + applybegin=None, finalize=None, composite=None, custom=None, @@ -52,6 +54,7 @@ def __init__( self._semiring = semiring self._semiring2 = semiring2 self._switch = switch + self._applybegin = applybegin self._finalize = finalize self._composite = composite self._custom = custom @@ -73,9 +76,9 @@ def __init__( @property def types(self): if self._types is None: - if type(self._semiring) is str: + if isinstance(self._semiring, str): self._semiring = semiring.from_string(self._semiring) - if type(self._types_orig[0]) is str: # pragma: no branch + if isinstance(self._types_orig[0], str): # pragma: no branch self._types_orig[0] = semiring.from_string(self._types_orig[0]) self._types = _get_types( self._types_orig, None if self._initval_orig is None else self._initdtype @@ -106,8 +109,8 @@ def __reduce__(self): def __call__(self, val, *, rowwise=False, columnwise=False): # Should we expose `allow_empty=` keyword when reducing to Scalar? - from .matrix import Matrix, TransposedMatrix - from .vector import Vector + from ..matrix import Matrix, TransposedMatrix + from ..vector import Vector typ = output_type(val) if typ is Vector: @@ -152,8 +155,11 @@ def __repr__(self): def _new(self, updater, expr, *, in_composite=False): agg = self.parent + opts = updater.opts if agg._monoid is not None: x = expr.args[0] + if agg._applybegin is not None: # pragma: no cover (unused) + x = agg._applybegin(x).new(**opts) method = getattr(x, expr.method_name) if expr.output_type.__name__ == "Scalar": expr = method(agg._monoid[self.type], allow_empty=not expr._is_cscalar) @@ -167,7 +173,6 @@ def _new(self, updater, expr, *, in_composite=False): return parent._as_vector() return - opts = updater.opts if agg._composite is not None: # Masks are applied throughout the aggregation, including composite aggregations. 
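The `*_from_string` helpers removed from `graphblas/core/operator.py` above (they move into the new `operator/utils.py`, per the package `__init__` just added) parse a small grammar: an operator name or symbol, an optional `[dtype]` suffix, and for semirings a `monoid.binaryop` pair separated by exactly one period. A short sketch of typical lookups:

```python
from graphblas import binary, op, semiring, unary

unary.from_string("abs[int]")       # unary.abs typed to INT64
binary.from_string("+")             # binary.plus
semiring.from_string("min.+[int]")  # monoid "min" + binaryop "+" -> min_plus[INT64]
op.from_string("abs[int]")          # tries unary, binary, monoid, ... in order
```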
# Aggregations done while `in_composite is True` should return the updater parent @@ -203,6 +208,8 @@ def _new(self, updater, expr, *, in_composite=False): if expr.cfunc_name == "GrB_Matrix_reduce_Aggregator": # Matrix -> Vector A = expr.args[0] + if agg._applybegin is not None: + A = agg._applybegin(A).new(**opts) orig_updater = updater if agg._finalize is not None: step1 = expr.construct_output(semiring.return_type) @@ -223,6 +230,8 @@ def _new(self, updater, expr, *, in_composite=False): elif expr.cfunc_name.startswith("GrB_Vector_reduce"): # Vector -> Scalar v = expr.args[0] + if agg._applybegin is not None: + v = agg._applybegin(v).new(**opts) step1 = expr._new_vector(semiring.return_type, size=1) init = expr._new_matrix(agg._initdtype, nrows=v._size, ncols=1) init(**opts)[...] = agg._initval # O(1) dense column vector in SuiteSparse 5 @@ -242,6 +251,8 @@ def _new(self, updater, expr, *, in_composite=False): elif expr.cfunc_name.startswith("GrB_Matrix_reduce"): # Matrix -> Scalar A = expr.args[0] + if agg._applybegin is not None: + A = agg._applybegin(A).new(**opts) # We need to compute in two steps: Matrix -> Vector -> Scalar. # This has not been benchmarked or optimized. # We may be able to intelligently choose the faster path. @@ -339,11 +350,21 @@ def __reduce__(self): # logaddexp2 = Aggregator('logaddexp2', monoid=semiring.numpy.logaddexp2) # hypot as monoid doesn't work if single negative element! # hypot = Aggregator('hypot', monoid=semiring.numpy.hypot) +# hypot = Aggregator('hypot', applybegin=unary.abs, monoid=semiring.numpy.hypot) agg.L0norm = agg.count_nonzero -agg.L1norm = Aggregator("L1norm", semiring="plus_absfirst", semiring2=semiring.plus_first) agg.L2norm = agg.hypot -agg.Linfnorm = Aggregator("Linfnorm", semiring="max_absfirst", semiring2=semiring.max_first) +if _supports_udfs: + agg.L1norm = Aggregator("L1norm", semiring="plus_absfirst", semiring2=semiring.plus_first) + agg.Linfnorm = Aggregator("Linfnorm", semiring="max_absfirst", semiring2=semiring.max_first) +else: + # Are these always better? + agg.L1norm = Aggregator( + "L1norm", applybegin=unary.abs, semiring=semiring.plus_first, semiring2=semiring.plus_first + ) + agg.Linfnorm = Aggregator( + "Linfnorm", applybegin=unary.abs, semiring=semiring.max_first, semiring2=semiring.max_first + ) # Composite @@ -677,4 +698,4 @@ def _first_last_index(agg, updater, expr, opts, *, in_composite, semiring): agg.Aggregator = Aggregator agg.TypedAggregator = TypedAggregator -from .operator import get_typed_op # noqa: E402 isort:skip +from .utils import get_typed_op # noqa: E402 isort:skip diff --git a/graphblas/core/operator/base.py b/graphblas/core/operator/base.py new file mode 100644 index 000000000..97b2c9fbd --- /dev/null +++ b/graphblas/core/operator/base.py @@ -0,0 +1,526 @@ +from functools import lru_cache +from operator import getitem +from types import BuiltinFunctionType, ModuleType + +from ... import _STANDARD_OPERATOR_NAMES, backend, op +from ...dtypes import BOOL, INT8, UINT64, _supports_complex, lookup_dtype +from .. import _has_numba, _supports_udfs, lib +from ..expr import InfixExprBase +from ..utils import output_type + +if _has_numba: + import numba + from numba import NumbaError +else: + NumbaError = TypeError + +UNKNOWN_OPCLASS = "UnknownOpClass" + +# These now live as e.g. `gb.unary.ss.positioni` +# Deprecations such as `gb.unary.positioni` will be removed in 2023.9.0 or later. 
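To make the deprecation note above concrete: under the SuiteSparse backend these operators now live in an `ss` subnamespace, and the flat names survive only through each module's `_deprecated` mapping. A sketch, assuming `backend == "suitesparse"`:

```python
import graphblas as gb

pos = gb.unary.ss.positioni      # preferred location going forward
sr = gb.semiring.ss.plus_firsti  # SuiteSparse-specific semirings follow the same pattern

# Flat aliases such as gb.unary.positioni still resolve through `_deprecated`
# during the deprecation window, but are slated for removal.
```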
+_SS_OPERATORS = { + # unary + "erf", # scipy.special.erf + "erfc", # scipy.special.erfc + "frexpe", # np.frexp[1] + "frexpx", # np.frexp[0] + "lgamma", # scipy.special.loggamma + "tgamma", # scipy.special.gamma + # Positional + # unary + "positioni", + "positioni1", + "positionj", + "positionj1", + # binary + "firsti", + "firsti1", + "firstj", + "firstj1", + "secondi", + "secondi1", + "secondj", + "secondj1", + # semiring + "any_firsti", + "any_firsti1", + "any_firstj", + "any_firstj1", + "any_secondi", + "any_secondi1", + "any_secondj", + "any_secondj1", + "max_firsti", + "max_firsti1", + "max_firstj", + "max_firstj1", + "max_secondi", + "max_secondi1", + "max_secondj", + "max_secondj1", + "min_firsti", + "min_firsti1", + "min_firstj", + "min_firstj1", + "min_secondi", + "min_secondi1", + "min_secondj", + "min_secondj1", + "plus_firsti", + "plus_firsti1", + "plus_firstj", + "plus_firstj1", + "plus_secondi", + "plus_secondi1", + "plus_secondj", + "plus_secondj1", + "times_firsti", + "times_firsti1", + "times_firstj", + "times_firstj1", + "times_secondi", + "times_secondi1", + "times_secondj", + "times_secondj1", +} + + +def _hasop(module, name): + return ( + name in module.__dict__ + or name in module._delayed + or name in getattr(module, "_deprecated", ()) + ) + + +class OpPath: + def __init__(self, parent, name): + self._parent = parent + self._name = name + self._delayed = {} + self._delayed_commutes_to = {} + + def __getattr__(self, key): + if key in self._delayed: + func, kwargs = self._delayed.pop(key) + return func(**kwargs) + self.__getattribute__(key) # raises + + +def _call_op(op, left, right=None, thunk=None, **kwargs): + if right is None and thunk is None: + if isinstance(left, InfixExprBase): + # op(A & B), op(A | B), op(A @ B) + return getattr(left.left, f"_{left.method_name}")( + left.right, op, is_infix=True, **kwargs + ) + if find_opclass(op)[1] == "Semiring": + raise TypeError( + f"Bad type when calling {op!r}. Got type: {type(left)}.\n" + f"Expected an infix expression, such as: {op!r}(A @ B)" + ) + raise TypeError( + f"Bad type when calling {op!r}. Got type: {type(left)}.\n" + "Expected an infix expression or an apply with a Vector or Matrix and a scalar:\n" + f" - {op!r}(A & B)\n" + f" - {op!r}(A, 1)\n" + f" - {op!r}(1, A)" + ) + + # op(A, 1) -> apply (or select if thunk provided) + from ..matrix import Matrix, TransposedMatrix + from ..vector import Vector + + if (left_type := output_type(left)) in {Vector, Matrix, TransposedMatrix}: + if thunk is not None: + return left.select(op, thunk=thunk, **kwargs) + return left.apply(op, right=right, **kwargs) + if (right_type := output_type(right)) in {Vector, Matrix, TransposedMatrix}: + return right.apply(op, left=left, **kwargs) + + from ..scalar import Scalar, _as_scalar + + if left_type is Scalar: + if thunk is not None: + return left.select(op, thunk=thunk, **kwargs) + return left.apply(op, right=right, **kwargs) + if right_type is Scalar: + return right.apply(op, left=left, **kwargs) + try: + left_scalar = _as_scalar(left, is_cscalar=False) + except Exception: + pass + else: + if thunk is not None: + return left_scalar.select(op, thunk=thunk, **kwargs) + return left_scalar.apply(op, right=right, **kwargs) + raise TypeError( + f"Bad types when calling {op!r}. 
Got types: {type(left)}, {type(right)}.\n" + "Expected an infix expression or an apply with a Vector or Matrix and a scalar:\n" + f" - {op!r}(A & B)\n" + f" - {op!r}(A, 1)\n" + f" - {op!r}(1, A)" + ) + + +if _has_numba: + + def _get_udt_wrapper(numba_func, return_type, dtype, dtype2=None, *, include_indexes=False): + ztype = INT8 if return_type == BOOL else return_type + xtype = INT8 if dtype == BOOL else dtype + nt = numba.types + wrapper_args = [nt.CPointer(ztype.numba_type), nt.CPointer(xtype.numba_type)] + if include_indexes: + wrapper_args.extend([UINT64.numba_type, UINT64.numba_type]) + if dtype2 is not None: + ytype = INT8 if dtype2 == BOOL else dtype2 + wrapper_args.append(nt.CPointer(ytype.numba_type)) + wrapper_sig = nt.void(*wrapper_args) + + zarray = xarray = yarray = BL = BR = yarg = yname = rcidx = "" + if return_type._is_udt: + if return_type.np_type.subdtype is None: + zarray = " z = numba.carray(z_ptr, 1)\n" + zname = "z[0]" + else: + zname = "z_ptr[0]" + BR = "[0]" + else: + zname = "z_ptr[0]" + if return_type == BOOL: + BL = "bool(" + BR = ")" + + if dtype._is_udt: + if dtype.np_type.subdtype is None: + xarray = " x = numba.carray(x_ptr, 1)\n" + xname = "x[0]" + else: + xname = "x_ptr" + elif dtype == BOOL: + xname = "bool(x_ptr[0])" + else: + xname = "x_ptr[0]" + + if dtype2 is not None: + yarg = ", y_ptr" + if dtype2._is_udt: + if dtype2.np_type.subdtype is None: + yarray = " y = numba.carray(y_ptr, 1)\n" + yname = ", y[0]" + else: + yname = ", y_ptr" + elif dtype2 == BOOL: + yname = ", bool(y_ptr[0])" + else: + yname = ", y_ptr[0]" + + if include_indexes: + rcidx = ", row, col" + + d = {"numba": numba, "numba_func": numba_func} + text = ( + f"def wrapper(z_ptr, x_ptr{rcidx}{yarg}):\n" + f"{zarray}{xarray}{yarray}" + f" {zname} = {BL}numba_func({xname}{rcidx}{yname}){BR}\n" + ) + exec(text, d) # pylint: disable=exec-used + return d["wrapper"], wrapper_sig + + +class TypedOpBase: + __slots__ = ( + "parent", + "name", + "type", + "return_type", + "gb_obj", + "gb_name", + "_type2", + "__weakref__", + ) + + def __init__(self, parent, name, type_, return_type, gb_obj, gb_name, dtype2=None): + self.parent = parent + self.name = name + self.type = type_ + self.return_type = return_type + self.gb_obj = gb_obj + self.gb_name = gb_name + self._type2 = dtype2 + + def __repr__(self): + classname = self.opclass.lower() + classname = classname.removesuffix("op") + dtype2 = "" if self._type2 is None else f", {self._type2.name}" + return f"{classname}.{self.name}[{self.type.name}{dtype2}]" + + @property + def _carg(self): + return self.gb_obj + + @property + def is_positional(self): + return self.parent.is_positional + + def __reduce__(self): + if self._type2 is None or self.type == self._type2: + return (getitem, (self.parent, self.type)) + return (getitem, (self.parent, (self.type, self._type2))) + + +def _deserialize_parameterized(parameterized_op, args, kwargs): + return parameterized_op(*args, **kwargs) + + +class ParameterizedUdf: + __slots__ = "name", "__call__", "_anonymous", "__weakref__" + is_positional = False + _custom_dtype = None + + def __init__(self, name, anonymous): + self.name = name + self._anonymous = anonymous + # lru_cache per instance + method = self._call.__get__(self, type(self)) + self.__call__ = lru_cache(maxsize=1024)(method) + + def _call(self, *args, **kwargs): + raise NotImplementedError + + +_VARNAMES = tuple(x for x in dir(lib) if x[0] != "_") + + +class OpBase: + __slots__ = ( + "name", + "_typed_ops", + "types", + "coercions", + "_anonymous", + 
"_udt_types", + "_udt_ops", + "__weakref__", + ) + _parse_config = None + _initialized = False + _module = None + _positional = None + + def __init__(self, name, *, anonymous=False): + self.name = name + self._typed_ops = {} + self.types = {} + self.coercions = {} + self._anonymous = anonymous + self._udt_types = None + self._udt_ops = None + + def __repr__(self): + return f"{self._modname}.{self.name}" + + def __getitem__(self, type_): + if type(type_) is tuple: + from .utils import get_typed_op + + dtype1, dtype2 = type_ + dtype1 = lookup_dtype(dtype1) + dtype2 = lookup_dtype(dtype2) + return get_typed_op(self, dtype1, dtype2) + if not self._is_udt: + type_ = lookup_dtype(type_) + if type_ not in self._typed_ops: + if self._udt_types is None: + if self.is_positional: + return self._typed_ops[UINT64] + raise KeyError(f"{self.name} does not work with {type_}") + else: + return self._typed_ops[type_] + # This is a UDT or is able to operate on UDTs such as `first` any `any` + dtype = lookup_dtype(type_) + return self._compile_udt(dtype, dtype) + + def _add(self, op, *, is_jit=False): + if is_jit: + if hasattr(op, "type2") or hasattr(op, "thunk_type"): + dtypes = (op.type, op._type2) + else: + dtypes = op.type + self.types[dtypes] = op.return_type # This is a different use of .types + self._udt_types[dtypes] = op.return_type + self._udt_ops[dtypes] = op + else: + self._typed_ops[op.type] = op + self.types[op.type] = op.return_type + + def __delitem__(self, type_): + type_ = lookup_dtype(type_) + del self._typed_ops[type_] + del self.types[type_] + + def __contains__(self, type_): + try: + self[type_] + except (TypeError, KeyError, NumbaError): + return False + return True + + @classmethod + def _remove_nesting(cls, funcname, *, module=None, modname=None, strict=True): + if module is None: + module = cls._module + if modname is None: + modname = cls._modname + if "." not in funcname: + if strict and _hasop(module, funcname): + raise AttributeError(f"{modname}.{funcname} is already defined") + else: + path, funcname = funcname.rsplit(".", 1) + for folder in path.split("."): + if not _hasop(module, folder): + setattr(module, folder, OpPath(module, folder)) + module = getattr(module, folder) + modname = f"{modname}.{folder}" + if not isinstance(module, (OpPath, ModuleType)): + raise AttributeError( + f"{modname} is already defined. Cannot use as a nested path." + ) + if strict and _hasop(module, funcname): + raise AttributeError(f"{path}.{funcname} is already defined") + return module, funcname + + @classmethod + def _find(cls, funcname): + rv = cls._module + for attr in funcname.split("."): + if attr in getattr(rv, "_deprecated", ()): + rv = rv._deprecated[attr] + else: + rv = getattr(rv, attr, None) + if rv is None: + break + return rv + + @classmethod + def _initialize(cls, include_in_ops=True): + """Initialize operators for this operator type. + + include_in_ops determines whether the operators are included in the + ``gb.ops`` namespace in addition to the defined module. 
+ """ + if cls._initialized: # pragma: no cover (safety) + return + # Read in the parse configs + trim_from_front = cls._parse_config.get("trim_from_front", 0) + delete_exact = cls._parse_config.get("delete_exact") + num_underscores = cls._parse_config["num_underscores"] + + for re_str, return_prefix in [ + ("re_exprs", None), + ("re_exprs_return_bool", "BOOL"), + ("re_exprs_return_float", "FP"), + ("re_exprs_return_complex", "FC"), + ]: + if re_str not in cls._parse_config: + continue + if "complex" in re_str and not _supports_complex: + continue + for r in reversed(cls._parse_config[re_str]): + for varname in _VARNAMES: + m = r.match(varname) + if m: + # Parse function into name and datatype + gb_name = m.string + splitname = gb_name[trim_from_front:].split("_") + if delete_exact and delete_exact in splitname: + splitname.remove(delete_exact) + if len(splitname) == num_underscores + 1: + *splitname, type_ = splitname + else: + type_ = None + name = "_".join(splitname).lower() + # Create object for name unless it already exists + if not _hasop(cls._module, name): + if backend == "suitesparse" and name in _SS_OPERATORS: + fullname = f"ss.{name}" + else: + fullname = name + if cls._positional is None: + obj = cls(fullname) + else: + obj = cls(fullname, is_positional=name in cls._positional) + if name in _SS_OPERATORS: + if backend == "suitesparse": + setattr(cls._module.ss, name, obj) + cls._module._deprecated[name] = obj + if include_in_ops and not _hasop(op, name): # pragma: no branch + op._deprecated[name] = obj + if backend == "suitesparse": + setattr(op.ss, name, obj) + else: + setattr(cls._module, name, obj) + if include_in_ops and not _hasop(op, name): + setattr(op, name, obj) + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{fullname}") + elif name in _SS_OPERATORS: + obj = cls._module._deprecated[name] + else: + obj = getattr(cls._module, name) + gb_obj = getattr(lib, varname) + # Determine return type + if return_prefix == "BOOL": + return_type = BOOL + if type_ is None: + type_ = BOOL + else: + if type_ is None: # pragma: no cover (safety) + raise TypeError(f"Unable to determine return type for {varname}") + if return_prefix is None: + return_type = type_ + else: + # Grab the number of bits from type_ + num_bits = type_[-2:] + if num_bits not in {"32", "64"}: # pragma: no cover (safety) + raise TypeError(f"Unexpected number of bits: {num_bits}") + return_type = f"{return_prefix}{num_bits}" + builtin_op = cls._typed_class( + obj, + name, + lookup_dtype(type_), + lookup_dtype(return_type), + gb_obj, + gb_name, + ) + obj._add(builtin_op) + + @classmethod + def _deserialize(cls, name, *args): + if (rv := cls._find(name)) is not None: + return rv # Should we verify this is what the user expects? + return cls.register_new(name, *args) + + @classmethod + def _check_supports_udf(cls, method_name): + if not _supports_udfs: + raise RuntimeError( + f"{cls.__name__}.{method_name}(...) 
unavailable; install numba for UDF support" + ) + + +_builtin_to_op = {} # Populated in .utils + + +def find_opclass(gb_op): + if isinstance(gb_op, OpBase): + opclass = type(gb_op).__name__ + elif isinstance(gb_op, TypedOpBase): + opclass = gb_op.opclass + elif isinstance(gb_op, ParameterizedUdf): + gb_op = gb_op() # Use default parameters of parameterized UDFs + gb_op, opclass = find_opclass(gb_op) + elif isinstance(gb_op, BuiltinFunctionType) and gb_op in _builtin_to_op: + gb_op, opclass = find_opclass(_builtin_to_op[gb_op]) + else: + opclass = UNKNOWN_OPCLASS + return gb_op, opclass diff --git a/graphblas/core/operator/binary.py b/graphblas/core/operator/binary.py new file mode 100644 index 000000000..3ee089fe4 --- /dev/null +++ b/graphblas/core/operator/binary.py @@ -0,0 +1,985 @@ +import inspect +import re +from functools import lru_cache, reduce +from operator import mul +from types import FunctionType + +import numpy as np + +from ... import _STANDARD_OPERATOR_NAMES, backend, binary, monoid, op +from ...dtypes import ( + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + _supports_complex, + lookup_dtype, +) +from ...exceptions import UdfParseError, check_status_carg +from .. import _has_numba, _supports_udfs, ffi, lib +from ..dtypes import _sample_values +from ..expr import InfixExprBase +from .base import ( + _SS_OPERATORS, + OpBase, + ParameterizedUdf, + TypedOpBase, + _call_op, + _deserialize_parameterized, + _hasop, +) + +if _has_numba: + import numba + + from .base import _get_udt_wrapper +if _supports_complex: + from ...dtypes import FC32, FC64 + +ffi_new = ffi.new + +if _has_numba: + _udt_mask_cache = {} + + def _udt_mask(dtype): + """Create mask to determine which bytes of UDTs to use for equality check.""" + if dtype in _udt_mask_cache: + return _udt_mask_cache[dtype] + if dtype.subdtype is not None: + mask = _udt_mask(dtype.subdtype[0]) + N = reduce(mul, dtype.subdtype[1]) + rv = np.concatenate([mask] * N) + elif dtype.names is not None: + prev_offset = mask = None + masks = [] + for name in dtype.names: + dtype2, offset = dtype.fields[name] + if mask is not None: + masks.append(np.pad(mask, (0, offset - prev_offset - mask.size))) + mask = _udt_mask(dtype2) + prev_offset = offset + masks.append(np.pad(mask, (0, dtype.itemsize - prev_offset - mask.size))) + rv = np.concatenate(masks) + else: + rv = np.ones(dtype.itemsize, dtype=bool) + # assert rv.size == dtype.itemsize + _udt_mask_cache[dtype] = rv + return rv + + +class TypedBuiltinBinaryOp(TypedOpBase): + __slots__ = () + opclass = "BinaryOp" + + def __call__(self, left, right=None, *, left_default=None, right_default=None): + if left_default is not None or right_default is not None: + if ( + left_default is None + or right_default is None + or right is not None + or not isinstance(left, InfixExprBase) + or left.method_name != "ewise_add" + ): + raise TypeError( + "Specifying `left_default` or `right_default` keyword arguments implies " + "performing `ewise_union` operation with infix notation.\n" + "There is only one valid way to do this:\n\n" + f">>> {self}(x | y, left_default=0, right_default=0)\n\nwhere x and y " + "are Vectors or Matrices, and left_default and right_default are scalars." 
+ ) + return left.left._ewise_union( + left.right, self, left_default, right_default, is_infix=True + ) + return _call_op(self, left, right) + + @property + def monoid(self): + rv = getattr(monoid, self.name, None) + if rv is not None and self.type in rv._typed_ops: + return rv[self.type] + + @property + def commutes_to(self): + commutes_to = self.parent.commutes_to + if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt): + return commutes_to[self.type] + + @property + def _semiring_commutes_to(self): + commutes_to = self.parent._semiring_commutes_to + if commutes_to is not None and (self.type in commutes_to._typed_ops or self.type._is_udt): + return commutes_to[self.type] + + @property + def is_commutative(self): + return self.commutes_to is self + + @property + def type2(self): + return self.type if self._type2 is None else self._type2 + + +class TypedUserBinaryOp(TypedOpBase): + __slots__ = "_monoid" + opclass = "BinaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) + self._monoid = None + + @property + def monoid(self): + if self._monoid is None: + monoid = self.parent.monoid + if monoid is not None and self.type in monoid: + self._monoid = monoid[self.type] + return self._monoid + + @property + def orig_func(self): + return self.parent.orig_func + + @property + def _numba_func(self): + return self.parent._numba_func + + commutes_to = TypedBuiltinBinaryOp.commutes_to + _semiring_commutes_to = TypedBuiltinBinaryOp._semiring_commutes_to + is_commutative = TypedBuiltinBinaryOp.is_commutative + type2 = TypedBuiltinBinaryOp.type2 + __call__ = TypedBuiltinBinaryOp.__call__ + + +class ParameterizedBinaryOp(ParameterizedUdf): + __slots__ = "func", "__signature__", "_monoid", "_cached_call", "_commutes_to", "_is_udt" + + def __init__(self, name, func, *, anonymous=False, is_udt=False): + self.func = func + self.__signature__ = inspect.signature(func) + self._monoid = None + self._is_udt = is_udt + if name is None: + name = getattr(func, "__name__", name) + super().__init__(name, anonymous) + method = self._call_to_cache.__get__(self, type(self)) + self._cached_call = lru_cache(maxsize=1024)(method) + self.__call__ = self._call + self._commutes_to = None + + def _call_to_cache(self, *args, **kwargs): + binary = self.func(*args, **kwargs) + binary._parameterized_info = (self, args, kwargs) + return BinaryOp.register_anonymous(binary, self.name, is_udt=self._is_udt) + + def _call(self, *args, **kwargs): + binop = self._cached_call(*args, **kwargs) + if self._monoid is not None and binop._monoid is None: + # This is all a bit funky. We try our best to associate a binaryop + # to a monoid. So, if we made a ParameterizedMonoid using this object, + # then try to create a monoid with the given arguments. + binop._monoid = binop # temporary! + try: + # If this call is successful, then it will set `binop._monoid` + self._monoid(*args, **kwargs) # pylint: disable=not-callable + except Exception: + binop._monoid = None + # assert binop._monoid is not binop + if self.is_commutative: + binop._commutes_to = binop + # Don't bother yet with creating `binop.commutes_to` (but we could!) 
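A usage sketch of the caching performed by `_call` above: calling a `ParameterizedBinaryOp` bakes the arguments into a compiled `BinaryOp`, and repeat calls with equal arguments return the same cached object (via `lru_cache`). This assumes a UDF-enabled install with numba, since `binary.isclose`, registered as a parameterized op later in this diff, compiles a user-defined function:

```python
from graphblas import Vector, binary

v = Vector.from_coo([0, 1], [1.0, 2.0])
w = Vector.from_coo([0, 1], [1.0 + 1e-9, 2.0])

close = binary.isclose(rel_tol=1e-7, abs_tol=0.0)  # returns a compiled BinaryOp
assert close is binary.isclose(rel_tol=1e-7, abs_tol=0.0)  # lru_cache hit

result = v.ewise_mult(w, close).new()  # BOOL vector; both pairs compare True
```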
+ return binop + + @property + def monoid(self): + return self._monoid + + @property + def commutes_to(self): + if isinstance(self._commutes_to, str): + self._commutes_to = BinaryOp._find(self._commutes_to) + return self._commutes_to + + is_commutative = TypedBuiltinBinaryOp.is_commutative + + def __reduce__(self): + name = f"binary.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.func, self._anonymous)) + + @staticmethod + def _deserialize(name, func, anonymous): + if anonymous: + return BinaryOp.register_anonymous(func, name, parameterized=True) + if (rv := BinaryOp._find(name)) is not None: + return rv + return BinaryOp.register_new(name, func, parameterized=True) + + +def _floordiv(x, y): + return x // y # pragma: no cover (numba) + + +def _rfloordiv(x, y): + return y // x # pragma: no cover (numba) + + +def _absfirst(x, y): + return np.abs(x) # pragma: no cover (numba) + + +def _abssecond(x, y): + return np.abs(y) # pragma: no cover (numba) + + +def _rpow(x, y): + return y**x # pragma: no cover (numba) + + +def _isclose(rel_tol=1e-7, abs_tol=0.0): + def inner(x, y): # pragma: no cover (numba) + return x == y or abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol) + + return inner + + +_MAX_INT64 = np.iinfo(np.int64).max + + +def _binom(N, k): # pragma: no cover (numba) + # Returns 0 if overflow or out-of-bounds + if k > N or k < 0: + return 0 + val = np.int64(1) + for i in range(min(k, N - k)): + if val > _MAX_INT64 // (N - i): # Overflow + return 0 + val *= N - i + val //= i + 1 + return val + + +# Kinda complicated, but works for now +def _register_binom(): + # "Fake" UDT so we only compile once for INT64 + op = BinaryOp.register_new("binom", _binom, is_udt=True) + typed_op = op[INT64, INT64] + # Make this look like a normal operator + for dtype in [UINT8, UINT16, UINT32, UINT64, INT8, INT16, INT32, INT64]: + op.types[dtype] = INT64 + op._typed_ops[dtype] = typed_op + if dtype != INT64: + op.coercions[dtype] = typed_op + # And make it not look like it operates on UDTs + typed_op._type2 = None + op._is_udt = False + op._udt_types = None + op._udt_ops = None + return op + + +def _first(x, y): + return x # pragma: no cover (numba) + + +def _second(x, y): + return y # pragma: no cover (numba) + + +def _pair(x, y): + return 1 # pragma: no cover (numba) + + +def _first_dtype(op, dtype, dtype2): + if dtype._is_udt or dtype2._is_udt: + return op._compile_udt(dtype, dtype2) + + +def _second_dtype(op, dtype, dtype2): + if dtype._is_udt or dtype2._is_udt: + return op._compile_udt(dtype, dtype2) + + +def _pair_dtype(op, dtype, dtype2): + return op[INT64] + + +class BinaryOp(OpBase): + """Takes two inputs and returns one output, possibly of a different data type. + + Built-in and registered BinaryOps are located in the ``graphblas.binary`` namespace + as well as in the ``graphblas.ops`` combined namespace. 
+ """ + + __slots__ = ( + "_monoid", + "_commutes_to", + "_semiring_commutes_to", + "orig_func", + "is_positional", + "_is_udt", + "_numba_func", + "_custom_dtype", + ) + _module = binary + _modname = "binary" + _typed_class = TypedBuiltinBinaryOp + _parse_config = { + "trim_from_front": 4, + "num_underscores": 1, + "re_exprs": [ + re.compile( + "^GrB_(FIRST|SECOND|PLUS|MINUS|TIMES|DIV|MIN|MAX)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" + ), + re.compile( + "GrB_(BOR|BAND|BXOR|BXNOR)_(INT8|INT16|INT32|INT64|UINT8|UINT16|UINT32|UINT64)$" + ), + re.compile( + "^GxB_(POW|RMINUS|RDIV|PAIR|ANY|ISEQ|ISNE|ISGT|ISLT|ISGE|ISLE|LOR|LAND|LXOR)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" + ), + re.compile("^GxB_(FIRST|SECOND|PLUS|MINUS|TIMES|DIV)_(FC32|FC64)$"), + re.compile("^GxB_(ATAN2|HYPOT|FMOD|REMAINDER|LDEXP|COPYSIGN)_(FP32|FP64)$"), + re.compile( + "GxB_(BGET|BSET|BCLR|BSHIFT|FIRSTI1|FIRSTI|FIRSTJ1|FIRSTJ" + "|SECONDI1|SECONDI|SECONDJ1|SECONDJ)" + "_(INT8|INT16|INT32|INT64|UINT8|UINT16|UINT32|UINT64)$" + ), + # These are coerced to 0 or 1, but don't return BOOL + re.compile( + "^GxB_(LOR|LAND|LXOR|LXNOR)_" + "(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + ], + "re_exprs_return_bool": [ + re.compile("^GrB_(LOR|LAND|LXOR|LXNOR)$"), + re.compile( + "^GrB_(EQ|NE|GT|LT|GE|LE)_" + "(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile("^GxB_(EQ|NE)_(FC32|FC64)$"), + ], + "re_exprs_return_complex": [re.compile("^GxB_(CMPLX)_(FP32|FP64)$")], + } + _commutes = { + # builtins + "cdiv": "rdiv", + "first": "second", + "ge": "le", + "gt": "lt", + "isge": "isle", + "isgt": "islt", + "minus": "rminus", + "pow": "rpow", + # special + "firsti": "secondi", + "firsti1": "secondi1", + "firstj": "secondj", + "firstj1": "secondj1", + # custom + # "absfirst": "abssecond", # handled in graphblas.binary + # "floordiv": "rfloordiv", + "truediv": "rtruediv", + } + _commutes_to_in_semiring = { + "firsti": "secondj", + "firsti1": "secondj1", + "firstj": "secondi", + "firstj1": "secondi1", + } + _commutative = { + # monoids + "any", + "band", + "bor", + "bxnor", + "bxor", + "eq", + "land", + "lor", + "lxnor", + "lxor", + "max", + "min", + "plus", + "times", + # other + "hypot", + "isclose", + "iseq", + "isne", + "ne", + "pair", + } + # Don't commute: atan2, bclr, bget, bset, bshift, cmplx, copysign, fmod, ldexp, remainder + _positional = { + "firsti", + "firsti1", + "firstj", + "firstj1", + "secondi", + "secondi1", + "secondj", + "secondj1", + } + + @classmethod + def _build(cls, name, func, *, is_udt=False, anonymous=False): + if not isinstance(func, FunctionType): + raise TypeError(f"UDF argument must be a function, not {type(func)}") + if name is None: + name = getattr(func, "__name__", "") + success = False + binary_udf = numba.njit(func) + new_type_obj = cls(name, func, anonymous=anonymous, is_udt=is_udt, numba_func=binary_udf) + return_types = {} + nt = numba.types + if not is_udt: + for type_ in _sample_values: + sig = (type_.numba_type, type_.numba_type) + try: + binary_udf.compile(sig) + except numba.TypingError: + continue + ret_type = lookup_dtype(binary_udf.overloads[sig].signature.return_type) + if ret_type != type_ and ( + ("INT" in ret_type.name and "INT" in type_.name) + or ("FP" in ret_type.name and "FP" in type_.name) + or ("FC" in ret_type.name and "FC" in type_.name) + or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) + ): + # Downcast 
`ret_type` to `type_`. + # This is what users want most of the time, but we can't make a perfect rule. + # There should be a way for users to be explicit. + ret_type = type_ + elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8: + ret_type = INT8 + + # Numba is unable to handle BOOL correctly right now, but we have a workaround + # See: https://github.com/numba/numba/issues/5395 + # We're relying on coercion behaving correctly here + input_type = INT8 if type_ == BOOL else type_ + return_type = INT8 if ret_type == BOOL else ret_type + + # Build wrapper because GraphBLAS wants pointers and void return + wrapper_sig = nt.void( + nt.CPointer(return_type.numba_type), + nt.CPointer(input_type.numba_type), + nt.CPointer(input_type.numba_type), + ) + + if type_ == BOOL: + if ret_type == BOOL: + + def binary_wrapper(z, x, y): # pragma: no cover (numba) + z[0] = bool(binary_udf(bool(x[0]), bool(y[0]))) + + else: + + def binary_wrapper(z, x, y): # pragma: no cover (numba) + z[0] = binary_udf(bool(x[0]), bool(y[0])) + + elif ret_type == BOOL: + + def binary_wrapper(z, x, y): # pragma: no cover (numba) + z[0] = bool(binary_udf(x[0], y[0])) + + else: + + def binary_wrapper(z, x, y): # pragma: no cover (numba) + z[0] = binary_udf(x[0], y[0]) + + binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper) + new_binary = ffi_new("GrB_BinaryOp*") + check_status_carg( + lib.GrB_BinaryOp_new( + new_binary, + binary_wrapper.cffi, + ret_type.gb_obj, + type_.gb_obj, + type_.gb_obj, + ), + "BinaryOp", + new_binary[0], + ) + op = TypedUserBinaryOp(new_type_obj, name, type_, ret_type, new_binary[0]) + new_type_obj._add(op) + success = True + return_types[type_] = ret_type + if success or is_udt: + return new_type_obj + raise UdfParseError("Unable to parse function using Numba") + + def _compile_udt(self, dtype, dtype2): + if dtype2 is None: + dtype2 = dtype + dtypes = (dtype, dtype2) + if dtypes in self._udt_types: + return self._udt_ops[dtypes] + + if self.name == "eq" and not self._anonymous and _has_numba: + nt = numba.types + # assert dtype.np_type == dtype2.np_type + itemsize = dtype.np_type.itemsize + mask = _udt_mask(dtype.np_type) + ret_type = BOOL + wrapper_sig = nt.void( + nt.CPointer(INT8.numba_type), + nt.CPointer(UINT8.numba_type), + nt.CPointer(UINT8.numba_type), + ) + # PERF: we can probably make this faster + if mask.all(): + + def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) + x = numba.carray(x_ptr, itemsize) + y = numba.carray(y_ptr, itemsize) + # for i in range(itemsize): + # if x[i] != y[i]: + # z_ptr[0] = False + # break + # else: + # z_ptr[0] = True + z_ptr[0] = (x == y).all() + + else: + + def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) + x = numba.carray(x_ptr, itemsize) + y = numba.carray(y_ptr, itemsize) + # for i in range(itemsize): + # if mask[i] and x[i] != y[i]: + # z_ptr[0] = False + # break + # else: + # z_ptr[0] = True + z_ptr[0] = (x[mask] == y[mask]).all() + + elif self.name == "ne" and not self._anonymous and _has_numba: + nt = numba.types + # assert dtype.np_type == dtype2.np_type + itemsize = dtype.np_type.itemsize + mask = _udt_mask(dtype.np_type) + ret_type = BOOL + wrapper_sig = nt.void( + nt.CPointer(INT8.numba_type), + nt.CPointer(UINT8.numba_type), + nt.CPointer(UINT8.numba_type), + ) + if mask.all(): + + def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) + x = numba.carray(x_ptr, itemsize) + y = numba.carray(y_ptr, itemsize) + # for i in range(itemsize): + # if x[i] != y[i]: + 
# z_ptr[0] = True + # break + # else: + # z_ptr[0] = False + z_ptr[0] = (x != y).any() + + else: + + def binary_wrapper(z_ptr, x_ptr, y_ptr): # pragma: no cover (numba) + x = numba.carray(x_ptr, itemsize) + y = numba.carray(y_ptr, itemsize) + # for i in range(itemsize): + # if mask[i] and x[i] != y[i]: + # z_ptr[0] = True + # break + # else: + # z_ptr[0] = False + z_ptr[0] = (x[mask] != y[mask]).any() + + elif self._numba_func is None: + raise KeyError(f"{self.name} does not work with {dtypes} types") + else: + numba_func = self._numba_func + sig = (dtype.numba_type, dtype2.numba_type) + numba_func.compile(sig) # Should we catch and give additional error message? + ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type) + binary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype, dtype2) + + binary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(binary_wrapper) + new_binary = ffi_new("GrB_BinaryOp*") + check_status_carg( + lib.GrB_BinaryOp_new( + new_binary, binary_wrapper.cffi, ret_type._carg, dtype._carg, dtype2._carg + ), + "BinaryOp", + new_binary[0], + ) + op = TypedUserBinaryOp( + self, + self.name, + dtype, + ret_type, + new_binary[0], + dtype2=dtype2, + ) + self._udt_types[dtypes] = ret_type + self._udt_ops[dtypes] = op + return op + + @classmethod + def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): + """Register a BinaryOp without registering it in the ``graphblas.binary`` namespace. + + Because it is not registered in the namespace, the name is optional. + + Parameters + ---------- + func : FunctionType + The function to compile. For all current backends, this must be able + to be compiled with ``numba.njit``. + ``func`` takes two input parameters of any dtype and returns any dtype. + name : str, optional + The name of the operator. This *does not* show up as ``gb.binary.{name}``. + parameterized : bool, default False + When True, create a parameterized user-defined operator, which means + additional parameters can be "baked into" the operator when used. + For example, ``gb.binary.isclose`` is a parameterized function that + optionally accepts ``rel_tol`` and ``abs_tol`` parameters, and it + can be used as: ``A.ewise_mult(B, gb.binary.isclose(rel_tol=1e-5))``. + When creating a parameterized user-defined operator, the ``func`` + parameter must be a callable that *returns* a function that will + then get compiled. + is_udt : bool, default False + Whether the operator is intended to operate on user-defined types. + If True, then the function will not be automatically compiled for + builtin types, and it will be compiled "just in time" when used. + Setting ``is_udt=True`` is also helpful when the left and right + dtypes need to be different. + + Returns + ------- + BinaryOp or ParameterizedBinaryOp + + """ + cls._check_supports_udf("register_anonymous") + if parameterized: + return ParameterizedBinaryOp(name, func, anonymous=True, is_udt=is_udt) + return cls._build(name, func, anonymous=True, is_udt=is_udt) + + @classmethod + def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False): + """Register a new BinaryOp and save it to ``graphblas.binary`` namespace. + + Parameters + ---------- + name : str + The name of the operator. This will show up as ``gb.binary.{name}``. + The name may contain periods, ".", which will result in nested objects + such as ``gb.binary.x.y.z`` for name ``"x.y.z"``. + func : FunctionType + The function to compile. 
For all current backends, this must be able + to be compiled with ``numba.njit``. + ``func`` takes two input parameters of any dtype and returns any dtype. + parameterized : bool, default False + When True, create a parameterized user-defined operator, which means + additional parameters can be "baked into" the operator when used. + For example, ``gb.binary.isclose`` is a parameterized function that + optionally accepts ``rel_tol`` and ``abs_tol`` parameters, and it + can be used as: ``A.ewise_mult(B, gb.binary.isclose(rel_tol=1e-5))``. + When creating a parameterized user-defined operator, the ``func`` + parameter must be a callable that *returns* a function that will + then get compiled. See the ``user_isclose`` example below. + is_udt : bool, default False + Whether the operator is intended to operate on user-defined types. + If True, then the function will not be automatically compiled for + builtin types, and it will be compiled "just in time" when used. + Setting ``is_udt=True`` is also helpful when the left and right + dtypes need to be different. + lazy : bool, default False + If False (the default), then the function will be automatically + compiled for builtin types (unless ``is_udt`` is True). + Compiling functions can be slow, however, so you may want to + delay compilation and only compile when the operator is used, + which is done by setting ``lazy=True``. + + Examples + -------- + >>> def max_zero(x, y): + r = 0 + if x > r: + r = x + if y > r: + r = y + return r + >>> gb.core.operator.BinaryOp.register_new("max_zero", max_zero) + >>> dir(gb.binary) + [..., 'max_zero', ...] + + This is how ``gb.binary.isclose`` is defined: + + >>> def user_isclose(rel_tol=1e-7, abs_tol=0.0): + >>> def inner(x, y): + >>> return x == y or abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol) + >>> return inner + >>> gb.binary.register_new("user_isclose", user_isclose, parameterized=True) + + """ + cls._check_supports_udf("register_new") + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "func": func, "parameterized": parameterized}, + ) + elif parameterized: + binary_op = ParameterizedBinaryOp(name, func, is_udt=is_udt) + setattr(module, funcname, binary_op) + else: + binary_op = cls._build(name, func, is_udt=is_udt) + setattr(module, funcname, binary_op) + # Also save it to `graphblas.op` if not yet defined + opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) + if not _hasop(opmodule, funcname): + if lazy: + opmodule._delayed[funcname] = module + else: + setattr(opmodule, funcname, binary_op) + if not cls._initialized: + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return binary_op + + @classmethod + def _initialize(cls): + if cls._initialized: # pragma: no cover (safety) + return + super()._initialize() + # Rename div to cdiv + cdiv = binary.cdiv = op.cdiv = BinaryOp("cdiv") + for dtype, ret_type in binary.div.types.items(): + orig_op = binary.div[dtype] + cur_op = TypedBuiltinBinaryOp( + cdiv, "cdiv", dtype, ret_type, orig_op.gb_obj, orig_op.gb_name + ) + cdiv._add(cur_op) + del binary.div + del op.div + # Add truediv which always points to floating point cdiv + # We are effectively hacking cdiv to always return floating point values + # If the inputs are FP32, we use DIV_FP32; use DIV_FP64 for all other input dtypes + truediv = binary.truediv = op.truediv = BinaryOp("truediv") + rtruediv = binary.rtruediv = op.rtruediv = BinaryOp("rtruediv") + for 
new_op, builtin_op in [(truediv, binary.cdiv), (rtruediv, binary.rdiv)]: + for dtype in builtin_op.types: + if dtype.name in {"FP32", "FC32", "FC64"}: + orig_dtype = dtype + else: + orig_dtype = FP64 + orig_op = builtin_op[orig_dtype] + cur_op = TypedBuiltinBinaryOp( + new_op, + new_op.name, + dtype, + builtin_op.types[orig_dtype], + orig_op.gb_obj, + orig_op.gb_name, + ) + new_op._add(cur_op) + if _supports_udfs: + # Add floordiv + # cdiv truncates towards 0, while floordiv truncates towards -inf + BinaryOp.register_new("floordiv", _floordiv, lazy=True) # cast to integer + BinaryOp.register_new("rfloordiv", _rfloordiv, lazy=True) # cast to integer + + # For aggregators + BinaryOp.register_new("absfirst", _absfirst, lazy=True) + BinaryOp.register_new("abssecond", _abssecond, lazy=True) + BinaryOp.register_new("rpow", _rpow, lazy=True) + + # For algorithms + binary._delayed["binom"] = (_register_binom, {}) # Lazy with custom creation + op._delayed["binom"] = binary + + BinaryOp.register_new("isclose", _isclose, parameterized=True) + + # Update type information with sane coercion + position_dtypes = [ + BOOL, + FP32, + FP64, + INT8, + INT16, + UINT8, + UINT16, + UINT32, + UINT64, + ] + if _supports_complex: + position_dtypes.extend([FC32, FC64]) + name_types = [ + # fmt: off + ( + ("atan2", "copysign", "fmod", "hypot", "ldexp", "remainder"), + ((BOOL, INT8, INT16, UINT8, UINT16), FP32), + ((INT32, INT64, UINT32, UINT64), FP64), + ), + ( + ( + "firsti", "firsti1", "firstj", "firstj1", "secondi", "secondi1", + "secondj", "secondj1"), + ( + position_dtypes, + INT64, + ), + ), + ( + ["lxnor"], + ( + ( + FP32, FP64, INT8, INT16, INT32, INT64, + UINT8, UINT16, UINT32, UINT64, + ), + BOOL, + ), + ), + # fmt: on + ] + if _supports_complex: + name_types.append( + ( + ["cmplx"], + ((BOOL, INT8, INT16, UINT8, UINT16), FP32), + ((INT32, INT64, UINT32, UINT64), FP64), + ) + ) + for names, *types in name_types: + for name in names: + if name in _SS_OPERATORS: + cur_op = binary._deprecated[name] + else: + cur_op = getattr(binary, name) + for input_types, target_type in types: + typed_op = cur_op._typed_ops[target_type] + output_type = cur_op.types[target_type] + for dtype in input_types: + if dtype not in cur_op.types: # pragma: no branch (safety) + cur_op.types[dtype] = output_type + cur_op._typed_ops[dtype] = typed_op + cur_op.coercions[dtype] = target_type + # Not valid input dtypes + del binary.ldexp[FP32] + del binary.ldexp[FP64] + # Fill in commutes info + for left_name, right_name in cls._commutes.items(): + if left_name in _SS_OPERATORS: + left = binary._deprecated[left_name] + else: + left = getattr(binary, left_name) + if backend == "suitesparse" and right_name in _SS_OPERATORS: + left._commutes_to = f"ss.{right_name}" + else: + left._commutes_to = right_name + if right_name not in binary._delayed: + if right_name in _SS_OPERATORS: + right = binary._deprecated[right_name] + elif _supports_udfs: + right = getattr(binary, right_name) + else: + right = getattr(binary, right_name, None) + if right is None: + continue + if backend == "suitesparse" and left_name in _SS_OPERATORS: + right._commutes_to = f"ss.{left_name}" + else: + right._commutes_to = left_name + for name in cls._commutative: + if _supports_udfs: + cur_op = getattr(binary, name) + else: + cur_op = getattr(binary, name, None) + if cur_op is None: + continue + cur_op._commutes_to = name + for left_name, right_name in cls._commutes_to_in_semiring.items(): + if left_name in _SS_OPERATORS: + left = binary._deprecated[left_name] + else: # 
pragma: no cover (safety) + left = getattr(binary, left_name) + if right_name in _SS_OPERATORS: + right = binary._deprecated[right_name] + else: # pragma: no cover (safety) + right = getattr(binary, right_name) + left._semiring_commutes_to = right + right._semiring_commutes_to = left + # Allow some functions to work on UDTs + for binop, func in [ + (binary.first, _first), + (binary.second, _second), + (binary.pair, _pair), + (binary.any, _first), + ]: + binop.orig_func = func + if _has_numba: + binop._numba_func = numba.njit(func) + else: + binop._numba_func = None + binop._udt_types = {} + binop._udt_ops = {} + binary.any._numba_func = binary.first._numba_func + binary.eq._udt_types = {} + binary.eq._udt_ops = {} + binary.ne._udt_types = {} + binary.ne._udt_ops = {} + # Set custom dtype handling + binary.first._custom_dtype = _first_dtype + binary.second._custom_dtype = _second_dtype + binary.pair._custom_dtype = _pair_dtype + cls._initialized = True + + def __init__( + self, + name, + func=None, + *, + anonymous=False, + is_positional=False, + is_udt=False, + numba_func=None, + ): + super().__init__(name, anonymous=anonymous) + self._monoid = None + self._commutes_to = None + self._semiring_commutes_to = None + self.orig_func = func + self._numba_func = numba_func + self._is_udt = is_udt + self.is_positional = is_positional + self._custom_dtype = None + if is_udt: + self._udt_types = {} # {(dtype, dtype): DataType} + self._udt_ops = {} # {(dtype, dtype): TypedUserBinaryOp} + + def __reduce__(self): + if self._anonymous: + if hasattr(self.orig_func, "_parameterized_info"): + return (_deserialize_parameterized, self.orig_func._parameterized_info) + return (self.register_anonymous, (self.orig_func, self.name)) + if (name := f"binary.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.orig_func)) + + __call__ = TypedBuiltinBinaryOp.__call__ + is_commutative = TypedBuiltinBinaryOp.is_commutative + commutes_to = ParameterizedBinaryOp.commutes_to + + @property + def monoid(self): + if self._monoid is None and not self._anonymous: + from .monoid import Monoid + + self._monoid = Monoid._find(self.name) + return self._monoid diff --git a/graphblas/core/operator/indexunary.py b/graphblas/core/operator/indexunary.py new file mode 100644 index 000000000..6fdacbcc1 --- /dev/null +++ b/graphblas/core/operator/indexunary.py @@ -0,0 +1,442 @@ +import inspect +import re +from types import FunctionType + +from ... import _STANDARD_OPERATOR_NAMES, indexunary, select +from ...dtypes import BOOL, FP64, INT8, INT64, UINT64, lookup_dtype +from ...exceptions import UdfParseError, check_status_carg +from .. 
import _has_numba, ffi, lib +from ..dtypes import _sample_values +from .base import OpBase, ParameterizedUdf, TypedOpBase, _call_op, _deserialize_parameterized + +if _has_numba: + import numba + + from .base import _get_udt_wrapper +ffi_new = ffi.new + + +class TypedBuiltinIndexUnaryOp(TypedOpBase): + __slots__ = () + opclass = "IndexUnaryOp" + + def __call__(self, val, thunk=None): + if thunk is None: + thunk = False # most basic form of 0 when unifying dtypes + return _call_op(self, val, right=thunk) + + @property + def thunk_type(self): + return self.type if self._type2 is None else self._type2 + + +class TypedUserIndexUnaryOp(TypedOpBase): + __slots__ = () + opclass = "IndexUnaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) + + @property + def orig_func(self): + return self.parent.orig_func + + @property + def _numba_func(self): + return self.parent._numba_func + + thunk_type = TypedBuiltinIndexUnaryOp.thunk_type + __call__ = TypedBuiltinIndexUnaryOp.__call__ + + +class ParameterizedIndexUnaryOp(ParameterizedUdf): + __slots__ = "func", "__signature__", "_is_udt" + + def __init__(self, name, func, *, anonymous=False, is_udt=False): + self.func = func + self.__signature__ = inspect.signature(func) + self._is_udt = is_udt + if name is None: + name = getattr(func, "__name__", name) + super().__init__(name, anonymous) + + def _call(self, *args, **kwargs): + indexunary = self.func(*args, **kwargs) + indexunary._parameterized_info = (self, args, kwargs) + return IndexUnaryOp.register_anonymous(indexunary, self.name, is_udt=self._is_udt) + + def __reduce__(self): + # NOT COVERED + name = f"indexunary.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.func, self._anonymous)) + + @staticmethod + def _deserialize(name, func, anonymous): + # NOT COVERED + if anonymous: + return IndexUnaryOp.register_anonymous(func, name, parameterized=True) + if (rv := IndexUnaryOp._find(name)) is not None: + return rv + return IndexUnaryOp.register_new(name, func, parameterized=True) + + +class IndexUnaryOp(OpBase): + """Takes one input and a thunk and returns one output, possibly of a different data type. + Along with the input value, the index(es) of the element are given to the function. + + This is an advanced form of a unary operation that allows, for example, converting + elements of a Vector to their index position to build a ramp structure. Another use + case is returning a boolean value indicating whether the element is part of the upper + triangular structure of a Matrix. + + Built-in and registered IndexUnaryOps are located in the ``graphblas.indexunary`` namespace. 
+ """ + + __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" + _module = indexunary + _modname = "indexunary" + _custom_dtype = None + _typed_class = TypedBuiltinIndexUnaryOp + _typed_user_class = TypedUserIndexUnaryOp + _parse_config = { + "trim_from_front": 4, + "num_underscores": 1, + "re_exprs": [ + re.compile("^GrB_(ROWINDEX|COLINDEX|DIAGINDEX)_(INT32|INT64)$"), + ], + "re_exprs_return_bool": [ + re.compile("^GrB_(TRIL|TRIU|DIAG|OFFDIAG|COLLE|COLGT|ROWLE|ROWGT)$"), + re.compile( + "^GrB_(VALUEEQ|VALUENE|VALUEGT|VALUEGE|VALUELT|VALUELE)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile("^GxB_(VALUEEQ|VALUENE)_(FC32|FC64)$"), + ], + } + _positional = {"tril", "triu", "diag", "offdiag", "colle", "colgt", "rowle", "rowgt", + "rowindex", "colindex"} # fmt: skip + + @classmethod + def _build(cls, name, func, *, is_udt=False, anonymous=False): + if not isinstance(func, FunctionType): + raise TypeError(f"UDF argument must be a function, not {type(func)}") + if name is None: + name = getattr(func, "__name__", "") + success = False + indexunary_udf = numba.njit(func) + new_type_obj = cls( + name, func, anonymous=anonymous, is_udt=is_udt, numba_func=indexunary_udf + ) + return_types = {} + nt = numba.types + if not is_udt: + for type_ in _sample_values: + sig = (type_.numba_type, UINT64.numba_type, UINT64.numba_type, type_.numba_type) + try: + indexunary_udf.compile(sig) + except numba.TypingError: + continue + ret_type = lookup_dtype(indexunary_udf.overloads[sig].signature.return_type) + if ret_type != type_ and ( + ("INT" in ret_type.name and "INT" in type_.name) + or ("FP" in ret_type.name and "FP" in type_.name) + or ("FC" in ret_type.name and "FC" in type_.name) + or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) + ): + # Downcast `ret_type` to `type_`. + # This is what users want most of the time, but we can't make a perfect rule. + # There should be a way for users to be explicit. 
+                    ret_type = type_
+                elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8:
+                    ret_type = INT8
+
+                # Numba is unable to handle BOOL correctly right now, but we have a workaround
+                # See: https://github.com/numba/numba/issues/5395
+                # We're relying on coercion behaving correctly here
+                input_type = INT8 if type_ == BOOL else type_
+                return_type = INT8 if ret_type == BOOL else ret_type
+
+                # Build wrapper because GraphBLAS wants pointers and void return
+                wrapper_sig = nt.void(
+                    nt.CPointer(return_type.numba_type),
+                    nt.CPointer(input_type.numba_type),
+                    UINT64.numba_type,
+                    UINT64.numba_type,
+                    nt.CPointer(input_type.numba_type),
+                )
+
+                if type_ == BOOL:
+                    if ret_type == BOOL:
+
+                        def indexunary_wrapper(z, x, row, col, y):  # pragma: no cover (numba)
+                            z[0] = bool(indexunary_udf(bool(x[0]), row, col, bool(y[0])))
+
+                    else:
+
+                        def indexunary_wrapper(z, x, row, col, y):  # pragma: no cover (numba)
+                            z[0] = indexunary_udf(bool(x[0]), row, col, bool(y[0]))
+
+                elif ret_type == BOOL:
+
+                    def indexunary_wrapper(z, x, row, col, y):  # pragma: no cover (numba)
+                        z[0] = bool(indexunary_udf(x[0], row, col, y[0]))
+
+                else:
+
+                    def indexunary_wrapper(z, x, row, col, y):  # pragma: no cover (numba)
+                        z[0] = indexunary_udf(x[0], row, col, y[0])
+
+                indexunary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(indexunary_wrapper)
+                new_indexunary = ffi_new("GrB_IndexUnaryOp*")
+                check_status_carg(
+                    lib.GrB_IndexUnaryOp_new(
+                        new_indexunary,
+                        indexunary_wrapper.cffi,
+                        ret_type.gb_obj,
+                        type_.gb_obj,
+                        type_.gb_obj,
+                    ),
+                    "IndexUnaryOp",
+                    new_indexunary[0],
+                )
+                op = cls._typed_user_class(new_type_obj, name, type_, ret_type, new_indexunary[0])
+                new_type_obj._add(op)
+                success = True
+                return_types[type_] = ret_type
+        if success or is_udt:
+            return new_type_obj
+        raise UdfParseError("Unable to parse function using Numba")
+
+    def _compile_udt(self, dtype, dtype2):
+        if dtype2 is None:  # pragma: no cover
+            dtype2 = dtype
+        dtypes = (dtype, dtype2)
+        if dtypes in self._udt_types:
+            return self._udt_ops[dtypes]
+        if self._numba_func is None:
+            raise KeyError(f"{self.name} does not work with {dtypes} types")
+
+        numba_func = self._numba_func
+        sig = (dtype.numba_type, UINT64.numba_type, UINT64.numba_type, dtype2.numba_type)
+        numba_func.compile(sig)  # Should we catch and give additional error message?
+        ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type)
+        indexunary_wrapper, wrapper_sig = _get_udt_wrapper(
+            numba_func, ret_type, dtype, dtype2, include_indexes=True
+        )
+
+        indexunary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(indexunary_wrapper)
+        new_indexunary = ffi_new("GrB_IndexUnaryOp*")
+        check_status_carg(
+            lib.GrB_IndexUnaryOp_new(
+                new_indexunary, indexunary_wrapper.cffi, ret_type._carg, dtype._carg, dtype2._carg
+            ),
+            "IndexUnaryOp",
+            new_indexunary[0],
+        )
+        op = TypedUserIndexUnaryOp(
+            self,
+            self.name,
+            dtype,
+            ret_type,
+            new_indexunary[0],
+            dtype2=dtype2,
+        )
+        self._udt_types[dtypes] = ret_type
+        self._udt_ops[dtypes] = op
+        return op
+
+    @classmethod
+    def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False):
+        """Register an IndexUnaryOp without registering it in the ``graphblas.indexunary`` namespace.
+
+        Because it is not registered in the namespace, the name is optional.
+
+        Parameters
+        ----------
+        func : FunctionType
+            The function to compile. For all current backends, this must be able
+            to be compiled with ``numba.njit``.
+            ``func`` takes four input parameters--any dtype, int64, int64,
+            any dtype and returns any dtype. The first argument (any dtype) is
+            the value of the input Matrix or Vector, the second argument (int64)
+            is the row index of the Matrix or the index of the Vector, the third
+            argument (int64) is the column index of the Matrix or 0 for a Vector,
+            and the fourth argument (any dtype) is the value of the input Scalar.
+        name : str, optional
+            The name of the operator. This *does not* show up as ``gb.indexunary.{name}``.
+        parameterized : bool, default False
+            When True, create a parameterized user-defined operator, which means
+            additional parameters can be "baked into" the operator when used.
+            For example, ``gb.binary.isclose`` is a parameterized BinaryOp that
+            optionally accepts ``rel_tol`` and ``abs_tol`` parameters, and it
+            can be used as: ``A.ewise_mult(B, gb.binary.isclose(rel_tol=1e-5))``.
+            When creating a parameterized user-defined operator, the ``func``
+            parameter must be a callable that *returns* a function that will
+            then get compiled.
+        is_udt : bool, default False
+            Whether the operator is intended to operate on user-defined types.
+            If True, then the function will not be automatically compiled for
+            builtin types, and it will be compiled "just in time" when used.
+            Setting ``is_udt=True`` is also helpful when the left and right
+            dtypes need to be different.
+
+        Returns
+        -------
+        IndexUnaryOp or ParameterizedIndexUnaryOp
+
+        """
+        cls._check_supports_udf("register_anonymous")
+        if parameterized:
+            return ParameterizedIndexUnaryOp(name, func, anonymous=True, is_udt=is_udt)
+        return cls._build(name, func, anonymous=True, is_udt=is_udt)
+
+    @classmethod
+    def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False):
+        """Register a new IndexUnaryOp and save it to ``graphblas.indexunary`` namespace.
+
+        If the return type is Boolean, the function will also be registered as a SelectOp
+        (and saved to ``graphblas.select`` namespace) with the same name.
+
+        Parameters
+        ----------
+        name : str
+            The name of the operator. This will show up as ``gb.indexunary.{name}``.
+            The name may contain periods, ".", which will result in nested objects
+            such as ``gb.indexunary.x.y.z`` for name ``"x.y.z"``.
+        func : FunctionType
+            The function to compile. For all current backends, this must be able
+            to be compiled with ``numba.njit``.
+            ``func`` takes four input parameters--any dtype, int64, int64,
+            any dtype and returns any dtype. The first argument (any dtype) is
+            the value of the input Matrix or Vector, the second argument (int64)
+            is the row index of the Matrix or the index of the Vector, the third
+            argument (int64) is the column index of the Matrix or 0 for a Vector,
+            and the fourth argument (any dtype) is the value of the input Scalar.
+        parameterized : bool, default False
+            When True, create a parameterized user-defined operator, which means
+            additional parameters can be "baked into" the operator when used.
+            For example, ``gb.binary.isclose`` is a parameterized BinaryOp that
+            optionally accepts ``rel_tol`` and ``abs_tol`` parameters, and it
+            can be used as: ``A.ewise_mult(B, gb.binary.isclose(rel_tol=1e-5))``.
+            When creating a parameterized user-defined operator, the ``func``
+            parameter must be a callable that *returns* a function that will
+            then get compiled.
+        is_udt : bool, default False
+            Whether the operator is intended to operate on user-defined types.
+ If True, then the function will not be automatically compiled for + builtin types, and it will be compiled "just in time" when used. + Setting ``is_udt=True`` is also helpful when the left and right + dtypes need to be different. + lazy : bool, default False + If False (the default), then the function will be automatically + compiled for builtin types (unless ``is_udt`` is True). + Compiling functions can be slow, however, so you may want to + delay compilation and only compile when the operator is used, + which is done by setting ``lazy=True``. + + Examples + -------- + >>> gb.indexunary.register_new("row_mod", lambda x, i, j, thunk: i % max(thunk, 2)) + >>> dir(gb.indexunary) + [..., 'row_mod', ...] + + """ + cls._check_supports_udf("register_new") + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "func": func, "parameterized": parameterized}, + ) + elif parameterized: + indexunary_op = ParameterizedIndexUnaryOp(name, func, is_udt=is_udt) + setattr(module, funcname, indexunary_op) + else: + indexunary_op = cls._build(name, func, is_udt=is_udt) + setattr(module, funcname, indexunary_op) + # If return type is BOOL, register additionally as a SelectOp + if all(x == BOOL for x in indexunary_op.types.values()): + from .select import SelectOp + + select_module, funcname = SelectOp._remove_nesting(name, strict=False) + setattr(select_module, funcname, SelectOp._from_indexunary(indexunary_op)) + if not cls._initialized: # pragma: no cover (safety) + _STANDARD_OPERATOR_NAMES.add(f"{SelectOp._modname}.{name}") + + if not cls._initialized: # pragma: no cover (safety) + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return indexunary_op + + @classmethod + def _initialize(cls): + if cls._initialized: + return + from .select import SelectOp + + super()._initialize(include_in_ops=False) + # Update type information to include UINT64 for positional ops + for name in ["tril", "triu", "diag", "offdiag", "colle", "colgt", "rowle", "rowgt"]: + op = getattr(indexunary, name) + typed_op = op._typed_ops[BOOL] + output_type = op.types[BOOL] + if UINT64 not in op.types: # pragma: no branch (safety) + op.types[UINT64] = output_type + op._typed_ops[UINT64] = typed_op + op.coercions[UINT64] = BOOL + for name in ["rowindex", "colindex"]: + op = getattr(indexunary, name) + typed_op = op._typed_ops[INT64] + output_type = op.types[INT64] + if UINT64 not in op.types: # pragma: no branch (safety) + op.types[UINT64] = output_type + op._typed_ops[UINT64] = typed_op + op.coercions[UINT64] = INT64 + # Add index->row alias to make it more intuitive which to use for vectors + indexunary.indexle = indexunary.rowle + indexunary.indexgt = indexunary.rowgt + indexunary.index = indexunary.rowindex + # fmt: off + # Add SelectOp when it makes sense + for name in ["tril", "triu", "diag", "offdiag", + "colle", "colgt", "rowle", "rowgt", "indexle", "indexgt", + "valueeq", "valuene", "valuegt", "valuege", "valuelt", "valuele"]: + iop = getattr(indexunary, name) + setattr(select, name, SelectOp._from_indexunary(iop)) + _STANDARD_OPERATOR_NAMES.add(f"{SelectOp._modname}.{name}") + # fmt: on + cls._initialized = True + + def __init__( + self, + name, + func=None, + *, + anonymous=False, + is_positional=False, + is_udt=False, + numba_func=None, + ): + super().__init__(name, anonymous=anonymous) + self.orig_func = func + self._numba_func = numba_func + self.is_positional = is_positional + self._is_udt = is_udt + if is_udt: + self._udt_types 
= {} # {dtype: DataType} + self._udt_ops = {} # {dtype: TypedUserIndexUnaryOp} + + def __reduce__(self): + if self._anonymous: + if hasattr(self.orig_func, "_parameterized_info"): + # NOT COVERED + return (_deserialize_parameterized, self.orig_func._parameterized_info) + return (self.register_anonymous, (self.orig_func, self.name)) + if (name := f"indexunary.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + # NOT COVERED + return (self._deserialize, (self.name, self.orig_func)) + + __call__ = TypedBuiltinIndexUnaryOp.__call__ diff --git a/graphblas/core/operator/monoid.py b/graphblas/core/operator/monoid.py new file mode 100644 index 000000000..e3f218a90 --- /dev/null +++ b/graphblas/core/operator/monoid.py @@ -0,0 +1,428 @@ +import inspect +import re +from collections.abc import Mapping + +from ... import _STANDARD_OPERATOR_NAMES, binary, monoid, op +from ...dtypes import ( + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + lookup_dtype, +) +from ...exceptions import check_status_carg +from .. import ffi, lib +from ..utils import libget +from .base import OpBase, ParameterizedUdf, TypedOpBase, _hasop +from .binary import BinaryOp, ParameterizedBinaryOp, TypedBuiltinBinaryOp + +ffi_new = ffi.new + + +class TypedBuiltinMonoid(TypedOpBase): + __slots__ = "_identity" + opclass = "Monoid" + is_commutative = True + + def __init__(self, parent, name, type_, return_type, gb_obj, gb_name): + super().__init__(parent, name, type_, return_type, gb_obj, gb_name) + self._identity = None + + @property + def identity(self): + if self._identity is None: + from ..recorder import skip_record + from ..vector import Vector + + with skip_record: + self._identity = ( + Vector(self.type, size=1, name="").reduce(self, allow_empty=False).new().value + ) + return self._identity + + @property + def binaryop(self): + return getattr(binary, self.name)[self.type] + + @property + def commutes_to(self): + return self + + @property + def type2(self): + return self.type + + @property + def is_idempotent(self): + """True if ``monoid(x, x) == x`` for any x.""" + return self.parent.is_idempotent + + __call__ = TypedBuiltinBinaryOp.__call__ + + +class TypedUserMonoid(TypedOpBase): + __slots__ = "binaryop", "identity" + opclass = "Monoid" + is_commutative = True + + def __init__(self, parent, name, type_, return_type, gb_obj, binaryop, identity): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") + self.binaryop = binaryop + self.identity = identity + binaryop._monoid = self + + commutes_to = TypedBuiltinMonoid.commutes_to + type2 = TypedBuiltinMonoid.type2 + is_idempotent = TypedBuiltinMonoid.is_idempotent + __call__ = TypedBuiltinMonoid.__call__ + + +class ParameterizedMonoid(ParameterizedUdf): + __slots__ = "binaryop", "identity", "_is_idempotent", "__signature__" + is_commutative = True + + def __init__(self, name, binaryop, identity, *, is_idempotent=False, anonymous=False): + if type(binaryop) is not ParameterizedBinaryOp: + raise TypeError("binaryop must be parameterized") + self.binaryop = binaryop + self.__signature__ = binaryop.__signature__ + if callable(identity): + # assume it must be parameterized as well, so signature must match + sig = inspect.signature(identity) + if sig != self.__signature__: + raise ValueError( + "Signatures of binaryop and identity passed to " + f"{type(self).__name__} must be the same. 
Got:\n" + f" binaryop{self.__signature__}\n" + " !=\n" + f" identity{sig}" + ) + self.identity = identity + self._is_idempotent = is_idempotent + if name is None: + name = binaryop.name + super().__init__(name, anonymous) + binaryop._monoid = self + # clear binaryop cache so it can be associated with this monoid + binaryop._cached_call.cache_clear() + + def _call(self, *args, **kwargs): + binary = self.binaryop(*args, **kwargs) + identity = self.identity + if callable(identity): + identity = identity(*args, **kwargs) + return Monoid.register_anonymous( + binary, identity, self.name, is_idempotent=self._is_idempotent + ) + + commutes_to = TypedBuiltinMonoid.commutes_to + + @property + def is_idempotent(self): + """True if ``monoid(x, x) == x`` for any x.""" + return self._is_idempotent + + def __reduce__(self): + name = f"monoid.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover + return name + return (self._deserialize, (self.name, self.binaryop, self.identity, self._anonymous)) + + @staticmethod + def _deserialize(name, binaryop, identity, anonymous): + if anonymous: + return Monoid.register_anonymous(binaryop, identity, name) + if (rv := Monoid._find(name)) is not None: + return rv + return Monoid.register_new(name, binaryop, identity) + + +class Monoid(OpBase): + """Takes two inputs and returns one output, all of the same data type. + + Built-in and registered Monoids are located in the ``graphblas.monoid`` namespace + as well as in the ``graphblas.ops`` combined namespace. + """ + + __slots__ = "_binaryop", "_identity", "_is_idempotent" + is_commutative = True + is_positional = False + _custom_dtype = None + _module = monoid + _modname = "monoid" + _typed_class = TypedBuiltinMonoid + _parse_config = { + "trim_from_front": 4, + "delete_exact": "MONOID", + "num_underscores": 1, + "re_exprs": [ + re.compile( + "^GrB_(MIN|MAX|PLUS|TIMES|LOR|LAND|LXOR|LXNOR)_MONOID" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(ANY)_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)_MONOID$" + ), + re.compile("^GxB_(PLUS|TIMES|ANY)_(FC32|FC64)_MONOID$"), + re.compile("^GxB_(EQ|ANY)_BOOL_MONOID$"), + re.compile("^GxB_(BOR|BAND|BXOR|BXNOR)_(UINT8|UINT16|UINT32|UINT64)_MONOID$"), + ], + } + + @classmethod + def _build(cls, name, binaryop, identity, *, is_idempotent=False, anonymous=False): + if type(binaryop) is not BinaryOp: + raise TypeError(f"binaryop must be a BinaryOp, not {type(binaryop)}") + if name is None: + name = binaryop.name + new_type_obj = cls( + name, binaryop, identity, is_idempotent=is_idempotent, anonymous=anonymous + ) + if not binaryop._is_udt: + if not isinstance(identity, Mapping): + identities = dict.fromkeys(binaryop.types, identity) + explicit_identities = False + else: + identities = {lookup_dtype(key): val for key, val in identity.items()} + explicit_identities = True + for type_, ident in identities.items(): + ret_type = binaryop[type_].return_type + # If there is a domain mismatch, then DomainMismatch will be raised + # below if identities were explicitly given. 
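+                # Illustrative sketch: a UDF binaryop like `lambda x, y: x == y`
+                # returns BOOL for INT64 inputs, so INT64 is skipped here unless
+                # an identity was explicitly given for it.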
+                if type_ != ret_type and not explicit_identities:
+                    continue
+                new_monoid = ffi_new("GrB_Monoid*")
+                func = libget(f"GrB_Monoid_new_{type_.name}")
+                zcast = ffi.cast(type_.c_type, ident)
+                check_status_carg(
+                    func(new_monoid, binaryop[type_].gb_obj, zcast), "Monoid", new_monoid[0]
+                )
+                op = TypedUserMonoid(
+                    new_type_obj,
+                    name,
+                    type_,
+                    ret_type,
+                    new_monoid[0],
+                    binaryop[type_],
+                    ident,
+                )
+                new_type_obj._add(op)
+        return new_type_obj
+
+    def _compile_udt(self, dtype, dtype2):
+        if dtype2 is None:
+            dtype2 = dtype
+        elif dtype != dtype2:
+            raise TypeError(
+                f"Monoid inputs must be the same dtype (got {dtype} and {dtype2}); "
+                "unable to coerce when using UDTs."
+            )
+        if dtype in self._udt_types:
+            return self._udt_ops[dtype]
+        binaryop = self.binaryop._compile_udt(dtype, dtype2)
+        from ..scalar import Scalar
+
+        ret_type = binaryop.return_type
+        identity = Scalar.from_value(self._identity, dtype=ret_type, is_cscalar=True)
+        new_monoid = ffi_new("GrB_Monoid*")
+        status = lib.GrB_Monoid_new_UDT(new_monoid, binaryop.gb_obj, identity.gb_obj)
+        check_status_carg(status, "Monoid", new_monoid[0])
+        op = TypedUserMonoid(
+            self,
+            self.name,
+            dtype,
+            ret_type,
+            new_monoid[0],
+            binaryop,
+            identity,
+        )
+        self._udt_types[dtype] = dtype
+        self._udt_ops[dtype] = op
+        return op
+
+    @classmethod
+    def register_anonymous(cls, binaryop, identity, name=None, *, is_idempotent=False):
+        """Register a Monoid without registering it in the ``graphblas.monoid`` namespace.
+
+        A monoid is a binary operator whose inputs and output are the same dtype.
+        Because it is not registered in the namespace, the name is optional.
+
+        Parameters
+        ----------
+        binaryop: BinaryOp or ParameterizedBinaryOp
+            The binary operator of the monoid, which should be able to use the same
+            dtype for both inputs and the output.
+        identity: scalar or Mapping
+            The identity of the monoid such that ``op(x, identity) == x`` for any x.
+            ``identity`` may also be a mapping from dtype to scalar.
+        name : str, optional
+            The name of the operator. This *does not* show up as ``gb.monoid.{name}``.
+        is_idempotent : bool, default False
+            Does ``op(x, x) == x`` for any x?
+
+        Returns
+        -------
+        Monoid or ParameterizedMonoid
+
+        """
+        if type(binaryop) is ParameterizedBinaryOp:
+            return ParameterizedMonoid(
+                name, binaryop, identity, is_idempotent=is_idempotent, anonymous=True
+            )
+        return cls._build(name, binaryop, identity, is_idempotent=is_idempotent, anonymous=True)
+
+    @classmethod
+    def register_new(cls, name, binaryop, identity, *, is_idempotent=False, lazy=False):
+        """Register a new Monoid and save it to ``graphblas.monoid`` namespace.
+
+        A monoid is a binary operator whose inputs and output are the same dtype.
+
+        Parameters
+        ----------
+        name : str
+            The name of the operator. This will show up as ``gb.monoid.{name}``.
+            The name may contain periods, ".", which will result in nested objects
+            such as ``gb.monoid.x.y.z`` for name ``"x.y.z"``.
+        binaryop: BinaryOp or ParameterizedBinaryOp
+            The binary operator of the monoid, which should be able to use the same
+            dtype for both inputs and the output.
+        identity: scalar or Mapping
+            The identity of the monoid such that ``op(x, identity) == x`` for any x.
+            ``identity`` may also be a mapping from dtype to scalar.
+        is_idempotent : bool, default False
+            Does ``op(x, x) == x`` for any x?
+        lazy : bool, default False
+            If False (the default), then the function will be automatically
+            compiled for builtin types (unless ``is_udt`` was True for the binaryop).
+ Compiling functions can be slow, however, so you may want to + delay compilation and only compile when the operator is used, + which is done by setting ``lazy=True``. + + Examples + -------- + >>> gb.core.operator.Monoid.register_new("max_zero", gb.binary.max_zero, 0) + >>> dir(gb.monoid) + [..., 'max_zero', ...] + + """ + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "binaryop": binaryop, "identity": identity}, + ) + elif type(binaryop) is ParameterizedBinaryOp: + monoid = ParameterizedMonoid(name, binaryop, identity, is_idempotent=is_idempotent) + setattr(module, funcname, monoid) + else: + monoid = cls._build(name, binaryop, identity, is_idempotent=is_idempotent) + setattr(module, funcname, monoid) + # Also save it to `graphblas.op` if not yet defined + opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) + if not _hasop(opmodule, funcname): + if lazy: + opmodule._delayed[funcname] = module + else: + setattr(opmodule, funcname, monoid) + if not cls._initialized: # pragma: no cover + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return monoid + + def __init__(self, name, binaryop=None, identity=None, *, is_idempotent=False, anonymous=False): + super().__init__(name, anonymous=anonymous) + self._binaryop = binaryop + self._identity = identity + self._is_idempotent = is_idempotent + if binaryop is not None: + binaryop._monoid = self + if binaryop._is_udt: + self._udt_types = {} # {dtype: DataType} + self._udt_ops = {} # {dtype: TypedUserMonoid} + + def __reduce__(self): + if self._anonymous: + return (self.register_anonymous, (self._binaryop, self._identity, self.name)) + if (name := f"monoid.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self._binaryop, self._identity)) + + @property + def binaryop(self): + """The :class:`BinaryOp` associated with the Monoid.""" + if self._binaryop is not None: + return self._binaryop + # Must be builtin + return getattr(binary, self.name) + + @property + def identities(self): + """The per-dtype identity values for the Monoid.""" + return {dtype: val.identity for dtype, val in self._typed_ops.items()} + + @property + def is_idempotent(self): + """True if ``monoid(x, x) == x`` for any x.""" + return self._is_idempotent + + @property + def _is_udt(self): + return self._binaryop is not None and self._binaryop._is_udt + + @classmethod + def _initialize(cls): + if cls._initialized: # pragma: no cover (safety) + return + super()._initialize() + lor = monoid.lor._typed_ops[BOOL] + land = monoid.land._typed_ops[BOOL] + for cur_op, typed_op in [ + (monoid.max, lor), + (monoid.min, land), + # (monoid.plus, lor), # two choices: lor, or plus[int] + (monoid.times, land), + ]: + if BOOL not in cur_op.types: # pragma: no branch (safety) + cur_op.types[BOOL] = BOOL + cur_op.coercions[BOOL] = BOOL + cur_op._typed_ops[BOOL] = typed_op + + for cur_op in [monoid.lor, monoid.land, monoid.lxnor, monoid.lxor]: + bool_op = cur_op._typed_ops[BOOL] + for dtype in [ + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + ]: + if dtype in cur_op.types: # pragma: no cover (safety) + continue + cur_op.types[dtype] = BOOL + cur_op.coercions[dtype] = BOOL + cur_op._typed_ops[dtype] = bool_op + + # Builtin monoids that are idempotent; i.e., `op(x, x) == x` for any x + for name in ["any", "band", "bor", "land", "lor", "max", "min"]: + getattr(monoid, 
name)._is_idempotent = True
+        # Allow some functions to work on UDTs
+        any_ = monoid.any
+        any_._identity = 0
+        any_._udt_types = {}
+        any_._udt_ops = {}
+        cls._initialized = True
+
+    commutes_to = TypedBuiltinMonoid.commutes_to
+    __call__ = TypedBuiltinMonoid.__call__
diff --git a/graphblas/core/operator/select.py b/graphblas/core/operator/select.py
new file mode 100644
index 000000000..6de4fa89a
--- /dev/null
+++ b/graphblas/core/operator/select.py
@@ -0,0 +1,340 @@
+import inspect
+
+from ... import _STANDARD_OPERATOR_NAMES, select
+from ...dtypes import BOOL, UINT64
+from ...exceptions import check_status_carg
+from .. import _has_numba, ffi, lib
+from .base import OpBase, ParameterizedUdf, TypedOpBase, _call_op, _deserialize_parameterized
+from .indexunary import IndexUnaryOp, TypedBuiltinIndexUnaryOp
+
+if _has_numba:
+    import numba
+
+    from .base import _get_udt_wrapper
+ffi_new = ffi.new
+
+
+class TypedBuiltinSelectOp(TypedOpBase):
+    __slots__ = ()
+    opclass = "SelectOp"
+
+    def __call__(self, val, thunk=None):
+        if thunk is None:
+            thunk = False  # most basic form of 0 when unifying dtypes
+        return _call_op(self, val, thunk=thunk)
+
+    thunk_type = TypedBuiltinIndexUnaryOp.thunk_type
+
+
+class TypedUserSelectOp(TypedOpBase):
+    __slots__ = ()
+    opclass = "SelectOp"
+
+    def __init__(self, parent, name, type_, return_type, gb_obj, dtype2=None):
+        super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2)
+
+    @property
+    def orig_func(self):
+        return self.parent.orig_func
+
+    @property
+    def _numba_func(self):
+        return self.parent._numba_func
+
+    thunk_type = TypedBuiltinSelectOp.thunk_type
+    __call__ = TypedBuiltinSelectOp.__call__
+
+
+class ParameterizedSelectOp(ParameterizedUdf):
+    __slots__ = "func", "__signature__", "_is_udt"
+
+    def __init__(self, name, func, *, anonymous=False, is_udt=False):
+        self.func = func
+        self.__signature__ = inspect.signature(func)
+        self._is_udt = is_udt
+        if name is None:
+            name = getattr(func, "__name__", name)
+        super().__init__(name, anonymous)
+
+    def _call(self, *args, **kwargs):
+        sel = self.func(*args, **kwargs)
+        sel._parameterized_info = (self, args, kwargs)
+        return SelectOp.register_anonymous(sel, self.name, is_udt=self._is_udt)
+
+    def __reduce__(self):
+        # NOT COVERED
+        name = f"select.{self.name}"
+        if not self._anonymous and name in _STANDARD_OPERATOR_NAMES:
+            return name
+        return (self._deserialize, (self.name, self.func, self._anonymous))
+
+    @staticmethod
+    def _deserialize(name, func, anonymous):
+        # NOT COVERED
+        if anonymous:
+            return SelectOp.register_anonymous(func, name, parameterized=True)
+        if (rv := SelectOp._find(name)) is not None:
+            return rv
+        return SelectOp.register_new(name, func, parameterized=True)
+
+
+class SelectOp(OpBase):
+    """Identical to an :class:`IndexUnaryOp`,
+    but must have a Boolean return type.
+
+    A SelectOp is used exclusively to select a subset of values from a collection where
+    the function returns True.
+
+    Built-in and registered SelectOps are located in the ``graphblas.select`` namespace.
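+
+    For example (an illustrative sketch; ``A`` is assumed to be an existing
+    Matrix), keeping only the upper triangular values:
+
+    >>> U = gb.select.triu(A).new()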
+ """ + + __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" + _module = select + _modname = "select" + _custom_dtype = None + _typed_class = TypedBuiltinSelectOp + _typed_user_class = TypedUserSelectOp + + @classmethod + def _from_indexunary(cls, iop): + obj = cls( + iop.name, + iop.orig_func, + anonymous=iop._anonymous, + is_positional=iop.is_positional, + is_udt=iop._is_udt, + numba_func=iop._numba_func, + ) + if not all(x == BOOL for x in iop.types.values()): + raise ValueError("SelectOp must have BOOL return type") + for type_, t in iop._typed_ops.items(): + if iop.orig_func is not None: + op = cls._typed_user_class( + obj, + iop.name, + t.type, + t.return_type, + t.gb_obj, + ) + else: + op = cls._typed_class( + obj, + iop.name, + t.type, + t.return_type, + t.gb_obj, + t.gb_name, + ) + # type is not always equal to t.type, so can't use op._add + # but otherwise perform the same logic + obj._typed_ops[type_] = op + obj.types[type_] = op.return_type + return obj + + def _compile_udt(self, dtype, dtype2): + if dtype2 is None: # pragma: no cover + dtype2 = dtype + dtypes = (dtype, dtype2) + if dtypes in self._udt_types: + return self._udt_ops[dtypes] + if self._numba_func is None: + raise KeyError(f"{self.name} does not work with {dtypes} types") + + # It would be nice if we could reuse compiling done for IndexUnaryOp + numba_func = self._numba_func + sig = (dtype.numba_type, UINT64.numba_type, UINT64.numba_type, dtype2.numba_type) + numba_func.compile(sig) # Should we catch and give additional error message? + select_wrapper, wrapper_sig = _get_udt_wrapper( + numba_func, BOOL, dtype, dtype2, include_indexes=True + ) + + select_wrapper = numba.cfunc(wrapper_sig, nopython=True)(select_wrapper) + new_select = ffi_new("GrB_IndexUnaryOp*") + check_status_carg( + lib.GrB_IndexUnaryOp_new( + new_select, select_wrapper.cffi, BOOL._carg, dtype._carg, dtype2._carg + ), + "IndexUnaryOp", + new_select[0], + ) + op = TypedUserSelectOp( + self, + self.name, + dtype, + BOOL, + new_select[0], + dtype2=dtype2, + ) + self._udt_types[dtypes] = BOOL + self._udt_ops[dtypes] = op + return op + + @classmethod + def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False): + """Register a SelectOp without registering it in the ``graphblas.select`` namespace. + + Because it is not registered in the namespace, the name is optional. + The return type must be Boolean. + + Parameters + ---------- + func : FunctionType + The function to compile. For all current backends, this must be able + to be compiled with ``numba.njit``. + ``func`` takes four input parameters--any dtype, int64, int64, + any dtype and returns boolean. The first argument (any dtype) is + the value of the input Matrix or Vector, the second argument (int64) + is the row index of the Matrix or the index of the Vector, the third + argument (int64) is the column index of the Matrix or 0 for a Vector, + and the fourth argument (any dtype) is the value of the input Scalar. + name : str, optional + The name of the operator. This *does not* show up as ``gb.select.{name}``. + parameterized : bool, default False + When True, create a parameterized user-defined operator, which means + additional parameters can be "baked into" the operator when used. + For example, ``gb.binary.isclose`` is a parameterized BinaryOp that + optionally accepts ``rel_tol`` and ``abs_tol`` parameters, and it + can be used as: ``A.ewise_mult(B, gb.binary.isclose(rel_tol=1e-5))``. 
+            When creating a parameterized user-defined operator, the ``func``
+            parameter must be a callable that *returns* a function that will
+            then get compiled.
+        is_udt : bool, default False
+            Whether the operator is intended to operate on user-defined types.
+            If True, then the function will not be automatically compiled for
+            builtin types, and it will be compiled "just in time" when used.
+            Setting ``is_udt=True`` is also helpful when the left and right
+            dtypes need to be different.
+
+        Returns
+        -------
+        SelectOp or ParameterizedSelectOp
+
+        """
+        cls._check_supports_udf("register_anonymous")
+        if parameterized:
+            return ParameterizedSelectOp(name, func, anonymous=True, is_udt=is_udt)
+        iop = IndexUnaryOp._build(name, func, anonymous=True, is_udt=is_udt)
+        return SelectOp._from_indexunary(iop)
+
+    @classmethod
+    def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False):
+        """Register a new SelectOp and save it to ``graphblas.select`` namespace.
+
+        The function will also be registered as an IndexUnaryOp with the same name.
+        The return type must be Boolean.
+
+        Parameters
+        ----------
+        name : str
+            The name of the operator. This will show up as ``gb.select.{name}``.
+            The name may contain periods, ".", which will result in nested objects
+            such as ``gb.select.x.y.z`` for name ``"x.y.z"``.
+        func : FunctionType
+            The function to compile. For all current backends, this must be able
+            to be compiled with ``numba.njit``.
+            ``func`` takes four input parameters--any dtype, int64, int64,
+            any dtype and returns boolean. The first argument (any dtype) is
+            the value of the input Matrix or Vector, the second argument (int64)
+            is the row index of the Matrix or the index of the Vector, the third
+            argument (int64) is the column index of the Matrix or 0 for a Vector,
+            and the fourth argument (any dtype) is the value of the input Scalar.
+        parameterized : bool, default False
+            When True, create a parameterized user-defined operator, which means
+            additional parameters can be "baked into" the operator when used.
+            For example, ``gb.binary.isclose`` is a parameterized BinaryOp that
+            optionally accepts ``rel_tol`` and ``abs_tol`` parameters, and it
+            can be used as: ``A.ewise_mult(B, gb.binary.isclose(rel_tol=1e-5))``.
+            When creating a parameterized user-defined operator, the ``func``
+            parameter must be a callable that *returns* a function that will
+            then get compiled.
+        is_udt : bool, default False
+            Whether the operator is intended to operate on user-defined types.
+            If True, then the function will not be automatically compiled for
+            builtin types, and it will be compiled "just in time" when used.
+            Setting ``is_udt=True`` is also helpful when the left and right
+            dtypes need to be different.
+        lazy : bool, default False
+            If False (the default), then the function will be automatically
+            compiled for builtin types (unless ``is_udt`` is True).
+            Compiling functions can be slow, however, so you may want to
+            delay compilation and only compile when the operator is used,
+            which is done by setting ``lazy=True``.
+
+        Examples
+        --------
+        >>> gb.select.register_new("upper_left_triangle", lambda x, i, j, thunk: i + j <= thunk)
+        >>> dir(gb.select)
+        [..., 'upper_left_triangle', ...]
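+
+        The registered op can then be used to filter values (an illustrative
+        sketch; ``A`` is assumed to be an existing Matrix):
+
+        >>> A2 = A.select(gb.select.upper_left_triangle, 4).new()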
+ + """ + cls._check_supports_udf("register_new") + iop = IndexUnaryOp.register_new( + name, func, parameterized=parameterized, is_udt=is_udt, lazy=lazy + ) + module, funcname = cls._remove_nesting(name, strict=False) + if lazy: + module._delayed[funcname] = ( + cls._get_delayed, + {"name": name}, + ) + elif parameterized: + op = ParameterizedSelectOp(funcname, func, is_udt=is_udt) + setattr(module, funcname, op) + return op + elif not all(x == BOOL for x in iop.types.values()): + # Undo registration of indexunaryop + imodule, funcname = IndexUnaryOp._remove_nesting(name, strict=False) + delattr(imodule, funcname) + raise ValueError("SelectOp must have BOOL return type") + else: + return getattr(module, funcname) + + @classmethod + def _get_delayed(cls, name): + imodule, funcname = IndexUnaryOp._remove_nesting(name, strict=False) + iop = getattr(imodule, name) + if not all(x == BOOL for x in iop.types.values()): + raise ValueError("SelectOp must have BOOL return type") + module, funcname = cls._remove_nesting(name, strict=False) + return getattr(module, funcname) + + @classmethod + def _initialize(cls): + if cls._initialized: # pragma: no cover (safety) + return + # IndexUnaryOp adds it boolean-returning objects to SelectOp + IndexUnaryOp._initialize() + cls._initialized = True + + def __init__( + self, + name, + func=None, + *, + anonymous=False, + is_positional=False, + is_udt=False, + numba_func=None, + ): + super().__init__(name, anonymous=anonymous) + self.orig_func = func + self._numba_func = numba_func + self.is_positional = is_positional + self._is_udt = is_udt + if is_udt: + # NOT COVERED + self._udt_types = {} # {dtype: DataType} + self._udt_ops = {} # {dtype: TypedUserIndexUnaryOp} + + def __reduce__(self): + if self._anonymous: + if hasattr(self.orig_func, "_parameterized_info"): + # NOT COVERED + return (_deserialize_parameterized, self.orig_func._parameterized_info) + return (self.register_anonymous, (self.orig_func, self.name)) + if (name := f"select.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + # NOT COVERED + return (self._deserialize, (self.name, self.orig_func)) + + __call__ = TypedBuiltinSelectOp.__call__ diff --git a/graphblas/core/operator/semiring.py b/graphblas/core/operator/semiring.py new file mode 100644 index 000000000..a8d18f1bf --- /dev/null +++ b/graphblas/core/operator/semiring.py @@ -0,0 +1,567 @@ +import itertools +import re + +from ... import _STANDARD_OPERATOR_NAMES, binary, monoid, op, semiring +from ...dtypes import ( + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + _supports_complex, +) +from ...exceptions import check_status_carg +from .. import _supports_udfs, ffi, lib +from .base import _SS_OPERATORS, OpBase, ParameterizedUdf, TypedOpBase, _call_op, _hasop +from .binary import BinaryOp, ParameterizedBinaryOp +from .monoid import Monoid, ParameterizedMonoid + +if _supports_complex: + from ...dtypes import FC32, FC64 + +ffi_new = ffi.new + + +class TypedBuiltinSemiring(TypedOpBase): + __slots__ = () + opclass = "Semiring" + + def __call__(self, left, right=None): + if right is not None: + raise TypeError( + f"Bad types when calling {self!r}. 
Got types: {type(left)}, {type(right)}.\n" + f"Expected an infix expression, such as: {self!r}(A @ B)" + ) + return _call_op(self, left) + + @property + def binaryop(self): + name = self.name.split("_", 1)[1] + if name in _SS_OPERATORS: + binop = binary._deprecated[name] + else: + binop = getattr(binary, name) + return binop[self.type] + + @property + def monoid(self): + monoid_name, binary_name = self.name.split("_", 1) + if binary_name in _SS_OPERATORS: + binop = binary._deprecated[binary_name] + else: + binop = getattr(binary, binary_name) + binop = binop[self.type] + val = getattr(monoid, monoid_name) + return val[binop.return_type] + + @property + def commutes_to(self): + binop = self.binaryop + commutes_to = binop._semiring_commutes_to or binop.commutes_to + if commutes_to is None: + return + if commutes_to is binop: + return self + from .utils import get_semiring + + return get_semiring(self.monoid, commutes_to) + + @property + def is_commutative(self): + return self.binaryop.is_commutative + + @property + def type2(self): + return self.type if self._type2 is None else self._type2 + + +class TypedUserSemiring(TypedOpBase): + __slots__ = "monoid", "binaryop" + opclass = "Semiring" + + def __init__(self, parent, name, type_, return_type, gb_obj, monoid, binaryop, dtype2=None): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}", dtype2=dtype2) + self.monoid = monoid + self.binaryop = binaryop + + commutes_to = TypedBuiltinSemiring.commutes_to + is_commutative = TypedBuiltinSemiring.is_commutative + type2 = TypedBuiltinSemiring.type2 + __call__ = TypedBuiltinSemiring.__call__ + + +class ParameterizedSemiring(ParameterizedUdf): + __slots__ = "monoid", "binaryop", "__signature__" + + def __init__(self, name, monoid, binaryop, *, anonymous=False): + if type(monoid) not in {ParameterizedMonoid, Monoid}: + raise TypeError("monoid must be of type Monoid or ParameterizedMonoid") + if type(binaryop) is ParameterizedBinaryOp: + self.__signature__ = binaryop.__signature__ + if type(monoid) is ParameterizedMonoid and monoid.__signature__ != self.__signature__: + raise ValueError( + "Signatures of monoid and binaryop passed to " + f"{type(self).__name__} must be the same. Got:\n" + f" monoid{monoid.__signature__}\n" + " !=\n" + f" binaryop{self.__signature__}\n\n" + "Perhaps call monoid or binaryop with parameters before creating the semiring." 
+ ) + elif type(binaryop) is BinaryOp: + if type(monoid) is Monoid: + raise TypeError("At least one of monoid or binaryop must be parameterized") + self.__signature__ = monoid.__signature__ + else: + raise TypeError("binaryop must be of type BinaryOp or ParameterizedBinaryOp") + self.monoid = monoid + self.binaryop = binaryop + if name is None: + name = f"{monoid.name}_{binaryop.name}" + super().__init__(name, anonymous) + + def _call(self, *args, **kwargs): + monoid = self.monoid + if type(monoid) is ParameterizedMonoid: + monoid = monoid(*args, **kwargs) + binary = self.binaryop + if type(binary) is ParameterizedBinaryOp: + binary = binary(*args, **kwargs) + return Semiring.register_anonymous(monoid, binary, self.name) + + commutes_to = TypedBuiltinSemiring.commutes_to + is_commutative = TypedBuiltinSemiring.is_commutative + + def __reduce__(self): + name = f"semiring.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover + return name + return (self._deserialize, (self.name, self.monoid, self.binaryop, self._anonymous)) + + @staticmethod + def _deserialize(name, monoid, binaryop, anonymous): + if anonymous: + return Semiring.register_anonymous(monoid, binaryop, name) + if (rv := Semiring._find(name)) is not None: + return rv + return Semiring.register_new(name, monoid, binaryop) + + +class Semiring(OpBase): + """Combination of a :class:`Monoid` and a :class:`BinaryOp`. + + Semirings are most commonly used for performing matrix multiplication, + with the BinaryOp taking the place of the standard multiplication operator + and the Monoid taking the place of the standard addition operator. + + Built-in and registered Semirings are located in the ``graphblas.semiring`` namespace + as well as in the ``graphblas.ops`` combined namespace. 
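+
+    For example (an illustrative sketch; ``A`` and ``B`` are assumed to be
+    existing Matrices), min-plus matrix multiplication:
+
+    >>> C = A.mxm(B, gb.semiring.min_plus).new()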
+ """ + + __slots__ = "_monoid", "_binaryop" + _module = semiring + _modname = "semiring" + _typed_class = TypedBuiltinSemiring + _parse_config = { + "trim_from_front": 4, + "delete_exact": "SEMIRING", + "num_underscores": 2, + "re_exprs": [ + re.compile( + "^GrB_(PLUS|MIN|MAX)_(PLUS|TIMES|FIRST|SECOND|MIN|MAX)_SEMIRING" + "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(MIN|MAX|PLUS|TIMES|ANY)" + "_(FIRST|SECOND|PAIR|MIN|MAX|PLUS|MINUS|RMINUS|TIMES" + "|DIV|RDIV|ISEQ|ISNE|ISGT|ISLT|ISGE|ISLE|LOR|LAND|LXOR" + "|FIRSTI1|FIRSTI|FIRSTJ1|FIRSTJ|SECONDI1|SECONDI|SECONDJ1|SECONDJ)" + "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(PLUS|TIMES|ANY)_(FIRST|SECOND|PAIR|PLUS|MINUS|TIMES|DIV|RDIV|RMINUS)" + "_(FC32|FC64)$" + ), + re.compile( + "^GxB_(BOR|BAND|BXOR|BXNOR)_(BOR|BAND|BXOR|BXNOR)_(UINT8|UINT16|UINT32|UINT64)$" + ), + ], + "re_exprs_return_bool": [ + re.compile("^GrB_(LOR|LAND|LXOR|LXNOR)_(LOR|LAND)_SEMIRING_BOOL$"), + re.compile( + "^GxB_(LOR|LAND|LXOR|EQ|ANY)_(EQ|NE|GT|LT|GE|LE)" + "_(INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(LOR|LAND|LXOR|EQ|ANY)_(FIRST|SECOND|PAIR|LOR|LAND|LXOR|EQ|GT|LT|GE|LE)_BOOL$" + ), + ], + } + + @classmethod + def _build(cls, name, monoid, binaryop, *, anonymous=False): + if type(monoid) is not Monoid: + raise TypeError(f"monoid must be a Monoid, not {type(monoid)}") + if type(binaryop) is not BinaryOp: + raise TypeError(f"binaryop must be a BinaryOp, not {type(binaryop)}") + if name is None: + name = f"{monoid.name}_{binaryop.name}".replace(".", "_") + new_type_obj = cls(name, monoid, binaryop, anonymous=anonymous) + if binaryop._is_udt: + return new_type_obj + for binary_in, binary_func in binaryop._typed_ops.items(): + binary_out = binary_func.return_type + # Unfortunately, we can't have user-defined monoids over bools yet + # because numba can't compile correctly. + if ( + binary_out not in monoid.types + # Are all coercions bad, or just to bool? + or monoid.coercions.get(binary_out, binary_out) != binary_out + ): + continue + new_semiring = ffi_new("GrB_Semiring*") + check_status_carg( + lib.GrB_Semiring_new(new_semiring, monoid[binary_out].gb_obj, binary_func.gb_obj), + "Semiring", + new_semiring[0], + ) + ret_type = monoid[binary_out].return_type + op = TypedUserSemiring( + new_type_obj, + name, + binary_in, + ret_type, + new_semiring[0], + monoid[binary_out], + binary_func, + ) + new_type_obj._add(op) + return new_type_obj + + def _compile_udt(self, dtype, dtype2): + if dtype2 is None: + dtype2 = dtype + dtypes = (dtype, dtype2) + if dtypes in self._udt_types: + return self._udt_ops[dtypes] + binaryop = self.binaryop._compile_udt(dtype, dtype2) + monoid = self.monoid[binaryop.return_type] + ret_type = monoid.return_type + new_semiring = ffi_new("GrB_Semiring*") + status = lib.GrB_Semiring_new(new_semiring, monoid.gb_obj, binaryop.gb_obj) + check_status_carg(status, "Semiring", new_semiring[0]) + op = TypedUserSemiring( + new_semiring, + self.name, + dtype, + ret_type, + new_semiring[0], + monoid, + binaryop, + dtype2=dtype2, + ) + self._udt_types[dtypes] = dtype + self._udt_ops[dtypes] = op + return op + + @classmethod + def register_anonymous(cls, monoid, binaryop, name=None): + """Register a Semiring without registering it in the ``graphblas.semiring`` namespace. + + Because it is not registered in the namespace, the name is optional. 
+
+        Parameters
+        ----------
+        monoid : Monoid or ParameterizedMonoid
+            The monoid of the semiring (like "plus" in the default "plus_times" semiring).
+        binaryop : BinaryOp or ParameterizedBinaryOp
+            The binaryop of the semiring (like "times" in the default "plus_times" semiring).
+        name : str, optional
+            The name of the operator. This *does not* show up as ``gb.semiring.{name}``.
+
+        Returns
+        -------
+        Semiring or ParameterizedSemiring
+
+        """
+        if type(monoid) is ParameterizedMonoid or type(binaryop) is ParameterizedBinaryOp:
+            return ParameterizedSemiring(name, monoid, binaryop, anonymous=True)
+        return cls._build(name, monoid, binaryop, anonymous=True)
+
+    @classmethod
+    def register_new(cls, name, monoid, binaryop, *, lazy=False):
+        """Register a new Semiring and save it to ``graphblas.semiring`` namespace.
+
+        Parameters
+        ----------
+        name : str
+            The name of the operator. This will show up as ``gb.semiring.{name}``.
+            The name may contain periods, ".", which will result in nested objects
+            such as ``gb.semiring.x.y.z`` for name ``"x.y.z"``.
+        monoid : Monoid or ParameterizedMonoid
+            The monoid of the semiring (like "plus" in the default "plus_times" semiring).
+        binaryop : BinaryOp or ParameterizedBinaryOp
+            The binaryop of the semiring (like "times" in the default "plus_times" semiring).
+        lazy : bool, default False
+            If False (the default), then the semiring will be built immediately.
+            Building operators can be slow, however, so you may want to
+            delay it and only build the semiring when it is first used,
+            which is done by setting ``lazy=True``.
+
+        Examples
+        --------
+        >>> gb.core.operator.Semiring.register_new("max_max", gb.monoid.max, gb.binary.max)
+        >>> dir(gb.semiring)
+        [..., 'max_max', ...]
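+
+        The new semiring can then be used like any built-in one (an illustrative
+        sketch; ``A`` and ``B`` are assumed to be existing Matrices):
+
+        >>> C = gb.semiring.max_max(A @ B).new()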
+ + """ + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "monoid": monoid, "binaryop": binaryop}, + ) + elif type(monoid) is ParameterizedMonoid or type(binaryop) is ParameterizedBinaryOp: + semiring = ParameterizedSemiring(name, monoid, binaryop) + setattr(module, funcname, semiring) + else: + semiring = cls._build(name, monoid, binaryop) + setattr(module, funcname, semiring) + # Also save it to `graphblas.op` if not yet defined + opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) + if not _hasop(opmodule, funcname): + if lazy: + opmodule._delayed[funcname] = module + else: + setattr(opmodule, funcname, semiring) + if not cls._initialized: + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return semiring + + @classmethod + def _initialize(cls): + if cls._initialized: # pragma: no cover (safety) + return + super()._initialize() + # Rename div to cdiv (truncate towards 0) + div_semirings = { + attr: val + for attr, val in vars(semiring).items() + if type(val) is Semiring and attr.endswith("_div") + } + for orig_name, orig in div_semirings.items(): + name = f"{orig_name[:-3]}cdiv" + cdiv_semiring = Semiring(name) + setattr(semiring, name, cdiv_semiring) + setattr(op, name, cdiv_semiring) + delattr(semiring, orig_name) + delattr(op, orig_name) + for dtype, ret_type in orig.types.items(): + orig_semiring = orig[dtype] + new_semiring = TypedBuiltinSemiring( + cdiv_semiring, + name, + dtype, + ret_type, + orig_semiring.gb_obj, + orig_semiring.gb_name, + ) + cdiv_semiring._add(new_semiring) + # Also add truediv (always floating point) and floordiv (truncate towards -inf) + for orig_name, orig in div_semirings.items(): + cls.register_new(f"{orig_name[:-3]}truediv", orig.monoid, binary.truediv, lazy=True) + cls.register_new(f"{orig_name[:-3]}rtruediv", orig.monoid, "rtruediv", lazy=True) + if _supports_udfs: + cls.register_new(f"{orig_name[:-3]}floordiv", orig.monoid, "floordiv", lazy=True) + cls.register_new(f"{orig_name[:-3]}rfloordiv", orig.monoid, "rfloordiv", lazy=True) + # For aggregators + cls.register_new("plus_pow", monoid.plus, binary.pow) + if _supports_udfs: + cls.register_new("plus_rpow", monoid.plus, "rpow", lazy=True) + cls.register_new("plus_absfirst", monoid.plus, "absfirst", lazy=True) + cls.register_new("max_absfirst", monoid.max, "absfirst", lazy=True) + cls.register_new("plus_abssecond", monoid.plus, "abssecond", lazy=True) + cls.register_new("max_abssecond", monoid.max, "abssecond", lazy=True) + + # Update type information with sane coercion + for lname in ["any", "eq", "land", "lor", "lxnor", "lxor"]: + target_name = f"{lname}_ne" + source_name = f"{lname}_lxor" + if not _hasop(semiring, target_name): + continue + target_op = getattr(semiring, target_name) + if BOOL not in target_op.types: # pragma: no branch (safety) + source_op = getattr(semiring, source_name) + typed_op = source_op._typed_ops[BOOL] + target_op.types[BOOL] = BOOL + target_op._typed_ops[BOOL] = typed_op + target_op.coercions[dtype] = BOOL + + position_dtypes = [ + BOOL, + FP32, + FP64, + INT8, + INT16, + UINT8, + UINT16, + UINT32, + UINT64, + ] + notbool_dtypes = [ + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + ] + if _supports_complex: + position_dtypes.extend([FC32, FC64]) + notbool_dtypes.extend([FC32, FC64]) + for lnames, rnames, *types in [ + # fmt: off + ( + ("any", "max", "min", "plus", "times"), + ( + "firsti", "firsti1", 
"firstj", "firstj1", + "secondi", "secondi1", "secondj", "secondj1", + ), + ( + position_dtypes, + INT64, + ), + ), + ( + ("eq", "land", "lor", "lxnor", "lxor"), + ("first", "pair", "second"), + # TODO: check if FC coercion works here + ( + notbool_dtypes, + BOOL, + ), + ), + ( + ("band", "bor", "bxnor", "bxor"), + ("band", "bor", "bxnor", "bxor"), + ([INT8], UINT16), + ([INT16], UINT32), + ([INT32], UINT64), + ([INT64], UINT64), + ), + ( + ("any", "eq", "land", "lor", "lxnor", "lxor"), + ("eq", "land", "lor", "lxnor", "lxor", "ne"), + ( + ( + FP32, FP64, INT8, INT16, INT32, INT64, + UINT8, UINT16, UINT32, UINT64, + ), + BOOL, + ), + ), + # fmt: on + ]: + for left, right in itertools.product(lnames, rnames): + name = f"{left}_{right}" + if not _hasop(semiring, name): + continue + if name in _SS_OPERATORS: + cur_op = semiring._deprecated[name] + else: + cur_op = getattr(semiring, name) + for input_types, target_type in types: + typed_op = cur_op._typed_ops[target_type] + output_type = cur_op.types[target_type] + for dtype in input_types: + if dtype not in cur_op.types: + cur_op.types[dtype] = output_type + cur_op._typed_ops[dtype] = typed_op + cur_op.coercions[dtype] = target_type + + # Handle a few boolean cases + for opname, targetname in [ + ("max_first", "lor_first"), + ("max_second", "lor_second"), + ("max_land", "lor_land"), + ("max_lor", "lor_lor"), + ("max_lxor", "lor_lxor"), + ("min_first", "land_first"), + ("min_second", "land_second"), + ("min_land", "land_land"), + ("min_lor", "land_lor"), + ("min_lxor", "land_lxor"), + ]: + cur_op = getattr(semiring, opname) + target = getattr(semiring, targetname) + if BOOL in cur_op.types or BOOL not in target.types: # pragma: no cover (safety) + continue + cur_op.types[BOOL] = target.types[BOOL] + cur_op._typed_ops[BOOL] = target._typed_ops[BOOL] + cur_op.coercions[BOOL] = BOOL + cls._initialized = True + + def __init__(self, name, monoid=None, binaryop=None, *, anonymous=False): + super().__init__(name, anonymous=anonymous) + self._monoid = monoid + self._binaryop = binaryop + try: + if self.binaryop._udt_types is not None: + self._udt_types = {} # {(dtype, dtype): DataType} + self._udt_ops = {} # {(dtype, dtype): TypedUserSemiring} + except AttributeError: + # `*_div` semirings raise here, but don't need `_udt_types` + pass + + def __reduce__(self): + if self._anonymous: + return (self.register_anonymous, (self._monoid, self._binaryop, self.name)) + if (name := f"semiring.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self._monoid, self._binaryop)) + + @property + def binaryop(self): + """The :class:`BinaryOp` associated with the Semiring.""" + if self._binaryop is not None: + return self._binaryop + # Must be builtin + name = self.name.split("_")[1] + if name in _SS_OPERATORS: + return binary._deprecated[name] + return getattr(binary, name) + + @property + def monoid(self): + """The :class:`Monoid` associated with the Semiring.""" + if self._monoid is not None: + return self._monoid + # Must be builtin + return getattr(monoid, self.name.split("_")[0].split(".")[-1]) + + @property + def is_positional(self): + return self.binaryop.is_positional + + @property + def _is_udt(self): + return self._binaryop is not None and self._binaryop._is_udt + + @property + def _custom_dtype(self): + return self.binaryop._custom_dtype + + commutes_to = TypedBuiltinSemiring.commutes_to + is_commutative = TypedBuiltinSemiring.is_commutative + __call__ = TypedBuiltinSemiring.__call__ diff --git 
a/graphblas/core/operator/unary.py b/graphblas/core/operator/unary.py new file mode 100644 index 000000000..26e0ca61c --- /dev/null +++ b/graphblas/core/operator/unary.py @@ -0,0 +1,475 @@ +import inspect +import re +from types import FunctionType + +from ... import _STANDARD_OPERATOR_NAMES, op, unary +from ...dtypes import ( + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + _supports_complex, + lookup_dtype, +) +from ...exceptions import UdfParseError, check_status_carg +from .. import _has_numba, ffi, lib +from ..dtypes import _sample_values +from ..utils import output_type +from .base import ( + _SS_OPERATORS, + OpBase, + ParameterizedUdf, + TypedOpBase, + _deserialize_parameterized, + _hasop, +) + +if _supports_complex: + from ...dtypes import FC32, FC64 +if _has_numba: + import numba + + from .base import _get_udt_wrapper + +ffi_new = ffi.new + + +class TypedBuiltinUnaryOp(TypedOpBase): + __slots__ = () + opclass = "UnaryOp" + + def __call__(self, val): + from ..matrix import Matrix, TransposedMatrix + from ..vector import Vector + + if (typ := output_type(val)) in {Vector, Matrix, TransposedMatrix}: + return val.apply(self) + from ..scalar import Scalar, _as_scalar + + if typ is Scalar: + return val.apply(self) + try: + scalar = _as_scalar(val, is_cscalar=False) + except Exception: + pass + else: + return scalar.apply(self) + raise TypeError( + f"Bad type when calling {self!r}.\n" + " - Expected type: Scalar, Vector, Matrix, TransposedMatrix.\n" + f" - Got: {type(val)}.\n" + "Calling a UnaryOp is syntactic sugar for calling apply. " + f"For example, `A.apply({self!r})` is the same as `{self!r}(A)`." + ) + + +class TypedUserUnaryOp(TypedOpBase): + __slots__ = () + opclass = "UnaryOp" + + def __init__(self, parent, name, type_, return_type, gb_obj): + super().__init__(parent, name, type_, return_type, gb_obj, f"{name}_{type_}") + + @property + def orig_func(self): + return self.parent.orig_func + + @property + def _numba_func(self): + return self.parent._numba_func + + __call__ = TypedBuiltinUnaryOp.__call__ + + +class ParameterizedUnaryOp(ParameterizedUdf): + __slots__ = "func", "__signature__", "_is_udt" + + def __init__(self, name, func, *, anonymous=False, is_udt=False): + self.func = func + self.__signature__ = inspect.signature(func) + self._is_udt = is_udt + if name is None: + name = getattr(func, "__name__", name) + super().__init__(name, anonymous) + + def _call(self, *args, **kwargs): + unary = self.func(*args, **kwargs) + unary._parameterized_info = (self, args, kwargs) + return UnaryOp.register_anonymous(unary, self.name, is_udt=self._is_udt) + + def __reduce__(self): + name = f"unary.{self.name}" + if not self._anonymous and name in _STANDARD_OPERATOR_NAMES: # pragma: no cover + return name + return (self._deserialize, (self.name, self.func, self._anonymous)) + + @staticmethod + def _deserialize(name, func, anonymous): + if anonymous: + return UnaryOp.register_anonymous(func, name, parameterized=True) + if (rv := UnaryOp._find(name)) is not None: + return rv + return UnaryOp.register_new(name, func, parameterized=True) + + +def _identity(x): + return x # pragma: no cover (numba) + + +def _one(x): + return 1 # pragma: no cover (numba) + + +class UnaryOp(OpBase): + """Takes one input and returns one output, possibly of a different data type. + + Built-in and registered UnaryOps are located in the ``graphblas.unary`` namespace + as well as in the ``graphblas.ops`` combined namespace. 
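+
+    Calling a UnaryOp is syntactic sugar for ``apply`` (an illustrative sketch;
+    ``v`` is assumed to be an existing Vector):
+
+    >>> w = gb.unary.ainv(v).new()  # equivalent to v.apply(gb.unary.ainv).new()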
+ """ + + __slots__ = "orig_func", "is_positional", "_is_udt", "_numba_func" + _custom_dtype = None + _module = unary + _modname = "unary" + _typed_class = TypedBuiltinUnaryOp + _parse_config = { + "trim_from_front": 4, + "num_underscores": 1, + "re_exprs": [ + re.compile( + "^GrB_(IDENTITY|AINV|MINV|ABS|BNOT)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64|FC32|FC64)$" + ), + re.compile( + "^GxB_(LNOT|ONE|POSITIONI1|POSITIONI|POSITIONJ1|POSITIONJ)" + "_(BOOL|INT8|UINT8|INT16|UINT16|INT32|UINT32|INT64|UINT64|FP32|FP64)$" + ), + re.compile( + "^GxB_(SQRT|LOG|EXP|LOG2|SIN|COS|TAN|ACOS|ASIN|ATAN|SINH|COSH|TANH|ACOSH" + "|ASINH|ATANH|SIGNUM|CEIL|FLOOR|ROUND|TRUNC|EXP2|EXPM1|LOG10|LOG1P)" + "_(FP32|FP64|FC32|FC64)$" + ), + re.compile("^GxB_(LGAMMA|TGAMMA|ERF|ERFC|FREXPX|FREXPE|CBRT)_(FP32|FP64)$"), + re.compile("^GxB_(IDENTITY|AINV|MINV|ONE|CONJ)_(FC32|FC64)$"), + ], + "re_exprs_return_bool": [ + re.compile("^GrB_LNOT$"), + re.compile("^GxB_(ISINF|ISNAN|ISFINITE)_(FP32|FP64|FC32|FC64)$"), + ], + "re_exprs_return_float": [re.compile("^GxB_(CREAL|CIMAG|CARG|ABS)_(FC32|FC64)$")], + } + _positional = {"positioni", "positioni1", "positionj", "positionj1"} + + @classmethod + def _build(cls, name, func, *, anonymous=False, is_udt=False): + if type(func) is not FunctionType: + raise TypeError(f"UDF argument must be a function, not {type(func)}") + if name is None: + name = getattr(func, "__name__", "") + success = False + unary_udf = numba.njit(func) + new_type_obj = cls(name, func, anonymous=anonymous, is_udt=is_udt, numba_func=unary_udf) + return_types = {} + nt = numba.types + if not is_udt: + for type_ in _sample_values: + sig = (type_.numba_type,) + try: + unary_udf.compile(sig) + except numba.TypingError: + continue + ret_type = lookup_dtype(unary_udf.overloads[sig].signature.return_type) + if ret_type != type_ and ( + ("INT" in ret_type.name and "INT" in type_.name) + or ("FP" in ret_type.name and "FP" in type_.name) + or ("FC" in ret_type.name and "FC" in type_.name) + or (type_ == UINT64 and ret_type == FP64 and return_types.get(INT64) == INT64) + ): + # Downcast `ret_type` to `type_`. + # This is what users want most of the time, but we can't make a perfect rule. + # There should be a way for users to be explicit. 
+                    ret_type = type_
+                elif type_ == BOOL and ret_type == INT64 and return_types.get(INT8) == INT8:
+                    ret_type = INT8
+
+                # Numba is unable to handle BOOL correctly right now, but we have a workaround
+                # See: https://github.com/numba/numba/issues/5395
+                # We're relying on coercion behaving correctly here
+                input_type = INT8 if type_ == BOOL else type_
+                return_type = INT8 if ret_type == BOOL else ret_type
+
+                # Build wrapper because GraphBLAS wants pointers and void return
+                wrapper_sig = nt.void(
+                    nt.CPointer(return_type.numba_type),
+                    nt.CPointer(input_type.numba_type),
+                )
+
+                if type_ == BOOL:
+                    if ret_type == BOOL:
+
+                        def unary_wrapper(z, x):
+                            z[0] = bool(unary_udf(bool(x[0])))  # pragma: no cover (numba)
+
+                    else:
+
+                        def unary_wrapper(z, x):
+                            z[0] = unary_udf(bool(x[0]))  # pragma: no cover (numba)
+
+                elif ret_type == BOOL:
+
+                    def unary_wrapper(z, x):
+                        z[0] = bool(unary_udf(x[0]))  # pragma: no cover (numba)
+
+                else:
+
+                    def unary_wrapper(z, x):
+                        z[0] = unary_udf(x[0])  # pragma: no cover (numba)
+
+                unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper)
+                new_unary = ffi_new("GrB_UnaryOp*")
+                check_status_carg(
+                    lib.GrB_UnaryOp_new(
+                        new_unary, unary_wrapper.cffi, ret_type.gb_obj, type_.gb_obj
+                    ),
+                    "UnaryOp",
+                    new_unary[0],
+                )
+                op = TypedUserUnaryOp(new_type_obj, name, type_, ret_type, new_unary[0])
+                new_type_obj._add(op)
+                success = True
+                return_types[type_] = ret_type
+        if success or is_udt:
+            return new_type_obj
+        raise UdfParseError("Unable to parse function using Numba")
+
+    def _compile_udt(self, dtype, dtype2):
+        if dtype in self._udt_types:
+            return self._udt_ops[dtype]
+        if self._numba_func is None:
+            raise KeyError(f"{self.name} does not work with {dtype}")
+
+        numba_func = self._numba_func
+        sig = (dtype.numba_type,)
+        numba_func.compile(sig)  # Should we catch and give additional error message?
+        ret_type = lookup_dtype(numba_func.overloads[sig].signature.return_type)
+
+        unary_wrapper, wrapper_sig = _get_udt_wrapper(numba_func, ret_type, dtype)
+        unary_wrapper = numba.cfunc(wrapper_sig, nopython=True)(unary_wrapper)
+        new_unary = ffi_new("GrB_UnaryOp*")
+        check_status_carg(
+            lib.GrB_UnaryOp_new(new_unary, unary_wrapper.cffi, ret_type._carg, dtype._carg),
+            "UnaryOp",
+            new_unary[0],
+        )
+        op = TypedUserUnaryOp(self, self.name, dtype, ret_type, new_unary[0])
+        self._udt_types[dtype] = ret_type
+        self._udt_ops[dtype] = op
+        return op
+
+    @classmethod
+    def register_anonymous(cls, func, name=None, *, parameterized=False, is_udt=False):
+        """Register a UnaryOp without registering it in the ``graphblas.unary`` namespace.
+
+        Because it is not registered in the namespace, the name is optional.
+
+        Parameters
+        ----------
+        func : FunctionType
+            The function to compile. For all current backends, this must be able
+            to be compiled with ``numba.njit``.
+            ``func`` takes one input parameter of any dtype and returns any dtype.
+        name : str, optional
+            The name of the operator. This *does not* show up as ``gb.unary.{name}``.
+        parameterized : bool, default False
+            When True, create a parameterized user-defined operator, which means
+            additional parameters can be "baked into" the operator when used.
+            For example, ``gb.binary.isclose`` is a parameterized function that
+            optionally accepts ``rel_tol`` and ``abs_tol`` parameters, and it
+            can be used as: ``A.ewise_mult(B, gb.binary.isclose(rel_tol=1e-5))``.
+            When creating a parameterized user-defined operator, the ``func``
+            parameter must be a callable that *returns* a function that will
+            then get compiled. See ``gb.binary.isclose`` for an example of
+            this pattern.
+        is_udt : bool, default False
+            Whether the operator is intended to operate on user-defined types.
+            If True, then the function will not be automatically compiled for
+            builtin types, and it will be compiled "just in time" when used.
+
+        Returns
+        -------
+        UnaryOp or ParameterizedUnaryOp
+
+        """
+        cls._check_supports_udf("register_anonymous")
+        if parameterized:
+            return ParameterizedUnaryOp(name, func, anonymous=True, is_udt=is_udt)
+        return cls._build(name, func, anonymous=True, is_udt=is_udt)
+
+    @classmethod
+    def register_new(cls, name, func, *, parameterized=False, is_udt=False, lazy=False):
+        """Register a new UnaryOp and save it to ``graphblas.unary`` namespace.
+
+        Parameters
+        ----------
+        name : str
+            The name of the operator. This will show up as ``gb.unary.{name}``.
+            The name may contain periods, ".", which will result in nested objects
+            such as ``gb.unary.x.y.z`` for name ``"x.y.z"``.
+        func : FunctionType
+            The function to compile. For all current backends, this must be able
+            to be compiled with ``numba.njit``.
+            ``func`` takes one input parameter of any dtype and returns any dtype.
+        parameterized : bool, default False
+            When True, create a parameterized user-defined operator, which means
+            additional parameters can be "baked into" the operator when used.
+            For example, ``gb.binary.isclose`` is a parameterized function that
+            optionally accepts ``rel_tol`` and ``abs_tol`` parameters, and it
+            can be used as: ``A.ewise_mult(B, gb.binary.isclose(rel_tol=1e-5))``.
+            When creating a parameterized user-defined operator, the ``func``
+            parameter must be a callable that *returns* a function that will
+            then get compiled. See ``gb.binary.isclose`` for an example of
+            this pattern.
+        is_udt : bool, default False
+            Whether the operator is intended to operate on user-defined types.
+            If True, then the function will not be automatically compiled for
+            builtin types, and it will be compiled "just in time" when used.
+        lazy : bool, default False
+            If False (the default), then the function will be automatically
+            compiled for builtin types (unless ``is_udt`` is True).
+            Compiling functions can be slow, however, so you may want to
+            delay compilation and only compile when the operator is used,
+            which is done by setting ``lazy=True``.
+
+        Examples
+        --------
+        >>> gb.core.operator.UnaryOp.register_new("plus_one", lambda x: x + 1)
+        >>> dir(gb.unary)
+        [..., 'plus_one', ...]
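+
+        The registered operator can then be applied like any built-in one;
+        calling it is syntactic sugar for ``apply`` (illustrative):
+
+        >>> v = gb.Vector.from_coo([0, 1], [10, 20])
+        >>> w = gb.unary.plus_one(v).new()  # same as v.apply(gb.unary.plus_one).new()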
+ + """ + cls._check_supports_udf("register_new") + module, funcname = cls._remove_nesting(name) + if lazy: + module._delayed[funcname] = ( + cls.register_new, + {"name": name, "func": func, "parameterized": parameterized}, + ) + elif parameterized: + unary_op = ParameterizedUnaryOp(name, func, is_udt=is_udt) + setattr(module, funcname, unary_op) + else: + unary_op = cls._build(name, func, is_udt=is_udt) + setattr(module, funcname, unary_op) + # Also save it to `graphblas.op` if not yet defined + opmodule, funcname = cls._remove_nesting(name, module=op, modname="op", strict=False) + if not _hasop(opmodule, funcname): + if lazy: + opmodule._delayed[funcname] = module + else: + setattr(opmodule, funcname, unary_op) + if not cls._initialized: # pragma: no cover + _STANDARD_OPERATOR_NAMES.add(f"{cls._modname}.{name}") + if not lazy: + return unary_op + + @classmethod + def _initialize(cls): + if cls._initialized: + return + super()._initialize() + # Update type information with sane coercion + position_dtypes = [ + BOOL, + FP32, + FP64, + INT8, + INT16, + UINT8, + UINT16, + UINT32, + UINT64, + ] + if _supports_complex: + position_dtypes.extend([FC32, FC64]) + for names, *types in [ + # fmt: off + ( + ( + "erf", "erfc", "lgamma", "tgamma", "acos", "acosh", "asin", "asinh", + "atan", "atanh", "ceil", "cos", "cosh", "exp", "exp2", "expm1", "floor", + "log", "log10", "log1p", "log2", "round", "signum", "sin", "sinh", "sqrt", + "tan", "tanh", "trunc", "cbrt", + ), + ((BOOL, INT8, INT16, UINT8, UINT16), FP32), + ((INT32, INT64, UINT32, UINT64), FP64), + ), + ( + ("positioni", "positioni1", "positionj", "positionj1"), + ( + position_dtypes, + INT64, + ), + ), + # fmt: on + ]: + for name in names: + if name in _SS_OPERATORS: + op = unary._deprecated[name] + else: + op = getattr(unary, name) + for input_types, target_type in types: + typed_op = op._typed_ops[target_type] + output_type = op.types[target_type] + for dtype in input_types: + if dtype not in op.types: # pragma: no branch (safety) + op.types[dtype] = output_type + op._typed_ops[dtype] = typed_op + op.coercions[dtype] = target_type + # Allow some functions to work on UDTs + for unop, func in [ + (unary.identity, _identity), + (unary.one, _one), + ]: + unop.orig_func = func + if _has_numba: + unop._numba_func = numba.njit(func) + else: + unop._numba_func = None + unop._udt_types = {} + unop._udt_ops = {} + cls._initialized = True + + def __init__( + self, + name, + func=None, + *, + anonymous=False, + is_positional=False, + is_udt=False, + numba_func=None, + ): + super().__init__(name, anonymous=anonymous) + self.orig_func = func + self._numba_func = numba_func + self.is_positional = is_positional + self._is_udt = is_udt + if is_udt: + self._udt_types = {} # {dtype: DataType} + self._udt_ops = {} # {dtype: TypedUserUnaryOp} + + def __reduce__(self): + if self._anonymous: + if hasattr(self.orig_func, "_parameterized_info"): + return (_deserialize_parameterized, self.orig_func._parameterized_info) + return (self.register_anonymous, (self.orig_func, self.name)) + if (name := f"unary.{self.name}") in _STANDARD_OPERATOR_NAMES: + return name + return (self._deserialize, (self.name, self.orig_func)) + + __call__ = TypedBuiltinUnaryOp.__call__ diff --git a/graphblas/core/operator/utils.py b/graphblas/core/operator/utils.py new file mode 100644 index 000000000..1442a9b5e --- /dev/null +++ b/graphblas/core/operator/utils.py @@ -0,0 +1,476 @@ +from types import BuiltinFunctionType, FunctionType, ModuleType + +from ... 
import backend, binary, config, indexunary, monoid, op, select, semiring, unary +from ...dtypes import UINT64, lookup_dtype, unify +from ..expr import InfixExprBase +from .base import ( + _SS_OPERATORS, + OpBase, + OpPath, + ParameterizedUdf, + TypedOpBase, + _builtin_to_op, + _hasop, + find_opclass, +) +from .binary import BinaryOp +from .indexunary import IndexUnaryOp +from .monoid import Monoid +from .select import SelectOp +from .semiring import Semiring +from .unary import UnaryOp + +# Now initialize all the things! +try: + UnaryOp._initialize() + IndexUnaryOp._initialize() + SelectOp._initialize() + BinaryOp._initialize() + Monoid._initialize() + Semiring._initialize() +except Exception: # pragma: no cover (debug) + # Exceptions here can often get ignored by Python + import traceback + + traceback.print_exc() + raise + + +def get_typed_op(op, dtype, dtype2=None, *, is_left_scalar=False, is_right_scalar=False, kind=None): + if isinstance(op, OpBase): + # UDTs always get compiled + if op._is_udt: + return op._compile_udt(dtype, dtype2) + # Single dtype is simple lookup + if dtype2 is None: + return op[dtype] + # Handle special cases such as first and second (may have UDTs) + if op._custom_dtype is not None and (rv := op._custom_dtype(op, dtype, dtype2)) is not None: + return rv + # Generic case: try to unify the two dtypes + try: + return op[ + unify(dtype, dtype2, is_left_scalar=is_left_scalar, is_right_scalar=is_right_scalar) + ] + except (TypeError, AttributeError): + # Failure to unify implies a dtype is UDT; some builtin operators can handle UDTs + if op.is_positional: + return op[UINT64] + if op._udt_types is None: + raise + return op._compile_udt(dtype, dtype2) + if isinstance(op, ParameterizedUdf): + op = op() # Use default parameters of parameterized UDFs + return get_typed_op( + op, + dtype, + dtype2, + is_left_scalar=is_left_scalar, + is_right_scalar=is_right_scalar, + kind=kind, + ) + if isinstance(op, TypedOpBase): + return op + + from .agg import Aggregator, TypedAggregator + + if isinstance(op, Aggregator): + # agg._any_dtype basically serves the same purpose as op._custom_dtype + if op._any_dtype is not None and op._any_dtype is not True: + return op[op._any_dtype] + return op[dtype] + if isinstance(op, TypedAggregator): + return op + if isinstance(op, str): + if kind == "unary": + op = unary_from_string(op) + elif kind == "select": + op = select_from_string(op) + elif kind == "binary": + op = binary_from_string(op) + elif kind == "monoid": + op = monoid_from_string(op) + elif kind == "semiring": + op = semiring_from_string(op) + elif kind == "binary|aggregator": + try: + op = binary_from_string(op) + except ValueError: + try: + op = aggregator_from_string(op) + except ValueError: + raise ValueError( + f"Unknown binary or aggregator string: {op!r}. Example usage: '+[int]'" + ) from None + + else: + raise ValueError( + f"Unable to get op from string {op!r}. `kind=` argument must be provided as " + '"unary", "binary", "monoid", "semiring", "indexunary", "select", ' + 'or "binary|aggregator".' 
+ ) + return get_typed_op( + op, + dtype, + dtype2, + is_left_scalar=is_left_scalar, + is_right_scalar=is_right_scalar, + kind=kind, + ) + if isinstance(op, FunctionType): + if kind == "unary": + op = UnaryOp.register_anonymous(op, is_udt=True) + return op._compile_udt(dtype, dtype2) + if kind.startswith("binary"): + op = BinaryOp.register_anonymous(op, is_udt=True) + return op._compile_udt(dtype, dtype2) + if isinstance(op, BuiltinFunctionType) and op in _builtin_to_op: + return get_typed_op( + _builtin_to_op[op], + dtype, + dtype2, + is_left_scalar=is_left_scalar, + is_right_scalar=is_right_scalar, + kind=kind, + ) + raise TypeError(f"Unable to get typed operator from object with type {type(op)}") + + +def _get_typed_op_from_exprs(op, left, right, *, kind=None): + if isinstance(left, InfixExprBase): + left_op = _get_typed_op_from_exprs(op, left.left, left.right, kind=kind) + left_dtype = left_op.type + else: + left_op = None + left_dtype = left.dtype + if isinstance(right, InfixExprBase): + right_op = _get_typed_op_from_exprs(op, right.left, right.right, kind=kind) + if right_op is left_op: + return right_op + right_dtype = right_op.type2 + else: + right_dtype = right.dtype + return get_typed_op( + op, + left_dtype, + right_dtype, + is_left_scalar=left._is_scalar, + is_right_scalar=right._is_scalar, + kind=kind, + ) + + +def get_semiring(monoid, binaryop, name=None): + """Get or create a Semiring object from a monoid and binaryop. + + If either are typed, then the returned semiring will also be typed. + + See Also + -------- + semiring.register_anonymous + semiring.register_new + semiring.from_string + + """ + monoid, opclass = find_opclass(monoid) + switched = False + if opclass == "BinaryOp" and monoid.monoid is not None: + switched = True + monoid = monoid.monoid + elif opclass != "Monoid": + raise TypeError(f"Expected a Monoid for the monoid argument. Got type: {type(monoid)}") + binaryop, opclass = find_opclass(binaryop) + if opclass == "Monoid": + if switched: + raise TypeError( + "Got a BinaryOp for the monoid argument and a Monoid for the binaryop argument. " + "Are the arguments switched? Hint: you can do `mymonoid.binaryop` to get the " + "binaryop from a monoid." + ) + binaryop = binaryop.binaryop + elif opclass != "BinaryOp": + raise TypeError( + f"Expected a BinaryOp for the binaryop argument. 
Got type: {type(binaryop)}" + ) + if isinstance(monoid, Monoid): + monoid_type = None + else: + monoid_type = monoid.type + monoid = monoid.parent + if isinstance(binaryop, BinaryOp): + binary_type = None + else: + binary_type = binaryop.type + binaryop = binaryop.parent + if monoid._anonymous or binaryop._anonymous: + rv = Semiring.register_anonymous(monoid, binaryop, name=name) + else: + *monoid_prefix, monoid_name = monoid.name.rsplit(".", 1) + *binary_prefix, binary_name = binaryop.name.rsplit(".", 1) + if ( + monoid_prefix + and binary_prefix + and monoid_prefix == binary_prefix + or config.get("mapnumpy") + and ( + monoid_prefix == ["numpy"] + and not binary_prefix + or binary_prefix == ["numpy"] + and not monoid_prefix + ) + or backend == "suitesparse" + and binary_name in _SS_OPERATORS + ): + canonical_name = ( + ".".join(monoid_prefix or binary_prefix) + f".{monoid_name}_{binary_name}" + ) + else: + canonical_name = f"{monoid.name}_{binaryop.name}".replace(".", "_") + if name is None: + name = canonical_name + + module, funcname = Semiring._remove_nesting(canonical_name, strict=False) + rv = ( + getattr(module, funcname) + if funcname in module.__dict__ or funcname in module._delayed + else getattr(module, "_deprecated", {}).get(funcname) + ) + if rv is None and name != canonical_name: + module, funcname = Semiring._remove_nesting(name, strict=False) + rv = ( + getattr(module, funcname) + if funcname in module.__dict__ or funcname in module._delayed + else getattr(module, "_deprecated", {}).get(funcname) + ) + if rv is None: + rv = Semiring.register_new(canonical_name, monoid, binaryop) + elif rv.monoid is not monoid or rv.binaryop is not binaryop: # pragma: no cover + # It's not the object we expect (can this happen?) + rv = Semiring.register_anonymous(monoid, binaryop, name=name) + if name != canonical_name: + module, funcname = Semiring._remove_nesting(name, strict=False) + if not _hasop(module, funcname): # pragma: no branch (safety) + setattr(module, funcname, rv) + + if binary_type is not None: + return rv[binary_type] + if monoid_type is not None: + return rv[monoid_type] + return rv + + +unary.register_new = UnaryOp.register_new +unary.register_anonymous = UnaryOp.register_anonymous +indexunary.register_new = IndexUnaryOp.register_new +indexunary.register_anonymous = IndexUnaryOp.register_anonymous +select.register_new = SelectOp.register_new +select.register_anonymous = SelectOp.register_anonymous +binary.register_new = BinaryOp.register_new +binary.register_anonymous = BinaryOp.register_anonymous +monoid.register_new = Monoid.register_new +monoid.register_anonymous = Monoid.register_anonymous +semiring.register_new = Semiring.register_new +semiring.register_anonymous = Semiring.register_anonymous +semiring.get_semiring = get_semiring + +select._binary_to_select.update( + { + binary.eq: select.valueeq, + binary.ne: select.valuene, + binary.le: select.valuele, + binary.lt: select.valuelt, + binary.ge: select.valuege, + binary.gt: select.valuegt, + binary.iseq: select.valueeq, + binary.isne: select.valuene, + binary.isle: select.valuele, + binary.islt: select.valuelt, + binary.isge: select.valuege, + binary.isgt: select.valuegt, + } +) + +_builtin_to_op.update( + { + abs: unary.abs, + max: binary.max, + min: binary.min, + # Maybe someday: all, any, pow, sum + } +) + +_str_to_unary = { + "-": unary.ainv, + "~": unary.lnot, +} +_str_to_select = { + "<": select.valuelt, + ">": select.valuegt, + "<=": select.valuele, + ">=": select.valuege, + "!=": select.valuene, + "==": 
select.valueeq, + "col<=": select.colle, + "col>": select.colgt, + "row<=": select.rowle, + "row>": select.rowgt, + "index<=": select.indexle, + "index>": select.indexgt, +} +_str_to_binary = { + "<": binary.lt, + ">": binary.gt, + "<=": binary.le, + ">=": binary.ge, + "!=": binary.ne, + "==": binary.eq, + "+": binary.plus, + "-": binary.minus, + "*": binary.times, + "/": binary.truediv, + "//": "floordiv", + "%": "numpy.mod", + "**": binary.pow, + "&": binary.land, + "|": binary.lor, + "^": binary.lxor, +} +_str_to_monoid = { + "==": monoid.eq, + "+": monoid.plus, + "*": monoid.times, + "&": monoid.land, + "|": monoid.lor, + "^": monoid.lxor, +} + + +def _from_string(string, module, mapping, example): + s = string.lower().strip() + base, *dtype = s.split("[") + if len(dtype) > 1: + name = module.__name__.split(".")[-1] + raise ValueError( + f'Bad {name} string: {string!r}. Contains too many "[". Example usage: {example!r}' + ) + if dtype: + dtype = dtype[0] + if not dtype.endswith("]"): + name = module.__name__.split(".")[-1] + raise ValueError( + f'Bad {name} string: {string!r}. Datatype specification does not end with "]". ' + f"Example usage: {example!r}" + ) + dtype = lookup_dtype(dtype[:-1]) + if "]" in base: + name = module.__name__.split(".")[-1] + raise ValueError( + f'Bad {name} string: {string!r}. "]" not matched by "[". Example usage: {example!r}' + ) + if base in mapping: + op = mapping[base] + if isinstance(op, str): + op = mapping[base] = module.from_string(op) + elif hasattr(module, base): + op = getattr(module, base) + elif hasattr(module, "numpy") and hasattr(module.numpy, base): + op = getattr(module.numpy, base) + else: + *paths, attr = base.split(".") + op = None + cur = module + for path in paths: + cur = getattr(cur, path, None) + if not isinstance(cur, (OpPath, ModuleType)): + cur = None + break + op = getattr(cur, attr, None) + if op is None: + name = module.__name__.split(".")[-1] + raise ValueError(f"Unknown {name} string: {string!r}. Example usage: {example!r}") + if dtype: + op = op[dtype] + return op + + +def unary_from_string(string): + return _from_string(string, unary, _str_to_unary, "abs[int]") + + +def indexunary_from_string(string): + # "select" is a variant of IndexUnary, so the string abbreviations in + # _str_to_select are appropriate to reuse here + return _from_string(string, indexunary, _str_to_select, "row_index") + + +def select_from_string(string): + return _from_string(string, select, _str_to_select, "tril") + + +def binary_from_string(string): + return _from_string(string, binary, _str_to_binary, "+[int]") + + +def monoid_from_string(string): + return _from_string(string, monoid, _str_to_monoid, "+[int]") + + +def semiring_from_string(string): + split = string.split(".") + if len(split) == 1: + try: + return _from_string(string, semiring, {}, "min.+[int]") + except Exception: + pass + if len(split) != 2: + raise ValueError( + f"Bad semiring string: {string!r}. " + 'The monoid and binaryop should be separated by exactly one period, ".". 
' + "Example usage: min.+[int]" + ) + cur_monoid = monoid_from_string(split[0]) + cur_binary = binary_from_string(split[1]) + return get_semiring(cur_monoid, cur_binary) + + +def op_from_string(string): + for func in [ + # Note: order matters here + unary_from_string, + binary_from_string, + monoid_from_string, + semiring_from_string, + indexunary_from_string, + select_from_string, + aggregator_from_string, + ]: + try: + return func(string) + except Exception: + pass + raise ValueError(f"Unknown op string: {string!r}. Example usage: 'abs[int]'") + + +unary.from_string = unary_from_string +indexunary.from_string = indexunary_from_string +select.from_string = select_from_string +binary.from_string = binary_from_string +monoid.from_string = monoid_from_string +semiring.from_string = semiring_from_string +op.from_string = op_from_string + +_str_to_agg = { + "+": "sum", + "*": "prod", + "&": "all", + "|": "any", +} + + +def aggregator_from_string(string): + return _from_string(string, agg, _str_to_agg, "sum[int]") + + +from ... import agg # noqa: E402 isort:skip + +agg.from_string = aggregator_from_string diff --git a/graphblas/core/recorder.py b/graphblas/core/recorder.py index 455166544..ca776f697 100644 --- a/graphblas/core/recorder.py +++ b/graphblas/core/recorder.py @@ -3,7 +3,6 @@ from ..dtypes import DataType from . import base, lib from .base import _recorder -from .formatting import CSS_STYLE from .mask import Mask from .matrix import TransposedMatrix from .operator import TypedOpBase @@ -35,7 +34,7 @@ def gbstr(arg): class Recorder: """Record GraphBLAS C calls. - The recorder can use `.start()` and `.stop()` to enable/disable recording, + The recorder can use ``.start()`` and ``.stop()`` to enable/disable recording, or it can be used as a context manager. For example, @@ -103,6 +102,8 @@ def is_recording(self): return self._token is not None and _recorder.get(base._prev_recorder) is self def _repr_base_(self): + from .formatting import CSS_STYLE + status = ( '
`__. + """ how = how.lower() if how == "materialize": @@ -477,6 +501,7 @@ def get(self, default=None): Returns ------- Python scalar + """ return default if self._is_empty else self.value @@ -500,6 +525,7 @@ def from_value(cls, value, dtype=None, *, is_cscalar=False, name=None): Returns ------- Scalar + """ typ = output_type(value) if dtype is None: @@ -609,8 +635,25 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax c << monoid.max(a | b) + """ + return self._ewise_add(other, op) + + def _ewise_add(self, other, op=monoid.plus, is_infix=False): method_name = "ewise_add" + if is_infix: + from .infix import ScalarEwiseAddExpr + + # This is a little different than how we handle ewise_add for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. + if isinstance(self, ScalarEwiseAddExpr): + self = op(self).new() + if isinstance(other, ScalarEwiseAddExpr): + other = op(other).new() + if type(other) is not Scalar: dtype = self.dtype if self.dtype._is_udt else None try: @@ -663,8 +706,25 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax c << binary.gt(a & b) + """ + return self._ewise_mult(other, op) + + def _ewise_mult(self, other, op=binary.times, is_infix=False): method_name = "ewise_mult" + if is_infix: + from .infix import ScalarEwiseMultExpr + + # This is a little different than how we handle ewise_mult for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. + if isinstance(self, ScalarEwiseMultExpr): + self = op(self).new() + if isinstance(other, ScalarEwiseMultExpr): + other = op(other).new() + if type(other) is not Scalar: dtype = self.dtype if self.dtype._is_udt else None try: @@ -721,9 +781,27 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax c << binary.div(a | b, left_default=1, right_default=1) + """ + return self._ewise_union(other, op, left_default, right_default) + + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): method_name = "ewise_union" - dtype = self.dtype if self.dtype._is_udt else None + if is_infix: + from .infix import ScalarEwiseAddExpr + + # This is a little different than how we handle ewise_union for Vector and + # Matrix where we are super-careful to handle dtypes well to support UDTs. + # For Scalar, we're going to let dtypes in expressions resolve themselves. + # Scalars are more challenging, because they may be literal scalars. + # Also, we have not yet resolved `op` here, so errors may be different. 
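+            # (Illustrative) a nested infix operand such as the inner `a | b`
+            # in `(a | b) | c` arrives here as a ScalarEwiseAddExpr and is
+            # materialized with `op(...).new()` before resolution proceeds.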
+ if isinstance(self, ScalarEwiseAddExpr): + self = op(self, left_default=left_default, right_default=right_default).new() + if isinstance(other, ScalarEwiseAddExpr): + other = op(other, left_default=left_default, right_default=right_default).new() + + right_dtype = self.dtype + dtype = right_dtype if right_dtype._is_udt else None if type(other) is not Scalar: try: other = Scalar.from_value(other, dtype, is_cscalar=False, name="") @@ -736,6 +814,13 @@ def ewise_union(self, other, op, left_default, right_default): extra_message="Literal scalars also accepted.", op=op, ) + else: + other = _as_scalar(other, dtype, is_cscalar=False) # pragma: is_grbscalar + + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -752,6 +837,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -768,9 +855,15 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = _as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - defaults_dtype = unify(left.dtype, right.dtype) - args_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, defaults_dtype, args_dtype, kind="binary") + + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop @@ -786,11 +879,10 @@ def ewise_union(self, other, op, left_default, right_default): scalar_as_vector=True, ) else: - dtype = unify(defaults_dtype, args_dtype) expr = ScalarExpression( method_name, None, - [self, left, other, right, _s_union_s, (self, other, left, right, op, dtype)], + [self, left, other, right, _s_union_s, (self, other, left, right, op)], op=op, expr_repr=expr_repr, is_cscalar=False, @@ -835,6 +927,7 @@ def apply(self, op, right=None, *, left=None): # Functional syntax b << op.abs(a) + """ expr = self._as_vector().apply(op, right, left=left) return ScalarExpression( @@ -1041,7 +1134,7 @@ def _as_scalar(scalar, dtype=None, *, is_cscalar): def _dict_to_record(np_type, d): - """Converts e.g. `{"x": 1, "y": 2.3}` to `(1, 2.3)`.""" + """Converts e.g. ``{"x": 1, "y": 2.3}`` to ``(1, 2.3)``.""" rv = [] for name, (dtype, _) in np_type.fields.items(): val = d[name] diff --git a/graphblas/core/ss/__init__.py b/graphblas/core/ss/__init__.py index e69de29bb..10a6fed94 100644 --- a/graphblas/core/ss/__init__.py +++ b/graphblas/core/ss/__init__.py @@ -0,0 +1,5 @@ +import suitesparse_graphblas as _ssgb + +(version_major, version_minor, version_bug) = map(int, _ssgb.__version__.split(".")[:3]) + +_IS_SSGB7 = version_major == 7 diff --git a/graphblas/core/ss/binary.py b/graphblas/core/ss/binary.py new file mode 100644 index 000000000..d53608818 --- /dev/null +++ b/graphblas/core/ss/binary.py @@ -0,0 +1,128 @@ +from ... 
import backend
+from ...dtypes import lookup_dtype
+from ...exceptions import check_status_carg
+from .. import NULL, ffi, lib
+from ..operator.base import TypedOpBase
+from ..operator.binary import BinaryOp, TypedUserBinaryOp
+from . import _IS_SSGB7
+
+ffi_new = ffi.new
+
+
+class TypedJitBinaryOp(TypedOpBase):
+    __slots__ = "_monoid", "_jit_c_definition"
+    opclass = "BinaryOp"
+
+    def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition, dtype2=None):
+        super().__init__(parent, name, type_, return_type, gb_obj, name, dtype2=dtype2)
+        self._monoid = None
+        self._jit_c_definition = jit_c_definition
+
+    @property
+    def jit_c_definition(self):
+        return self._jit_c_definition
+
+    monoid = TypedUserBinaryOp.monoid
+    commutes_to = TypedUserBinaryOp.commutes_to
+    _semiring_commutes_to = TypedUserBinaryOp._semiring_commutes_to
+    is_commutative = TypedUserBinaryOp.is_commutative
+    type2 = TypedUserBinaryOp.type2
+    __call__ = TypedUserBinaryOp.__call__
+
+
+def register_new(name, jit_c_definition, left_type, right_type, ret_type):
+    """Register a new BinaryOp using the SuiteSparse:GraphBLAS JIT compiler.
+
+    This creates a BinaryOp by compiling the C string definition of the function.
+    It requires a shell call to a C compiler. The resulting operator will be as
+    fast as if it were built-in to SuiteSparse:GraphBLAS and does not have the
+    overhead of additional function calls as when using ``gb.binary.register_new``.
+
+    This is an advanced feature that requires a C compiler and proper configuration.
+    Configuration is handled by ``gb.ss.config``; see its docstring for details.
+    By default, the JIT caches results in ``~/.SuiteSparse/``. For more information,
+    see the SuiteSparse:GraphBLAS user guide.
+
+    Only one type signature may be registered at a time, but repeated calls using
+    the same name with different input types are allowed.
+
+    Parameters
+    ----------
+    name : str
+        The name of the operator. This will show up as ``gb.binary.ss.{name}``.
+        The name may contain periods, ".", which will result in nested objects
+        such as ``gb.binary.ss.x.y.z`` for name ``"x.y.z"``.
+    jit_c_definition : str
+        The C definition as a string of the user-defined function. For example:
+        ``"void absdiff (double *z, double *x, double *y) { (*z) = fabs ((*x) - (*y)) ; }"``.
+    left_type : dtype
+        The dtype of the left operand of the binary operator.
+    right_type : dtype
+        The dtype of the right operand of the binary operator.
+    ret_type : dtype
+        The dtype of the result of the binary operator.
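+
+    A minimal sketch of a call (assumes a working C compiler and the
+    "suitesparse" backend with SuiteSparse:GraphBLAS 8 or later)::
+
+        absdiff = gb.binary.ss.register_new(
+            "absdiff",
+            "void absdiff (double *z, double *x, double *y) "
+            "{ (*z) = fabs ((*x) - (*y)) ; }",
+            "FP64", "FP64", "FP64",
+        )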
+ + Returns + ------- + BinaryOp + + See Also + -------- + gb.binary.register_new + gb.binary.register_anonymous + gb.unary.ss.register_new + + """ + if backend != "suitesparse": # pragma: no cover (safety) + raise RuntimeError( + "`gb.binary.ss.register_new` invalid when not using 'suitesparse' backend" + ) + if _IS_SSGB7: + # JIT was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise RuntimeError( + "JIT was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + left_type = lookup_dtype(left_type) + right_type = lookup_dtype(right_type) + ret_type = lookup_dtype(ret_type) + name = name if name.startswith("ss.") else f"ss.{name}" + module, funcname = BinaryOp._remove_nesting(name, strict=False) + if hasattr(module, funcname): + rv = getattr(module, funcname) + if not isinstance(rv, BinaryOp): + BinaryOp._remove_nesting(name) + if ( + (left_type, right_type) in rv.types + or rv._udt_types is not None + and (left_type, right_type) in rv._udt_types + ): + raise TypeError( + f"BinaryOp gb.binary.{name} already defined for " + f"({left_type}, {right_type}) input types" + ) + else: + # We use `is_udt=True` to make dtype handling flexible and explicit. + rv = BinaryOp(name, is_udt=True) + gb_obj = ffi_new("GrB_BinaryOp*") + check_status_carg( + lib.GxB_BinaryOp_new( + gb_obj, + NULL, + ret_type._carg, + left_type._carg, + right_type._carg, + ffi_new("char[]", funcname.encode()), + ffi_new("char[]", jit_c_definition.encode()), + ), + "BinaryOp", + gb_obj[0], + ) + op = TypedJitBinaryOp( + rv, funcname, left_type, ret_type, gb_obj[0], jit_c_definition, dtype2=right_type + ) + rv._add(op, is_jit=True) + setattr(module, funcname, rv) + return rv diff --git a/graphblas/core/ss/config.py b/graphblas/core/ss/config.py index ca91cc198..70a7dd196 100644 --- a/graphblas/core/ss/config.py +++ b/graphblas/core/ss/config.py @@ -1,10 +1,9 @@ from collections.abc import MutableMapping -from numbers import Integral from ...dtypes import lookup_dtype from ...exceptions import _error_code_lookup, check_status from .. 
import NULL, ffi, lib -from ..utils import values_to_numpy_buffer +from ..utils import maybe_integral, values_to_numpy_buffer class BaseConfig(MutableMapping): @@ -12,6 +11,9 @@ class BaseConfig(MutableMapping): # Subclasses should redefine these _get_function = None _set_function = None + _context_get_function = "GxB_Context_get" + _context_set_function = "GxB_Context_set" + _context_keys = set() _null_valid = {} _options = {} _defaults = {} @@ -28,7 +30,7 @@ class BaseConfig(MutableMapping): "GxB_Format_Value", } - def __init__(self, parent=None): + def __init__(self, parent=None, context=None): cls = type(self) if not cls._initialized: cls._reverse_enumerations = {} @@ -51,6 +53,7 @@ def __init__(self, parent=None): rd[k] = k cls._initialized = True self._parent = parent + self._context = context def __delitem__(self, key): raise TypeError("Configuration options can't be deleted.") @@ -61,19 +64,27 @@ def __getitem__(self, key): raise KeyError(key) key_obj, ctype = self._options[key] is_bool = ctype == "bool" + if is_context := (key in self._context_keys): + get_function_base = self._context_get_function + else: + get_function_base = self._get_function if ctype in self._int32_ctypes: ctype = "int32_t" - get_function_name = f"{self._get_function}_INT32" + get_function_name = f"{get_function_base}_INT32" elif ctype.startswith("int64_t"): - get_function_name = f"{self._get_function}_INT64" + get_function_name = f"{get_function_base}_INT64" elif ctype.startswith("double"): - get_function_name = f"{self._get_function}_FP64" + get_function_name = f"{get_function_base}_FP64" + elif ctype.startswith("char"): + get_function_name = f"{get_function_base}_CHAR" else: # pragma: no cover (sanity) raise ValueError(ctype) get_function = getattr(lib, get_function_name) is_array = "[" in ctype val_ptr = ffi.new(ctype if is_array else f"{ctype}*") - if self._parent is None: + if is_context: + info = get_function(self._context._carg, key_obj, val_ptr) + elif self._parent is None: info = get_function(key_obj, val_ptr) else: info = get_function(self._parent._carg, key_obj, val_ptr) @@ -88,11 +99,13 @@ def __getitem__(self, key): return {reverse_bitwise[val]} rv = set() for k, v in self._bitwise[key].items(): - if isinstance(k, str) and val & v and bin(v).count("1") == 1: + if isinstance(k, str) and val & v and (v).bit_count() == 1: rv.add(k) return rv if is_bool: return bool(val_ptr[0]) + if ctype.startswith("char"): + return ffi.string(val_ptr[0]).decode() return val_ptr[0] raise _error_code_lookup[info](f"Failed to get info for {key!r}") # pragma: no cover @@ -103,15 +116,21 @@ def __setitem__(self, key, val): if key in self._read_only: raise ValueError(f"Config option {key!r} is read-only") key_obj, ctype = self._options[key] + if is_context := (key in self._context_keys): + set_function_base = self._context_set_function + else: + set_function_base = self._set_function if ctype in self._int32_ctypes: ctype = "int32_t" - set_function_name = f"{self._set_function}_INT32" + set_function_name = f"{set_function_base}_INT32" elif ctype == "double": - set_function_name = f"{self._set_function}_FP64" + set_function_name = f"{set_function_base}_FP64" elif ctype.startswith("int64_t["): - set_function_name = f"{self._set_function}_INT64_ARRAY" + set_function_name = f"{set_function_base}_INT64_ARRAY" elif ctype.startswith("double["): - set_function_name = f"{self._set_function}_FP64_ARRAY" + set_function_name = f"{set_function_base}_FP64_ARRAY" + elif ctype.startswith("char"): + set_function_name = 
f"{set_function_base}_CHAR" else: # pragma: no cover (sanity) raise ValueError(ctype) set_function = getattr(lib, set_function_name) @@ -127,8 +146,8 @@ def __setitem__(self, key, val): bitwise = self._bitwise[key] if isinstance(val, str): val = bitwise[val.lower()] - elif isinstance(val, Integral): - val = bitwise.get(val, val) + elif (x := maybe_integral(val)) is not None: + val = bitwise.get(x, x) else: bits = 0 for x in val: @@ -154,9 +173,19 @@ def __setitem__(self, key, val): f"expected {size}, got {vals.size}: {val}" ) val_obj = ffi.from_buffer(ctype, vals) + elif ctype.startswith("char"): + val_obj = ffi.new("char[]", val.encode()) else: val_obj = ffi.cast(ctype, val) - if self._parent is None: + if is_context: + if self._context is None: + from .context import Context + + self._context = Context(engage=False) + self._context._engage() # Disengage when context goes out of scope + self._parent._context = self._context # Set context to descriptor + info = set_function(self._context._carg, key_obj, val_obj) + elif self._parent is None: info = set_function(key_obj, val_obj) else: info = set_function(self._parent._carg, key_obj, val_obj) @@ -174,7 +203,12 @@ def __len__(self): return len(self._options) def __repr__(self): - return "{" + ",\n ".join(f"{k!r}: {v!r}" for k, v in self.items()) + "}" + return ( + type(self).__name__ + + "({" + + ",\n ".join(f"{k!r}: {v!r}" for k, v in self.items()) + + "})" + ) def _ipython_key_completions_(self): # pragma: no cover (ipython) return list(self) diff --git a/graphblas/core/ss/context.py b/graphblas/core/ss/context.py new file mode 100644 index 000000000..f93d1ec1c --- /dev/null +++ b/graphblas/core/ss/context.py @@ -0,0 +1,147 @@ +import threading + +from ...exceptions import InvalidValue, check_status, check_status_carg +from .. import ffi, lib +from . 
import _IS_SSGB7 +from .config import BaseConfig + +ffi_new = ffi.new +if _IS_SSGB7: + # Context was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise ImportError( + "Context was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + + +class Context(BaseConfig): + _context_keys = {"chunk", "gpu_id", "nthreads"} + _options = { + "chunk": (lib.GxB_CONTEXT_CHUNK, "double"), + "gpu_id": (lib.GxB_CONTEXT_GPU_ID, "int"), + "nthreads": (lib.GxB_CONTEXT_NTHREADS, "int"), + } + _defaults = { + "nthreads": 0, + "chunk": 0, + "gpu_id": -1, # -1 means no GPU + } + + def __init__(self, engage=True, *, stack=True, nthreads=None, chunk=None, gpu_id=None): + super().__init__() + self.gb_obj = ffi_new("GxB_Context*") + check_status_carg(lib.GxB_Context_new(self.gb_obj), "Context", self.gb_obj[0]) + if stack: + context = threadlocal.context + self["nthreads"] = context["nthreads"] if nthreads is None else nthreads + self["chunk"] = context["chunk"] if chunk is None else chunk + self["gpu_id"] = context["gpu_id"] if gpu_id is None else gpu_id + else: + if nthreads is not None: + self["nthreads"] = nthreads + if chunk is not None: + self["chunk"] = chunk + if gpu_id is not None: + self["gpu_id"] = gpu_id + self._prev_context = None + if engage: + self.engage() + + @classmethod + def _from_obj(cls, gb_obj=None): + self = object.__new__(cls) + self.gb_obj = gb_obj + self._prev_context = None + super().__init__(self) + return self + + @property + def _carg(self): + return self.gb_obj[0] + + def dup(self, engage=True, *, nthreads=None, chunk=None, gpu_id=None): + if nthreads is None: + nthreads = self["nthreads"] + if chunk is None: + chunk = self["chunk"] + if gpu_id is None: + gpu_id = self["gpu_id"] + return type(self)(engage, stack=False, nthreads=nthreads, chunk=chunk, gpu_id=gpu_id) + + def __del__(self): + gb_obj = getattr(self, "gb_obj", None) + if gb_obj is not None and lib is not None: # pragma: no branch (safety) + try: + self.disengage() + except InvalidValue: + pass + lib.GxB_Context_free(gb_obj) + + def engage(self): + if self._prev_context is None and (context := threadlocal.context) is not self: + self._prev_context = context + check_status(lib.GxB_Context_engage(self._carg), self) + threadlocal.context = self + + def _engage(self): + """Like engage, but don't set to threadlocal.context. + + This is useful if you want to disengage when the object is deleted by going out of scope. 
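+
+        A sketch of the intended pattern (illustrative)::
+
+            ctx = Context(engage=False)
+            ctx._engage()  # engaged, but threadlocal.context is unchanged
+            del ctx        # __del__ disengages the context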
+ """ + if self._prev_context is None and (context := threadlocal.context) is not self: + self._prev_context = context + check_status(lib.GxB_Context_engage(self._carg), self) + + def disengage(self): + prev_context = self._prev_context + self._prev_context = None + if threadlocal.context is self: + if prev_context is not None: + threadlocal.context = prev_context + prev_context.engage() + else: + threadlocal.context = global_context + check_status(lib.GxB_Context_disengage(self._carg), self) + elif prev_context is not None and threadlocal.context is prev_context: + prev_context.engage() + else: + check_status(lib.GxB_Context_disengage(self._carg), self) + + def __enter__(self): + self.engage() + return self + + def __exit__(self, exc_type, exc, exc_tb): + self.disengage() + + @property + def _context(self): + return self + + @_context.setter + def _context(self, val): + if val is not None and val is not self: + raise AttributeError("'_context' attribute is read-only") + + +class GlobalContext(Context): + @property + def _carg(self): + return self.gb_obj + + def __del__(self): # pragma: no cover (safety) + pass + + +global_context = GlobalContext._from_obj(lib.GxB_CONTEXT_WORLD) + + +class ThreadLocal(threading.local): + """Hold the active context for the current thread.""" + + context = global_context + + +threadlocal = ThreadLocal() diff --git a/graphblas/core/ss/descriptor.py b/graphblas/core/ss/descriptor.py index dffc4dec1..781661b7b 100644 --- a/graphblas/core/ss/descriptor.py +++ b/graphblas/core/ss/descriptor.py @@ -1,6 +1,7 @@ from ...exceptions import check_status, check_status_carg from .. import ffi, lib from ..descriptor import Descriptor +from . import _IS_SSGB7 from .config import BaseConfig ffi_new = ffi.new @@ -18,6 +19,8 @@ class _DescriptorConfig(BaseConfig): _get_function = "GxB_Desc_get" _set_function = "GxB_Desc_set" + if not _IS_SSGB7: + _context_keys = {"chunk", "gpu_id", "nthreads"} _options = { # GrB "output_replace": (lib.GrB_OUTP, "GrB_Desc_Value"), @@ -26,13 +29,25 @@ class _DescriptorConfig(BaseConfig): "transpose_first": (lib.GrB_INP0, "GrB_Desc_Value"), "transpose_second": (lib.GrB_INP1, "GrB_Desc_Value"), # GxB - "nthreads": (lib.GxB_DESCRIPTOR_NTHREADS, "int"), - "chunk": (lib.GxB_DESCRIPTOR_CHUNK, "double"), "axb_method": (lib.GxB_AxB_METHOD, "GrB_Desc_Value"), "sort": (lib.GxB_SORT, "int"), "secure_import": (lib.GxB_IMPORT, "int"), - # "gpu_control": (GxB_DESCRIPTOR_GPU_CONTROL, "GrB_Desc_Value"), # Coming soon... } + if _IS_SSGB7: + _options.update( + { + "nthreads": (lib.GxB_DESCRIPTOR_NTHREADS, "int"), + "chunk": (lib.GxB_DESCRIPTOR_CHUNK, "double"), + } + ) + else: + _options.update( + { + "chunk": (lib.GxB_CONTEXT_CHUNK, "double"), + "gpu_id": (lib.GxB_CONTEXT_GPU_ID, "int"), + "nthreads": (lib.GxB_CONTEXT_NTHREADS, "int"), + } + ) _enumerations = { # GrB "output_replace": { @@ -71,10 +86,6 @@ class _DescriptorConfig(BaseConfig): False: False, True: lib.GxB_SORT, }, - # "gpu_control": { # Coming soon... - # "always": lib.GxB_GPU_ALWAYS, - # "never": lib.GxB_GPU_NEVER, - # }, } _defaults = { # GrB @@ -90,7 +101,8 @@ class _DescriptorConfig(BaseConfig): "sort": False, "secure_import": False, } - _count = 0 + if not _IS_SSGB7: + _defaults["gpu_id"] = -1 def __init__(self): gb_obj = ffi_new("GrB_Descriptor*") @@ -132,7 +144,7 @@ def get_descriptor(**opts): sort : bool, default False A hint for whether methods may return a "jumbled" matrix secure_import : bool, default False - Whether to trust the data for `import` and `pack` functions. 
+        Whether to trust the data for ``import`` and ``pack`` functions.
         When True, checks are performed to ensure input data is valid.
     compression : str, {"none", "default", "lz4", "lz4hc", "zstd"}
         Whether and how to compress the data for serialization.
@@ -145,6 +157,7 @@
     Returns
     -------
     Descriptor or None
+
     """
     if not opts or all(val is False or val is None for val in opts.values()):
         return
diff --git a/graphblas/core/ss/dtypes.py
new file mode 100644
index 000000000..d2eb5b416
--- /dev/null
+++ b/graphblas/core/ss/dtypes.py
@@ -0,0 +1,88 @@
+import numpy as np
+
+from ... import backend, core, dtypes
+from ...exceptions import check_status_carg
+from .. import _has_numba, ffi, lib
+from . import _IS_SSGB7
+
+ffi_new = ffi.new
+if _has_numba:
+    import numba
+    from cffi import FFI
+    from numba.core.typing import cffi_utils
+
+    jit_ffi = FFI()
+
+
+def register_new(name, jit_c_definition, *, np_type=None):
+    if backend != "suitesparse":  # pragma: no cover (safety)
+        raise RuntimeError(
+            "`gb.dtypes.ss.register_new` invalid when not using 'suitesparse' backend"
+        )
+    if _IS_SSGB7:
+        # JIT was introduced in SuiteSparse:GraphBLAS 8.0
+        import suitesparse_graphblas as ssgb
+
+        raise RuntimeError(
+            "JIT was added to SuiteSparse:GraphBLAS in version 8; "
+            f"current version is {ssgb.__version__}"
+        )
+    if not name.isidentifier():
+        raise ValueError(f"`name` argument must be a valid Python identifier; got: {name!r}")
+    if name in core.dtypes._registry or hasattr(dtypes.ss, name):
+        raise ValueError(f"{name!r} name for dtype is unavailable")
+    if len(name) > lib.GxB_MAX_NAME_LEN:
+        raise ValueError(
+            f"`name` argument is too large. Max size is {lib.GxB_MAX_NAME_LEN}; got {len(name)}"
+        )
+    if name not in jit_c_definition:
+        raise ValueError("`name` argument must be same name as the typedef in `jit_c_definition`")
+    if "struct" not in jit_c_definition:
+        raise ValueError("Only struct typedefs are currently allowed for JIT dtypes")
+
+    gb_obj = ffi.new("GrB_Type*")
+    status = lib.GxB_Type_new(
+        gb_obj, 0, ffi_new("char[]", name.encode()), ffi_new("char[]", jit_c_definition.encode())
+    )
+    check_status_carg(status, "Type", gb_obj[0])
+
+    # Let SuiteSparse:GraphBLAS determine the size (we gave 0 as size above)
+    size_ptr = ffi_new("size_t*")
+    check_status_carg(lib.GxB_Type_size(size_ptr, gb_obj[0]), "Type", gb_obj[0])
+    size = size_ptr[0]
+
+    save_np_type = True
+    if np_type is None and _has_numba and numba.__version__[:5] > "0.56.":
+        jit_ffi.cdef(jit_c_definition)
+        numba_type = cffi_utils.map_type(jit_ffi.typeof(name), use_record_dtype=True)
+        np_type = numba_type.dtype
+        if np_type.itemsize != size:  # pragma: no cover
+            raise RuntimeError(
+                "Size of compiled user-defined type does not match size of inferred numpy type: "
+                f"{size} != {np_type.itemsize}.\n\n"
+                f"UDT C definition: {jit_c_definition}\n"
+                f"numpy dtype: {np_type}\n\n"
+                "To get around this, you may pass `np_type=` keyword argument."
+            )
+    else:
+        if np_type is not None:
+            np_type = np.dtype(np_type)
+        else:
+            # Not an ideal numpy type, but minimally useful
+            np_type = np.dtype((np.uint8, size))
+            save_np_type = False
+        if _has_numba:
+            numba_type = numba.typeof(np_type).dtype
+        else:
+            numba_type = None
+
+    # For now, let's use "opaque" unsigned bytes for the c type.
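+    # (Illustrative) a typical registration looks like:
+    #     gb.dtypes.ss.register_new(
+    #         "mypoint", "typedef struct { double x ; double y ; } mypoint ;"
+    #     )
+    # and the new dtype is then available as `gb.dtypes.ss.mypoint`.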
+    rv = core.dtypes.DataType(name, gb_obj, None, f"uint8_t[{size}]", numba_type, np_type)
+    core.dtypes._registry[gb_obj] = rv
+    if save_np_type or np_type not in core.dtypes._registry:
+        core.dtypes._registry[np_type] = rv
+    if numba_type is not None and (save_np_type or numba_type not in core.dtypes._registry):
+        core.dtypes._registry[numba_type] = rv
+        core.dtypes._registry[numba_type.name] = rv
+    setattr(dtypes.ss, name, rv)
+    return rv
diff --git a/graphblas/core/ss/indexunary.py
new file mode 100644
index 000000000..b60837acf
--- /dev/null
+++ b/graphblas/core/ss/indexunary.py
@@ -0,0 +1,153 @@
+from ... import backend
+from ...dtypes import BOOL, lookup_dtype
+from ...exceptions import check_status_carg
+from .. import NULL, ffi, lib
+from ..operator.base import TypedOpBase
+from ..operator.indexunary import IndexUnaryOp, TypedUserIndexUnaryOp
+from . import _IS_SSGB7
+
+ffi_new = ffi.new
+
+
+class TypedJitIndexUnaryOp(TypedOpBase):
+    __slots__ = "_jit_c_definition"
+    opclass = "IndexUnaryOp"
+
+    def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition, dtype2=None):
+        super().__init__(parent, name, type_, return_type, gb_obj, name, dtype2=dtype2)
+        self._jit_c_definition = jit_c_definition
+
+    @property
+    def jit_c_definition(self):
+        return self._jit_c_definition
+
+    thunk_type = TypedUserIndexUnaryOp.thunk_type
+    __call__ = TypedUserIndexUnaryOp.__call__
+
+
+def register_new(name, jit_c_definition, input_type, thunk_type, ret_type):
+    """Register a new IndexUnaryOp using the SuiteSparse:GraphBLAS JIT compiler.
+
+    This creates an IndexUnaryOp by compiling the C string definition of the function.
+    It requires a shell call to a C compiler. The resulting operator will be as
+    fast as if it were built-in to SuiteSparse:GraphBLAS and does not have the
+    overhead of additional function calls as when using ``gb.indexunary.register_new``.
+
+    This is an advanced feature that requires a C compiler and proper configuration.
+    Configuration is handled by ``gb.ss.config``; see its docstring for details.
+    By default, the JIT caches results in ``~/.SuiteSparse/``. For more information,
+    see the SuiteSparse:GraphBLAS user guide.
+
+    Only one type signature may be registered at a time, but repeated calls using
+    the same name with different input types are allowed.
+
+    This will also create a SelectOp under ``gb.select.ss`` if the return
+    type is boolean.
+
+    Parameters
+    ----------
+    name : str
+        The name of the operator. This will show up as ``gb.indexunary.ss.{name}``.
+        The name may contain periods, ".", which will result in nested objects
+        such as ``gb.indexunary.ss.x.y.z`` for name ``"x.y.z"``.
+    jit_c_definition : str
+        The C definition as a string of the user-defined function. For example:
+        ``"void diffy (double *z, double *x, GrB_Index i, GrB_Index j, double *y) "``
+        ``"{ (*z) = (i + j) * fabs ((*x) - (*y)) ; }"``
+    input_type : dtype
+        The dtype of the operand of the indexunary operator.
+    thunk_type : dtype
+        The dtype of the thunk of the indexunary operator.
+    ret_type : dtype
+        The dtype of the result of the indexunary operator.
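+
+    A minimal sketch of a call (assumes a working C compiler and the
+    "suitesparse" backend with SuiteSparse:GraphBLAS 8 or later)::
+
+        diffy = gb.indexunary.ss.register_new(
+            "diffy",
+            "void diffy (double *z, double *x, GrB_Index i, GrB_Index j, double *y) "
+            "{ (*z) = (i + j) * fabs ((*x) - (*y)) ; }",
+            "FP64", "FP64", "FP64",
+        )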
+ + Returns + ------- + IndexUnaryOp + + See Also + -------- + gb.indexunary.register_new + gb.indexunary.register_anonymous + gb.select.ss.register_new + + """ + if backend != "suitesparse": # pragma: no cover (safety) + raise RuntimeError( + "`gb.indexunary.ss.register_new` invalid when not using 'suitesparse' backend" + ) + if _IS_SSGB7: + # JIT was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise RuntimeError( + "JIT was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + input_type = lookup_dtype(input_type) + thunk_type = lookup_dtype(thunk_type) + ret_type = lookup_dtype(ret_type) + name = name if name.startswith("ss.") else f"ss.{name}" + module, funcname = IndexUnaryOp._remove_nesting(name, strict=False) + if hasattr(module, funcname): + rv = getattr(module, funcname) + if not isinstance(rv, IndexUnaryOp): + IndexUnaryOp._remove_nesting(name) + if ( + (input_type, thunk_type) in rv.types + or rv._udt_types is not None + and (input_type, thunk_type) in rv._udt_types + ): + raise TypeError( + f"IndexUnaryOp gb.indexunary.{name} already defined for " + f"({input_type}, {thunk_type}) input types" + ) + else: + # We use `is_udt=True` to make dtype handling flexible and explicit. + rv = IndexUnaryOp(name, is_udt=True) + gb_obj = ffi_new("GrB_IndexUnaryOp*") + check_status_carg( + lib.GxB_IndexUnaryOp_new( + gb_obj, + NULL, + ret_type._carg, + input_type._carg, + thunk_type._carg, + ffi_new("char[]", funcname.encode()), + ffi_new("char[]", jit_c_definition.encode()), + ), + "IndexUnaryOp", + gb_obj[0], + ) + op = TypedJitIndexUnaryOp( + rv, funcname, input_type, ret_type, gb_obj[0], jit_c_definition, dtype2=thunk_type + ) + rv._add(op, is_jit=True) + if ret_type == BOOL: + from ..operator.select import SelectOp + from .select import TypedJitSelectOp + + select_module, funcname = SelectOp._remove_nesting(name, strict=False) + if hasattr(select_module, funcname): + selectop = getattr(select_module, funcname) + if not isinstance(selectop, SelectOp): + SelectOp._remove_nesting(name) + if ( + (input_type, thunk_type) in selectop.types + or selectop._udt_types is not None + and (input_type, thunk_type) in selectop._udt_types + ): + raise TypeError( + f"SelectOp gb.select.{name} already defined for " + f"({input_type}, {thunk_type}) input types" + ) + else: + # We use `is_udt=True` to make dtype handling flexible and explicit. + selectop = SelectOp(name, is_udt=True) + op2 = TypedJitSelectOp( + selectop, funcname, input_type, ret_type, gb_obj[0], jit_c_definition, dtype2=thunk_type + ) + selectop._add(op2, is_jit=True) + setattr(select_module, funcname, selectop) + setattr(module, funcname, rv) + return rv diff --git a/graphblas/core/ss/matrix.py b/graphblas/core/ss/matrix.py index b455d760e..509c56113 100644 --- a/graphblas/core/ss/matrix.py +++ b/graphblas/core/ss/matrix.py @@ -1,18 +1,16 @@ import itertools -import warnings -import numba import numpy as np -from numba import njit from suitesparse_graphblas.utils import claim_buffer, claim_buffer_2d, unclaim_buffer import graphblas as gb from ... import binary, monoid -from ...dtypes import _INDEX, BOOL, INT64, UINT64, _string_to_dtype, lookup_dtype +from ...dtypes import _INDEX, BOOL, INT64, UINT64, lookup_dtype from ...exceptions import _error_code_lookup, check_status, check_status_carg -from .. import NULL, ffi, lib +from .. 
import NULL, _has_numba, ffi, lib from ..base import call +from ..dtypes import _string_to_dtype from ..operator import get_typed_op from ..scalar import Scalar, _as_scalar, _scalar_index from ..utils import ( @@ -30,6 +28,16 @@ from .config import BaseConfig from .descriptor import get_descriptor +if _has_numba: + from numba import njit, prange +else: + + def njit(func=None, **kwargs): + if func is not None: + return func + return njit + + prange = range ffi_new = ffi.new @@ -50,12 +58,12 @@ def head(matrix, n=10, dtype=None, *, sort=False): dtype = matrix.dtype else: dtype = lookup_dtype(dtype) - rows, cols, vals = zip(*itertools.islice(matrix.ss.iteritems(), n)) + rows, cols, vals = zip(*itertools.islice(matrix.ss.iteritems(), n), strict=True) return np.array(rows, np.uint64), np.array(cols, np.uint64), np.array(vals, dtype.np_type) def _concat_mn(tiles, *, is_matrix=None): - """Argument checking for `Matrix.ss.concat` and returns number of tiles in each dimension.""" + """Argument checking for ``Matrix.ss.concat`` and returns number of tiles in each dimension.""" from ..matrix import Matrix, TransposedMatrix from ..vector import Vector @@ -242,8 +250,7 @@ def orientation(self): return "rowwise" def build_diag(self, vector, k=0, **opts): - """ - GxB_Matrix_diag. + """GxB_Matrix_diag. Construct a diagonal Matrix from the given vector. Existing entries in the Matrix are discarded. @@ -253,8 +260,8 @@ def build_diag(self, vector, k=0, **opts): vector : Vector Create a diagonal from this Vector. k : int, default 0 - Diagonal in question. Use `k>0` for diagonals above the main diagonal, - and `k<0` for diagonals below the main diagonal. + Diagonal in question. Use ``k>0`` for diagonals above the main diagonal, + and ``k<0`` for diagonals below the main diagonal. See Also -------- @@ -271,15 +278,14 @@ def build_diag(self, vector, k=0, **opts): ) def split(self, chunks, *, name=None, **opts): - """ - GxB_Matrix_split. + """GxB_Matrix_split. - Split a Matrix into a 2D array of sub-matrices according to `chunks`. + Split a Matrix into a 2D array of sub-matrices according to ``chunks``. This performs the opposite operation as ``concat``. - `chunks` is short for "chunksizes" and indicates the chunk sizes for each dimension. - `chunks` may be a single integer, or a length 2 tuple or list. Example chunks: + ``chunks`` is short for "chunksizes" and indicates the chunk sizes for each dimension. + ``chunks`` may be a single integer, or a length 2 tuple or list. Example chunks: - ``chunks=10`` - Split each dimension into chunks of size 10 (the last chunk may be smaller). @@ -287,13 +293,14 @@ def split(self, chunks, *, name=None, **opts): - Split rows into chunks of size 10 and columns into chunks of size 20. - ``chunks=(None, [5, 10])`` - Don't split rows into chunks, and split columns into two chunks of size 5 and 10. - ` ``chunks=(10, [20, None])`` + - ``chunks=(10, [20, None])`` - Split columns into two chunks of size 20 and ``ncols - 20`` See Also -------- Matrix.ss.concat graphblas.ss.concat + """ from ..matrix import Matrix @@ -353,14 +360,13 @@ def _concat(self, tiles, m, n, opts): ) def concat(self, tiles, **opts): - """ - GxB_Matrix_concat. + """GxB_Matrix_concat. Concatenate a 2D list of Matrix objects into the current Matrix. Any existing values in the current Matrix will be discarded. - To concatenate into a new Matrix, use `graphblas.ss.concat`. + To concatenate into a new Matrix, use ``graphblas.ss.concat``. - Vectors may be used as `Nx1` Matrix objects. 
+ Vectors may be used as ``Nx1`` Matrix objects. This performs the opposite operation as ``split``. @@ -368,13 +374,13 @@ def concat(self, tiles, **opts): -------- Matrix.ss.split graphblas.ss.concat + """ tiles, m, n, is_matrix = _concat_mn(tiles, is_matrix=True) self._concat(tiles, m, n, opts) def build_scalar(self, rows, columns, value): - """ - GxB_Matrix_build_Scalar. + """GxB_Matrix_build_Scalar. Like ``build``, but uses a scalar for all the values. @@ -382,6 +388,7 @@ def build_scalar(self, rows, columns, value): -------- Matrix.build Matrix.from_coo + """ rows = ints_to_numpy_buffer(rows, np.uint64, name="row indices") columns = ints_to_numpy_buffer(columns, np.uint64, name="column indices") @@ -528,14 +535,13 @@ def iteritems(self, seek=0): lib.GxB_Iterator_free(it_ptr) def export(self, format=None, *, sort=False, give_ownership=False, raw=False, **opts): - """ - GxB_Matrix_export_xxx. + """GxB_Matrix_export_xxx. Parameters ---------- format : str, optional - If `format` is not specified, this method exports in the currently stored format. - To control the export format, set `format` to one of: + If ``format`` is not specified, this method exports in the currently stored format. + To control the export format, set ``format`` to one of: - "csr" - "csc" - "hypercsr" @@ -570,7 +576,7 @@ def export(self, format=None, *, sort=False, give_ownership=False, raw=False, ** Returns ------- - dict; keys depend on `format` and `raw` arguments (see below). + dict; keys depend on ``format`` and ``raw`` arguments (see below). See Also -------- @@ -710,6 +716,7 @@ def export(self, format=None, *, sort=False, give_ownership=False, raw=False, ** >>> pieces = A.ss.export() >>> A2 = Matrix.ss.import_any(**pieces) + """ return self._export( format, @@ -721,13 +728,12 @@ def export(self, format=None, *, sort=False, give_ownership=False, raw=False, ** ) def unpack(self, format=None, *, sort=False, raw=False, **opts): - """ - GxB_Matrix_unpack_xxx. + """GxB_Matrix_unpack_xxx. - `unpack` is like `export`, except that the Matrix remains valid but empty. - `pack_*` methods are the opposite of `unpack`. + ``unpack`` is like ``export``, except that the Matrix remains valid but empty. + ``pack_*`` methods are the opposite of ``unpack``. - See `Matrix.ss.export` documentation for more details. + See ``Matrix.ss.export`` documentation for more details. 
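A quick sketch of the unpack/pack pairing described above, mirroring the ``import_any(**pieces)`` example used elsewhere in these docstrings; it assumes the dict returned by ``unpack`` is accepted as-is by ``pack_any``:

    import graphblas as gb

    A = gb.Matrix.from_coo([0, 1], [1, 0], [10, 20], nrows=2, ncols=2)
    parts = A.ss.unpack("csr")  # A remains a valid 2x2 Matrix, but is now empty
    assert A.nvals == 0
    A.ss.pack_any(**parts)      # move the data back into A
    assert A.nvals == 2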
""" return self._export( format, sort=sort, raw=raw, give_ownership=True, method="unpack", opts=opts @@ -888,16 +894,15 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m col_indices = claim_buffer(ffi, Aj[0], Aj_size[0] // index_dtype.itemsize, index_dtype) values = claim_buffer(ffi, Ax[0], Ax_size[0] // dtype.itemsize, dtype) if not raw: - if indptr.size > nrows + 1: + if indptr.size > nrows + 1: # pragma: no cover (suitesparse) indptr = indptr[: nrows + 1] if col_indices.size > nvals: col_indices = col_indices[:nvals] if is_iso: if values.size > 1: # pragma: no branch (suitesparse) values = values[:1] - else: - if values.size > nvals: # pragma: no branch (suitesparse) - values = values[:nvals] + elif values.size > nvals: # pragma: no branch (suitesparse) + values = values[:nvals] # Note: nvals is also at `indptr[nrows]` rv = { "indptr": indptr, @@ -930,16 +935,15 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m row_indices = claim_buffer(ffi, Ai[0], Ai_size[0] // index_dtype.itemsize, index_dtype) values = claim_buffer(ffi, Ax[0], Ax_size[0] // dtype.itemsize, dtype) if not raw: - if indptr.size > ncols + 1: + if indptr.size > ncols + 1: # pragma: no cover (suitesparse) indptr = indptr[: ncols + 1] if row_indices.size > nvals: row_indices = row_indices[:nvals] if is_iso: if values.size > 1: # pragma: no cover (suitesparse) values = values[:1] - else: - if values.size > nvals: - values = values[:nvals] + elif values.size > nvals: + values = values[:nvals] # Note: nvals is also at `indptr[ncols]` rv = { "indptr": indptr, @@ -989,9 +993,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m if is_iso: if values.size > 1: # pragma: no cover (suitesparse) values = values[:1] - else: - if values.size > nvals: - values = values[:nvals] + elif values.size > nvals: + values = values[:nvals] # Note: nvals is also at `indptr[nvec]` rv = { "indptr": indptr, @@ -1044,9 +1047,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m if is_iso: if values.size > 1: # pragma: no cover (suitesparse) values = values[:1] - else: - if values.size > nvals: - values = values[:nvals] + elif values.size > nvals: + values = values[:nvals] # Note: nvals is also at `indptr[nvec]` rv = { "indptr": indptr, @@ -1175,8 +1177,7 @@ def import_csr( name=None, **opts, ): - """ - GxB_Matrix_import_CSR. + """GxB_Matrix_import_CSR. Create a new Matrix from standard CSR format. @@ -1189,7 +1190,7 @@ def import_csr( col_indices : array-like is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. sorted_cols : bool, default False Indicate whether the values in "col_indices" are sorted. take_ownership : bool, default False @@ -1206,7 +1207,7 @@ def import_csr( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "csr" or None. This is included to be compatible with the dict returned from exporting. @@ -1216,6 +1217,7 @@ def import_csr( Returns ------- Matrix + """ return cls._import_csr( nrows=nrows, @@ -1252,13 +1254,12 @@ def pack_csr( name=None, **opts, ): - """ - GxB_Matrix_pack_CSR. + """GxB_Matrix_pack_CSR. 
- `pack_csr` is like `import_csr` except it "packs" data into an + ``pack_csr`` is like ``import_csr`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("csr")`` - See `Matrix.ss.import_csr` documentation for more details. + See ``Matrix.ss.import_csr`` documentation for more details. """ return self._import_csr( indptr=indptr, @@ -1365,8 +1366,7 @@ def import_csc( name=None, **opts, ): - """ - GxB_Matrix_import_CSC. + """GxB_Matrix_import_CSC. Create a new Matrix from standard CSC format. @@ -1379,7 +1379,7 @@ def import_csc( row_indices : array-like is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. sorted_rows : bool, default False Indicate whether the values in "row_indices" are sorted. take_ownership : bool, default False @@ -1396,7 +1396,7 @@ def import_csc( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "csc" or None. This is included to be compatible with the dict returned from exporting. @@ -1406,6 +1406,7 @@ def import_csc( Returns ------- Matrix + """ return cls._import_csc( nrows=nrows, @@ -1442,13 +1443,12 @@ def pack_csc( name=None, **opts, ): - """ - GxB_Matrix_pack_CSC. + """GxB_Matrix_pack_CSC. - `pack_csc` is like `import_csc` except it "packs" data into an + ``pack_csc`` is like ``import_csc`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("csc")`` - See `Matrix.ss.import_csc` documentation for more details. + See ``Matrix.ss.import_csc`` documentation for more details. """ return self._import_csc( indptr=indptr, @@ -1557,8 +1557,7 @@ def import_hypercsr( name=None, **opts, ): - """ - GxB_Matrix_import_HyperCSR. + """GxB_Matrix_import_HyperCSR. Create a new Matrix from standard HyperCSR format. @@ -1575,7 +1574,7 @@ def import_hypercsr( If not specified, will be set to ``len(rows)``. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. sorted_cols : bool, default False Indicate whether the values in "col_indices" are sorted. take_ownership : bool, default False @@ -1592,7 +1591,7 @@ def import_hypercsr( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "hypercsr" or None. This is included to be compatible with the dict returned from exporting. @@ -1602,6 +1601,7 @@ def import_hypercsr( Returns ------- Matrix + """ return cls._import_hypercsr( nrows=nrows, @@ -1642,13 +1642,12 @@ def pack_hypercsr( name=None, **opts, ): - """ - GxB_Matrix_pack_HyperCSR. + """GxB_Matrix_pack_HyperCSR. - `pack_hypercsr` is like `import_hypercsr` except it "packs" data into an + ``pack_hypercsr`` is like ``import_hypercsr`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("hypercsr")`` - See `Matrix.ss.import_hypercsr` documentation for more details. + See ``Matrix.ss.import_hypercsr`` documentation for more details. 
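A short sketch of the CSR import documented above (the array contents are arbitrary):

    import numpy as np
    import graphblas as gb

    A = gb.Matrix.ss.import_csr(
        nrows=2,
        ncols=3,
        indptr=np.array([0, 2, 3], dtype=np.uint64),      # row i owns values[indptr[i]:indptr[i+1]]
        col_indices=np.array([0, 2, 1], dtype=np.uint64),
        values=np.array([1.5, 2.5, 3.5]),
        sorted_cols=True,
    )
    assert A.nvals == 3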
""" return self._import_hypercsr( rows=rows, @@ -1781,8 +1780,7 @@ def import_hypercsc( name=None, **opts, ): - """ - GxB_Matrix_import_HyperCSC. + """GxB_Matrix_import_HyperCSC. Create a new Matrix from standard HyperCSC format. @@ -1790,6 +1788,7 @@ def import_hypercsc( ---------- nrows : int ncols : int + cols : array-like indptr : array-like values : array-like row_indices : array-like @@ -1798,7 +1797,7 @@ def import_hypercsc( If not specified, will be set to ``len(cols)``. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. sorted_rows : bool, default False Indicate whether the values in "row_indices" are sorted. take_ownership : bool, default False @@ -1815,7 +1814,7 @@ def import_hypercsc( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "hypercsc" or None. This is included to be compatible with the dict returned from exporting. @@ -1825,6 +1824,7 @@ def import_hypercsc( Returns ------- Matrix + """ return cls._import_hypercsc( nrows=nrows, @@ -1865,13 +1865,12 @@ def pack_hypercsc( name=None, **opts, ): - """ - GxB_Matrix_pack_HyperCSC. + """GxB_Matrix_pack_HyperCSC. - `pack_hypercsc` is like `import_hypercsc` except it "packs" data into an + ``pack_hypercsc`` is like ``import_hypercsc`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("hypercsc")`` - See `Matrix.ss.import_hypercsc` documentation for more details. + See ``Matrix.ss.import_hypercsc`` documentation for more details. """ return self._import_hypercsc( cols=cols, @@ -2001,8 +2000,7 @@ def import_bitmapr( name=None, **opts, ): - """ - GxB_Matrix_import_BitmapR. + """GxB_Matrix_import_BitmapR. Create a new Matrix from values and bitmap (as mask) arrays. @@ -2023,7 +2021,7 @@ def import_bitmapr( If not provided, will be inferred from values or bitmap if either is 2d. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. take_ownership : bool, default False If True, perform a zero-copy data transfer from input numpy arrays to GraphBLAS if possible. To give ownership of the underlying @@ -2038,7 +2036,7 @@ def import_bitmapr( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "bitmapr" or None. This is included to be compatible with the dict returned from exporting. @@ -2048,6 +2046,7 @@ def import_bitmapr( Returns ------- Matrix + """ return cls._import_bitmapr( bitmap=bitmap, @@ -2082,13 +2081,12 @@ def pack_bitmapr( name=None, **opts, ): - """ - GxB_Matrix_pack_BitmapR. + """GxB_Matrix_pack_BitmapR. - `pack_bitmapr` is like `import_bitmapr` except it "packs" data into an + ``pack_bitmapr`` is like ``import_bitmapr`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("bitmapr")`` - See `Matrix.ss.import_bitmapr` documentation for more details. + See ``Matrix.ss.import_bitmapr`` documentation for more details. 
""" return self._import_bitmapr( bitmap=bitmap, @@ -2194,8 +2192,7 @@ def import_bitmapc( name=None, **opts, ): - """ - GxB_Matrix_import_BitmapC. + """GxB_Matrix_import_BitmapC. Create a new Matrix from values and bitmap (as mask) arrays. @@ -2216,7 +2213,7 @@ def import_bitmapc( If not provided, will be inferred from values or bitmap if either is 2d. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. take_ownership : bool, default False If True, perform a zero-copy data transfer from input numpy arrays to GraphBLAS if possible. To give ownership of the underlying @@ -2231,7 +2228,7 @@ def import_bitmapc( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "bitmapc" or None. This is included to be compatible with the dict returned from exporting. @@ -2241,6 +2238,7 @@ def import_bitmapc( Returns ------- Matrix + """ return cls._import_bitmapc( bitmap=bitmap, @@ -2275,13 +2273,12 @@ def pack_bitmapc( name=None, **opts, ): - """ - GxB_Matrix_pack_BitmapC. + """GxB_Matrix_pack_BitmapC. - `pack_bitmapc` is like `import_bitmapc` except it "packs" data into an + ``pack_bitmapc`` is like ``import_bitmapc`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("bitmapc")`` - See `Matrix.ss.import_bitmapc` documentation for more details. + See ``Matrix.ss.import_bitmapc`` documentation for more details. """ return self._import_bitmapc( bitmap=bitmap, @@ -2385,8 +2382,7 @@ def import_fullr( name=None, **opts, ): - """ - GxB_Matrix_import_FullR. + """GxB_Matrix_import_FullR. Create a new Matrix from values. @@ -2402,7 +2398,7 @@ def import_fullr( If not provided, will be inferred from values if it is 2d. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. take_ownership : bool, default False If True, perform a zero-copy data transfer from input numpy arrays to GraphBLAS if possible. To give ownership of the underlying @@ -2417,7 +2413,7 @@ def import_fullr( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "fullr" or None. This is included to be compatible with the dict returned from exporting. @@ -2427,6 +2423,7 @@ def import_fullr( Returns ------- Matrix + """ return cls._import_fullr( values=values, @@ -2457,13 +2454,12 @@ def pack_fullr( name=None, **opts, ): - """ - GxB_Matrix_pack_FullR. + """GxB_Matrix_pack_FullR. - `pack_fullr` is like `import_fullr` except it "packs" data into an + ``pack_fullr`` is like ``import_fullr`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("fullr")`` - See `Matrix.ss.import_fullr` documentation for more details. + See ``Matrix.ss.import_fullr`` documentation for more details. """ return self._import_fullr( values=values, @@ -2544,8 +2540,7 @@ def import_fullc( name=None, **opts, ): - """ - GxB_Matrix_import_FullC. + """GxB_Matrix_import_FullC. Create a new Matrix from values. 
@@ -2561,7 +2556,7 @@ def import_fullc( If not provided, will be inferred from values if it is 2d. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. take_ownership : bool, default False If True, perform a zero-copy data transfer from input numpy arrays to GraphBLAS if possible. To give ownership of the underlying @@ -2576,7 +2571,7 @@ def import_fullc( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "fullc" or None. This is included to be compatible with the dict returned from exporting. @@ -2586,6 +2581,7 @@ def import_fullc( Returns ------- Matrix + """ return cls._import_fullc( values=values, @@ -2616,13 +2612,12 @@ def pack_fullc( name=None, **opts, ): - """ - GxB_Matrix_pack_FullC. + """GxB_Matrix_pack_FullC. - `pack_fullc` is like `import_fullc` except it "packs" data into an + ``pack_fullc`` is like ``import_fullc`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("fullc")`` - See `Matrix.ss.import_fullc` documentation for more details. + See ``Matrix.ss.import_fullc`` documentation for more details. """ return self._import_fullc( values=values, @@ -2706,8 +2701,7 @@ def import_coo( name=None, **opts, ): - """ - GrB_Matrix_build_XXX and GxB_Matrix_build_Scalar. + """GrB_Matrix_build_XXX and GxB_Matrix_build_Scalar. Create a new Matrix from indices and values in coordinate format. @@ -2722,7 +2716,7 @@ def import_coo( The number of columns for the Matrix. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. sorted_rows : bool, default False True if rows are sorted or when (cols, rows) are sorted lexicographically sorted_cols : bool, default False @@ -2731,7 +2725,7 @@ def import_coo( Ignored. Zero-copy is not possible for "coo" format. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "coo" or None. This is included to be compatible with the dict returned from exporting. @@ -2741,6 +2735,7 @@ def import_coo( Returns ------- Matrix + """ return cls._import_coo( rows=rows, @@ -2779,13 +2774,12 @@ def pack_coo( name=None, **opts, ): - """ - GrB_Matrix_build_XXX and GxB_Matrix_build_Scalar. + """GrB_Matrix_build_XXX and GxB_Matrix_build_Scalar. - `pack_coo` is like `import_coo` except it "packs" data into an + ``pack_coo`` is like ``import_coo`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("coo")`` - See `Matrix.ss.import_coo` documentation for more details. + See ``Matrix.ss.import_coo`` documentation for more details. """ return self._import_coo( nrows=self._parent._nrows, @@ -2892,8 +2886,7 @@ def import_coor( name=None, **opts, ): - """ - GxB_Matrix_import_CSR. + """GxB_Matrix_import_CSR. Create a new Matrix from indices and values in coordinate format. Rows must be sorted. @@ -2909,7 +2902,7 @@ def import_coor( The number of columns for the Matrix. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. 
+ If true, then ``values`` should be a length 1 array. sorted_cols : bool, default False True indicates indices are sorted by column, then row. take_ownership : bool, default False @@ -2927,7 +2920,7 @@ def import_coor( For "coor", ownership of "rows" will never change. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "coor" or None. This is included to be compatible with the dict returned from exporting. @@ -2937,6 +2930,7 @@ def import_coor( Returns ------- Matrix + """ return cls._import_coor( rows=rows, @@ -2975,13 +2969,12 @@ def pack_coor( name=None, **opts, ): - """ - GxB_Matrix_pack_CSR. + """GxB_Matrix_pack_CSR. - `pack_coor` is like `import_coor` except it "packs" data into an + ``pack_coor`` is like ``import_coor`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("coor")`` - See `Matrix.ss.import_coor` documentation for more details. + See ``Matrix.ss.import_coor`` documentation for more details. """ return self._import_coor( rows=rows, @@ -3061,8 +3054,7 @@ def import_cooc( name=None, **opts, ): - """ - GxB_Matrix_import_CSC. + """GxB_Matrix_import_CSC. Create a new Matrix from indices and values in coordinate format. Rows must be sorted. @@ -3078,7 +3070,7 @@ def import_cooc( The number of columns for the Matrix. is_iso : bool, default False Is the Matrix iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. sorted_rows : bool, default False True indicates indices are sorted by column, then row. take_ownership : bool, default False @@ -3096,7 +3088,7 @@ def import_cooc( For "cooc", ownership of "cols" will never change. dtype : dtype, optional dtype of the new Matrix. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "cooc" or None. This is included to be compatible with the dict returned from exporting. @@ -3106,6 +3098,7 @@ def import_cooc( Returns ------- Matrix + """ return cls._import_cooc( rows=rows, @@ -3144,13 +3137,12 @@ def pack_cooc( name=None, **opts, ): - """ - GxB_Matrix_pack_CSC. + """GxB_Matrix_pack_CSC. - `pack_cooc` is like `import_cooc` except it "packs" data into an + ``pack_cooc`` is like ``import_cooc`` except it "packs" data into an existing Matrix. This is the opposite of ``unpack("cooc")`` - See `Matrix.ss.import_cooc` documentation for more details. + See ``Matrix.ss.import_cooc`` documentation for more details. """ return self._import_cooc( ncols=self._parent._ncols, @@ -3246,11 +3238,10 @@ def import_any( nvals=None, # optional **opts, ): - """ - GxB_Matrix_import_xxx. + """GxB_Matrix_import_xxx. Dispatch to appropriate import method inferred from inputs. - See the other import functions and `Matrix.ss.export`` for details. + See the other import functions and ``Matrix.ss.export`` for details. Returns ------- @@ -3275,6 +3266,7 @@ def import_any( >>> pieces = A.ss.export() >>> A2 = Matrix.ss.import_any(**pieces) + """ return cls._import_any( values=values, @@ -3344,13 +3336,12 @@ def pack_any( name=None, **opts, ): - """ - GxB_Matrix_pack_xxx. + """GxB_Matrix_pack_xxx. - `pack_any` is like `import_any` except it "packs" data into an + ``pack_any`` is like ``import_any`` except it "packs" data into an existing Matrix. 
This is the opposite of ``unpack()`` - See `Matrix.ss.import_any` documentation for more details. + See ``Matrix.ss.import_any`` documentation for more details. """ return self._import_any( values=values, @@ -3480,15 +3471,10 @@ def _import_any( format = "cooc" else: format = "coo" + elif isinstance(values, np.ndarray) and values.ndim == 2 and values.flags.f_contiguous: + format = "fullc" else: - if ( - isinstance(values, np.ndarray) - and values.ndim == 2 - and values.flags.f_contiguous - ): - format = "fullc" - else: - format = "fullr" + format = "fullr" else: format = format.lower() if method == "pack": @@ -3664,8 +3650,10 @@ def _import_any( def unpack_hyperhash(self, *, compute=False, name=None, **opts): """Unpacks the hyper_hash of a hypersparse matrix if possible. - Will return None if the matrix is not hypersparse or if the hash is not computed. - Use ``compute=True`` to compute the hyper_hash if the input is hypersparse. + Will return None if the matrix is not hypersparse, if the hash is not computed, + or if the hash is not needed. Use ``compute=True`` to try to compute the hyper_hash + if the input is hypersparse. The hyper_hash is optional in SuiteSparse:GraphBLAS, + so it may not be computed even with ``compute=True``. Use ``pack_hyperhash`` to move a hyper_hash matrix that was previously unpacked back into a matrix. @@ -3701,12 +3689,13 @@ def head(self, n=10, dtype=None, *, sort=False): def scan(self, op=monoid.plus, order="rowwise", *, name=None, **opts): """Perform a prefix scan across rows (default) or columns with the given monoid. - For example, use `monoid.plus` (the default) to perform a cumulative sum, - and `monoid.times` for cumulative product. Works with any monoid. + For example, use ``monoid.plus`` (the default) to perform a cumulative sum, + and ``monoid.times`` for cumulative product. Works with any monoid. Returns ------- Matrix + """ order = get_order(order) parent = self._parent @@ -3714,51 +3703,6 @@ def scan(self, op=monoid.plus, order="rowwise", *, name=None, **opts): parent = parent.T return prefix_scan(parent, op, name=name, within="scan", **opts) - def scan_columnwise(self, op=monoid.plus, *, name=None, **opts): - """Perform a prefix scan across columns with the given monoid. - - .. deprecated:: 2022.11.1 - `Matrix.ss.scan_columnwise` will be removed in a future release. - Use `Matrix.ss.scan(order="columnwise")` instead. - Will be removed in version 2023.7.0 or later - - For example, use `monoid.plus` (the default) to perform a cumulative sum, - and `monoid.times` for cumulative product. Works with any monoid. - - Returns - ------- - Matrix - """ - warnings.warn( - "`Matrix.ss.scan_columnwise` is deprecated; " - 'please use `Matrix.ss.scan(order="columnwise")` instead.', - DeprecationWarning, - stacklevel=2, - ) - return prefix_scan(self._parent.T, op, name=name, within="scan_columnwise", **opts) - - def scan_rowwise(self, op=monoid.plus, *, name=None, **opts): - """Perform a prefix scan across rows with the given monoid. - - .. deprecated:: 2022.11.1 - `Matrix.ss.scan_rowwise` will be removed in a future release. - Use `Matrix.ss.scan` instead. - Will be removed in version 2023.7.0 or later - - For example, use `monoid.plus` (the default) to perform a cumulative sum, - and `monoid.times` for cumulative product. Works with any monoid. 
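The deprecated helpers removed here are subsumed by the ``order=`` keyword on ``scan``; a sketch, given some existing Matrix ``A``:

    import graphblas as gb

    cumsum_rows = A.ss.scan()                                # rowwise cumulative sum (monoid.plus)
    cumprod_cols = A.ss.scan(gb.monoid.times, "columnwise")  # columnwise cumulative product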
- - Returns - ------- - Matrix - """ - warnings.warn( - "`Matrix.ss.scan_rowwise` is deprecated; please use `Matrix.ss.scan` instead.", - DeprecationWarning, - stacklevel=2, - ) - return prefix_scan(self._parent, op, name=name, within="scan_rowwise", **opts) - def flatten(self, order="rowwise", *, name=None, **opts): """Return a copy of the Matrix collapsed into a Vector. @@ -3780,6 +3724,7 @@ def flatten(self, order="rowwise", *, name=None, **opts): See Also -------- Vector.ss.reshape : copy a Vector to a Matrix. + """ rv = self.reshape(-1, 1, order=order, name=name, **opts) return rv._as_vector() @@ -3816,6 +3761,7 @@ def reshape(self, nrows, ncols=None, order="rowwise", *, inplace=False, name=Non -------- Matrix.ss.flatten : flatten a Matrix into a Vector. Vector.ss.reshape : copy a Vector to a Matrix. + """ from ..matrix import Matrix @@ -3870,6 +3816,7 @@ def selectk(self, how, k, order="rowwise", *, name=None): The number of elements to choose from each row **THIS API IS EXPERIMENTAL AND MAY CHANGE** + """ # TODO: largest, smallest, random_weighted order = get_order(order) @@ -3900,99 +3847,6 @@ def selectk(self, how, k, order="rowwise", *, name=None): k, fmt, indices, sort_axis, choose_func, is_random, do_sort, name ) - def selectk_rowwise(self, how, k, *, name=None): # pragma: no cover (deprecated) - """Select (up to) k elements from each row. - - .. deprecated:: 2022.11.1 - `Matrix.ss.selectk_rowwise` will be removed in a future release. - Use `Matrix.ss.selectk` instead. - Will be removed in version 2023.7.0 or later - - Parameters - ---------- - how : str - "random": choose k elements with equal probability - "first": choose the first k elements - "last": choose the last k elements - k : int - The number of elements to choose from each row - - **THIS API IS EXPERIMENTAL AND MAY CHANGE** - """ - warnings.warn( - "`Matrix.ss.selectk_rowwise` is deprecated; please use `Matrix.ss.selectk` instead.", - DeprecationWarning, - stacklevel=2, - ) - how = how.lower() - fmt = "hypercsr" - indices = "col_indices" - sort_axis = "sorted_cols" - if how == "random": - choose_func = choose_random - is_random = True - do_sort = False - elif how == "first": - choose_func = choose_first - is_random = False - do_sort = True - elif how == "last": - choose_func = choose_last - is_random = False - do_sort = True - else: - raise ValueError('`how` argument must be one of: "random", "first", "last"') - return self._select_random( - k, fmt, indices, sort_axis, choose_func, is_random, do_sort, name - ) - - def selectk_columnwise(self, how, k, *, name=None): # pragma: no cover (deprecated) - """Select (up to) k elements from each column. - - .. deprecated:: 2022.11.1 - `Matrix.ss.selectk_columnwise` will be removed in a future release. - Use `Matrix.ss.selectk(order="columnwise")` instead. 
- Will be removed in version 2023.7.0 or later - - Parameters - ---------- - how : str - - "random": choose elements with equal probability - - "first": choose the first k elements - - "last": choose the last k elements - k : int - The number of elements to choose from each column - - **THIS API IS EXPERIMENTAL AND MAY CHANGE** - """ - warnings.warn( - "`Matrix.ss.selectk_columnwise` is deprecated; " - 'please use `Matrix.ss.selectk(order="columnwise")` instead.', - DeprecationWarning, - stacklevel=2, - ) - how = how.lower() - fmt = "hypercsc" - indices = "row_indices" - sort_axis = "sorted_rows" - if how == "random": - choose_func = choose_random - is_random = True - do_sort = False - elif how == "first": - choose_func = choose_first - is_random = False - do_sort = True - elif how == "last": - choose_func = choose_last - is_random = False - do_sort = True - else: - raise ValueError('`how` argument must be one of: "random", "first", "last"') - return self._select_random( - k, fmt, indices, sort_axis, choose_func, is_random, do_sort, name - ) - def _select_random(self, k, fmt, indices, sort_axis, choose_func, is_random, do_sort, name): if k < 0: raise ValueError("negative k is not allowed") @@ -4057,92 +3911,6 @@ def compactify( indices = "row_indices" return self._compactify(how, reverse, asindex, dimname, k, fmt, indices, name) - def compactify_rowwise( - self, how="first", ncols=None, *, reverse=False, asindex=False, name=None - ): - """Shift all values to the left so all values in a row are contiguous. - - This returns a new Matrix. - - Parameters - ---------- - how : {"first", "last", "smallest", "largest", "random"}, optional - How to compress the values: - - first : take the values furthest to the left - - last : take the values furthest to the right - - smallest : take the smallest values (if tied, may take any) - - largest : take the largest values (if tied, may take any) - - random : take values randomly with equal probability and without replacement - Chosen values may not be ordered randomly - reverse : bool, default False - Reverse the values in each row when True - asindex : bool, default False - Return the column index of the value when True. If there are ties for - "smallest" and "largest", then any valid index may be returned. - ncols : int, optional - The number of columns of the returned Matrix. If not specified, then - the Matrix will be "compacted" to the smallest ncols that doesn't lose - values. - - **THIS API IS EXPERIMENTAL AND MAY CHANGE** - - See Also - -------- - Matrix.ss.sort - """ - warnings.warn( - "`Matrix.ss.compactify_rowwise` is deprecated; " - "please use `Matrix.ss.compactify` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self._compactify( - how, reverse, asindex, "ncols", ncols, "hypercsr", "col_indices", name - ) - - def compactify_columnwise( - self, how="first", nrows=None, *, reverse=False, asindex=False, name=None - ): - """Shift all values to the top so all values in a column are contiguous. - - This returns a new Matrix. 
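Likewise for ``selectk``: the ``order=`` keyword replaces the removed rowwise/columnwise variants (a sketch, given some existing Matrix ``A``):

    two_per_row = A.ss.selectk("first", 2)                      # was selectk_rowwise
    two_per_col = A.ss.selectk("first", 2, order="columnwise")  # was selectk_columnwise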
- - Parameters - ---------- - how : {"first", "last", "smallest", "largest", "random"}, optional - How to compress the values: - - first : take the values furthest to the top - - last : take the values furthest to the bottom - - smallest : take the smallest values (if tied, may take any) - - largest : take the largest values (if tied, may take any) - - random : take values randomly with equal probability and without replacement - Chosen values may not be ordered randomly - reverse : bool, default False - Reverse the values in each column when True - asindex : bool, default False - Return the row index of the value when True. If there are ties for - "smallest" and "largest", then any valid index may be returned. - nrows : int, optional - The number of rows of the returned Matrix. If not specified, then - the Matrix will be "compacted" to the smallest nrows that doesn't lose - values. - - **THIS API IS EXPERIMENTAL AND MAY CHANGE** - - See Also - -------- - Matrix.ss.sort - """ - warnings.warn( - "`Matrix.ss.compactify_columnwise` is deprecated; " - 'please use `Matrix.ss.compactify(order="columnwise")` instead.', - DeprecationWarning, - stacklevel=2, - ) - return self._compactify( - how, reverse, asindex, "nrows", nrows, "hypercsc", "row_indices", name - ) - def _compactify(self, how, reverse, asindex, nkey, nval, fmt, indices_name, name): how = how.lower() if how not in {"first", "last", "smallest", "largest", "random"}: @@ -4216,23 +3984,23 @@ def sort(self, op=binary.lt, order="rowwise", *, values=True, permutation=True, """GxB_Matrix_sort to sort values along the rows (default) or columns of the Matrix. Sorting moves all the elements to the left (if rowwise) or top (if columnwise) just - like `compactify`. The returned matrices will be the same shape as the input Matrix. + like ``compactify``. The returned matrices will be the same shape as the input Matrix. Parameters ---------- op : :class:`~graphblas.core.operator.BinaryOp`, optional Binary operator with a bool return type used to sort the values. - For example, `binary.lt` (the default) sorts the smallest elements first. + For example, ``binary.lt`` (the default) sorts the smallest elements first. Ties are broken according to indices (smaller first). order : {"rowwise", "columnwise"}, optional Whether to sort rowwise or columnwise. Rowwise shifts all values to the left, and columnwise shifts all values to the top. The default is "rowwise". values : bool, default=True - Whether to return values; will return `None` for values if `False`. + Whether to return values; will return ``None`` for values if ``False``. permutation : bool, default=True Whether to compute the permutation Matrix that has the original column indices (if rowwise) or row indices (if columnwise) of the sorted values. - Will return None if `False`. + Will return None if ``False``. nthreads : int, optional The maximum number of threads to use for this operation. None, 0 or negative nthreads means to use the default number of threads. @@ -4245,6 +4013,7 @@ def sort(self, op=binary.lt, order="rowwise", *, values=True, permutation=True, See Also -------- Matrix.ss.compactify + """ from ..matrix import Matrix @@ -4301,16 +4070,32 @@ def serialize(self, compression="default", level=None, **opts): None, 0 or negative nthreads means to use the default number of threads. For best performance, this function returns a numpy array with uint8 dtype. 
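A round-trip sketch of the serialization described here (the compression name and level are illustrative choices among the SuiteSparse options):

    import graphblas as gb

    A = gb.Matrix.from_coo([0, 1], [1, 0], [1.0, 2.0])
    blob = A.ss.serialize("lz4hc", 5)  # numpy array of uint8
    A2 = gb.Matrix.ss.deserialize(blob)
    assert A.isequal(A2, check_dtype=True)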
- Use `Matrix.ss.deserialize(blob)` to create a Matrix from the result of serialization + Use ``Matrix.ss.deserialize(blob)`` to create a Matrix from the result of serialization This method is intended to support all serialization options from SuiteSparse:GraphBLAS. *Warning*: Behavior of serializing UDTs is experimental and may change in a future release. + """ desc = get_descriptor(compression=compression, compression_level=level, **opts) blob_handle = ffi_new("void**") blob_size_handle = ffi_new("GrB_Index*") parent = self._parent + if parent.dtype._is_udt and hasattr(lib, "GrB_Type_get_String"): + # Get the name from the dtype and set it to the name of the matrix so we can + # recreate the UDT. This is a bit hacky and we should restore the original name. + # First get the size of name. + dtype_size = ffi_new("size_t*") + status = lib.GrB_Type_get_SIZE(parent.dtype.gb_obj[0], dtype_size, lib.GrB_NAME) + check_status_carg(status, "Type", parent.dtype.gb_obj[0]) + # Then get the name + dtype_char = ffi_new(f"char[{dtype_size[0]}]") + status = lib.GrB_Type_get_String(parent.dtype.gb_obj[0], dtype_char, lib.GrB_NAME) + check_status_carg(status, "Type", parent.dtype.gb_obj[0]) + # Then set the name + status = lib.GrB_Matrix_set_String(parent._carg, dtype_char, lib.GrB_NAME) + check_status_carg(status, "Matrix", parent._carg) + check_status( lib.GxB_Matrix_serialize( blob_handle, @@ -4327,7 +4112,7 @@ def deserialize(cls, data, dtype=None, *, name=None, **opts): """Deserialize a Matrix from bytes, buffer, or numpy array using GxB_Matrix_deserialize. The data should have been previously serialized with a compatible version of - SuiteSparse:GraphBLAS. For example, from the result of `data = matrix.ss.serialize()`. + SuiteSparse:GraphBLAS. For example, from the result of ``data = matrix.ss.serialize()``. Examples -------- @@ -4345,14 +4130,15 @@ def deserialize(cls, data, dtype=None, *, name=None, **opts): nthreads : int, optional The maximum number of threads to use when deserializing. None, 0 or negative nthreads means to use the default number of threads. + """ if isinstance(data, np.ndarray): data = ints_to_numpy_buffer(data, np.uint8) else: data = np.frombuffer(data, np.uint8) data_obj = ffi.from_buffer("void*", data) - # Get the dtype name first if dtype is None: + # Get the dtype name first (for non-UDTs) cname = ffi_new(f"char[{lib.GxB_MAX_NAME_LEN}]") info = lib.GxB_deserialize_type_name( cname, @@ -4362,6 +4148,22 @@ def deserialize(cls, data, dtype=None, *, name=None, **opts): if info != lib.GrB_SUCCESS: raise _error_code_lookup[info]("Matrix deserialize failed to get the dtype name") dtype_name = b"".join(itertools.takewhile(b"\x00".__ne__, cname)).decode() + if not dtype_name and hasattr(lib, "GxB_Serialized_get_String"): + # Handle UDTs. 
First get the size of name + dtype_size = ffi_new("size_t*") + info = lib.GxB_Serialized_get_SIZE(data_obj, dtype_size, lib.GrB_NAME, data.nbytes) + if info != lib.GrB_SUCCESS: + raise _error_code_lookup[info]( + "Matrix deserialize failed to get the size of name" + ) + # Then get the name + dtype_char = ffi_new(f"char[{dtype_size[0]}]") + info = lib.GxB_Serialized_get_String( + data_obj, dtype_char, lib.GrB_NAME, data.nbytes + ) + if info != lib.GrB_SUCCESS: + raise _error_code_lookup[info]("Matrix deserialize failed to get the name") + dtype_name = ffi.string(dtype_char).decode() dtype = _string_to_dtype(dtype_name) else: dtype = lookup_dtype(dtype) @@ -4380,28 +4182,28 @@ def deserialize(cls, data, dtype=None, *, name=None, **opts): return rv -@numba.njit(parallel=True) +@njit(parallel=True) def argsort_values(indptr, indices, values): # pragma: no cover (numba) rv = np.empty(indptr[-1], dtype=np.uint64) - for i in numba.prange(indptr.size - 1): + for i in prange(indptr.size - 1): rv[indptr[i] : indptr[i + 1]] = indices[ np.int64(indptr[i]) + np.argsort(values[indptr[i] : indptr[i + 1]]) ] return rv -@numba.njit(parallel=True) +@njit(parallel=True) def sort_values(indptr, values): # pragma: no cover (numba) rv = np.empty(indptr[-1], dtype=values.dtype) - for i in numba.prange(indptr.size - 1): + for i in prange(indptr.size - 1): rv[indptr[i] : indptr[i + 1]] = np.sort(values[indptr[i] : indptr[i + 1]]) return rv -@numba.njit(parallel=True) +@njit(parallel=True) def compact_values(old_indptr, new_indptr, values): # pragma: no cover (numba) rv = np.empty(new_indptr[-1], dtype=values.dtype) - for i in numba.prange(new_indptr.size - 1): + for i in prange(new_indptr.size - 1): start = np.int64(new_indptr[i]) offset = np.int64(old_indptr[i]) - start for j in range(start, new_indptr[i + 1]): @@ -4409,17 +4211,17 @@ def compact_values(old_indptr, new_indptr, values): # pragma: no cover (numba) return rv -@numba.njit(parallel=True) +@njit(parallel=True) def reverse_values(indptr, values): # pragma: no cover (numba) rv = np.empty(indptr[-1], dtype=values.dtype) - for i in numba.prange(indptr.size - 1): + for i in prange(indptr.size - 1): offset = np.int64(indptr[i]) + np.int64(indptr[i + 1]) - 1 for j in range(indptr[i], indptr[i + 1]): rv[j] = values[offset - j] return rv -@numba.njit(parallel=True) +@njit(parallel=True) def compact_indices(indptr, k): # pragma: no cover (numba) """Given indptr from hypercsr, create a new col_indices array that is compact. @@ -4429,7 +4231,7 @@ def compact_indices(indptr, k): # pragma: no cover (numba) indptr = create_indptr(indptr, k) col_indices = np.empty(indptr[-1], dtype=np.uint64) N = np.int64(0) - for i in numba.prange(indptr.size - 1): + for i in prange(indptr.size - 1): start = np.int64(indptr[i]) deg = np.int64(indptr[i + 1]) - start N = max(N, deg) @@ -4442,7 +4244,7 @@ def compact_indices(indptr, k): # pragma: no cover (numba) def choose_random1(indptr): # pragma: no cover (numba) choices = np.empty(indptr.size - 1, dtype=indptr.dtype) new_indptr = np.arange(indptr.size, dtype=indptr.dtype) - for i in numba.prange(indptr.size - 1): + for i in prange(indptr.size - 1): idx = np.int64(indptr[i]) deg = np.int64(indptr[i + 1]) - idx if deg == 1: @@ -4479,7 +4281,7 @@ def choose_random(indptr, k): # pragma: no cover (numba) # be nice to have them sorted if convenient to do so. 
     new_indptr = create_indptr(indptr, k)
     choices = np.empty(new_indptr[-1], dtype=indptr.dtype)
-    for i in numba.prange(indptr.size - 1):
+    for i in prange(indptr.size - 1):
         idx = np.int64(indptr[i])
         deg = np.int64(indptr[i + 1]) - idx
         if k < deg:
@@ -4560,7 +4362,7 @@ def choose_first(indptr, k):  # pragma: no cover (numba)

     new_indptr = create_indptr(indptr, k)
     choices = np.empty(new_indptr[-1], dtype=indptr.dtype)
-    for i in numba.prange(indptr.size - 1):
+    for i in prange(indptr.size - 1):
         idx = np.int64(indptr[i])
         deg = np.int64(indptr[i + 1]) - idx
         if k < deg:
@@ -4584,7 +4386,7 @@ def choose_last(indptr, k):  # pragma: no cover (numba)

     new_indptr = create_indptr(indptr, k)
     choices = np.empty(new_indptr[-1], dtype=indptr.dtype)
-    for i in numba.prange(indptr.size - 1):
+    for i in prange(indptr.size - 1):
         idx = np.int64(indptr[i])
         deg = np.int64(indptr[i + 1]) - idx
         if k < deg:
@@ -4617,19 +4419,20 @@ def indices_to_indptr(indices, size):  # pragma: no cover (numba)
     """Calculate the indptr for e.g. CSR from sorted COO rows."""
     indptr = np.zeros(size, dtype=indices.dtype)
     index = np.uint64(0)
+    one = np.uint64(1)
     for i in range(indices.size):
         row = indices[i]
         if row != index:
-            indptr[index + 1] = i
+            indptr[index + one] = i
             index = row
-    indptr[index + 1] = indices.size
+    indptr[index + one] = indices.size
     return indptr


 @njit(parallel=True)
 def indptr_to_indices(indptr):  # pragma: no cover (numba)
     indices = np.empty(indptr[-1], dtype=indptr.dtype)
-    for i in numba.prange(indptr.size - 1):
+    for i in prange(indptr.size - 1):
         for j in range(indptr[i], indptr[i + 1]):
             indices[j] = i
     return indices
diff --git a/graphblas/core/ss/select.py b/graphblas/core/ss/select.py
new file mode 100644
index 000000000..3ba135eee
--- /dev/null
+++ b/graphblas/core/ss/select.py
@@ -0,0 +1,89 @@
+from ... import backend, indexunary
+from ...dtypes import BOOL, lookup_dtype
+from .. import ffi
+from ..operator.base import TypedOpBase
+from ..operator.select import SelectOp, TypedUserSelectOp
+from . import _IS_SSGB7
+
+ffi_new = ffi.new
+
+
+class TypedJitSelectOp(TypedOpBase):
+    __slots__ = "_jit_c_definition"
+    opclass = "SelectOp"
+
+    def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition, dtype2=None):
+        super().__init__(parent, name, type_, return_type, gb_obj, name, dtype2=dtype2)
+        self._jit_c_definition = jit_c_definition
+
+    @property
+    def jit_c_definition(self):
+        return self._jit_c_definition
+
+    thunk_type = TypedUserSelectOp.thunk_type
+    __call__ = TypedUserSelectOp.__call__
+
+
+def register_new(name, jit_c_definition, input_type, thunk_type):
+    """Register a new SelectOp using the SuiteSparse:GraphBLAS JIT compiler.
+
+    This creates a SelectOp by compiling the C string definition of the function.
+    It requires a shell call to a C compiler. The resulting operator will be as
+    fast as if it were built into SuiteSparse:GraphBLAS and does not have the
+    overhead of additional function calls as when using ``gb.select.register_new``.
+
+    This is an advanced feature that requires a C compiler and proper configuration.
+    Configuration is handled by ``gb.ss.config``; see its docstring for details.
+    By default, the JIT caches results in ``~/.SuiteSparse/``. For more information,
+    see the SuiteSparse:GraphBLAS user guide.
+
+    Only one type signature may be registered at a time, but repeated calls using
+    the same name with different input types are allowed.
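A usage sketch paralleling the IndexUnaryOp example earlier: the ``woot`` definition is the one from this docstring, the dtype strings and thunk value are illustrative, and SuiteSparse:GraphBLAS 8+ plus a C compiler are assumed:

    import graphblas as gb

    woot = gb.select.ss.register_new(
        "woot",
        "void woot (bool *z, const int32_t *x, GrB_Index i, GrB_Index j, int32_t *y) "
        "{ (*z) = ((*x) + i + j == (*y)) ; }",
        "INT32",  # input_type
        "INT32",  # thunk_type
    )
    filtered = A.select(woot, 6).new()  # keep entries where x + i + j == 6, given some Matrix A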
+
+    This will also create an IndexUnaryOp operator under ``gb.indexunary.ss``.
+
+    Parameters
+    ----------
+    name : str
+        The name of the operator. This will show up as ``gb.select.ss.{name}``.
+        The name may contain periods, ".", which will result in nested objects
+        such as ``gb.select.ss.x.y.z`` for name ``"x.y.z"``.
+    jit_c_definition : str
+        The C definition as a string of the user-defined function. For example:
+        ``"void woot (bool *z, const int32_t *x, GrB_Index i, GrB_Index j, int32_t *y) "``
+        ``"{ (*z) = ((*x) + i + j == (*y)) ; }"``
+    input_type : dtype
+        The dtype of the operand of the select operator.
+    thunk_type : dtype
+        The dtype of the thunk of the select operator.
+
+    Returns
+    -------
+    SelectOp
+
+    See Also
+    --------
+    gb.select.register_new
+    gb.select.register_anonymous
+    gb.indexunary.ss.register_new
+
+    """
+    if backend != "suitesparse":  # pragma: no cover (safety)
+        raise RuntimeError(
+            "`gb.select.ss.register_new` invalid when not using 'suitesparse' backend"
+        )
+    if _IS_SSGB7:
+        # JIT was introduced in SuiteSparse:GraphBLAS 8.0
+        import suitesparse_graphblas as ssgb
+
+        raise RuntimeError(
+            "JIT was added to SuiteSparse:GraphBLAS in version 8; "
+            f"current version is {ssgb.__version__}"
+        )
+    input_type = lookup_dtype(input_type)
+    thunk_type = lookup_dtype(thunk_type)
+    name = name if name.startswith("ss.") else f"ss.{name}"
+    # Register to both `gb.indexunary.ss` and `gb.select.ss`
+    indexunary.ss.register_new(name, jit_c_definition, input_type, thunk_type, BOOL)
+    module, funcname = SelectOp._remove_nesting(name, strict=False)
+    return getattr(module, funcname)
diff --git a/graphblas/core/ss/unary.py b/graphblas/core/ss/unary.py
new file mode 100644
index 000000000..0b7ced3c8
--- /dev/null
+++ b/graphblas/core/ss/unary.py
@@ -0,0 +1,109 @@
+from ... import backend
+from ...dtypes import lookup_dtype
+from ...exceptions import check_status_carg
+from .. import NULL, ffi, lib
+from ..operator.base import TypedOpBase
+from ..operator.unary import TypedUserUnaryOp, UnaryOp
+from . import _IS_SSGB7
+
+ffi_new = ffi.new
+
+
+class TypedJitUnaryOp(TypedOpBase):
+    __slots__ = "_jit_c_definition"
+    opclass = "UnaryOp"
+
+    def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition):
+        super().__init__(parent, name, type_, return_type, gb_obj, name)
+        self._jit_c_definition = jit_c_definition
+
+    @property
+    def jit_c_definition(self):
+        return self._jit_c_definition
+
+    __call__ = TypedUserUnaryOp.__call__
+
+
+def register_new(name, jit_c_definition, input_type, ret_type):
+    """Register a new UnaryOp using the SuiteSparse:GraphBLAS JIT compiler.
+
+    This creates a UnaryOp by compiling the C string definition of the function.
+    It requires a shell call to a C compiler. The resulting operator will be as
+    fast as if it were built into SuiteSparse:GraphBLAS and does not have the
+    overhead of additional function calls as when using ``gb.unary.register_new``.
+
+    This is an advanced feature that requires a C compiler and proper configuration.
+    Configuration is handled by ``gb.ss.config``; see its docstring for details.
+    By default, the JIT caches results in ``~/.SuiteSparse/``. For more information,
+    see the SuiteSparse:GraphBLAS user guide.
+
+    Only one type signature may be registered at a time, but repeated calls using
+    the same name with different input types are allowed.
+
+    Parameters
+    ----------
+    name : str
+        The name of the operator. This will show up as ``gb.unary.ss.{name}``.
+ The name may contain periods, ".", which will result in nested objects + such as ``gb.unary.ss.x.y.z`` for name ``"x.y.z"``. + jit_c_definition : str + The C definition as a string of the user-defined function. For example: + ``"void square (float *z, float *x) { (*z) = (*x) * (*x) ; } ;"`` + input_type : dtype + The dtype of the operand of the unary operator. + ret_type : dtype + The dtype of the result of the unary operator. + + Returns + ------- + UnaryOp + + See Also + -------- + gb.unary.register_new + gb.unary.register_anonymous + gb.binary.ss.register_new + + """ + if backend != "suitesparse": # pragma: no cover (safety) + raise RuntimeError( + "`gb.unary.ss.register_new` invalid when not using 'suitesparse' backend" + ) + if _IS_SSGB7: + # JIT was introduced in SuiteSparse:GraphBLAS 8.0 + import suitesparse_graphblas as ssgb + + raise RuntimeError( + "JIT was added to SuiteSparse:GraphBLAS in version 8; " + f"current version is {ssgb.__version__}" + ) + input_type = lookup_dtype(input_type) + ret_type = lookup_dtype(ret_type) + name = name if name.startswith("ss.") else f"ss.{name}" + module, funcname = UnaryOp._remove_nesting(name, strict=False) + if hasattr(module, funcname): + rv = getattr(module, funcname) + if not isinstance(rv, UnaryOp): + UnaryOp._remove_nesting(name) + if input_type in rv.types or rv._udt_types is not None and input_type in rv._udt_types: + raise TypeError(f"UnaryOp gb.unary.{name} already defined for {input_type} input type") + else: + # We use `is_udt=True` to make dtype handling flexible and explicit. + rv = UnaryOp(name, is_udt=True) + gb_obj = ffi_new("GrB_UnaryOp*") + check_status_carg( + lib.GxB_UnaryOp_new( + gb_obj, + NULL, + ret_type._carg, + input_type._carg, + ffi_new("char[]", funcname.encode()), + ffi_new("char[]", jit_c_definition.encode()), + ), + "UnaryOp", + gb_obj[0], + ) + op = TypedJitUnaryOp(rv, funcname, input_type, ret_type, gb_obj[0], jit_c_definition) + rv._add(op, is_jit=True) + setattr(module, funcname, rv) + return rv diff --git a/graphblas/core/ss/vector.py b/graphblas/core/ss/vector.py index d13d78ac3..fdde7eb92 100644 --- a/graphblas/core/ss/vector.py +++ b/graphblas/core/ss/vector.py @@ -1,16 +1,16 @@ import itertools import numpy as np -from numba import njit from suitesparse_graphblas.utils import claim_buffer, unclaim_buffer import graphblas as gb from ... import binary, monoid -from ...dtypes import _INDEX, INT64, UINT64, _string_to_dtype, lookup_dtype +from ...dtypes import _INDEX, INT64, UINT64, lookup_dtype from ...exceptions import _error_code_lookup, check_status, check_status_carg from .. import NULL, ffi, lib from ..base import call +from ..dtypes import _string_to_dtype from ..operator import get_typed_op from ..scalar import Scalar, _as_scalar from ..utils import ( @@ -23,7 +23,7 @@ ) from .config import BaseConfig from .descriptor import get_descriptor -from .matrix import _concat_mn +from .matrix import _concat_mn, njit from .prefix_scan import prefix_scan ffi_new = ffi.new @@ -43,7 +43,7 @@ def head(vector, n=10, dtype=None, *, sort=False): dtype = vector.dtype else: dtype = lookup_dtype(dtype) - indices, vals = zip(*itertools.islice(vector.ss.iteritems(), n)) + indices, vals = zip(*itertools.islice(vector.ss.iteritems(), n), strict=True) return np.array(indices, np.uint64), np.array(vals, dtype.np_type) @@ -145,8 +145,7 @@ def format(self): return format def build_diag(self, matrix, k=0, **opts): - """ - GxB_Vector_diag. + """GxB_Vector_diag. 
     Extract a diagonal from a Matrix or TransposedMatrix into a Vector.
     Existing entries in the Vector are discarded.

     Parameters
     ----------
     matrix : Matrix or TransposedMatrix
         Extract a diagonal from this matrix.
     k : int, default 0
-        Diagonal in question. Use `k>0` for diagonals above the main diagonal,
-        and `k<0` for diagonals below the main diagonal.
+        Diagonal in question. Use ``k>0`` for diagonals above the main diagonal,
+        and ``k<0`` for diagonals below the main diagonal.

     See Also
     --------
@@ -183,15 +182,14 @@
         )

     def split(self, chunks, *, name=None, **opts):
-        """
-        GxB_Matrix_split.
+        """GxB_Matrix_split.

-        Split a Vector into a 1D array of sub-vectors according to `chunks`.
+        Split a Vector into a 1D array of sub-vectors according to ``chunks``.

         This performs the opposite operation as ``concat``.

-        `chunks` is short for "chunksizes" and indicates the chunk sizes.
-        `chunks` may be a single integer, or a tuple or list. Example chunks:
+        ``chunks`` is short for "chunksizes" and indicates the chunk sizes.
+        ``chunks`` may be a single integer, or a tuple or list. Example chunks:

         - ``chunks=10``
             - Split vector into chunks of size 10 (the last chunk may be smaller).
@@ -202,6 +200,7 @@
         --------
         Vector.ss.concat
         graphblas.ss.concat
+
         """
         from ..vector import Vector

@@ -249,12 +248,11 @@
         )

     def concat(self, tiles, **opts):
-        """
-        GxB_Matrix_concat.
+        """GxB_Matrix_concat.

         Concatenate a 1D list of Vector objects into the current Vector.
         Any existing values in the current Vector will be discarded.
-        To concatenate into a new Vector, use `graphblas.ss.concat`.
+        To concatenate into a new Vector, use ``graphblas.ss.concat``.

         This performs the opposite operation as ``split``.

@@ -262,13 +260,13 @@
         --------
         Vector.ss.split
         graphblas.ss.concat
+
         """
         tiles, m, n, is_matrix = _concat_mn(tiles, is_matrix=False)
         self._concat(tiles, m, opts)

     def build_scalar(self, indices, value):
-        """
-        GxB_Vector_build_Scalar.
+        """GxB_Vector_build_Scalar.

         Like ``build``, but uses a scalar for all the values.

@@ -276,6 +274,7 @@
         --------
         Vector.build
         Vector.from_coo
+
         """
         indices = ints_to_numpy_buffer(indices, np.uint64, name="indices")
         scalar = _as_scalar(value, self._parent.dtype, is_cscalar=False)  # pragma: is_grbscalar
@@ -410,14 +409,13 @@ def iteritems(self, seek=0):
         lib.GxB_Iterator_free(it_ptr)

     def export(self, format=None, *, sort=False, give_ownership=False, raw=False, **opts):
-        """
-        GxB_Vextor_export_xxx.
+        """GxB_Vector_export_xxx.

         Parameters
         ----------
         format : str or None, default None
-            If `format` is not specified, this method exports in the currently stored format.
-            To control the export format, set `format` to one of:
+            If ``format`` is not specified, this method exports in the currently stored format.
+            To control the export format, set ``format`` to one of:
             - "sparse"
             - "bitmap"
             - "full"
@@ -435,7 +433,7 @@
         Returns
         -------
-        dict; keys depend on `format` and `raw` arguments (see below).
+        dict; keys depend on ``format`` and ``raw`` arguments (see below).

         See Also
         --------
@@ -443,7 +441,7 @@
         Vector.ss.import_any

         Return values
-            - Note: for `raw=True`, arrays may be larger than specified.
+ - Note: for ``raw=True``, arrays may be larger than specified. - "sparse" format - indices : ndarray(dtype=uint64, size=nvals) - values : ndarray(size=nvals) @@ -468,6 +466,7 @@ def export(self, format=None, *, sort=False, give_ownership=False, raw=False, ** >>> pieces = v.ss.export() >>> v2 = Vector.ss.import_any(**pieces) + """ return self._export( format=format, @@ -479,13 +478,12 @@ def export(self, format=None, *, sort=False, give_ownership=False, raw=False, ** ) def unpack(self, format=None, *, sort=False, raw=False, **opts): - """ - GxB_Vector_unpack_xxx. + """GxB_Vector_unpack_xxx. - `unpack` is like `export`, except that the Vector remains valid but empty. - `pack_*` methods are the opposite of `unpack`. + ``unpack`` is like ``export``, except that the Vector remains valid but empty. + ``pack_*`` methods are the opposite of ``unpack``. - See `Vector.ss.export` documentation for more details. + See ``Vector.ss.export`` documentation for more details. """ return self._export( format=format, sort=sort, give_ownership=True, raw=raw, method="unpack", opts=opts @@ -551,9 +549,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m if is_iso: if values.size > 1: # pragma: no cover (suitesparse) values = values[:1] - else: - if values.size > nvals: - values = values[:nvals] + elif values.size > nvals: + values = values[:nvals] rv = { "size": size, "indices": indices, @@ -589,9 +586,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m if is_iso: if values.size > 1: # pragma: no cover (suitesparse) values = values[:1] - else: - if values.size > size: # pragma: no branch (suitesparse) - values = values[:size] + elif values.size > size: # pragma: no cover (suitesparse) + values = values[:size] rv = { "bitmap": bitmap, "nvals": nvals[0], @@ -616,9 +612,8 @@ def _export(self, format=None, *, sort=False, give_ownership=False, raw=False, m if is_iso: if values.size > 1: values = values[:1] - else: - if values.size > size: # pragma: no branch (suitesparse) - values = values[:size] + elif values.size > size: # pragma: no branch (suitesparse) + values = values[:size] rv = {} if raw or is_iso: rv["size"] = size @@ -658,11 +653,10 @@ def import_any( nvals=None, # optional **opts, ): - """ - GxB_Vector_import_xxx. + """GxB_Vector_import_xxx. Dispatch to appropriate import method inferred from inputs. - See the other import functions and `Vector.ss.export`` for details. + See the other import functions and ``Vector.ss.export`` for details. Returns ------- @@ -682,6 +676,7 @@ def import_any( >>> pieces = v.ss.export() >>> v2 = Vector.ss.import_any(**pieces) + """ return cls._import_any( values=values, @@ -725,13 +720,12 @@ def pack_any( name=None, **opts, ): - """ - GxB_Vector_pack_xxx. + """GxB_Vector_pack_xxx. - `pack_any` is like `import_any` except it "packs" data into an + ``pack_any`` is like ``import_any`` except it "packs" data into an existing Vector. This is the opposite of ``unpack()`` - See `Vector.ss.import_any` documentation for more details. + See ``Vector.ss.import_any`` documentation for more details. """ return self._import_any( values=values, @@ -847,8 +841,7 @@ def import_sparse( name=None, **opts, ): - """ - GxB_Vector_import_CSC. + """GxB_Vector_import_CSC. Create a new Vector from sparse input. @@ -862,7 +855,7 @@ def import_sparse( If not specified, will be set to ``len(values)``. is_iso : bool, default False Is the Vector iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. 
+ If true, then ``values`` should be a length 1 array. sorted_index : bool, default False Indicate whether the values in "col_indices" are sorted. take_ownership : bool, default False @@ -879,7 +872,7 @@ def import_sparse( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Vector. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "sparse" or None. This is included to be compatible with the dict returned from exporting. @@ -889,6 +882,7 @@ def import_sparse( Returns ------- Vector + """ return cls._import_sparse( size=size, @@ -923,13 +917,12 @@ def pack_sparse( name=None, **opts, ): - """ - GxB_Vector_pack_CSC. + """GxB_Vector_pack_CSC. - `pack_sparse` is like `import_sparse` except it "packs" data into an + ``pack_sparse`` is like ``import_sparse`` except it "packs" data into an existing Vector. This is the opposite of ``unpack("sparse")`` - See `Vector.ss.import_sparse` documentation for more details. + See ``Vector.ss.import_sparse`` documentation for more details. """ return self._import_sparse( indices=indices, @@ -1032,8 +1025,7 @@ def import_bitmap( name=None, **opts, ): - """ - GxB_Vector_import_Bitmap. + """GxB_Vector_import_Bitmap. Create a new Vector from values and bitmap (as mask) arrays. @@ -1049,7 +1041,7 @@ def import_bitmap( If not specified, it will be set to the size of values. is_iso : bool, default False Is the Vector iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. take_ownership : bool, default False If True, perform a zero-copy data transfer from input numpy arrays to GraphBLAS if possible. To give ownership of the underlying @@ -1064,7 +1056,7 @@ def import_bitmap( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Vector. - If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "bitmap" or None. This is included to be compatible with the dict returned from exporting. @@ -1074,6 +1066,7 @@ def import_bitmap( Returns ------- Vector + """ return cls._import_bitmap( bitmap=bitmap, @@ -1106,13 +1099,12 @@ def pack_bitmap( name=None, **opts, ): - """ - GxB_Vector_pack_Bitmap. + """GxB_Vector_pack_Bitmap. - `pack_bitmap` is like `import_bitmap` except it "packs" data into an + ``pack_bitmap`` is like ``import_bitmap`` except it "packs" data into an existing Vector. This is the opposite of ``unpack("bitmap")`` - See `Vector.ss.import_bitmap` documentation for more details. + See ``Vector.ss.import_bitmap`` documentation for more details. """ return self._import_bitmap( bitmap=bitmap, @@ -1217,8 +1209,7 @@ def import_full( name=None, **opts, ): - """ - GxB_Vector_import_Full. + """GxB_Vector_import_Full. Create a new Vector from values. @@ -1230,7 +1221,7 @@ def import_full( If not specified, it will be set to the size of values. is_iso : bool, default False Is the Vector iso-valued (meaning all the same value)? - If true, then `values` should be a length 1 array. + If true, then ``values`` should be a length 1 array. take_ownership : bool, default False If True, perform a zero-copy data transfer from input numpy arrays to GraphBLAS if possible. To give ownership of the underlying @@ -1245,7 +1236,7 @@ def import_full( read-only and will no longer own the data. dtype : dtype, optional dtype of the new Vector. 
- If not specified, this will be inferred from `values`. + If not specified, this will be inferred from ``values``. format : str, optional Must be "full" or None. This is included to be compatible with the dict returned from exporting. @@ -1255,6 +1246,7 @@ def import_full( Returns ------- Vector + """ return cls._import_full( values=values, @@ -1283,13 +1275,12 @@ def pack_full( name=None, **opts, ): - """ - GxB_Vector_pack_Full. + """GxB_Vector_pack_Full. - `pack_full` is like `import_full` except it "packs" data into an + ``pack_full`` is like ``import_full`` except it "packs" data into an existing Vector. This is the opposite of ``unpack("full")`` - See `Vector.ss.import_full` documentation for more details. + See ``Vector.ss.import_full`` documentation for more details. """ return self._import_full( values=values, @@ -1368,12 +1359,13 @@ def head(self, n=10, dtype=None, *, sort=False): def scan(self, op=monoid.plus, *, name=None, **opts): """Perform a prefix scan with the given monoid. - For example, use `monoid.plus` (the default) to perform a cumulative sum, - and `monoid.times` for cumulative product. Works with any monoid. + For example, use ``monoid.plus`` (the default) to perform a cumulative sum, + and ``monoid.times`` for cumulative product. Works with any monoid. Returns ------- Scalar + """ return prefix_scan(self._parent, op, name=name, within="scan", **opts) @@ -1404,6 +1396,7 @@ def reshape(self, nrows, ncols=None, order="rowwise", *, name=None, **opts): See Also -------- Matrix.ss.flatten : flatten a Matrix into a Vector. + """ return self._parent._as_matrix().ss.reshape(nrows, ncols, order, name=name, **opts) @@ -1423,6 +1416,7 @@ def selectk(self, how, k, *, name=None): The number of elements to choose **THIS API IS EXPERIMENTAL AND MAY CHANGE** + """ how = how.lower() if k < 0: @@ -1565,20 +1559,20 @@ def compactify(self, how="first", size=None, *, reverse=False, asindex=False, na def sort(self, op=binary.lt, *, values=True, permutation=True, **opts): """GxB_Vector_sort to sort values of the Vector. - Sorting moves all the elements to the left just like `compactify`. + Sorting moves all the elements to the left just like ``compactify``. The returned vectors will be the same size as the input Vector. Parameters ---------- op : :class:`~graphblas.core.operator.BinaryOp`, optional Binary operator with a bool return type used to sort the values. - For example, `binary.lt` (the default) sorts the smallest elements first. + For example, ``binary.lt`` (the default) sorts the smallest elements first. Ties are broken according to indices (smaller first). values : bool, default=True - Whether to return values; will return `None` for values if `False`. + Whether to return values; will return ``None`` for values if ``False``. permutation : bool, default=True Whether to compute the permutation Vector that has the original indices of the - sorted values. Will return None if `False`. + sorted values. Will return None if ``False``. nthreads : int, optional The maximum number of threads to use for this operation. None, 0 or negative nthreads means to use the default number of threads. @@ -1591,6 +1585,7 @@ def sort(self, op=binary.lt, *, values=True, permutation=True, **opts): See Also -------- Vector.ss.compactify + """ from ..vector import Vector @@ -1646,16 +1641,32 @@ def serialize(self, compression="default", level=None, **opts): None, 0 or negative nthreads means to use the default number of threads. For best performance, this function returns a numpy array with uint8 dtype. 
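# A serialize/deserialize round-trip sketch for the methods documented here (a sketch,
# assuming the "suitesparse" backend and a non-UDT dtype; ``v`` is an existing Vector):
#
# >>> blob = v.ss.serialize()              # numpy array with uint8 dtype
# >>> w = gb.Vector.ss.deserialize(blob)   # dtype is recovered from the blob
# >>> w.isequal(v)                         # expected: True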
-        Use `Vector.ss.deserialize(blob)` to create a Vector from the result of serialization·
+        Use ``Vector.ss.deserialize(blob)`` to create a Vector from the result of serialization.

         This method is intended to support all serialization options from SuiteSparse:GraphBLAS.

         *Warning*: Behavior of serializing UDTs is experimental and may change in a future release.
+
         """
         desc = get_descriptor(compression=compression, compression_level=level, **opts)
         blob_handle = ffi_new("void**")
         blob_size_handle = ffi_new("GrB_Index*")
         parent = self._parent
+        if parent.dtype._is_udt and hasattr(lib, "GrB_Type_get_String"):
+            # Get the name from the dtype and set it to the name of the vector so we can
+            # recreate the UDT. This is a bit hacky and we should restore the original name.
+            # First get the size of name.
+            dtype_size = ffi_new("size_t*")
+            status = lib.GrB_Type_get_SIZE(parent.dtype.gb_obj[0], dtype_size, lib.GrB_NAME)
+            check_status_carg(status, "Type", parent.dtype.gb_obj[0])
+            # Then get the name
+            dtype_char = ffi_new(f"char[{dtype_size[0]}]")
+            status = lib.GrB_Type_get_String(parent.dtype.gb_obj[0], dtype_char, lib.GrB_NAME)
+            check_status_carg(status, "Type", parent.dtype.gb_obj[0])
+            # Then set the name
+            status = lib.GrB_Vector_set_String(parent._carg, dtype_char, lib.GrB_NAME)
+            check_status_carg(status, "Vector", parent._carg)
+
         check_status(
             lib.GxB_Vector_serialize(
                 blob_handle,
@@ -1672,7 +1683,7 @@ def deserialize(cls, data, dtype=None, *, name=None, **opts):
         """Deserialize a Vector from bytes, buffer, or numpy array using GxB_Vector_deserialize.

         The data should have been previously serialized with a compatible version of
-        SuiteSparse:GraphBLAS. For example, from the result of `data = vector.ss.serialize()`.
+        SuiteSparse:GraphBLAS. For example, from the result of ``data = vector.ss.serialize()``.

         Examples
         --------
@@ -1690,6 +1701,7 @@ def deserialize(cls, data, dtype=None, *, name=None, **opts):
         nthreads : int, optional
             The maximum number of threads to use when deserializing.
             None, 0 or negative nthreads means to use the default number of threads.
+
         """
         if isinstance(data, np.ndarray):
             data = ints_to_numpy_buffer(data, np.uint8)
@@ -1697,7 +1709,7 @@ def deserialize(cls, data, dtype=None, *, name=None, **opts):
             data = np.frombuffer(data, np.uint8)
         data_obj = ffi.from_buffer("void*", data)
         if dtype is None:
-            # Get the dtype name first
+            # Get the dtype name first (for non-UDTs)
            cname = ffi_new(f"char[{lib.GxB_MAX_NAME_LEN}]")
            info = lib.GxB_deserialize_type_name(
                cname,
@@ -1707,6 +1719,22 @@ def deserialize(cls, data, dtype=None, *, name=None, **opts):
             if info != lib.GrB_SUCCESS:
                 raise _error_code_lookup[info]("Vector deserialize failed to get the dtype name")
             dtype_name = b"".join(itertools.takewhile(b"\x00".__ne__, cname)).decode()
+            if not dtype_name and hasattr(lib, "GxB_Serialized_get_String"):
+                # Handle UDTs.
First get the size of name + dtype_size = ffi_new("size_t*") + info = lib.GxB_Serialized_get_SIZE(data_obj, dtype_size, lib.GrB_NAME, data.nbytes) + if info != lib.GrB_SUCCESS: + raise _error_code_lookup[info]( + "Vector deserialize failed to get the size of name" + ) + # Then get the name + dtype_char = ffi_new(f"char[{dtype_size[0]}]") + info = lib.GxB_Serialized_get_String( + data_obj, dtype_char, lib.GrB_NAME, data.nbytes + ) + if info != lib.GrB_SUCCESS: + raise _error_code_lookup[info]("Vector deserialize failed to get the name") + dtype_name = ffi.string(dtype_char).decode() dtype = _string_to_dtype(dtype_name) else: dtype = lookup_dtype(dtype) diff --git a/graphblas/core/utils.py b/graphblas/core/utils.py index 0beeb4a2a..e9a29b3a9 100644 --- a/graphblas/core/utils.py +++ b/graphblas/core/utils.py @@ -1,17 +1,19 @@ -from numbers import Integral, Number +from operator import index import numpy as np from ..dtypes import _INDEX, lookup_dtype from . import ffi, lib +_NP2 = np.__version__.startswith("2.") + def libget(name): """Helper to get items from GraphBLAS which might be GrB or GxB.""" try: return getattr(lib, name) except AttributeError: - if name[-4:] not in {"FC32", "FC64", "error"}: + if name[-4:] not in {"FC32", "FC64", "rror"}: raise ext_name = f"GxB_{name[4:]}" try: @@ -22,7 +24,7 @@ def libget(name): def wrapdoc(func_with_doc): - """Decorator to copy `__doc__` from a function onto the wrapped function.""" + """Decorator to copy ``__doc__`` from a function onto the wrapped function.""" def inner(func_wo_doc): func_wo_doc.__doc__ = func_with_doc.__doc__ @@ -43,7 +45,7 @@ def inner(func_wo_doc): object: object, type: type, } -_output_types.update((k, k) for k in np.cast) +_output_types.update((k, k) for k in set(np.sctypeDict.values())) def output_type(val): @@ -60,7 +62,8 @@ def ints_to_numpy_buffer(array, dtype, *, name="array", copy=False, ownable=Fals and not np.issubdtype(array.dtype, np.bool_) ): raise ValueError(f"{name} must be integers, not {array.dtype.name}") - array = np.array(array, dtype, copy=copy, order=order) + # https://numpy.org/doc/stable/release/2.0.0-notes.html#new-copy-keyword-meaning-for-array-and-asarray-constructors + array = np.array(array, dtype, copy=copy or _NP2 and None, order=order) if ownable and (not array.flags.owndata or not array.flags.writeable): array = array.copy(order) return array @@ -86,13 +89,18 @@ def values_to_numpy_buffer( ------- np.ndarray dtype + """ if dtype is not None: dtype = lookup_dtype(dtype) - array = np.array(array, _get_subdtype(dtype.np_type), copy=copy, order=order) + # https://numpy.org/doc/stable/release/2.0.0-notes.html#new-copy-keyword-meaning-for-array-and-asarray-constructors + array = np.array( + array, _get_subdtype(dtype.np_type), copy=copy or _NP2 and None, order=order + ) else: is_input_np = isinstance(array, np.ndarray) - array = np.array(array, copy=copy, order=order) + # https://numpy.org/doc/stable/release/2.0.0-notes.html#new-copy-keyword-meaning-for-array-and-asarray-constructors + array = np.array(array, copy=copy or _NP2 and None, order=order) if array.dtype.hasobject: raise ValueError("object dtype for values is not allowed") if not is_input_np and array.dtype == np.int32: # pragma: no cover @@ -131,6 +139,7 @@ def get_shape(nrows, ncols, dtype=None, **arrays): # We could be smarter and determine the shape of the dtype sub-arrays if arr.ndim >= 3: break + # BRANCH NOT COVERED elif arr.ndim == 2: break else: @@ -157,8 +166,19 @@ def get_order(order): ) +def maybe_integral(val): + """Ensure 
``val`` is an integer or return None if it's not.""" + try: + return index(val) + except TypeError: + pass + if isinstance(val, float) and val.is_integer(): + return int(val) + return None + + def normalize_chunks(chunks, shape): - """Normalize chunks argument for use by `Matrix.ss.split`. + """Normalize chunks argument for use by ``Matrix.ss.split``. Examples -------- @@ -171,11 +191,12 @@ def normalize_chunks(chunks, shape): [(10,), (5, 15)] >>> normalize_chunks((5, (5, None)), shape) [(5, 5), (5, 15)] + """ if isinstance(chunks, (list, tuple)): pass - elif isinstance(chunks, Number): - chunks = (chunks,) * len(shape) + elif (chunk := maybe_integral(chunks)) is not None: + chunks = (chunk,) * len(shape) elif isinstance(chunks, np.ndarray): chunks = chunks.tolist() else: @@ -188,25 +209,24 @@ def normalize_chunks(chunks, shape): f"chunks argument must be of length {len(shape)} (one for each dimension of a {typ})" ) chunksizes = [] - for size, chunk in zip(shape, chunks): + for size, chunk in zip(shape, chunks, strict=True): if chunk is None: cur_chunks = [size] - elif isinstance(chunk, Integral) or isinstance(chunk, float) and chunk.is_integer(): - chunk = int(chunk) - if chunk < 0: - raise ValueError(f"Chunksize must be greater than 0; got: {chunk}") - div, mod = divmod(size, chunk) - cur_chunks = [chunk] * div + elif (c := maybe_integral(chunk)) is not None: + if c < 0: + raise ValueError(f"Chunksize must be greater than 0; got: {c}") + div, mod = divmod(size, c) + cur_chunks = [c] * div if mod: cur_chunks.append(mod) elif isinstance(chunk, (list, tuple)): cur_chunks = [] none_index = None for c in chunk: - if isinstance(c, Integral) or isinstance(c, float) and c.is_integer(): - c = int(c) - if c < 0: - raise ValueError(f"Chunksize must be greater than 0; got: {c}") + if (val := maybe_integral(c)) is not None: + if val < 0: + raise ValueError(f"Chunksize must be greater than 0; got: {val}") + c = val elif c is None: if none_index is not None: raise TypeError( @@ -248,17 +268,17 @@ def normalize_chunks(chunks, shape): def ensure_type(x, types): - """Try to ensure `x` is one of the given types, computing if necessary. + """Try to ensure ``x`` is one of the given types, computing if necessary. - `types` must be a type or a tuple of types as used in `isinstance`. + ``types`` must be a type or a tuple of types as used in ``isinstance``. - For example, if `types` is a Vector, then a Vector input will be returned, - and a `VectorExpression` input will be computed and returned as a Vector. + For example, if ``types`` is a Vector, then a Vector input will be returned, + and a ``VectorExpression`` input will be computed and returned as a Vector. TypeError will be raised if the input is not or can't be converted to types. - This function ignores `graphblas.config["autocompute"]`; it always computes - if the return type will match `types`. + This function ignores ``graphblas.config["autocompute"]``; it always computes + if the return type will match ``types``. 
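# A usage sketch for ``ensure_type`` as described above (a sketch; ``u`` and ``v`` are
# assumed to be Vectors, and ``ensure_type`` is internal to ``graphblas.core.utils``):
#
# >>> from graphblas.core.utils import ensure_type
# >>> expr = u.ewise_mult(v, gb.binary.plus)   # a VectorExpression, not yet computed
# >>> w = ensure_type(expr, gb.Vector)         # computed and returned as a Vector
# >>> ensure_type(w, gb.Vector) is w           # already a Vector: expected True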
""" if isinstance(x, types): return x @@ -299,7 +319,10 @@ def __init__(self, array=None, dtype=_INDEX, *, size=None, name=None): if size is not None: self.array = np.empty(size, dtype=dtype.np_type) else: - self.array = np.array(array, dtype=_get_subdtype(dtype.np_type), copy=False, order="C") + # https://numpy.org/doc/stable/release/2.0.0-notes.html#new-copy-keyword-meaning-for-array-and-asarray-constructors + self.array = np.array( + array, dtype=_get_subdtype(dtype.np_type), copy=_NP2 and None, order="C" + ) c_type = dtype.c_type if dtype._is_udt else f"{dtype.c_type}*" self._carg = ffi.cast(c_type, ffi.from_buffer(self.array)) self.dtype = dtype @@ -357,6 +380,7 @@ def _autogenerate_code( specializer=None, begin="# Begin auto-generated code", end="# End auto-generated code", + callblack=True, ): """Super low-tech auto-code generation used by automethods.py and infixmethods.py.""" with filepath.open() as f: # pragma: no branch (flaky) @@ -383,7 +407,8 @@ def _autogenerate_code( f.write(new_text) import subprocess - try: - subprocess.check_call(["black", filepath]) - except FileNotFoundError: # pragma: no cover (safety) - pass # It's okay if `black` isn't installed; pre-commit hooks will do linting + if callblack: + try: + subprocess.check_call(["black", filepath]) + except FileNotFoundError: # pragma: no cover (safety) + pass # It's okay if `black` isn't installed; pre-commit hooks will do linting diff --git a/graphblas/core/vector.py b/graphblas/core/vector.py index dd183d856..8bac4198e 100644 --- a/graphblas/core/vector.py +++ b/graphblas/core/vector.py @@ -1,17 +1,23 @@ import itertools -import warnings import numpy as np -from .. import backend, binary, monoid, select, semiring +from .. import backend, binary, monoid, select, semiring, unary from ..dtypes import _INDEX, FP64, INT64, lookup_dtype, unify from ..exceptions import DimensionMismatch, NoValue, check_status -from . import automethods, ffi, lib, utils +from . 
import _supports_udfs, automethods, ffi, lib, utils from .base import BaseExpression, BaseType, _check_mask, call from .descriptor import lookup as descriptor_lookup -from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, Updater +from .expr import _ALL_INDICES, AmbiguousAssignOrExtract, IndexerResolver, InfixExprBase, Updater from .mask import Mask, StructuralMask, ValueMask -from .operator import UNKNOWN_OPCLASS, find_opclass, get_semiring, get_typed_op, op_from_string +from .operator import ( + UNKNOWN_OPCLASS, + _get_typed_op_from_exprs, + find_opclass, + get_semiring, + get_typed_op, + op_from_string, +) from .scalar import ( _COMPLETE, _MATERIALIZE, @@ -61,13 +67,13 @@ def _v_union_m(updater, left, right, left_default, right_default, op): updater << temp.ewise_union(right, op, left_default=left_default, right_default=right_default) -def _v_union_v(updater, left, right, left_default, right_default, op, dtype): +def _v_union_v(updater, left, right, left_default, right_default, op): mask = updater.kwargs.get("mask") opts = updater.opts - new_left = left.dup(dtype, clear=True) + new_left = left.dup(op.type, clear=True) new_left(mask=mask, **opts) << binary.second(right, left_default) new_left(mask=mask, **opts) << binary.first(left | new_left) - new_right = right.dup(dtype, clear=True) + new_right = right.dup(op.type2, clear=True) new_right(mask=mask, **opts) << binary.second(left, right_default) new_right(mask=mask, **opts) << binary.first(right | new_right) updater << op(new_left & new_right) @@ -93,6 +99,45 @@ def _select_mask(updater, obj, mask): updater << obj.dup(mask=mask) +def _isclose_recipe(self, other, rel_tol, abs_tol, **opts): + # x == y or abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol) + isequal = self.ewise_mult(other, binary.eq).new(bool, name="isclose", **opts) + if isequal._nvals != self._nvals: + return False + if type(isequal) is Vector: + val = isequal.reduce(monoid.land, allow_empty=False).new(**opts).value + else: + val = isequal.reduce_scalar(monoid.land, allow_empty=False).new(**opts).value + if val: + return True + # So we can use structural mask below + isequal(**opts) << select.value(isequal == True) # noqa: E712 + + # abs(x) + x = self.apply(unary.abs).new(FP64, mask=~isequal.S, **opts) + # abs(y) + y = other.apply(unary.abs).new(FP64, mask=~isequal.S, **opts) + # max(abs(x), abs(y)) + x(**opts) << x.ewise_mult(y, binary.max) + max_x_y = x + # rel_tol * max(abs(x), abs(y)) + max_x_y(**opts) << max_x_y.apply(binary.times, rel_tol) + # max(rel_tol * max(abs(x), abs(y)), abs_tol) + max_x_y(**opts) << max_x_y.apply(binary.max, abs_tol) + + # x - y + y(~isequal.S, replace=True, **opts) << self.ewise_mult(other, binary.minus) + abs_x_y = y + # abs(x - y) + abs_x_y(**opts) << abs_x_y.apply(unary.abs) + + # abs(x - y) <= max(rel_tol * max(abs(x), abs(y)), abs_tol) + isequal(**opts) << abs_x_y.ewise_mult(max_x_y, binary.le) + if isequal.ndim == 1: + return isequal.reduce(monoid.land, allow_empty=False).new(**opts).value + return isequal.reduce_scalar(monoid.land, allow_empty=False).new(**opts).value + + class Vector(BaseType): """Create a new GraphBLAS Sparse Vector. @@ -104,6 +149,7 @@ class Vector(BaseType): Size of the Vector. name : str, optional Name to give the Vector. This will be displayed in the ``__repr__``. 
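# A construction sketch for the Vector class documented below (standard public API;
# the values shown are illustrative):
#
# >>> v = gb.Vector(gb.dtypes.FP64, size=10, name="v")
# >>> v[3] = 4.5        # assign a single element
# >>> v.nvals           # expected: 1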
+ """ __slots__ = "_size", "_parent", "ss" @@ -220,6 +266,7 @@ def __delitem__(self, keys, **opts): Examples -------- >>> del v[1:-1] + """ del Updater(self, opts=opts)[keys] @@ -234,6 +281,7 @@ def __getitem__(self, keys): .. code-block:: python sub_v = v[[1, 3, 5]].new() + """ resolved_indexes = IndexerResolver(self, keys) shape = resolved_indexes.shape @@ -253,6 +301,7 @@ def __setitem__(self, keys, expr, **opts): # This makes a dense iso-value vector v[:] = 1 + """ Updater(self, opts=opts)[keys] = expr @@ -265,6 +314,7 @@ def __contains__(self, index): # Check if v[15] is non-empty 15 in v + """ extractor = self[index] if not extractor._is_scalar: @@ -304,6 +354,7 @@ def isequal(self, other, *, check_dtype=False, **opts): See Also -------- :meth:`isclose` : For equality check of floating point dtypes + """ other = self._expect_type(other, Vector, within="isequal", argname="other") if check_dtype and self.dtype != other.dtype: @@ -346,6 +397,7 @@ def isclose(self, other, *, rel_tol=1e-7, abs_tol=0.0, check_dtype=False, **opts Returns ------- bool + """ other = self._expect_type(other, Vector, within="isclose", argname="other") if check_dtype and self.dtype != other.dtype: @@ -354,6 +406,8 @@ def isclose(self, other, *, rel_tol=1e-7, abs_tol=0.0, check_dtype=False, **opts return False if self._nvals != other._nvals: return False + if not _supports_udfs: + return _isclose_recipe(self, other, rel_tol, abs_tol, **opts) matches = self.ewise_mult(other, binary.isclose(rel_tol, abs_tol)).new( bool, name="M_isclose", **opts @@ -408,36 +462,6 @@ def resize(self, size): call("GrB_Vector_resize", [self, size]) self._size = size.value - def to_values(self, dtype=None, *, indices=True, values=True, sort=True): - """Extract the indices and values as a 2-tuple of numpy arrays. - - .. deprecated:: 2022.11.0 - `Vector.to_values` will be removed in a future release. - Use `Vector.to_coo` instead. Will be removed in version 2023.9.0 or later - - Parameters - ---------- - dtype : - Requested dtype for the output values array. - indices :bool, default=True - Whether to return indices; will return `None` for indices if `False` - values : bool, default=True - Whether to return values; will return `None` for values if `False` - sort : bool, default=True - Whether to require sorted indices. - - Returns - ------- - np.ndarray[dtype=uint64] : Indices - np.ndarray : Values - """ - warnings.warn( - "`Vector.to_values(...)` is deprecated; please use `Vector.to_coo(...)` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.to_coo(dtype, indices=indices, values=values, sort=sort) - def to_coo(self, dtype=None, *, indices=True, values=True, sort=True): """Extract the indices and values as a 2-tuple of numpy arrays. @@ -446,9 +470,9 @@ def to_coo(self, dtype=None, *, indices=True, values=True, sort=True): dtype : Requested dtype for the output values array. indices :bool, default=True - Whether to return indices; will return `None` for indices if `False` + Whether to return indices; will return ``None`` for indices if ``False`` values : bool, default=True - Whether to return values; will return `None` for values if `False` + Whether to return values; will return ``None`` for values if ``False`` sort : bool, default=True Whether to require sorted indices. 
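# A ``to_coo`` sketch matching the docstring above (a sketch; the values dtype follows
# the ``from_coo`` default inference):
#
# >>> v = gb.Vector.from_coo([1, 3, 5], [10, 30, 50])
# >>> indices, values = v.to_coo()
# >>> indices           # expected: array([1, 3, 5], dtype=uint64)
# >>> values            # expected: array([10, 30, 50])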
@@ -462,6 +486,7 @@ def to_coo(self, dtype=None, *, indices=True, values=True, sort=True): ------- np.ndarray[dtype=uint64] : Indices np.ndarray : Values + """ if sort and backend == "suitesparse": self.wait() # sort in SS @@ -498,7 +523,7 @@ def build(self, indices, values, *, dup_op=None, clear=False, size=None): """Rarely used method to insert values into an existing Vector. The typical use case is to create a new Vector and insert values at the same time using :meth:`from_coo`. - All the arguments are used identically in :meth:`from_coo`, except for `clear`, which + All the arguments are used identically in :meth:`from_coo`, except for ``clear``, which indicates whether to clear the Vector prior to adding the new values. """ # TODO: accept `dtype` keyword to match the dtype of `values`? @@ -520,14 +545,15 @@ def build(self, indices, values, *, dup_op=None, clear=False, size=None): if not dup_op_given: if not self.dtype._is_udt: dup_op = binary.plus - else: + elif backend != "suitesparse": dup_op = binary.any - # SS:SuiteSparse-specific: we could use NULL for dup_op - dup_op = get_typed_op(dup_op, self.dtype, kind="binary") - if dup_op.opclass == "Monoid": - dup_op = dup_op.binaryop - else: - self._expect_op(dup_op, "BinaryOp", within="build", argname="dup_op") + # SS:SuiteSparse-specific: we use NULL for dup_op + if dup_op is not None: + dup_op = get_typed_op(dup_op, self.dtype, kind="binary") + if dup_op.opclass == "Monoid": + dup_op = dup_op.binaryop + else: + self._expect_op(dup_op, "BinaryOp", within="build", argname="dup_op") indices = _CArray(indices) values = _CArray(values, self.dtype) @@ -560,6 +586,7 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): Returns ------- Vector + """ if dtype is not None or mask is not None or clear: if dtype is None: @@ -570,7 +597,7 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): else: if opts: # Ignore opts for now - descriptor_lookup(**opts) + desc = descriptor_lookup(**opts) # noqa: F841 (keep desc in scope for context) rv = Vector._from_obj(ffi_new("GrB_Vector*"), self.dtype, self._size, name=name) call("GrB_Vector_dup", [_Pointer(rv), self]) return rv @@ -590,6 +617,7 @@ def diag(self, k=0, *, name=None): Returns ------- :class:`~graphblas.Matrix` + """ from .matrix import Matrix @@ -614,6 +642,7 @@ def wait(self, how="materialize"): Use wait to force completion of the Vector. Has no effect in `blocking mode <../user_guide/init.html#graphblas-modes>`__. + """ how = how.lower() if how == "materialize": @@ -638,6 +667,7 @@ def get(self, index, default=None): Returns ------- Python scalar + """ expr = self[index] if expr._is_scalar: @@ -648,43 +678,6 @@ def get(self, index, default=None): "A single index should be given, and the result will be a Python scalar." ) - @classmethod - def from_values(cls, indices, values, dtype=None, *, size=None, dup_op=None, name=None): - """Create a new Vector from indices and values. - - .. deprecated:: 2022.11.0 - `Vector.from_values` will be removed in a future release. - Use `Vector.from_coo` instead. Will be removed in version 2023.9.0 or later - - Parameters - ---------- - indices : list or np.ndarray - Vector indices. - values : list or np.ndarray or scalar - List of values. If a scalar is provided, all values will be set to this single value. - dtype : - Data type of the Vector. If not provided, the values will be inspected - to choose an appropriate dtype. - size : int, optional - Size of the Vector. 
If not provided, ``size`` is computed from - the maximum index found in ``indices``. - dup_op : BinaryOp, optional - Function used to combine values if duplicate indices are found. - Leaving ``dup_op=None`` will raise an error if duplicates are found. - name : str, optional - Name to give the Vector. - - Returns - ------- - Vector - """ - warnings.warn( - "`Vector.from_values(...)` is deprecated; please use `Vector.from_coo(...)` instead.", - DeprecationWarning, - stacklevel=2, - ) - return cls.from_coo(indices, values, dtype, size=size, dup_op=dup_op, name=name) - @classmethod def from_coo(cls, indices, values=1.0, dtype=None, *, size=None, dup_op=None, name=None): """Create a new Vector from indices and values. @@ -717,6 +710,7 @@ def from_coo(cls, indices, values=1.0, dtype=None, *, size=None, dup_op=None, na Returns ------- Vector + """ indices = ints_to_numpy_buffer(indices, np.uint64, name="indices") values, dtype = values_to_numpy_buffer(values, dtype, subarray_after=1) @@ -774,10 +768,11 @@ def from_pairs(cls, pairs, dtype=None, *, size=None, dup_op=None, name=None): Returns ------- Vector + """ if isinstance(pairs, np.ndarray): raise TypeError("pairs as NumPy array is not supported; use `Vector.from_coo` instead") - unzipped = list(zip(*pairs)) + unzipped = list(zip(*pairs, strict=True)) if len(unzipped) == 2: indices, values = unzipped elif not unzipped: @@ -825,6 +820,7 @@ def from_scalar(cls, value, size, dtype=None, *, name=None, **opts): Returns ------- Vector + """ if type(value) is not Scalar: try: @@ -877,6 +873,7 @@ def from_dense(cls, values, missing_value=None, *, dtype=None, name=None, **opts Returns ------- Vector + """ values, dtype = values_to_numpy_buffer(values, dtype, subarray_after=1) if values.ndim == 0: @@ -925,6 +922,7 @@ def to_dense(self, fill_value=None, dtype=None, **opts): Returns ------- np.ndarray + """ if fill_value is None or self._nvals == self._size: if self._nvals != self._size: @@ -995,16 +993,43 @@ def ewise_add(self, other, op=monoid.plus): # Functional syntax w << monoid.max(u | v) + """ + return self._ewise_add(other, op) + + def _ewise_add(self, other, op=monoid.plus, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_add" - other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, VectorEwiseAddExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. 
+ self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if other.ndim == 2: # Broadcast columnwise from the left if other._nrows != self._size: @@ -1060,16 +1085,42 @@ def ewise_mult(self, other, op=binary.times): # Functional syntax w << binary.gt(u & v) + """ + return self._ewise_mult(other, op) + + def _ewise_mult(self, other, op=binary.times, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_mult" - other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="binary") - # Per the spec, op may be a semiring, but this is weird, so don't. - self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if is_infix: + from .infix import MatrixEwiseMultExpr, VectorEwiseMultExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseMultExpr, VectorEwiseMultExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") + if isinstance(self, VectorEwiseMultExpr): + self = op(self).new() + if isinstance(other, InfixExprBase): + other = op(other).new() + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + # Per the spec, op may be a semiring, but this is weird, so don't. + self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if other.ndim == 2: # Broadcast columnwise from the left if other._nrows != self._size: @@ -1128,14 +1179,37 @@ def ewise_union(self, other, op, left_default, right_default): # Functional syntax w << binary.div(u | v, left_default=1, right_default=1) + """ + return self._ewise_union(other, op, left_default, right_default) + + def _ewise_union(self, other, op, left_default, right_default, is_infix=False): from .matrix import Matrix, MatrixExpression, TransposedMatrix method_name = "ewise_union" - other = self._expect_type( - other, (Vector, Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - dtype = self.dtype if self.dtype._is_udt else None + if is_infix: + from .infix import MatrixEwiseAddExpr, VectorEwiseAddExpr + + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix, MatrixEwiseAddExpr, VectorEwiseAddExpr), + within=method_name, + argname="other", + op=op, + ) + temp_op = _get_typed_op_from_exprs(op, self, other, kind="binary") + else: + other = self._expect_type( + other, + (Vector, Matrix, TransposedMatrix), + within=method_name, + argname="other", + op=op, + ) + temp_op = get_typed_op(op, self.dtype, other.dtype, kind="binary") + + left_dtype = temp_op.type + dtype = left_dtype if left_dtype._is_udt else None if type(left_default) is not Scalar: try: left = Scalar.from_value( @@ -1152,6 +1226,8 @@ def ewise_union(self, other, op, left_default, right_default): ) else: left = _as_scalar(left_default, dtype, is_cscalar=False) # pragma: is_grbscalar + right_dtype = temp_op.type2 + dtype = right_dtype if right_dtype._is_udt else None if type(right_default) is not Scalar: try: right = Scalar.from_value( @@ -1168,12 +1244,29 @@ def ewise_union(self, other, op, left_default, right_default): ) else: right = 
_as_scalar(right_default, dtype, is_cscalar=False) # pragma: is_grbscalar - scalar_dtype = unify(left.dtype, right.dtype) - nonscalar_dtype = unify(self.dtype, other.dtype) - op = get_typed_op(op, scalar_dtype, nonscalar_dtype, is_left_scalar=True, kind="binary") + + if is_infix: + op1 = _get_typed_op_from_exprs(op, self, right, kind="binary") + op2 = _get_typed_op_from_exprs(op, left, other, kind="binary") + else: + op1 = get_typed_op(op, self.dtype, right.dtype, kind="binary") + op2 = get_typed_op(op, left.dtype, other.dtype, kind="binary") + if op1 is not op2: + left_dtype = unify(op1.type, op2.type, is_right_scalar=True) + right_dtype = unify(op1.type2, op2.type2, is_left_scalar=True) + op = get_typed_op(op, left_dtype, right_dtype, kind="binary") + else: + op = op1 self._expect_op(op, ("BinaryOp", "Monoid"), within=method_name, argname="op") if op.opclass == "Monoid": op = op.binaryop + + if is_infix: + if isinstance(self, VectorEwiseAddExpr): + self = op(self, left_default=left, right_default=right).new() + if isinstance(other, InfixExprBase): + other = op(other, left_default=left, right_default=right).new() + expr_repr = "{0.name}.{method_name}({2.name}, {op}, {1._expr_name}, {3._expr_name})" if other.ndim == 2: # Broadcast columnwise from the left @@ -1201,11 +1294,10 @@ def ewise_union(self, other, op, left_default, right_default): expr_repr=expr_repr, ) else: - dtype = unify(scalar_dtype, nonscalar_dtype, is_left_scalar=True) expr = VectorExpression( method_name, None, - [self, left, other, right, _v_union_v, (self, other, left, right, op, dtype)], + [self, left, other, right, _v_union_v, (self, other, left, right, op)], expr_repr=expr_repr, size=self._size, op=op, @@ -1242,15 +1334,37 @@ def vxm(self, other, op=semiring.plus_times): # Functional syntax C << semiring.min_plus(v @ A) + """ + return self._vxm(other, op) + + def _vxm(self, other, op=semiring.plus_times, is_infix=False): from .matrix import Matrix, TransposedMatrix method_name = "vxm" - other = self._expect_type( - other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op - ) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") + if is_infix: + from .infix import MatrixMatMulExpr, VectorMatMulExpr + + other = self._expect_type( + other, + (Matrix, TransposedMatrix, MatrixMatMulExpr), + within=method_name, + argname="other", + op=op, + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, VectorMatMulExpr): + self = op(self).new() + if isinstance(other, MatrixMatMulExpr): + other = op(other).new() + else: + other = self._expect_type( + other, (Matrix, TransposedMatrix), within=method_name, argname="other", op=op + ) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = VectorExpression( method_name, "GrB_vxm", @@ -1300,6 +1414,7 @@ def apply(self, op, right=None, *, left=None): # Functional syntax w << op.abs(v) + """ method_name = "apply" extra_message = ( @@ -1445,6 +1560,7 @@ def select(self, op, thunk=None): # Functional syntax w << select.value(v >= 1) + """ method_name = "select" if isinstance(op, str): @@ -1500,6 +1616,7 @@ def select(self, op, thunk=None): if thunk.dtype._is_udt: dtype_name = "UDT" thunk = _Pointer(thunk) + # NOT COVERED else: dtype_name = thunk.dtype.name cfunc_name = f"GrB_Vector_select_{dtype_name}" @@ 
-1538,6 +1655,7 @@ def reduce(self, op=monoid.plus, *, allow_empty=True): .. code-block:: python total << v.reduce(monoid.plus) + """ method_name = "reduce" op = get_typed_op(op, self.dtype, kind="binary|aggregator") @@ -1590,11 +1708,29 @@ def inner(self, other, op=semiring.plus_times): *Note*: This is not a standard GraphBLAS function, but fits with other functions in the `Matrix Multiplication <../user_guide/operations.html#matrix-multiply>`__ family of functions. + """ + return self._inner(other, op) + + def _inner(self, other, op=semiring.plus_times, is_infix=False): method_name = "inner" - other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) - op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") - self._expect_op(op, "Semiring", within=method_name, argname="op") + if is_infix: + from .infix import VectorMatMulExpr + + other = self._expect_type( + other, (Vector, VectorMatMulExpr), within=method_name, argname="other", op=op + ) + op = _get_typed_op_from_exprs(op, self, other, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + if isinstance(self, VectorMatMulExpr): + self = op(self).new() + if isinstance(other, VectorMatMulExpr): + other = op(other).new() + else: + other = self._expect_type(other, Vector, within=method_name, argname="other", op=op) + op = get_typed_op(op, self.dtype, other.dtype, kind="semiring") + self._expect_op(op, "Semiring", within=method_name, argname="op") + expr = ScalarExpression( method_name, "GrB_vxm", @@ -1628,6 +1764,7 @@ def outer(self, other, op=binary.times): C << v.outer(w, op=binary.times) *Note*: This is not a standard GraphBLAS function. + """ from .matrix import MatrixExpression @@ -1676,6 +1813,7 @@ def reposition(self, offset, *, size=None): .. 
code-block:: python w = v.reposition(20).new() + """ if size is None: size = self._size @@ -1714,7 +1852,7 @@ def _extract_element( result = Scalar(dtype, is_cscalar=is_cscalar, name=name) if opts: # Ignore opts for now - descriptor_lookup(**opts) + desc = descriptor_lookup(**opts) # noqa: F841 (keep desc in scope for context) if is_cscalar: dtype_name = "UDT" if dtype._is_udt else dtype.name if ( @@ -1817,13 +1955,14 @@ def _prep_for_assign(self, resolved_indexes, value, mask, is_submask, replace, o shape = values.shape try: vals = Vector.from_dense(values, dtype=dtype) - except Exception: # pragma: no cover (safety) + except Exception: vals = None else: if dtype.np_type.subdtype is not None: shape = vals.shape if vals is None or shape != (size,): if dtype.np_type.subdtype is not None: + # NOT COVERED extra = ( " (this is assigning to a vector with sub-array dtype " f"({dtype}), so array shape should include dtype shape)" @@ -1868,12 +2007,11 @@ def _prep_for_assign(self, resolved_indexes, value, mask, is_submask, replace, o else: cfunc_name = f"GrB_Vector_assign_{dtype_name}" mask = _vanilla_subassign_mask(self, mask, idx, replace, opts) + elif backend == "suitesparse": + cfunc_name = "GxB_Vector_subassign_Scalar" else: - if backend == "suitesparse": - cfunc_name = "GxB_Vector_subassign_Scalar" - else: - cfunc_name = "GrB_Vector_assign_Scalar" - mask = _vanilla_subassign_mask(self, mask, idx, replace, opts) + cfunc_name = "GrB_Vector_assign_Scalar" + mask = _vanilla_subassign_mask(self, mask, idx, replace, opts) expr_repr = ( "[[{2._expr_name} elements]]" f"({mask.name})" # fmt: skip @@ -1936,6 +2074,7 @@ def from_dict(cls, d, dtype=None, *, size=None, name=None): Returns ------- Vector + """ indices = np.fromiter(d.keys(), np.uint64) if dtype is None: @@ -1944,7 +2083,7 @@ def from_dict(cls, d, dtype=None, *, size=None, name=None): # If we know the dtype, then using `np.fromiter` is much faster dtype = lookup_dtype(dtype) if dtype.np_type.subdtype is not None and np.__version__[:5] in {"1.21.", "1.22."}: - values, dtype = values_to_numpy_buffer(list(d.values()), dtype) + values, dtype = values_to_numpy_buffer(list(d.values()), dtype) # FLAKY COVERAGE else: values = np.fromiter(d.values(), dtype.np_type) if size is None and indices.size == 0: @@ -1963,9 +2102,10 @@ def to_dict(self): Returns ------- dict + """ indices, values = self.to_coo(sort=False) - return dict(zip(indices.tolist(), values.tolist())) + return dict(zip(indices.tolist(), values.tolist(), strict=True)) if backend == "suitesparse": @@ -2092,7 +2232,6 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): to_coo = wrapdoc(Vector.to_coo)(property(automethods.to_coo)) to_dense = wrapdoc(Vector.to_dense)(property(automethods.to_dense)) to_dict = wrapdoc(Vector.to_dict)(property(automethods.to_dict)) - to_values = wrapdoc(Vector.to_values)(property(automethods.to_values)) vxm = wrapdoc(Vector.vxm)(property(automethods.vxm)) wait = wrapdoc(Vector.wait)(property(automethods.wait)) # These raise exceptions @@ -2134,6 +2273,9 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): if clear: if dtype is None: dtype = self.dtype + if opts: + # Ignore opts for now + desc = descriptor_lookup(**opts) # noqa: F841 (keep desc in scope for context) return self.output_type(dtype, *self.shape, name=name) return self.new(dtype, mask=mask, name=name, **opts) @@ -2177,7 +2319,6 @@ def dup(self, dtype=None, *, clear=False, mask=None, name=None, **opts): to_coo = 
wrapdoc(Vector.to_coo)(property(automethods.to_coo)) to_dense = wrapdoc(Vector.to_dense)(property(automethods.to_dense)) to_dict = wrapdoc(Vector.to_dict)(property(automethods.to_dict)) - to_values = wrapdoc(Vector.to_values)(property(automethods.to_values)) vxm = wrapdoc(Vector.vxm)(property(automethods.vxm)) wait = wrapdoc(Vector.wait)(property(automethods.wait)) # These raise exceptions diff --git a/graphblas/dtypes/__init__.py b/graphblas/dtypes/__init__.py new file mode 100644 index 000000000..f9c144f13 --- /dev/null +++ b/graphblas/dtypes/__init__.py @@ -0,0 +1,46 @@ +from ..core.dtypes import ( + _INDEX, + BOOL, + FP32, + FP64, + INT8, + INT16, + INT32, + INT64, + UINT8, + UINT16, + UINT32, + UINT64, + DataType, + _supports_complex, + lookup_dtype, + register_anonymous, + register_new, + unify, +) + +if _supports_complex: + from ..core.dtypes import FC32, FC64 + + +def __dir__(): + return globals().keys() | {"ss"} + + +def __getattr__(key): + if key == "ss": + from .. import backend + + if backend != "suitesparse": + raise AttributeError( + f'module {__name__!r} only has attribute "ss" when backend is "suitesparse"' + ) + from importlib import import_module + + ss = import_module(".ss", __name__) + globals()["ss"] = ss + return ss + raise AttributeError(f"module {__name__!r} has no attribute {key!r}") + + +_index_dtypes = {BOOL, INT8, UINT8, INT16, UINT16, INT32, UINT32, INT64, UINT64, _INDEX} diff --git a/graphblas/dtypes/ss.py b/graphblas/dtypes/ss.py new file mode 100644 index 000000000..9f6083e01 --- /dev/null +++ b/graphblas/dtypes/ss.py @@ -0,0 +1 @@ +from ..core.ss.dtypes import register_new # noqa: F401 diff --git a/graphblas/exceptions.py b/graphblas/exceptions.py index 0acc9ed0b..05cac988a 100644 --- a/graphblas/exceptions.py +++ b/graphblas/exceptions.py @@ -1,4 +1,3 @@ -from . import backend as _backend from .core import ffi as _ffi from .core import lib as _lib from .core.utils import _Pointer @@ -85,9 +84,14 @@ class NotImplementedException(GraphblasException): """ +# SuiteSparse errors +class JitError(GraphblasException): + """SuiteSparse:GraphBLAS error using JIT.""" + + # Our errors class UdfParseError(GraphblasException): - """Unable to parse the user-defined function.""" + """SuiteSparse:GraphBLAS unable to parse the user-defined function.""" _error_code_lookup = { @@ -112,8 +116,12 @@ class UdfParseError(GraphblasException): } GrB_SUCCESS = _lib.GrB_SUCCESS GrB_NO_VALUE = _lib.GrB_NO_VALUE -if _backend == "suitesparse": + +# SuiteSparse-specific errors +if hasattr(_lib, "GxB_EXHAUSTED"): _error_code_lookup[_lib.GxB_EXHAUSTED] = StopIteration +if hasattr(_lib, "GxB_JIT_ERROR"): # Added in 9.4 + _error_code_lookup[_lib.GxB_JIT_ERROR] = JitError def check_status(response_code, args): @@ -121,7 +129,7 @@ def check_status(response_code, args): return if response_code == GrB_NO_VALUE: return NoValue - if type(args) is list: + if isinstance(args, list): arg = args[0] else: arg = args diff --git a/graphblas/indexunary/__init__.py b/graphblas/indexunary/__init__.py index 472231597..a3cb06608 100644 --- a/graphblas/indexunary/__init__.py +++ b/graphblas/indexunary/__init__.py @@ -4,7 +4,7 @@ def __dir__(): - return globals().keys() | _delayed.keys() + return globals().keys() | _delayed.keys() | {"ss"} def __getattr__(key): @@ -13,6 +13,18 @@ def __getattr__(key): rv = func(**kwargs) globals()[key] = rv return rv + if key == "ss": + from .. 
import backend + + if backend != "suitesparse": + raise AttributeError( + f'module {__name__!r} only has attribute "ss" when backend is "suitesparse"' + ) + from importlib import import_module + + ss = import_module(".ss", __name__) + globals()["ss"] = ss + return ss raise AttributeError(f"module {__name__!r} has no attribute {key!r}") diff --git a/graphblas/indexunary/ss.py b/graphblas/indexunary/ss.py new file mode 100644 index 000000000..58218df6f --- /dev/null +++ b/graphblas/indexunary/ss.py @@ -0,0 +1,6 @@ +from ..core import operator +from ..core.ss.indexunary import register_new # noqa: F401 + +_delayed = {} + +del operator diff --git a/graphblas/io.py b/graphblas/io.py deleted file mode 100644 index e9d8ccfe6..000000000 --- a/graphblas/io.py +++ /dev/null @@ -1,631 +0,0 @@ -from warnings import warn as _warn - -import numpy as _np - -from . import backend as _backend -from .core.matrix import Matrix as _Matrix -from .core.utils import normalize_values as _normalize_values -from .core.utils import output_type as _output_type -from .core.vector import Vector as _Vector -from .dtypes import lookup_dtype as _lookup_dtype -from .exceptions import GraphblasException as _GraphblasException - - -def draw(m): # pragma: no cover - """Draw a square adjacency Matrix as a graph. - - Requires `networkx `_ and - `matplotlib `_ to be installed. - - Example output: - - .. image:: /_static/img/draw-example.png - """ - from . import viz - - _warn( - "`graphblas.io.draw` is deprecated; it has been moved to `graphblas.viz.draw`", - DeprecationWarning, - ) - viz.draw(m) - - -def from_networkx(G, nodelist=None, dtype=None, weight="weight", name=None): - """Create a square adjacency Matrix from a networkx Graph. - - Parameters - ---------- - G : nx.Graph - Graph to convert - nodelist : list, optional - List of nodes in the nx.Graph. If not provided, all nodes will be used. - dtype : - Data type - weight : str, default="weight" - Weight attribute - name : str, optional - Name of resulting Matrix - - Returns - ------- - :class:`~graphblas.Matrix` - """ - import networkx as nx - - if dtype is not None: - dtype = _lookup_dtype(dtype).np_type - A = nx.to_scipy_sparse_array(G, nodelist=nodelist, dtype=dtype, weight=weight) - return from_scipy_sparse(A, name=name) - - -def from_numpy(m): # pragma: no cover (deprecated) - """Create a sparse Vector or Matrix from a dense numpy array. - - .. deprecated:: 2023.2.0 - `from_numpy` will be removed in a future release. - Use `Vector.from_dense` or `Matrix.from_dense` instead. - Will be removed in version 2023.10.0 or later - - A value of 0 is considered as "missing". 
- - - m.ndim == 1 returns a `Vector` - - m.ndim == 2 returns a `Matrix` - - m.ndim > 2 raises an error - - dtype is inferred from m.dtype - - Parameters - ---------- - m : np.ndarray - Input array - - See Also - -------- - Matrix.from_dense - Vector.from_dense - from_scipy_sparse - - Returns - ------- - Vector or Matrix - """ - _warn( - "`graphblas.io.from_numpy` is deprecated; " - "use `Matrix.from_dense` and `Vector.from_dense` instead.", - DeprecationWarning, - ) - if m.ndim > 2: - raise _GraphblasException("m.ndim must be <= 2") - - try: - from scipy.sparse import coo_array, csr_array - except ImportError: # pragma: no cover (import) - raise ImportError("scipy is required to import from numpy") from None - - if m.ndim == 1: - A = csr_array(m) - _, size = A.shape - dtype = _lookup_dtype(m.dtype) - return _Vector.from_coo(A.indices, A.data, size=size, dtype=dtype) - A = coo_array(m) - return from_scipy_sparse(A) - - -def from_scipy_sparse(A, *, dup_op=None, name=None): - """Create a Matrix from a scipy.sparse array or matrix. - - Input data in "csr" or "csc" format will be efficient when importing with SuiteSparse:GraphBLAS. - - Parameters - ---------- - A : scipy.sparse - Scipy sparse array or matrix - dup_op : BinaryOp, optional - Aggregation function for formats that allow duplicate entries (e.g. coo) - name : str, optional - Name of resulting Matrix - - Returns - ------- - :class:`~graphblas.Matrix` - """ - nrows, ncols = A.shape - dtype = _lookup_dtype(A.dtype) - if A.nnz == 0: - return _Matrix(dtype, nrows=nrows, ncols=ncols, name=name) - if _backend == "suitesparse" and A.format in {"csr", "csc"}: - data = A.data - is_iso = (data[[0]] == data).all() - if is_iso: - data = data[[0]] - if A.format == "csr": - return _Matrix.ss.import_csr( - nrows=nrows, - ncols=ncols, - indptr=A.indptr, - col_indices=A.indices, - values=data, - is_iso=is_iso, - sorted_cols=getattr(A, "_has_sorted_indices", False), - name=name, - ) - return _Matrix.ss.import_csc( - nrows=nrows, - ncols=ncols, - indptr=A.indptr, - row_indices=A.indices, - values=data, - is_iso=is_iso, - sorted_rows=getattr(A, "_has_sorted_indices", False), - name=name, - ) - if A.format == "csr": - return _Matrix.from_csr(A.indptr, A.indices, A.data, ncols=ncols, name=name) - if A.format == "csc": - return _Matrix.from_csc(A.indptr, A.indices, A.data, nrows=nrows, name=name) - if A.format != "coo": - A = A.tocoo() - return _Matrix.from_coo( - A.row, A.col, A.data, nrows=nrows, ncols=ncols, dtype=dtype, dup_op=dup_op, name=name - ) - - -def from_awkward(A, *, name=None): - """Create a Matrix or Vector from an Awkward Array. 
- - The Awkward Array must have top-level parameters: format, shape - - The Awkward Array must have top-level attributes based on format: - - vec/csr/csc: values, indices - - hypercsr/hypercsc: values, indices, offset_labels - - Parameters - ---------- - A : awkward.Array - Awkward Array with values and indices - name : str, optional - Name of resulting Matrix or Vector - - Returns - ------- - Vector or Matrix - """ - params = A.layout.parameters - if missing := {"format", "shape"} - params.keys(): - raise ValueError(f"Missing parameters: {missing}") - format = params["format"] - shape = params["shape"] - - if len(shape) == 1: - if format != "vec": - raise ValueError(f"Invalid format for Vector: {format}") - return _Vector.from_coo( - A.indices.layout.data, A.values.layout.data, size=shape[0], name=name - ) - nrows, ncols = shape - values = A.values.layout.content.data - indptr = A.values.layout.offsets.data - if format == "csr": - cols = A.indices.layout.content.data - return _Matrix.from_csr(indptr, cols, values, ncols=ncols, name=name) - if format == "csc": - rows = A.indices.layout.content.data - return _Matrix.from_csc(indptr, rows, values, nrows=nrows, name=name) - if format == "hypercsr": - rows = A.offset_labels.layout.data - cols = A.indices.layout.content.data - return _Matrix.from_dcsr(rows, indptr, cols, values, nrows=nrows, ncols=ncols, name=name) - if format == "hypercsc": - cols = A.offset_labels.layout.data - rows = A.indices.layout.content.data - return _Matrix.from_dcsc(cols, indptr, rows, values, nrows=nrows, ncols=ncols, name=name) - raise ValueError(f"Invalid format for Matrix: {format}") - - -def from_pydata_sparse(s, *, dup_op=None, name=None): - """Create a Vector or a Matrix from a pydata.sparse array or matrix. - - Input data in "gcxs" format will be efficient when importing with SuiteSparse:GraphBLAS. - - Parameters - ---------- - s : sparse - PyData sparse array or matrix (see https://sparse.pydata.org) - dup_op : BinaryOp, optional - Aggregation function for formats that allow duplicate entries (e.g. coo) - name : str, optional - Name of resulting Matrix - - Returns - ------- - :class:`~graphblas.Vector` - :class:`~graphblas.Matrix` - """ - try: - import sparse - except ImportError: # pragma: no cover (import) - raise ImportError("sparse is required to import from pydata sparse") from None - if not isinstance(s, sparse.SparseArray): - raise TypeError( - "from_pydata_sparse only accepts objects from the `sparse` library; " - "see https://sparse.pydata.org" - ) - if s.ndim > 2: - raise _GraphblasException("m.ndim must be <= 2") - - if s.ndim == 1: - # the .asformat('coo') makes it easier to convert dok/gcxs using a single approach - _s = s.asformat("coo") - return _Vector.from_coo( - _s.coords, _s.data, dtype=_s.dtype, size=_s.shape[0], dup_op=dup_op, name=name - ) - # handle two-dimensional arrays - if isinstance(s, sparse.GCXS): - return from_scipy_sparse(s.to_scipy_sparse(), dup_op=dup_op, name=name) - if isinstance(s, (sparse.DOK, sparse.COO)): - _s = s.asformat("coo") - return _Matrix.from_coo( - *_s.coords, - _s.data, - nrows=_s.shape[0], - ncols=_s.shape[1], - dtype=_s.dtype, - dup_op=dup_op, - name=name, - ) - raise ValueError(f"Unknown sparse array type: {type(s).__name__}") # pragma: no cover (safety) - - -# TODO: add parameters to allow different networkx classes and attribute names -def to_networkx(m, edge_attribute="weight"): - """Create a networkx DiGraph from a square adjacency Matrix. 
- - Parameters - ---------- - m : Matrix - Square adjacency Matrix - edge_attribute : str, optional - Name of edge attribute from values of Matrix. If None, values will be skipped. - Default is "weight". - - Returns - ------- - nx.DiGraph - """ - import networkx as nx - - rows, cols, vals = m.to_coo() - rows = rows.tolist() - cols = cols.tolist() - G = nx.DiGraph() - if edge_attribute is None: - G.add_edges_from(zip(rows, cols)) - else: - G.add_weighted_edges_from(zip(rows, cols, vals.tolist()), weight=edge_attribute) - return G - - -def to_numpy(m): # pragma: no cover (deprecated) - """Create a dense numpy array from a sparse Vector or Matrix. - - .. deprecated:: 2023.2.0 - `to_numpy` will be removed in a future release. - Use `Vector.to_dense` or `Matrix.to_dense` instead. - Will be removed in version 2023.10.0 or later - - Missing values will become 0 in the output. - - numpy dtype will match the GraphBLAS dtype - - Parameters - ---------- - m : Vector or Matrix - GraphBLAS Vector or Matrix - - See Also - -------- - to_scipy_sparse - Matrix.to_dense - Vector.to_dense - - Returns - ------- - np.ndarray - """ - _warn( - "`graphblas.io.to_numpy` is deprecated; " - "use `Matrix.to_dense` and `Vector.to_dense` instead.", - DeprecationWarning, - ) - try: - import scipy # noqa: F401 - except ImportError: # pragma: no cover (import) - raise ImportError("scipy is required to export to numpy") from None - if _output_type(m) is _Vector: - return to_scipy_sparse(m).toarray()[0] - sparse = to_scipy_sparse(m, "coo") - return sparse.toarray() - - -def to_scipy_sparse(A, format="csr"): - """Create a scipy.sparse array from a GraphBLAS Matrix or Vector. - - Parameters - ---------- - A : Matrix or Vector - GraphBLAS object to be converted - format : str - {'bsr', 'csr', 'csc', 'coo', 'lil', 'dia', 'dok'} - - Returns - ------- - scipy.sparse array - - """ - import scipy.sparse as ss - - format = format.lower() - if format not in {"bsr", "csr", "csc", "coo", "lil", "dia", "dok"}: - raise ValueError(f"Invalid format: {format}") - if _output_type(A) is _Vector: - indices, data = A.to_coo() - if format == "csc": - return ss.csc_array((data, indices, [0, len(data)]), shape=(A._size, 1)) - rv = ss.csr_array((data, indices, [0, len(data)]), shape=(1, A._size)) - if format == "csr": - return rv - elif _backend == "suitesparse" and format in {"csr", "csc"}: - if A._is_transposed: - info = A.T.ss.export("csc" if format == "csr" else "csr", sort=True) - if "col_indices" in info: - info["row_indices"] = info["col_indices"] - else: - info["col_indices"] = info["row_indices"] - else: - info = A.ss.export(format, sort=True) - values = _normalize_values(A, info["values"], None, (A._nvals,), info["is_iso"]) - if format == "csr": - return ss.csr_array((values, info["col_indices"], info["indptr"]), shape=A.shape) - return ss.csc_array((values, info["row_indices"], info["indptr"]), shape=A.shape) - elif format == "csr": - indptr, cols, vals = A.to_csr() - return ss.csr_array((vals, cols, indptr), shape=A.shape) - elif format == "csc": - indptr, rows, vals = A.to_csc() - return ss.csc_array((vals, rows, indptr), shape=A.shape) - else: - rows, cols, data = A.to_coo() - rv = ss.coo_array((data, (rows, cols)), shape=A.shape) - if format == "coo": - return rv - return rv.asformat(format) - - -_AwkwardDoublyCompressedMatrix = None - - -def to_awkward(A, format=None): - """Create an Awkward Array from a GraphBLAS Matrix. 
- - Parameters - ---------- - A : Matrix or Vector - GraphBLAS object to be converted - format : str {'csr', 'csc', 'hypercsr', 'hypercsc', 'vec} - Default format is csr for Matrix; vec for Vector - - The Awkward Array will have top-level attributes based on format: - - vec/csr/csc: values, indices - - hypercsr/hypercsc: values, indices, offset_labels - - Top-level parameters will also be set: format, shape - - Returns - ------- - awkward.Array - - """ - try: - # awkward version 1 - # MAINT: we can probably drop awkward v1 at the end of 2024 or 2025 - import awkward._v2 as ak - from awkward._v2.forms.listoffsetform import ListOffsetForm - from awkward._v2.forms.numpyform import NumpyForm - from awkward._v2.forms.recordform import RecordForm - except ImportError: - # awkward version 2 - import awkward as ak - from awkward.forms.listoffsetform import ListOffsetForm - from awkward.forms.numpyform import NumpyForm - from awkward.forms.recordform import RecordForm - - out_type = _output_type(A) - if format is None: - format = "vec" if out_type is _Vector else "csr" - format = format.lower() - classname = None - - if out_type is _Vector: - if format != "vec": - raise ValueError(f"Invalid format for Vector: {format}") - size = A.nvals - indices, values = A.to_coo() - form = RecordForm( - contents=[ - NumpyForm(A.dtype.numba_type.name, form_key="node1"), - NumpyForm("int64", form_key="node0"), - ], - fields=["values", "indices"], - ) - d = {"node0-data": indices, "node1-data": values} - - elif out_type is _Matrix: - if format == "csr": - indptr, cols, values = A.to_csr() - d = {"node3-data": cols} - size = A.nrows - elif format == "csc": - indptr, rows, values = A.to_csc() - d = {"node3-data": rows} - size = A.ncols - elif format == "hypercsr": - rows, indptr, cols, values = A.to_dcsr() - d = {"node3-data": cols, "node5-data": rows} - size = len(rows) - elif format == "hypercsc": - cols, indptr, rows, values = A.to_dcsc() - d = {"node3-data": rows, "node5-data": cols} - size = len(cols) - else: - raise ValueError(f"Invalid format for Matrix: {format}") - d["node1-offsets"] = indptr - d["node4-data"] = _np.ascontiguousarray(values) - - form = ListOffsetForm( - "i64", - RecordForm( - contents=[ - NumpyForm("int64", form_key="node3"), - NumpyForm(A.dtype.numba_type.name, form_key="node4"), - ], - fields=["indices", "values"], - ), - form_key="node1", - ) - if format.startswith("hyper"): - global _AwkwardDoublyCompressedMatrix - if _AwkwardDoublyCompressedMatrix is None: # pylint: disable=used-before-assignment - # Define behaviors to make all fields function at the top-level - @ak.behaviors.mixins.mixin_class(ak.behavior) - class _AwkwardDoublyCompressedMatrix: - @property - def values(self): - return self.data.values - - @property - def indices(self): - return self.data.indices - - form = RecordForm( - contents=[ - form, - NumpyForm("int64", form_key="node5"), - ], - fields=["data", "offset_labels"], - ) - classname = "_AwkwardDoublyCompressedMatrix" - - else: - raise TypeError(f"A must be a Matrix or Vector, found {type(A)}") - - ret = ak.from_buffers(form, size, d) - ret = ak.with_parameter(ret, "format", format) - ret = ak.with_parameter(ret, "shape", list(A.shape)) - if classname: - ret = ak.with_name(ret, classname) - return ret - - -def to_pydata_sparse(A, format="coo"): - """Create a pydata.sparse array from a GraphBLAS Matrix or Vector. 
- - Parameters - ---------- - A : Matrix or Vector - GraphBLAS object to be converted - format : str - {'coo', 'dok', 'gcxs'} - - Returns - ------- - sparse array (see https://sparse.pydata.org) - - """ - try: - from sparse import COO - except ImportError: # pragma: no cover (import) - raise ImportError("sparse is required to export to pydata sparse") from None - - format = format.lower() - if format not in {"coo", "dok", "gcxs"}: - raise ValueError(f"Invalid format: {format}") - - if _output_type(A) is _Vector: - indices, values = A.to_coo(sort=False) - s = COO(indices, values, shape=A.shape) - else: - if format == "gcxs": - B = to_scipy_sparse(A, format="csr") - else: - # obtain an intermediate conversion via hardcoded 'coo' intermediate object - B = to_scipy_sparse(A, format="coo") - # convert to pydata.sparse - s = COO.from_scipy_sparse(B) - - # express in the desired format - return s.asformat(format) - - -def mmread(source, *, dup_op=None, name=None): - """Create a GraphBLAS Matrix from the contents of a Matrix Market file. - - This uses `scipy.io.mmread - `_. - - Parameters - ---------- - filename : str or file - Filename (.mtx or .mtz.gz) or file-like object - dup_op : BinaryOp, optional - Aggregation function for duplicate coordinates (if found) - name : str, optional - Name of resulting Matrix - - Returns - ------- - :class:`~graphblas.Matrix` - """ - try: - from scipy.io import mmread - from scipy.sparse import isspmatrix_coo - except ImportError: # pragma: no cover (import) - raise ImportError("scipy is required to read Matrix Market files") from None - array = mmread(source) - if isspmatrix_coo(array): - nrows, ncols = array.shape - return _Matrix.from_coo( - array.row, array.col, array.data, nrows=nrows, ncols=ncols, dup_op=dup_op, name=name - ) - return _Matrix.from_dense(array, name=name) - - -def mmwrite(target, matrix, *, comment="", field=None, precision=None, symmetry=None): - """Write a Matrix Market file from the contents of a GraphBLAS Matrix. - - This uses `scipy.io.mmwrite - `_. 
-
-    Parameters
-    ----------
-    filename : str or file target
-        Filename (.mtx) or file-like object opened for writing
-    matrix : Matrix
-        Matrix to be written
-    comment : str, optional
-        Comments to be prepended to the Matrix Market file
-    field : str
-        {"real", "complex", "pattern", "integer"}
-    precision : int, optional
-        Number of digits to write for real or complex values
-    symmetry : str, optional
-        {"general", "symmetric", "skew-symmetric", "hermetian"}
-    """
-    try:
-        from scipy.io import mmwrite
-    except ImportError:  # pragma: no cover (import)
-        raise ImportError("scipy is required to write Matrix Market files") from None
-    if _backend == "suitesparse" and matrix.ss.format in {"fullr", "fullc"}:
-        array = matrix.ss.export()["values"]
-    else:
-        array = to_scipy_sparse(matrix, format="coo")
-    mmwrite(target, array, comment=comment, field=field, precision=precision, symmetry=symmetry)
diff --git a/graphblas/io/__init__.py b/graphblas/io/__init__.py
new file mode 100644
index 000000000..a1b71db40
--- /dev/null
+++ b/graphblas/io/__init__.py
@@ -0,0 +1,5 @@
+from ._awkward import from_awkward, to_awkward
+from ._matrixmarket import mmread, mmwrite
+from ._networkx import from_networkx, to_networkx
+from ._scipy import from_scipy_sparse, to_scipy_sparse
+from ._sparse import from_pydata_sparse, to_pydata_sparse
diff --git a/graphblas/io/_awkward.py b/graphblas/io/_awkward.py
new file mode 100644
index 000000000..b30984251
--- /dev/null
+++ b/graphblas/io/_awkward.py
@@ -0,0 +1,188 @@
+import numpy as np
+
+from ..core.matrix import Matrix
+from ..core.utils import output_type
+from ..core.vector import Vector
+
+_AwkwardDoublyCompressedMatrix = None
+
+
+def to_awkward(A, format=None):
+    """Create an Awkward Array from a GraphBLAS Matrix.
+
+    Parameters
+    ----------
+    A : Matrix or Vector
+        GraphBLAS object to be converted
+    format : str {'csr', 'csc', 'hypercsr', 'hypercsc', 'vec'}
+        Default format is csr for Matrix; vec for Vector
+
+    The Awkward Array will have top-level attributes based on format:
+    - vec/csr/csc: values, indices
+    - hypercsr/hypercsc: values, indices, offset_labels
+
+    Top-level parameters will also be set: format, shape
+
+    Returns
+    -------
+    awkward.Array
+
+    """
+    try:
+        # awkward version 1
+        # MAINT: we can probably drop awkward v1 at the end of 2024 or 2025
+        import awkward._v2 as ak
+        from awkward._v2.forms.listoffsetform import ListOffsetForm
+        from awkward._v2.forms.numpyform import NumpyForm
+        from awkward._v2.forms.recordform import RecordForm
+    except ImportError:
+        # awkward version 2
+        import awkward as ak
+        from awkward.forms.listoffsetform import ListOffsetForm
+        from awkward.forms.numpyform import NumpyForm
+        from awkward.forms.recordform import RecordForm
+
+    out_type = output_type(A)
+    if format is None:
+        format = "vec" if out_type is Vector else "csr"
+    format = format.lower()
+    classname = None
+
+    if out_type is Vector:
+        if format != "vec":
+            raise ValueError(f"Invalid format for Vector: {format}")
+        size = A.nvals
+        indices, values = A.to_coo()
+        form = RecordForm(
+            contents=[
+                NumpyForm(A.dtype.np_type.name, form_key="node1"),
+                NumpyForm("int64", form_key="node0"),
+            ],
+            fields=["values", "indices"],
+        )
+        d = {"node0-data": indices, "node1-data": values}
+
+    elif out_type is Matrix:
+        if format == "csr":
+            indptr, cols, values = A.to_csr()
+            d = {"node3-data": cols}
+            size = A.nrows
+        elif format == "csc":
+            indptr, rows, values = A.to_csc()
+            d = {"node3-data": rows}
+            size = A.ncols
+        elif format == "hypercsr":
+            rows, indptr, cols, values = A.to_dcsr()
+            d = {"node3-data": cols, "node5-data": rows}
+            size = len(rows)
+        elif format == "hypercsc":
+            cols, indptr, rows, values = A.to_dcsc()
+            d = {"node3-data": rows, "node5-data": cols}
+            size = len(cols)
+        else:
+            raise ValueError(f"Invalid format for Matrix: {format}")
+        d["node1-offsets"] = indptr
+        d["node4-data"] = np.ascontiguousarray(values)
+
+        form = ListOffsetForm(
+            "i64",
+            RecordForm(
+                contents=[
+                    NumpyForm("int64", form_key="node3"),
+                    NumpyForm(A.dtype.np_type.name, form_key="node4"),
+                ],
+                fields=["indices", "values"],
+            ),
+            form_key="node1",
+        )
+        if format.startswith("hyper"):
+            global _AwkwardDoublyCompressedMatrix
+            if _AwkwardDoublyCompressedMatrix is None:  # pylint: disable=used-before-assignment
+                # Define behaviors to make all fields function at the top-level
+                @ak.behaviors.mixins.mixin_class(ak.behavior)
+                class _AwkwardDoublyCompressedMatrix:
+                    @property
+                    def values(self):  # pragma: no branch (???)
+                        return self.data.values
+
+                    @property
+                    def indices(self):  # pragma: no branch (???)
+                        return self.data.indices
+
+            form = RecordForm(
+                contents=[
+                    form,
+                    NumpyForm("int64", form_key="node5"),
+                ],
+                fields=["data", "offset_labels"],
+            )
+            classname = "_AwkwardDoublyCompressedMatrix"
+
+    else:
+        raise TypeError(f"A must be a Matrix or Vector, found {type(A)}")
+
+    ret = ak.from_buffers(form, size, d)
+    ret = ak.with_parameter(ret, "format", format)
+    ret = ak.with_parameter(ret, "shape", list(A.shape))
+    if classname:
+        ret = ak.with_name(ret, classname)
+    return ret
+
+
+def from_awkward(A, *, name=None):
+    """Create a Matrix or Vector from an Awkward Array.
+
+    The Awkward Array must have top-level parameters: format, shape
+
+    The Awkward Array must have top-level attributes based on format:
+    - vec/csr/csc: values, indices
+    - hypercsr/hypercsc: values, indices, offset_labels
+
+    Parameters
+    ----------
+    A : awkward.Array
+        Awkward Array with values and indices
+    name : str, optional
+        Name of resulting Matrix or Vector
+
+    Returns
+    -------
+    Vector or Matrix
+
+    Note: the intended purpose of this function is to facilitate
+    conversion of an awkward-array that was created via the
+    ``to_awkward`` function. If attempting to convert an arbitrary
+    awkward-array, make sure that the top-level attributes and
+    parameters contain the expected values.
+ + """ + params = A.layout.parameters + if missing := {"format", "shape"} - params.keys(): + raise ValueError(f"Missing parameters: {missing}") + format = params["format"] + shape = params["shape"] + + if len(shape) == 1: + if format != "vec": + raise ValueError(f"Invalid format for Vector: {format}") + return Vector.from_coo( + A.indices.layout.data, A.values.layout.data, size=shape[0], name=name + ) + nrows, ncols = shape + values = A.values.layout.content.data + indptr = A.values.layout.offsets.data + if format == "csr": + cols = A.indices.layout.content.data + return Matrix.from_csr(indptr, cols, values, ncols=ncols, name=name) + if format == "csc": + rows = A.indices.layout.content.data + return Matrix.from_csc(indptr, rows, values, nrows=nrows, name=name) + if format == "hypercsr": + rows = A.offset_labels.layout.data + cols = A.indices.layout.content.data + return Matrix.from_dcsr(rows, indptr, cols, values, nrows=nrows, ncols=ncols, name=name) + if format == "hypercsc": + cols = A.offset_labels.layout.data + rows = A.indices.layout.content.data + return Matrix.from_dcsc(cols, indptr, rows, values, nrows=nrows, ncols=ncols, name=name) + raise ValueError(f"Invalid format for Matrix: {format}") diff --git a/graphblas/io/_matrixmarket.py b/graphblas/io/_matrixmarket.py new file mode 100644 index 000000000..8cf8738a3 --- /dev/null +++ b/graphblas/io/_matrixmarket.py @@ -0,0 +1,142 @@ +from .. import backend +from ..core.matrix import Matrix +from ._scipy import to_scipy_sparse + + +def mmread(source, engine="auto", *, dup_op=None, name=None, **kwargs): + """Create a GraphBLAS Matrix from the contents of a Matrix Market file. + + This uses `scipy.io.mmread + `_ + or `fast_matrix_market.mmread + `_. + + By default, ``fast_matrix_market`` will be used if available, because it + is faster. Additional keyword arguments in ``**kwargs`` will be passed + to the engine's ``mmread``. For example, ``parallelism=8`` will set the + number of threads to use to 8 when using ``fast_matrix_market``. + + Parameters + ---------- + source : str or file + Filename (.mtx or .mtz.gz) or file-like object + engine : {"auto", "scipy", "fmm", "fast_matrix_market"}, default "auto" + How to read the matrix market file. "scipy" uses ``scipy.io.mmread``, + "fmm" and "fast_matrix_market" uses ``fast_matrix_market.mmread``, + and "auto" will use "fast_matrix_market" if available. + dup_op : BinaryOp, optional + Aggregation function for duplicate coordinates (if found) + name : str, optional + Name of resulting Matrix + + Returns + ------- + :class:`~graphblas.Matrix` + + """ + try: + # scipy is currently needed for *all* engines + from scipy.io import mmread + except ImportError: # pragma: no cover (import) + raise ImportError("scipy is required to read Matrix Market files") from None + engine = engine.lower() + if engine in {"auto", "fmm", "fast_matrix_market"}: + try: + from fast_matrix_market import mmread # noqa: F811 + except ImportError: # pragma: no cover (import) + if engine != "auto": + raise ImportError( + "fast_matrix_market is required to read Matrix Market files " + f'using the "{engine}" engine' + ) from None + elif engine != "scipy": + raise ValueError( + f'Bad engine value: {engine!r}. 
Must be "auto", "scipy", "fmm", or "fast_matrix_market"' + ) + array = mmread(source, **kwargs) + if getattr(array, "format", None) == "coo": + nrows, ncols = array.shape + return Matrix.from_coo( + array.row, array.col, array.data, nrows=nrows, ncols=ncols, dup_op=dup_op, name=name + ) + return Matrix.from_dense(array, name=name) + + +def mmwrite( + target, + matrix, + engine="auto", + *, + comment="", + field=None, + precision=None, + symmetry=None, + **kwargs, +): + """Write a Matrix Market file from the contents of a GraphBLAS Matrix. + + This uses `scipy.io.mmwrite + `_. + + Parameters + ---------- + target : str or file target + Filename (.mtx) or file-like object opened for writing + matrix : Matrix + Matrix to be written + engine : {"auto", "scipy", "fmm", "fast_matrix_market"}, default "auto" + How to read the matrix market file. "scipy" uses ``scipy.io.mmwrite``, + "fmm" and "fast_matrix_market" uses ``fast_matrix_market.mmwrite``, + and "auto" will use "fast_matrix_market" if available. + comment : str, optional + Comments to be prepended to the Matrix Market file + field : str + {"real", "complex", "pattern", "integer"} + precision : int, optional + Number of digits to write for real or complex values + symmetry : str, optional + {"general", "symmetric", "skew-symmetric", "hermetian"} + + """ + try: + # scipy is currently needed for *all* engines + from scipy.io import mmwrite + except ImportError: # pragma: no cover (import) + raise ImportError("scipy is required to write Matrix Market files") from None + engine = engine.lower() + if engine in {"auto", "fmm", "fast_matrix_market"}: + try: + from fast_matrix_market import __version__, mmwrite # noqa: F811 + except ImportError: # pragma: no cover (import) + if engine != "auto": + raise ImportError( + "fast_matrix_market is required to write Matrix Market files " + f'using the "{engine}" engine' + ) from None + else: + import scipy as sp + + engine = "fast_matrix_market" + elif engine != "scipy": + raise ValueError( + f'Bad engine value: {engine!r}. Must be "auto", "scipy", "fmm", or "fast_matrix_market"' + ) + if backend == "suitesparse" and matrix.ss.format in {"fullr", "fullc"}: + array = matrix.ss.export()["values"] + else: + array = to_scipy_sparse(matrix, format="coo") + if engine == "fast_matrix_market" and __version__ < "1.7." and sp.__version__ > "1.11.": + # 2023-06-25: scipy 1.11.0 added `sparray` and changed e.g. `ss.isspmatrix_coo`. + # fast_matrix_market updated to handle this in version 1.7.0 + # Also, it looks like fast_matrix_market has special writers for csr and csc; + # should we see if using those are faster? + array = sp.sparse.coo_matrix(array) # FLAKY COVERAGE + mmwrite( + target, + array, + comment=comment, + field=field, + precision=precision, + symmetry=symmetry, + **kwargs, + ) diff --git a/graphblas/io/_networkx.py b/graphblas/io/_networkx.py new file mode 100644 index 000000000..8cf84e576 --- /dev/null +++ b/graphblas/io/_networkx.py @@ -0,0 +1,63 @@ +from ..dtypes import lookup_dtype +from ._scipy import from_scipy_sparse + + +def from_networkx(G, nodelist=None, dtype=None, weight="weight", name=None): + """Create a square adjacency Matrix from a networkx Graph. + + Parameters + ---------- + G : nx.Graph + Graph to convert + nodelist : list, optional + List of nodes in the nx.Graph. If not provided, all nodes will be used. 
diff --git a/graphblas/io/_networkx.py b/graphblas/io/_networkx.py
new file mode 100644
index 000000000..8cf84e576
--- /dev/null
+++ b/graphblas/io/_networkx.py
@@ -0,0 +1,63 @@
+from ..dtypes import lookup_dtype
+from ._scipy import from_scipy_sparse
+
+
+def from_networkx(G, nodelist=None, dtype=None, weight="weight", name=None):
+    """Create a square adjacency Matrix from a networkx Graph.
+
+    Parameters
+    ----------
+    G : nx.Graph
+        Graph to convert
+    nodelist : list, optional
+        List of nodes in the nx.Graph. If not provided, all nodes will be used.
+    dtype : dtype, optional
+        Data type
+    weight : str, default="weight"
+        Weight attribute
+    name : str, optional
+        Name of resulting Matrix
+
+    Returns
+    -------
+    :class:`~graphblas.Matrix`
+
+    """
+    import networkx as nx
+
+    if dtype is not None:
+        dtype = lookup_dtype(dtype).np_type
+    A = nx.to_scipy_sparse_array(G, nodelist=nodelist, dtype=dtype, weight=weight)
+    return from_scipy_sparse(A, name=name)
+
+
+# TODO: add parameters to allow different networkx classes and attribute names
+def to_networkx(m, edge_attribute="weight"):
+    """Create a networkx DiGraph from a square adjacency Matrix.
+
+    Parameters
+    ----------
+    m : Matrix
+        Square adjacency Matrix
+    edge_attribute : str, optional
+        Name of edge attribute from values of Matrix. If None, values will be skipped.
+        Default is "weight".
+
+    Returns
+    -------
+    nx.DiGraph
+
+    """
+    import networkx as nx
+
+    rows, cols, vals = m.to_coo()
+    rows = rows.tolist()
+    cols = cols.tolist()
+    G = nx.DiGraph()
+    if edge_attribute is None:
+        G.add_edges_from(zip(rows, cols, strict=True))
+    else:
+        G.add_weighted_edges_from(
+            zip(rows, cols, vals.tolist(), strict=True), weight=edge_attribute
+        )
+    return G
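For the networkx pair just defined, a brief usage sketch (assumes networkx and scipy are installed; integer node labels 0..2 keep graph nodes aligned with matrix indices):

    import networkx as nx

    import graphblas as gb

    G = nx.DiGraph()
    G.add_weighted_edges_from([(0, 1, 2.0), (1, 2, 3.0)])
    A = gb.io.from_networkx(G, nodelist=[0, 1, 2], dtype=float)
    G2 = gb.io.to_networkx(A)  # DiGraph with "weight" edge attributes
    assert G2[0][1]["weight"] == 2.0
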
diff --git a/graphblas/io/_scipy.py b/graphblas/io/_scipy.py
new file mode 100644
index 000000000..228432eed
--- /dev/null
+++ b/graphblas/io/_scipy.py
@@ -0,0 +1,119 @@
+from .. import backend
+from ..core.matrix import Matrix
+from ..core.utils import normalize_values, output_type
+from ..core.vector import Vector
+from ..dtypes import lookup_dtype
+
+
+def from_scipy_sparse(A, *, dup_op=None, name=None):
+    """Create a Matrix from a scipy.sparse array or matrix.
+
+    Input data in "csr" or "csc" format will be efficient when importing with SuiteSparse:GraphBLAS.
+
+    Parameters
+    ----------
+    A : scipy.sparse
+        Scipy sparse array or matrix
+    dup_op : BinaryOp, optional
+        Aggregation function for formats that allow duplicate entries (e.g. coo)
+    name : str, optional
+        Name of resulting Matrix
+
+    Returns
+    -------
+    :class:`~graphblas.Matrix`
+
+    """
+    nrows, ncols = A.shape
+    dtype = lookup_dtype(A.dtype)
+    if A.nnz == 0:
+        return Matrix(dtype, nrows=nrows, ncols=ncols, name=name)
+    if backend == "suitesparse" and A.format in {"csr", "csc"}:
+        data = A.data
+        is_iso = (data[[0]] == data).all()
+        if is_iso:
+            data = data[[0]]
+        if A.format == "csr":
+            return Matrix.ss.import_csr(
+                nrows=nrows,
+                ncols=ncols,
+                indptr=A.indptr,
+                col_indices=A.indices,
+                values=data,
+                is_iso=is_iso,
+                sorted_cols=getattr(A, "_has_sorted_indices", False),
+                name=name,
+            )
+        return Matrix.ss.import_csc(
+            nrows=nrows,
+            ncols=ncols,
+            indptr=A.indptr,
+            row_indices=A.indices,
+            values=data,
+            is_iso=is_iso,
+            sorted_rows=getattr(A, "_has_sorted_indices", False),
+            name=name,
+        )
+    if A.format == "csr":
+        return Matrix.from_csr(A.indptr, A.indices, A.data, ncols=ncols, name=name)
+    if A.format == "csc":
+        return Matrix.from_csc(A.indptr, A.indices, A.data, nrows=nrows, name=name)
+    if A.format != "coo":
+        A = A.tocoo()
+    return Matrix.from_coo(
+        A.row, A.col, A.data, nrows=nrows, ncols=ncols, dtype=dtype, dup_op=dup_op, name=name
+    )
+
+
+def to_scipy_sparse(A, format="csr"):
+    """Create a scipy.sparse array from a GraphBLAS Matrix or Vector.
+
+    Parameters
+    ----------
+    A : Matrix or Vector
+        GraphBLAS object to be converted
+    format : str
+        {'bsr', 'csr', 'csc', 'coo', 'lil', 'dia', 'dok'}
+
+    Returns
+    -------
+    scipy.sparse array
+
+    """
+    import scipy.sparse as ss
+
+    format = format.lower()
+    if format not in {"bsr", "csr", "csc", "coo", "lil", "dia", "dok"}:
+        raise ValueError(f"Invalid format: {format}")
+    if output_type(A) is Vector:
+        indices, data = A.to_coo()
+        if format == "csc":
+            return ss.csc_array((data, indices, [0, len(data)]), shape=(A._size, 1))
+        rv = ss.csr_array((data, indices, [0, len(data)]), shape=(1, A._size))
+        if format == "csr":
+            return rv
+    elif backend == "suitesparse" and format in {"csr", "csc"}:
+        if A._is_transposed:
+            info = A.T.ss.export("csc" if format == "csr" else "csr", sort=True)
+            if "col_indices" in info:
+                info["row_indices"] = info["col_indices"]
+            else:
+                info["col_indices"] = info["row_indices"]
+        else:
+            info = A.ss.export(format, sort=True)
+        values = normalize_values(A, info["values"], None, (A._nvals,), info["is_iso"])
+        if format == "csr":
+            return ss.csr_array((values, info["col_indices"], info["indptr"]), shape=A.shape)
+        return ss.csc_array((values, info["row_indices"], info["indptr"]), shape=A.shape)
+    elif format == "csr":
+        indptr, cols, vals = A.to_csr()
+        return ss.csr_array((vals, cols, indptr), shape=A.shape)
+    elif format == "csc":
+        indptr, rows, vals = A.to_csc()
+        return ss.csc_array((vals, rows, indptr), shape=A.shape)
+    else:
+        rows, cols, data = A.to_coo()
+        rv = ss.coo_array((data, (rows, cols)), shape=A.shape)
+        if format == "coo":
+            return rv
+    return rv.asformat(format)
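A conversion sketch covering both this scipy module and the pydata/sparse module that follows (assumes scipy; the last three lines additionally assume the `sparse` package is installed):

    import scipy.sparse as ss

    import graphblas as gb

    S = ss.csr_array(([1, 2], [1, 0], [0, 1, 2]), shape=(2, 2))
    A = gb.io.from_scipy_sparse(S)               # csr imports efficiently on suitesparse
    S2 = gb.io.to_scipy_sparse(A, format="coo")
    assert (S2.toarray() == S.toarray()).all()

    P = gb.io.to_pydata_sparse(A, format="coo")  # a sparse.COO object
    A2 = gb.io.from_pydata_sparse(P)
    assert A2.isequal(A)
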
diff --git a/graphblas/io/_sparse.py b/graphblas/io/_sparse.py
new file mode 100644
index 000000000..c0d4beabb
--- /dev/null
+++ b/graphblas/io/_sparse.py
@@ -0,0 +1,100 @@
+from ..core.matrix import Matrix
+from ..core.utils import output_type
+from ..core.vector import Vector
+from ..exceptions import GraphblasException
+from ._scipy import from_scipy_sparse, to_scipy_sparse
+
+
+def from_pydata_sparse(s, *, dup_op=None, name=None):
+    """Create a Vector or a Matrix from a pydata.sparse array or matrix.
+
+    Input data in "gcxs" format will be efficient when importing with SuiteSparse:GraphBLAS.
+
+    Parameters
+    ----------
+    s : sparse
+        PyData sparse array or matrix (see https://sparse.pydata.org)
+    dup_op : BinaryOp, optional
+        Aggregation function for formats that allow duplicate entries (e.g. coo)
+    name : str, optional
+        Name of resulting Vector or Matrix
+
+    Returns
+    -------
+    :class:`~graphblas.Vector`
+    :class:`~graphblas.Matrix`
+
+    """
+    try:
+        import sparse
+    except ImportError:  # pragma: no cover (import)
+        raise ImportError("sparse is required to import from pydata sparse") from None
+    if not isinstance(s, sparse.SparseArray):
+        raise TypeError(
+            "from_pydata_sparse only accepts objects from the `sparse` library; "
+            "see https://sparse.pydata.org"
+        )
+    if s.ndim > 2:
+        raise GraphblasException("s.ndim must be <= 2")
+
+    if s.ndim == 1:
+        # the .asformat('coo') makes it easier to convert dok/gcxs using a single approach
+        _s = s.asformat("coo")
+        return Vector.from_coo(
+            _s.coords, _s.data, dtype=_s.dtype, size=_s.shape[0], dup_op=dup_op, name=name
+        )
+    # handle two-dimensional arrays
+    if isinstance(s, sparse.GCXS):
+        return from_scipy_sparse(s.to_scipy_sparse(), dup_op=dup_op, name=name)
+    if isinstance(s, (sparse.DOK, sparse.COO)):
+        _s = s.asformat("coo")
+        return Matrix.from_coo(
+            *_s.coords,
+            _s.data,
+            nrows=_s.shape[0],
+            ncols=_s.shape[1],
+            dtype=_s.dtype,
+            dup_op=dup_op,
+            name=name,
+        )
+    raise ValueError(f"Unknown sparse array type: {type(s).__name__}")  # pragma: no cover (safety)
+
+
+def to_pydata_sparse(A, format="coo"):
+    """Create a pydata.sparse array from a GraphBLAS Matrix or Vector.
+
+    Parameters
+    ----------
+    A : Matrix or Vector
+        GraphBLAS object to be converted
+    format : str
+        {'coo', 'dok', 'gcxs'}
+
+    Returns
+    -------
+    sparse array (see https://sparse.pydata.org)
+
+    """
+    try:
+        from sparse import COO
+    except ImportError:  # pragma: no cover (import)
+        raise ImportError("sparse is required to export to pydata sparse") from None
+
+    format = format.lower()
+    if format not in {"coo", "dok", "gcxs"}:
+        raise ValueError(f"Invalid format: {format}")
+
+    if output_type(A) is Vector:
+        indices, values = A.to_coo(sort=False)
+        s = COO(indices, values, shape=A.shape)
+    else:
+        if format == "gcxs":
+            B = to_scipy_sparse(A, format="csr")
+        else:
+            # convert via an intermediate 'coo' object
+            B = to_scipy_sparse(A, format="coo")
+        # convert to pydata.sparse
+        s = COO.from_scipy_sparse(B)
+
+    # express in the desired format
+    return s.asformat(format)
diff --git a/graphblas/monoid/__init__.py b/graphblas/monoid/__init__.py
index 007aba416..027fc0afe 100644
--- a/graphblas/monoid/__init__.py
+++ b/graphblas/monoid/__init__.py
@@ -4,19 +4,31 @@
 
 
 def __dir__():
-    return globals().keys() | _delayed.keys()
+    return globals().keys() | _delayed.keys() | {"ss"}
 
 
 def __getattr__(key):
     if key in _delayed:
         func, kwargs = _delayed.pop(key)
-        if type(kwargs["binaryop"]) is str:
+        if isinstance(kwargs["binaryop"], str):
             from ..binary import from_string
 
             kwargs["binaryop"] = from_string(kwargs["binaryop"])
         rv = func(**kwargs)
         globals()[key] = rv
         return rv
+    if key == "ss":
+        from .. import backend
+
+        if backend != "suitesparse":
+            raise AttributeError(
+                f'module {__name__!r} only has attribute "ss" when backend is "suitesparse"'
+            )
+        from importlib import import_module
+
+        ss = import_module(".ss", __name__)
+        globals()["ss"] = ss
+        return ss
     raise AttributeError(f"module {__name__!r} has no attribute {key!r}")
diff --git a/graphblas/monoid/numpy.py b/graphblas/monoid/numpy.py
index 475266d5c..b9ff2b502 100644
--- a/graphblas/monoid/numpy.py
+++ b/graphblas/monoid/numpy.py
@@ -5,15 +5,19 @@
 https://numba.readthedocs.io/en/stable/reference/numpysupported.html#math-operations
 """
-import numba as _numba
+
 import numpy as _np
 
 from ..
import _STANDARD_OPERATOR_NAMES from .. import binary as _binary from .. import config as _config from .. import monoid as _monoid +from ..core import _has_numba, _supports_udfs from ..dtypes import _supports_complex +if _has_numba: + import numba as _numba + _delayed = {} _complex_dtypes = {"FC32", "FC64"} _float_dtypes = {"FP32", "FP64"} @@ -86,8 +90,8 @@ # To increase import speed, only call njit when `_config.get("mapnumpy")` is False if ( _config.get("mapnumpy") - or type(_numba.njit(lambda x, y: _np.fmax(x, y))(1, 2)) # pragma: no branch (numba) - is not float + or _has_numba + and not isinstance(_numba.njit(lambda x, y: _np.fmax(x, y))(1, 2), float) # pragma: no branch ): # Incorrect behavior was introduced in numba 0.56.2 and numpy 1.23 # See: https://github.com/numba/numba/issues/8478 @@ -140,15 +144,33 @@ # _graphblas_to_numpy = {val: key for key, val in _numpy_to_graphblas.items()} # Soon... # Not included: maximum, minimum, gcd, hypot, logaddexp, logaddexp2 +# True if ``monoid(x, x) == x`` for any x. +_idempotent = { + "bitwise_and", + "bitwise_or", + "fmax", + "fmin", + "gcd", + "logical_and", + "logical_or", + "maximum", + "minimum", +} + def __dir__(): - return globals().keys() | _delayed.keys() | _monoid_identities.keys() + if not _supports_udfs and not _config.get("mapnumpy"): + return globals().keys() # FLAKY COVERAGE + attrs = _delayed.keys() | _monoid_identities.keys() + if not _supports_udfs: + attrs &= _numpy_to_graphblas.keys() + return attrs | globals().keys() def __getattr__(name): if name in _delayed: func, kwargs = _delayed.pop(name) - if type(kwargs["binaryop"]) is str: + if isinstance(kwargs["binaryop"], str): from ..binary import from_string kwargs["binaryop"] = from_string(kwargs["binaryop"]) @@ -160,8 +182,8 @@ def __getattr__(name): if _config.get("mapnumpy") and name in _numpy_to_graphblas: globals()[name] = getattr(_monoid, _numpy_to_graphblas[name]) else: - from ..core import operator - func = getattr(_binary.numpy, name) - operator.Monoid.register_new(f"numpy.{name}", func, _monoid_identities[name]) + _monoid.register_new( + f"numpy.{name}", func, _monoid_identities[name], is_idempotent=name in _idempotent + ) return globals()[name] diff --git a/graphblas/monoid/ss.py b/graphblas/monoid/ss.py new file mode 100644 index 000000000..97852fc12 --- /dev/null +++ b/graphblas/monoid/ss.py @@ -0,0 +1,5 @@ +from ..core import operator + +_delayed = {} + +del operator diff --git a/graphblas/op/__init__.py b/graphblas/op/__init__.py index af05cbef4..1eb2b51d7 100644 --- a/graphblas/op/__init__.py +++ b/graphblas/op/__init__.py @@ -39,10 +39,18 @@ def __getattr__(key): ss = import_module(".ss", __name__) globals()["ss"] = ss return ss + if not _supports_udfs: + from .. import binary, semiring + + if key in binary._udfs or key in semiring._udfs: + raise AttributeError( + f"module {__name__!r} unable to compile UDF for {key!r}; " + "install numba for UDF support" + ) raise AttributeError(f"module {__name__!r} has no attribute {key!r}") -from ..core import operator # noqa: E402 isort:skip +from ..core import operator, _supports_udfs # noqa: E402 isort:skip from . 
import numpy # noqa: E402 isort:skip del operator diff --git a/graphblas/op/numpy.py b/graphblas/op/numpy.py index 497a6037c..cadba17eb 100644 --- a/graphblas/op/numpy.py +++ b/graphblas/op/numpy.py @@ -1,4 +1,5 @@ from ..binary import numpy as _np_binary +from ..core import _supports_udfs from ..semiring import numpy as _np_semiring from ..unary import numpy as _np_unary @@ -10,7 +11,10 @@ def __dir__(): - return globals().keys() | _delayed.keys() | _op_to_mod.keys() + attrs = _delayed.keys() | _op_to_mod.keys() + if not _supports_udfs: + attrs &= _np_unary.__dir__() | _np_binary.__dir__() | _np_semiring.__dir__() + return attrs | globals().keys() def __getattr__(name): diff --git a/graphblas/op/ss.py b/graphblas/op/ss.py index e45cbcda0..97852fc12 100644 --- a/graphblas/op/ss.py +++ b/graphblas/op/ss.py @@ -1,3 +1,5 @@ from ..core import operator +_delayed = {} + del operator diff --git a/graphblas/select/__init__.py b/graphblas/select/__init__.py index c7a1897f5..b55766ff8 100644 --- a/graphblas/select/__init__.py +++ b/graphblas/select/__init__.py @@ -8,7 +8,7 @@ def __dir__(): - return globals().keys() | _delayed.keys() + return globals().keys() | _delayed.keys() | {"ss"} def __getattr__(key): @@ -17,6 +17,18 @@ def __getattr__(key): rv = func(**kwargs) globals()[key] = rv return rv + if key == "ss": + from .. import backend + + if backend != "suitesparse": + raise AttributeError( + f'module {__name__!r} only has attribute "ss" when backend is "suitesparse"' + ) + from importlib import import_module + + ss = import_module(".ss", __name__) + globals()["ss"] = ss + return ss raise AttributeError(f"module {__name__!r} has no attribute {key!r}") @@ -57,9 +69,9 @@ def _resolve_expr(expr, callname, opname): def _match_expr(parent, expr): - """Match expressions to rewrite `A.select(A < 5)` into select expression. + """Match expressions to rewrite ``A.select(A < 5)`` into select expression. - The argument must match the parent, so this _won't_ be rewritten: `A.select(B < 5)` + The argument must match the parent, so this _won't_ be rewritten: ``A.select(B < 5)`` """ args = expr.args op = expr.op @@ -76,56 +88,49 @@ def _match_expr(parent, expr): def value(expr): - """ - An advanced select method which allows for easily expressing - value comparison logic. + """An advanced select method for easily expressing value comparison logic. Example usage: >>> gb.select.value(A > 0) - The example will dispatch to `gb.select.valuegt(A, 0)` + The example will dispatch to ``gb.select.valuegt(A, 0)`` while being nicer to read. """ return _resolve_expr(expr, "value", "value") def row(expr): - """ - An advanced select method which allows for easily expressing - Matrix row index comparison logic. + """An advanced select method for easily expressing Matrix row index comparison logic. Example usage: >>> gb.select.row(A <= 5) - The example will dispatch to `gb.select.rowle(A, 5)` + The example will dispatch to ``gb.select.rowle(A, 5)`` while being potentially nicer to read. """ return _resolve_expr(expr, "row", "row") def column(expr): - """ - An advanced select method which allows for easily expressing - Matrix column index comparison logic. + """An advanced select method for easily expressing Matrix column index comparison logic. Example usage: >>> gb.select.column(A <= 5) - The example will dispatch to `gb.select.colle(A, 5)` + The example will dispatch to ``gb.select.colle(A, 5)`` while being potentially nicer to read. 
""" return _resolve_expr(expr, "column", "col") def index(expr): - """ - An advanced select method which allows for easily expressing + """An advanced select method which allows for easily expressing Vector index comparison logic. Example usage: >>> gb.select.index(v <= 5) - The example will dispatch to `gb.select.indexle(v, 5)` + The example will dispatch to ``gb.select.indexle(v, 5)`` while being potentially nicer to read. """ return _resolve_expr(expr, "index", "index") diff --git a/graphblas/select/ss.py b/graphblas/select/ss.py new file mode 100644 index 000000000..173067382 --- /dev/null +++ b/graphblas/select/ss.py @@ -0,0 +1,6 @@ +from ..core import operator +from ..core.ss.select import register_new # noqa: F401 + +_delayed = {} + +del operator diff --git a/graphblas/semiring/__init__.py b/graphblas/semiring/__init__.py index 904ae192f..95a44261a 100644 --- a/graphblas/semiring/__init__.py +++ b/graphblas/semiring/__init__.py @@ -1,7 +1,29 @@ # All items are dynamically added by classes in operator.py # This module acts as a container of Semiring instances +from ..core import _supports_udfs + _delayed = {} _deprecated = {} +_udfs = { + # Used by aggregators + "max_absfirst", + "max_abssecond", + "plus_absfirst", + "plus_abssecond", + "plus_rpow", + # floordiv + "any_floordiv", + "max_floordiv", + "min_floordiv", + "plus_floordiv", + "times_floordiv", + # rfloordiv + "any_rfloordiv", + "max_rfloordiv", + "min_rfloordiv", + "plus_rfloordiv", + "times_rfloordiv", +} def __dir__(): @@ -24,11 +46,11 @@ def __getattr__(key): return rv if key in _delayed: func, kwargs = _delayed.pop(key) - if type(kwargs["binaryop"]) is str: + if isinstance(kwargs["binaryop"], str): from ..binary import from_string kwargs["binaryop"] = from_string(kwargs["binaryop"]) - if type(kwargs["monoid"]) is str: + if isinstance(kwargs["monoid"], str): from ..monoid import from_string kwargs["monoid"] = from_string(kwargs["monoid"]) @@ -47,6 +69,11 @@ def __getattr__(key): ss = import_module(".ss", __name__) globals()["ss"] = ss return ss + if not _supports_udfs and key in _udfs: + raise AttributeError( + f"module {__name__!r} unable to compile UDF for {key!r}; " + "install numba for UDF support" + ) raise AttributeError(f"module {__name__!r} has no attribute {key!r}") diff --git a/graphblas/semiring/numpy.py b/graphblas/semiring/numpy.py index 64169168a..10a680ea0 100644 --- a/graphblas/semiring/numpy.py +++ b/graphblas/semiring/numpy.py @@ -5,6 +5,7 @@ https://numba.readthedocs.io/en/stable/reference/numpysupported.html#math-operations """ + import itertools as _itertools from .. import _STANDARD_OPERATOR_NAMES @@ -12,6 +13,7 @@ from .. import config as _config from .. 
import monoid as _monoid from ..binary.numpy import _binary_names +from ..core import _supports_udfs from ..monoid.numpy import _fmin_is_float, _monoid_identities _delayed = {} @@ -132,19 +134,29 @@ def __dir__(): - return globals().keys() | _delayed.keys() | _semiring_names + if not _supports_udfs and not _config.get("mapnumpy"): + return globals().keys() # FLAKY COVERAGE + attrs = _delayed.keys() | _semiring_names + if not _supports_udfs: + attrs &= { + f"{monoid_name}_{binary_name}" + for monoid_name, binary_name in _itertools.product( + dir(_monoid.numpy), dir(_binary.numpy) + ) + } + return attrs | globals().keys() def __getattr__(name): - from ..core import operator + from ..core.operator import get_semiring if name in _delayed: func, kwargs = _delayed.pop(name) - if type(kwargs["binaryop"]) is str: + if isinstance(kwargs["binaryop"], str): from ..binary import from_string kwargs["binaryop"] = from_string(kwargs["binaryop"]) - if type(kwargs["monoid"]) is str: + if isinstance(kwargs["monoid"], str): from ..monoid import from_string kwargs["monoid"] = from_string(kwargs["monoid"]) @@ -161,7 +173,7 @@ def __getattr__(name): binary_name = "_".join(words[i:]) if hasattr(_binary.numpy, binary_name): # pragma: no branch break - operator.get_semiring( + get_semiring( getattr(_monoid.numpy, monoid_name), getattr(_binary.numpy, binary_name), name=f"numpy.{name}", diff --git a/graphblas/semiring/ss.py b/graphblas/semiring/ss.py index e45cbcda0..97852fc12 100644 --- a/graphblas/semiring/ss.py +++ b/graphblas/semiring/ss.py @@ -1,3 +1,5 @@ from ..core import operator +_delayed = {} + del operator diff --git a/graphblas/ss/__init__.py b/graphblas/ss/__init__.py index b36bc1bdc..1f059771b 100644 --- a/graphblas/ss/__init__.py +++ b/graphblas/ss/__init__.py @@ -1 +1,7 @@ -from ._core import about, concat, config, diag +from suitesparse_graphblas import burble + +from ._core import _IS_SSGB7, about, concat, config, diag + +if not _IS_SSGB7: + # Context was introduced in SuiteSparse:GraphBLAS 8.0 + from ..core.ss.context import Context, global_context diff --git a/graphblas/ss/_core.py b/graphblas/ss/_core.py index 441458a42..b42ea72b4 100644 --- a/graphblas/ss/_core.py +++ b/graphblas/ss/_core.py @@ -2,8 +2,10 @@ from ..core import ffi, lib from ..core.base import _expect_type +from ..core.descriptor import lookup as descriptor_lookup from ..core.matrix import Matrix, TransposedMatrix from ..core.scalar import _as_scalar +from ..core.ss import _IS_SSGB7 from ..core.ss.config import BaseConfig from ..core.ss.matrix import _concat_mn from ..core.vector import Vector @@ -12,7 +14,7 @@ class _graphblas_ss: - """Used in `_expect_type`.""" + """Used in ``_expect_type``.""" _graphblas_ss.__name__ = "graphblas.ss" @@ -20,8 +22,7 @@ class _graphblas_ss: def diag(x, k=0, dtype=None, *, name=None, **opts): - """ - GxB_Matrix_diag, GxB_Vector_diag. + """GxB_Matrix_diag, GxB_Vector_diag. Extract a diagonal Vector from a Matrix, or construct a diagonal Matrix from a Vector. Unlike ``Matrix.diag`` and ``Vector.diag``, this function @@ -33,8 +34,8 @@ def diag(x, k=0, dtype=None, *, name=None, **opts): The Vector to assign to the diagonal, or the Matrix from which to extract the diagonal. k : int, default 0 - Diagonal in question. Use `k>0` for diagonals above the main diagonal, - and `k<0` for diagonals below the main diagonal. + Diagonal in question. Use ``k>0`` for diagonals above the main diagonal, + and ``k<0`` for diagonals below the main diagonal. 
    See Also
    --------
@@ -52,6 +53,9 @@
         dtype = x.dtype
     typ = type(x)
     if typ is Vector:
+        if opts:
+            # Ignore opts for now
+            desc = descriptor_lookup(**opts)  # noqa: F841 (keep desc in scope for context)
         size = x._size + abs(k.value)
         rv = Matrix(dtype, nrows=size, ncols=size, name=name)
         rv.ss.build_diag(x, k)
@@ -66,14 +70,13 @@
 
 
 def concat(tiles, dtype=None, *, name=None, **opts):
-    """
-    GxB_Matrix_concat.
+    """GxB_Matrix_concat.
 
     Concatenate a 2D list of Matrix objects into a new Matrix, or a 1D list
     of Vector objects into a new Vector. To concatenate into existing objects,
-    use ``Matrix.ss.concat`` or `Vector.ss.concat`.
+    use ``Matrix.ss.concat`` or ``Vector.ss.concat``.
 
-    Vectors may be used as `Nx1` Matrix objects when creating a new Matrix.
+    Vectors may be used as ``Nx1`` Matrix objects when creating a new Matrix.
 
     This performs the opposite operation as ``split``.
 
@@ -117,18 +120,65 @@ class GlobalConfig(BaseConfig):
         Threshold that determines when to switch to bitmap format
     nthreads : int
         Maximum number of OpenMP threads to use
-    memory_pool : List[int]
+    chunk : double
+        Control the number of threads used for small problems.
+        For example, ``nthreads = floor(work / chunk)``.
     burble : bool
         Enable diagnostic printing from SuiteSparse:GraphBLAS
-    print_1based: bool
+    print_1based : bool
     gpu_control : str, {"always", "never"}
+        Only available for SuiteSparse:GraphBLAS 7
+        **GPU support is a work in progress--not recommended to use**
     gpu_chunk : double
+        Only available for SuiteSparse:GraphBLAS 7
+        **GPU support is a work in progress--not recommended to use**
+    gpu_id : int
+        Which GPU to use; default is -1, which means do not run on the GPU.
+        Only available for SuiteSparse:GraphBLAS >=8
+        **GPU support is a work in progress--not recommended to use**
+    jit_c_control : {"off", "pause", "run", "load", "on"}
+        Control the CPU JIT:
+        "off" : do not use the JIT and free all JIT kernels if loaded
+        "pause" : do not run JIT kernels, but keep any loaded
+        "run" : run JIT kernels if already loaded, but don't load or compile
+        "load" : able to load and run JIT kernels; may not compile
+        "on" : full JIT: able to compile, load, and run
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_use_cmake : bool
+        Whether to use cmake to compile the JIT kernels.
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_c_compiler_name : str
+        C compiler for JIT kernels.
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_c_compiler_flags : str
+        Flags for the C compiler.
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_c_linker_flags : str
+        Link flags for the C compiler.
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_c_libraries : str
+        Libraries to link against.
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_c_cmake_libs : str
+        Libraries to link against when cmake is used.
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_c_preface : str
+        C code as preface to JIT kernels.
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_error_log : str
+        Error log file.
+        Only available for SuiteSparse:GraphBLAS >=8
+    jit_cache_path : str
+        The folder with the compiled kernels.
+        Only available for SuiteSparse:GraphBLAS >=8
 
     Setting values to None restores the default value for most configurations.
""" _get_function = "GxB_Global_Option_get" _set_function = "GxB_Global_Option_set" + if not _IS_SSGB7: + _context_keys = {"chunk", "gpu_id", "nthreads"} _null_valid = {"bitmap_switch"} _options = { # Matrix/Vector format @@ -139,14 +189,36 @@ class GlobalConfig(BaseConfig): "nthreads": (lib.GxB_GLOBAL_NTHREADS, "int"), "chunk": (lib.GxB_GLOBAL_CHUNK, "double"), # Memory pool control - "memory_pool": (lib.GxB_MEMORY_POOL, "int64_t[64]"), + # "memory_pool": (lib.GxB_MEMORY_POOL, "int64_t[64]"), # No longer used # Diagnostics (skipping "printf" and "flush" for now) "burble": (lib.GxB_BURBLE, "bool"), "print_1based": (lib.GxB_PRINT_1BASED, "bool"), - # CUDA GPU control - "gpu_control": (lib.GxB_GLOBAL_GPU_CONTROL, "GrB_Desc_Value"), - "gpu_chunk": (lib.GxB_GLOBAL_GPU_CHUNK, "double"), } + if _IS_SSGB7: + _options.update( + { + "gpu_control": (lib.GxB_GLOBAL_GPU_CONTROL, "GrB_Desc_Value"), + "gpu_chunk": (lib.GxB_GLOBAL_GPU_CHUNK, "double"), + } + ) + else: + _options.update( + { + # JIT control + "jit_c_control": (lib.GxB_JIT_C_CONTROL, "int"), + "jit_use_cmake": (lib.GxB_JIT_USE_CMAKE, "bool"), + "jit_c_compiler_name": (lib.GxB_JIT_C_COMPILER_NAME, "char*"), + "jit_c_compiler_flags": (lib.GxB_JIT_C_COMPILER_FLAGS, "char*"), + "jit_c_linker_flags": (lib.GxB_JIT_C_LINKER_FLAGS, "char*"), + "jit_c_libraries": (lib.GxB_JIT_C_LIBRARIES, "char*"), + "jit_c_cmake_libs": (lib.GxB_JIT_C_CMAKE_LIBS, "char*"), + "jit_c_preface": (lib.GxB_JIT_C_PREFACE, "char*"), + "jit_error_log": (lib.GxB_JIT_ERROR_LOG, "char*"), + "jit_cache_path": (lib.GxB_JIT_CACHE_PATH, "char*"), + # CUDA GPU control + "gpu_id": (lib.GxB_GLOBAL_GPU_ID, "int"), + } + ) # Values to restore defaults _defaults = { "hyper_switch": lib.GxB_HYPER_DEFAULT, @@ -157,17 +229,28 @@ class GlobalConfig(BaseConfig): "burble": 0, "print_1based": 0, } + if not _IS_SSGB7: + _defaults["gpu_id"] = -1 # -1 means no GPU _enumerations = { "format": { "by_row": lib.GxB_BY_ROW, "by_col": lib.GxB_BY_COL, # "no_format": lib.GxB_NO_FORMAT, # Used by iterators; not valid here }, - "gpu_control": { + } + if _IS_SSGB7: + _enumerations["gpu_control"] = { "always": lib.GxB_GPU_ALWAYS, "never": lib.GxB_GPU_NEVER, - }, - } + } + else: + _enumerations["jit_c_control"] = { + "off": lib.GxB_JIT_OFF, + "pause": lib.GxB_JIT_PAUSE, + "run": lib.GxB_JIT_RUN, + "load": lib.GxB_JIT_LOAD, + "on": lib.GxB_JIT_ON, + } class About(Mapping): @@ -254,4 +337,10 @@ def __len__(self): about = About() -config = GlobalConfig() +if _IS_SSGB7: + config = GlobalConfig() +else: + # Context was introduced in SuiteSparse:GraphBLAS 8.0 + from ..core.ss.context import global_context + + config = GlobalConfig(context=global_context) diff --git a/graphblas/tests/conftest.py b/graphblas/tests/conftest.py index 24aba085f..964325e0d 100644 --- a/graphblas/tests/conftest.py +++ b/graphblas/tests/conftest.py @@ -1,39 +1,50 @@ import atexit +import contextlib import functools import itertools +import platform +import sys from pathlib import Path import numpy as np import pytest import graphblas as gb +from graphblas.core import _supports_udfs as supports_udfs orig_binaryops = set() orig_semirings = set() +pypy = platform.python_implementation() == "PyPy" + def pytest_configure(config): rng = np.random.default_rng() - randomly = config.getoption("--randomly", False) + randomly = config.getoption("--randomly", None) + if randomly is None: # pragma: no cover + options_unavailable = True + randomly = True + config.addinivalue_line("markers", "slow: Skipped unless --runslow passed") + else: + 
options_unavailable = False backend = config.getoption("--backend", None) if backend is None: if randomly: backend = "suitesparse" if rng.random() < 0.5 else "suitesparse-vanilla" else: backend = "suitesparse" - blocking = config.getoption("--blocking", True) + blocking = config.getoption("--blocking", None) if blocking is None: # pragma: no branch blocking = rng.random() < 0.5 if randomly else True record = config.getoption("--record", False) if record is None: # pragma: no branch record = rng.random() < 0.5 if randomly else False - mapnumpy = config.getoption("--mapnumpy", False) + mapnumpy = config.getoption("--mapnumpy", None) if mapnumpy is None: mapnumpy = rng.random() < 0.5 if randomly else False - runslow = config.getoption("--runslow", False) + runslow = config.getoption("--runslow", None) if runslow is None: - # Add a small amount of randomization to be safer - runslow = rng.random() < 0.05 if randomly else False + runslow = options_unavailable config.runslow = runslow gb.config.set(autocompute=False, mapnumpy=mapnumpy) @@ -48,7 +59,7 @@ def pytest_configure(config): rec.start() def save_records(): - with Path("record.txt").open("w") as f: # pragma: no cover + with Path("record.txt").open("w") as f: # pragma: no cover (???) f.write("\n".join(rec.data)) # I'm sure there's a `pytest` way to do this... @@ -58,9 +69,11 @@ def save_records(): for key in dir(gb.semiring) if key != "ss" and isinstance( - getattr(gb.semiring, key) - if key not in gb.semiring._deprecated - else gb.semiring._deprecated[key], + ( + getattr(gb.semiring, key) + if key not in gb.semiring._deprecated + else gb.semiring._deprecated[key] + ), (gb.core.operator.Semiring, gb.core.operator.ParameterizedSemiring), ) ) @@ -69,9 +82,11 @@ def save_records(): for key in dir(gb.binary) if key != "ss" and isinstance( - getattr(gb.binary, key) - if key not in gb.binary._deprecated - else gb.binary._deprecated[key], + ( + getattr(gb.binary, key) + if key not in gb.binary._deprecated + else gb.binary._deprecated[key] + ), (gb.core.operator.BinaryOp, gb.core.operator.ParameterizedBinaryOp), ) ) @@ -105,6 +120,27 @@ def ic(): # pragma: no cover (debug) return icecream.ic +@contextlib.contextmanager +def burble(): # pragma: no cover (debug) + """Show the burble diagnostics within a context.""" + if gb.backend != "suitesparse": + yield + return + prev = gb.ss.config["burble"] + gb.ss.config["burble"] = True + try: + yield + finally: + gb.ss.config["burble"] = prev + + +@pytest.fixture(scope="session") +def burble_all(): # pragma: no cover (debug) + """Show the burble diagnostics for the entire test.""" + with burble(): + yield burble + + def autocompute(func): @functools.wraps(func) def inner(*args, **kwargs): @@ -116,3 +152,15 @@ def inner(*args, **kwargs): def compute(x): return x + + +def shouldhave(module, opname): + """Whether an "operator" module should have the given operator.""" + return supports_udfs or hasattr(module, opname) + + +def dprint(*args, **kwargs): # pragma: no cover (debug) + """Print to stderr for debugging purposes.""" + kwargs["file"] = sys.stderr + kwargs["flush"] = True + print(*args, **kwargs) diff --git a/graphblas/tests/pickle1-vanilla.pkl b/graphblas/tests/pickle1-vanilla.pkl index 36ea20760..a494e405a 100644 Binary files a/graphblas/tests/pickle1-vanilla.pkl and b/graphblas/tests/pickle1-vanilla.pkl differ diff --git a/graphblas/tests/pickle1.pkl b/graphblas/tests/pickle1.pkl index 98a1fdf05..273b49901 100644 Binary files a/graphblas/tests/pickle1.pkl and b/graphblas/tests/pickle1.pkl differ diff 
--git a/graphblas/tests/pickle2-vanilla.pkl b/graphblas/tests/pickle2-vanilla.pkl
index 3c6e18ba4..dd091c823 100644
Binary files a/graphblas/tests/pickle2-vanilla.pkl and b/graphblas/tests/pickle2-vanilla.pkl differ
diff --git a/graphblas/tests/pickle2.pkl b/graphblas/tests/pickle2.pkl
index 3c6e18ba4..dd091c823 100644
Binary files a/graphblas/tests/pickle2.pkl and b/graphblas/tests/pickle2.pkl differ
diff --git a/graphblas/tests/pickle3-vanilla.pkl b/graphblas/tests/pickle3-vanilla.pkl
index 29e79d7db..7f8408c95 100644
Binary files a/graphblas/tests/pickle3-vanilla.pkl and b/graphblas/tests/pickle3-vanilla.pkl differ
diff --git a/graphblas/tests/pickle3.pkl b/graphblas/tests/pickle3.pkl
index d04a53cb9..28b308452 100644
Binary files a/graphblas/tests/pickle3.pkl and b/graphblas/tests/pickle3.pkl differ
diff --git a/graphblas/tests/test_core.py b/graphblas/tests/test_core.py
index c08ca416f..3586eb4a8 100644
--- a/graphblas/tests/test_core.py
+++ b/graphblas/tests/test_core.py
@@ -1,7 +1,18 @@
+import pathlib
+
 import pytest
 
 import graphblas as gb
 
+try:
+    import setuptools
+except ImportError:  # pragma: no cover (import)
+    setuptools = None
+try:
+    import tomli
+except ImportError:  # pragma: no cover (import)
+    tomli = None
+
 
 def test_import_special_attrs():
     not_hidden = {x for x in dir(gb) if not x.startswith("__")}
@@ -57,3 +68,29 @@ def test_version():
     from packaging.version import parse
 
     assert parse(gb.__version__) > parse("2022.11.0")
+
+
+@pytest.mark.skipif("not setuptools or not tomli or not gb.__file__")
+def test_packages():
+    """Ensure all packages are declared in pyproject.toml."""
+    # Currently assumes `pyproject.toml` is at the same level as the `graphblas` folder.
+    # This probably isn't always True, and we can probably do a better job of finding it.
+    path = pathlib.Path(gb.__file__).parent
+    pkgs = [f"graphblas.{x}" for x in setuptools.find_packages(str(path))]
+    pkgs.append("graphblas")
+    pkgs.sort()
+    pyproject = path.parent / "pyproject.toml"
+    if not pyproject.exists():  # pragma: no cover (safety)
+        pytest.skip("Did not find pyproject.toml")
+    with pyproject.open("rb") as f:
+        cfg = tomli.load(f)
+    if cfg.get("project", {}).get("name") != "python-graphblas":  # pragma: no cover (safety)
+        pytest.skip("Did not find correct pyproject.toml")
+    pkgs2 = sorted(cfg["tool"]["setuptools"]["packages"])
+    assert (
+        pkgs == pkgs2
+    ), "If there are extra items on the left, add them to pyproject.toml:tool.setuptools.packages"
+
+
+def test_index_max():
+    assert gb.MAX_SIZE == 2**60  # True for all current backends
diff --git a/graphblas/tests/test_descriptor.py b/graphblas/tests/test_descriptor.py
index 9209a8055..6ec9df36a 100644
--- a/graphblas/tests/test_descriptor.py
+++ b/graphblas/tests/test_descriptor.py
@@ -2,8 +2,7 @@
 
 
 def test_caching():
-    """
-    Test that building a descriptor is actually caching rather than building
+    """Test that building a descriptor is actually caching rather than building
     a new object for each call.
""" tocr = descriptor.lookup( diff --git a/graphblas/tests/test_dtype.py b/graphblas/tests/test_dtype.py index 64e6d69ab..ecbca707f 100644 --- a/graphblas/tests/test_dtype.py +++ b/graphblas/tests/test_dtype.py @@ -7,8 +7,9 @@ import pytest import graphblas as gb -from graphblas import dtypes +from graphblas import core, dtypes from graphblas.core import lib +from graphblas.core.utils import _NP2 from graphblas.dtypes import lookup_dtype suitesparse = gb.backend == "suitesparse" @@ -123,7 +124,7 @@ def test_dtype_bad_comparison(): def test_dtypes_match_numpy(): - for key, val in dtypes._registry.items(): + for key, val in core.dtypes._registry.items(): try: if key is int or (isinstance(key, str) and key == "int"): # For win64, numpy treats int as int32, not int64 @@ -137,7 +138,7 @@ def test_dtypes_match_numpy(): def test_pickle(): - for val in dtypes._registry.values(): + for val in core.dtypes._registry.values(): s = pickle.dumps(val) val2 = pickle.loads(s) if val._is_udt: # pragma: no cover @@ -205,7 +206,7 @@ def test_auto_register(): def test_default_names(): - from graphblas.dtypes import _default_name + from graphblas.core.dtypes import _default_name assert _default_name(np.dtype([("x", np.int32), ("y", np.float64)], align=True)) == ( "{'x': INT32, 'y': FP64}" @@ -224,15 +225,22 @@ def test_record_dtype_from_dict(): def test_dtype_to_from_string(): types = [dtypes.BOOL, dtypes.FP64] for c in string.ascii_letters: + if c == "T": + # See NEP 55 about StringDtype "T". Notably, this doesn't work: + # >>> np.dtype(np.dtype("T").str) + continue + if _NP2 and c == "a": + # Data type alias 'a' was deprecated in NumPy 2.0. Use the 'S' alias instead. + continue try: dtype = np.dtype(c) types.append(dtype) except Exception: pass for dtype in types: - s = dtypes._dtype_to_string(dtype) + s = core.dtypes._dtype_to_string(dtype) try: - dtype2 = dtypes._string_to_dtype(s) + dtype2 = core.dtypes._string_to_dtype(s) except Exception: with pytest.raises(ValueError, match="Unknown dtype"): lookup_dtype(dtype) @@ -241,7 +249,7 @@ def test_dtype_to_from_string(): def test_has_complex(): - """Only SuiteSparse has complex (with Windows support in Python after v7.4.3.1)""" + """Only SuiteSparse has complex (with Windows support in Python after v7.4.3.1).""" if not suitesparse: assert not dtypes._supports_complex return @@ -252,7 +260,21 @@ def test_has_complex(): import suitesparse_graphblas as ssgb from packaging.version import parse - if parse(ssgb.__version__) < parse("7.4.3.1"): - assert not dtypes._supports_complex + assert dtypes._supports_complex == (parse(ssgb.__version__) >= parse("7.4.3.1")) + + +def test_has_ss_attribute(): + if suitesparse: + assert dtypes.ss is not None else: - assert dtypes._supports_complex + with pytest.raises(AttributeError): + dtypes.ss + + +def test_dir(): + must_have = {"DataType", "lookup_dtype", "register_anonymous", "register_new", "ss", "unify"} + must_have.update({"FP32", "FP64", "INT8", "INT16", "INT32", "INT64"}) + must_have.update({"BOOL", "UINT8", "UINT16", "UINT32", "UINT64"}) + if dtypes._supports_complex: + must_have.update({"FC32", "FC64"}) + assert set(dir(dtypes)) & must_have == must_have diff --git a/graphblas/tests/test_formatting.py b/graphblas/tests/test_formatting.py index 3094aea91..faadc983b 100644 --- a/graphblas/tests/test_formatting.py +++ b/graphblas/tests/test_formatting.py @@ -40,9 +40,8 @@ def _printer(text, name, repr_name, indent): # line = f"f'{{CSS_STYLE}}'" in_style = False is_style = True - else: # pragma: no cover (???) 
-            # This definitely gets covered, but why is it not picked up?
-            continue
+        else:
+            continue  # FLAKY COVERAGE
 [trailing context lines built the expected repr_html string (an HTML table with header cells 0 and 1 and a row index 2); the markup was lost in extraction]
diff --git a/graphblas/tests/test_infix.py b/graphblas/tests/test_infix.py
index f496ade15..601f282a7 100644
--- a/graphblas/tests/test_infix.py
+++ b/graphblas/tests/test_infix.py
@@ -1,6 +1,6 @@
 import pytest
 
-from graphblas import monoid, op
+from graphblas import binary, monoid, op
 from graphblas.exceptions import DimensionMismatch
 
 from .conftest import autocompute
@@ -342,3 +342,440 @@ def test_inplace_infix(s1, v1, v2, A1, A2):
         expr @= A
     with pytest.raises(TypeError, match="not supported"):
         s1 @= v1
+
+
+@autocompute
+def test_infix_expr_value_types():
+    """Test bug where `infix_expr._value` was used as MatrixExpression or Matrix."""
+    from graphblas.core.matrix import MatrixExpression
+
+    A = Matrix(int, 3, 3)
+    A << 1
+    expr = A @ A.T
+    assert expr._expr is None
+    assert expr._value is None
+    assert type(expr._get_value()) is Matrix
+    assert type(expr._expr) is MatrixExpression
+    assert type(expr.new()) is Matrix
+    assert expr._expr is not None
+    assert expr._value is None
+    assert type(expr.new()) is Matrix
+    assert type(expr._get_value()) is Matrix
+    assert expr._expr is not None
+    assert expr._value is not None
+    assert expr._expr._value is not None
+    expr._value = None
+    assert expr._value is None
+    assert expr._expr._value is None
+
+
+def test_multi_infix_vector():
+    D0 = Vector.from_scalar(0, 3).diag()
+    v1 = Vector.from_coo([0, 1], [1, 2], size=3)  # 1 2 .
+    v2 = Vector.from_coo([1, 2], [1, 2], size=3)  # . 1 2
+    v3 = Vector.from_coo([2, 0], [1, 2], size=3)  # 2 . 1
+    # ewise_add
+    result = binary.plus((v1 | v2) | v3).new()
+    expected = Vector.from_scalar(3, size=3)
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (v2 | v3)).new()
+    assert result.isequal(expected)
+    result = monoid.min(v1 | v2 | v3).new()
+    expected = Vector.from_scalar(1, size=3)
+    assert result.isequal(expected)
+    # ewise_mult
+    result = monoid.max((v1 & v2) & v3).new()
+    expected = Vector(int, size=3)
+    assert result.isequal(expected)
+    result = monoid.max(v1 & (v2 & v3)).new()
+    assert result.isequal(expected)
+    result = monoid.min((v1 & v2) & v1).new()
+    expected = Vector.from_coo([1], [1], size=3)
+    assert result.isequal(expected)
+    # ewise_union
+    result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new()
+    expected = Vector.from_scalar(13, size=3)
+    assert result.isequal(expected)
+    result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new()
+    expected = Vector.from_scalar(13.0, size=3)
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new()
+    assert result.isequal(expected)
+    # inner
+    assert op.plus_plus(v1 @ v1).new().value == 6
+    assert op.plus_plus(v1 @ (v1 @ D0)).new().value == 6
+    assert op.plus_plus((D0 @ v1) @ v1).new().value == 6
+    # matrix-vector ewise_add
+    result = binary.plus((D0 | v1) | v2).new()
+    expected = binary.plus(binary.plus(D0 | v1).new() | v2).new()
+    assert result.isequal(expected)
+    result = binary.plus(D0 | (v1 | v2)).new()
+    assert result.isequal(expected)
+    result = binary.plus((v1 | v2) | D0).new()
+    assert result.isequal(expected.T)
+    result = binary.plus(v1 | (v2 | D0)).new()
+    assert result.isequal(expected.T)
+    # matrix-vector ewise_mult
+    result = binary.plus((D0 & v1) & v2).new()
+    expected = binary.plus(binary.plus(D0 & v1).new() & v2).new()
+    assert result.isequal(expected)
+    assert result.nvals > 0
+    result = binary.plus(D0 & (v1 & v2)).new()
+    assert result.isequal(expected)
+    result = binary.plus((v1 & v2) & D0).new()
+    assert result.isequal(expected.T)
+    result = binary.plus(v1 & (v2 & D0)).new()
+    assert result.isequal(expected.T)
+    # matrix-vector ewise_union
+    kwargs = {"left_default": 10, "right_default": 20}
+    result = binary.plus((D0 | v1) | v2, **kwargs).new()
+    expected = binary.plus(binary.plus(D0 | v1, **kwargs).new() | v2, **kwargs).new()
+    assert result.isequal(expected)
+    result = binary.plus(D0 | (v1 | v2), **kwargs).new()
+    expected = binary.plus(D0 | binary.plus(v1 | v2, **kwargs).new(), **kwargs).new()
+    assert result.isequal(expected)
+    result = binary.plus((v1 | v2) | D0, **kwargs).new()
+    expected = binary.plus(binary.plus(v1 | v2, **kwargs).new() | D0, **kwargs).new()
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (v2 | D0), **kwargs).new()
+    expected = binary.plus(v1 | binary.plus(v2 | D0, **kwargs).new(), **kwargs).new()
+    assert result.isequal(expected)
+    # vxm, mxv
+    result = op.plus_plus((D0 @ v1) @ D0).new()
+    assert result.isequal(v1)
+    result = op.plus_plus(D0 @ (v1 @ D0)).new()
+    assert result.isequal(v1)
+    result = op.plus_plus(v1 @ (D0 @ D0)).new()
+    assert result.isequal(v1)
+    result = op.plus_plus((D0 @ D0) @ v1).new()
+    assert result.isequal(v1)
+    result = op.plus_plus((v1 @ D0) @ D0).new()
+    assert result.isequal(v1)
+    result = op.plus_plus(D0 @ (D0 @ v1)).new()
+    assert result.isequal(v1)
+
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | v3
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2).__ror__(v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | (v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1 | (v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1.__ror__(v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) | (v2 & v3)
+
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1 & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1.__rand__(v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & v3
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2).__rand__(v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & (v2 & v3)
+
+    # We differentiate between infix and methods
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_add(v2 & v3)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 & v2).ewise_add(v3)
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_mult(v2 | v3)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 | v2).ewise_mult(v3)
+
+
+@autocompute
+def test_multi_infix_vector_auto():
+    v1 = Vector.from_coo([0, 1], [1, 2], size=3)  # 1 2 .
+    v2 = Vector.from_coo([1, 2], [1, 2], size=3)  # . 1 2
+    v3 = Vector.from_coo([2, 0], [1, 2], size=3)  # 2 . 1
+    # We differentiate between infix and methods
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_add(v2 & v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 & v2).ewise_add(v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_mult(v2 | v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 | v2).ewise_mult(v3)
+
+
+def test_multi_infix_matrix():
+    # Adapted from test_multi_infix_vector
+    D0 = Vector.from_scalar(0, 3).diag()
+    v1 = Matrix.from_coo([0, 1], [0, 0], [1, 2], nrows=3)  # 1 2 .
+    v2 = Matrix.from_coo([1, 2], [0, 0], [1, 2], nrows=3)  # . 1 2
+    v3 = Matrix.from_coo([2, 0], [0, 0], [1, 2], nrows=3)  # 2 . 1
+    # ewise_add
+    result = binary.plus((v1 | v2) | v3).new()
+    expected = Matrix.from_scalar(3, 3, 1)
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (v2 | v3)).new()
+    assert result.isequal(expected)
+    result = monoid.min(v1 | v2 | v3).new()
+    expected = Matrix.from_scalar(1, 3, 1)
+    assert result.isequal(expected)
+    result = binary.plus(v1 | v1 | v1 | v1 | v1).new()
+    expected = (5 * v1).new()
+    assert result.isequal(expected)
+    # ewise_mult
+    result = monoid.max((v1 & v2) & v3).new()
+    expected = Matrix(int, 3, 1)
+    assert result.isequal(expected)
+    result = monoid.max(v1 & (v2 & v3)).new()
+    assert result.isequal(expected)
+    result = monoid.min((v1 & v2) & v1).new()
+    expected = Matrix.from_coo([1], [0], [1], nrows=3)
+    assert result.isequal(expected)
+    result = binary.plus(v1 & v1 & v1 & v1 & v1).new()
+    expected = (5 * v1).new()
+    assert result.isequal(expected)
+    # ewise_union
+    result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new()
+    expected = Matrix.from_scalar(13, 3, 1)
+    assert result.isequal(expected)
+    result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new()
+    expected = Matrix.from_scalar(13.0, 3, 1)
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new()
+    assert result.isequal(expected)
+    # mxm
+    assert op.plus_plus(v1.T @ v1).new()[0, 0].new().value == 6
+    assert op.plus_plus(v1 @ (v1.T @ D0)).new()[0, 0].new().value == 2
+    assert op.plus_plus((v1.T @ D0) @ v1).new()[0, 0].new().value == 6
+    assert op.plus_plus(D0 @ D0 @ D0 @ D0 @ D0).new().isequal(D0)
+
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | v3
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2).__ror__(v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | (v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1 | (v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1.__ror__(v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) | (v2 & v3)
+
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1 & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1.__rand__(v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & v3
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2).__rand__(v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & (v2 & v3)
+
+    # We differentiate between infix and methods
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_add(v2 & v3)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 & v2).ewise_add(v3)
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_mult(v2 | v3)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 | v2).ewise_mult(v3)
+
+
+@autocompute
+def test_multi_infix_matrix_auto():
+    v1 = Matrix.from_coo([0, 1], [0, 0], [1, 2], nrows=3)  # 1 2 .
+    v2 = Matrix.from_coo([1, 2], [0, 0], [1, 2], nrows=3)  # . 1 2
+    v3 = Matrix.from_coo([2, 0], [0, 0], [1, 2], nrows=3)  # 2 . 1
+    # We differentiate between infix and methods
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_add(v2 & v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 & v2).ewise_add(v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_mult(v2 | v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 | v2).ewise_mult(v3)
+
+
+def test_multi_infix_scalar():
+    # Adapted from test_multi_infix_vector
+    v1 = Scalar.from_value(1)
+    v2 = Scalar.from_value(2)
+    v3 = Scalar(int)
+    # ewise_add
+    result = binary.plus((v1 | v2) | v3).new()
+    expected = 3
+    assert result.isequal(expected)
+    result = binary.plus((1 | v2) | v3).new()
+    assert result.isequal(expected)
+    result = binary.plus((1 | v2) | 0).new()
+    assert result.isequal(expected)
+    result = binary.plus((v1 | 2) | v3).new()
+    assert result.isequal(expected)
+    result = binary.plus((v1 | 2) | 0).new()
+    assert result.isequal(expected)
+    result = binary.plus((v1 | v2) | 0).new()
+    assert result.isequal(expected)
+
+    result = binary.plus(v1 | (v2 | v3)).new()
+    assert result.isequal(expected)
+    result = binary.plus(1 | (v2 | v3)).new()
+    assert result.isequal(expected)
+    result = binary.plus(1 | (2 | v3)).new()
+    assert result.isequal(expected)
+    result = binary.plus(1 | (v2 | 0)).new()
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (2 | v3)).new()
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (v2 | 0)).new()
+    assert result.isequal(expected)
+
+    result = monoid.min(v1 | v2 | v3).new()
+    expected = 1
+    assert result.isequal(expected)
+    # ewise_mult
+    result = monoid.max((v1 & v2) & v3).new()
+    expected = None
+    assert result.isequal(expected)
+    result = monoid.max(v1 & (v2 & v3)).new()
+    assert result.isequal(expected)
+    result = monoid.min((v1 & v2) & v1).new()
+    expected = 1
+    assert result.isequal(expected)
+
+    result = monoid.min((1 & v2) & v1).new()
+    assert result.isequal(expected)
+    result = monoid.min((1 & v2) & 1).new()
+    assert result.isequal(expected)
+    result = monoid.min((v1 & 2) & v1).new()
+    assert result.isequal(expected)
+    result = monoid.min((v1 & 2) & 1).new()
+    assert result.isequal(expected)
+    result = monoid.min((v1 & v2) & 1).new()
+    assert result.isequal(expected)
+
+    result = monoid.min(1 & (v2 & v1)).new()
+    assert result.isequal(expected)
+    result = monoid.min(1 & (2 & v1)).new()
+    assert result.isequal(expected)
+    result = monoid.min(1 & (v2 & 1)).new()
+    assert result.isequal(expected)
+    result = monoid.min(v1 & (2 & v1)).new()
+    assert result.isequal(expected)
+    result = monoid.min(v1 & (v2 & 1)).new()
+    assert result.isequal(expected)
+
+    # ewise_union
+    result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10).new()
+    expected = 13
+    assert result.isequal(expected)
+    result = binary.plus((1 | v2) | v3, left_default=10, right_default=10).new()
+    assert result.isequal(expected)
+    result = binary.plus((v1 | 2) | v3, left_default=10, right_default=10).new()
+    assert result.isequal(expected)
+    result = binary.plus((v1 | v2) | v3, left_default=10, right_default=10.0).new()
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (v2 | v3), left_default=10, right_default=10).new()
+    assert result.isequal(expected)
+    result = binary.plus(1 | (v2 | v3), left_default=10, right_default=10).new()
+    assert result.isequal(expected)
+    result = binary.plus(1 | (2 | v3), left_default=10, right_default=10).new()
+    assert result.isequal(expected)
+    result = binary.plus(v1 | (2 | v3), left_default=10, right_default=10).new()
+    assert result.isequal(expected)
+
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | v3
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2).__ror__(v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | (v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) | (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1 | (v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1.__ror__(v2 & v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) | (v2 & v3)
+
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1 & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        v1.__rand__(v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 & v2) & (v2 | v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & v3
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2).__rand__(v3)
+    with pytest.raises(TypeError, match="XXX"):  # TODO
+        (v1 | v2) & (v2 & v3)
+
+    # We differentiate between infix and methods
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_add(v2 & v3)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 & v2).ewise_add(v3)
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="to automatically compute"):
+        v1.ewise_mult(v2 | v3)
+    with pytest.raises(TypeError, match="Automatic computation"):
+        (v1 | v2).ewise_mult(v3)
+
+
+@autocompute
+def test_multi_infix_scalar_auto():
+    v1 = Scalar.from_value(1)
+    v2 = Scalar.from_value(2)
+    v3 = Scalar(int)
+    # We differentiate between infix and methods
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_add(v2 & v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 & v2).ewise_add(v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_union(v2 & v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 & v2).ewise_union(v3, binary.plus, left_default=1, right_default=1)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        v1.ewise_mult(v2 | v3)
+    with pytest.raises(TypeError, match="only valid for BOOL"):
+        (v1 | v2).ewise_mult(v3)
diff --git a/graphblas/tests/test_io.py b/graphblas/tests/test_io.py
index 6fa43ebbc..7e786f0da 100644
--- a/graphblas/tests/test_io.py
+++ b/graphblas/tests/test_io.py
@@ -30,21 +30,14 @@
 except ImportError:  # pragma: no cover (import)
     ak = None
 
+try:
+    import fast_matrix_market as fmm
+except ImportError:  # pragma: no cover (import)
+    fmm = None
 
 suitesparse = gb.backend == "suitesparse"
 
 
-@pytest.mark.skipif("not ss")
-def test_deprecated():
-    a = np.array([0.0, 2.0, 4.1])
-    with pytest.warns(DeprecationWarning):
-        v = gb.io.from_numpy(a)
-    assert v.isequal(gb.Vector.from_coo([1, 2], [2.0, 4.1]), check_dtype=True)
-    with pytest.warns(DeprecationWarning):
-        a2 = gb.io.to_numpy(v)
-    np.testing.assert_array_equal(a, a2)
-
-
 @pytest.mark.skipif("not ss")
 def test_vector_to_from_numpy():
     a = np.array([0.0, 2.0, 4.1])
@@ -55,18 +48,24 @@ def test_vector_to_from_numpy():
 
     csr = gb.io.to_scipy_sparse(v, "csr")
     assert csr.nnz == 2
-    assert ss.isspmatrix_csr(csr)
+    # 2023-06-25: scipy 1.11.0 added `sparray` and changed e.g. `ss.isspmatrix_csr`
+    assert isinstance(csr, getattr(ss, "sparray", ss.spmatrix))
+    assert csr.format == "csr"
     np.testing.assert_array_equal(csr.toarray(), np.array([[0.0, 2.0, 4.1]]))
 
     csc = gb.io.to_scipy_sparse(v, "csc")
     assert csc.nnz == 2
-    assert ss.isspmatrix_csc(csc)
+    # 2023-06-25: scipy 1.11.0 added `sparray` and changed e.g. `ss.isspmatrix_csc`
+    assert isinstance(csc, getattr(ss, "sparray", ss.spmatrix))
+    assert csc.format == "csc"
     np.testing.assert_array_equal(csc.toarray(), np.array([[0.0, 2.0, 4.1]]).T)
 
     # default to csr-like
     coo = gb.io.to_scipy_sparse(v, "coo")
     assert coo.shape == csr.shape
-    assert ss.isspmatrix_coo(coo)
+    # 2023-06-25: scipy 1.11.0 added `sparray` and changed e.g. `ss.isspmatrix_coo`
+    assert isinstance(coo, getattr(ss, "sparray", ss.spmatrix))
+    assert coo.format == "coo"
     assert coo.nnz == 2
     np.testing.assert_array_equal(coo.toarray(), np.array([[0.0, 2.0, 4.1]]))
@@ -95,7 +94,9 @@ def test_matrix_to_from_numpy():
 
     for format in ["csr", "csc", "coo"]:
         sparse = gb.io.to_scipy_sparse(M, format)
-        assert getattr(ss, f"isspmatrix_{format}")(sparse)
+        # 2023-06-25: scipy 1.11.0 added `sparray` and changed e.g. `ss.isspmatrix_csr`
+        assert isinstance(sparse, getattr(ss, "sparray", ss.spmatrix))
+        assert sparse.format == format
         assert sparse.nnz == 3
         np.testing.assert_array_equal(sparse.toarray(), a)
         M2 = gb.io.from_scipy_sparse(sparse)
@@ -145,7 +146,7 @@ def test_matrix_to_from_networkx():
     M = gb.io.from_networkx(G, nodelist=range(7))
     if suitesparse:
         assert M.ss.is_iso
-    rows, cols = zip(*edges)
+    rows, cols = zip(*edges, strict=True)
     expected = gb.Matrix.from_coo(rows, cols, 1)
     assert expected.isequal(M)
     # Test empty
@@ -159,8 +160,15 @@ def test_matrix_to_from_networkx():
 
 
 @pytest.mark.skipif("not ss")
-def test_mmread_mmwrite():
-    from scipy.io.tests import test_mmio
+@pytest.mark.parametrize("engine", ["auto", "scipy", "fmm"])
+def test_mmread_mmwrite(engine):
+    if engine == "fmm" and fmm is None:  # pragma: no cover (import)
+        pytest.skip("needs fast_matrix_market")
+    try:
+        from scipy.io.tests import test_mmio
+    except ImportError:
+        # Test files are mysteriously missing from some conda-forge builds
+        pytest.skip("scipy.io.tests.test_mmio unavailable :(")
 
     p31 = 2**31
     p63 = 2**63
@@ -256,10 +264,19 @@ def test_mmread_mmwrite(engine):
             continue
         mm_in = StringIO(getattr(test_mmio, example))
         if over64:
-            with pytest.raises(OverflowError):
-                M = gb.io.mmread(mm_in)
+            with pytest.raises((OverflowError, ValueError)):
+                # fast_matrix_market v1.4.5 raises ValueError instead of OverflowError
+                M = gb.io.mmread(mm_in, engine)
         else:
-            M = gb.io.mmread(mm_in)
+            if (
+                example == "_empty_lines_example"
+                and engine in {"fmm", "auto"}
+                and fmm is not None
+                and fmm.__version__ in {"1.4.5"}
+            ):
+                # `fast_matrix_market` v1.4.5 does not handle this, but v1.5.0 does
+                continue
+            M = gb.io.mmread(mm_in, engine)
             if not M.isequal(expected):  # pragma: no cover (debug)
                 print(example)
                 print("Expected:")
@@ -268,12 +285,12 @@ def test_mmread_mmwrite(engine):
                 print(M)
                 raise AssertionError("Matrix M not as expected. See print output above")
             mm_out = BytesIO()
-            gb.io.mmwrite(mm_out, M)
+            gb.io.mmwrite(mm_out, M, engine)
             mm_out.flush()
             mm_out.seek(0)
             mm_out_str = b"".join(mm_out.readlines()).decode()
             mm_out.seek(0)
-            M2 = gb.io.mmread(mm_out)
+            M2 = gb.io.mmread(mm_out, engine)
             if not M2.isequal(expected):  # pragma: no cover (debug)
                 print(example)
                 print("Expected:")
@@ -299,23 +316,38 @@ def test_from_scipy_sparse_duplicates():
 
 
 @pytest.mark.skipif("not ss")
-def test_matrix_market_sparse_duplicates():
-    mm = StringIO(
-        """%%MatrixMarket matrix coordinate real general
+@pytest.mark.parametrize("engine", ["auto", "scipy", "fast_matrix_market"])
+def test_matrix_market_sparse_duplicates(engine):
+    if engine == "fast_matrix_market" and fmm is None:  # pragma: no cover (import)
+        pytest.skip("needs fast_matrix_market")
+    string = """%%MatrixMarket matrix coordinate real general
 3 3 4
 1 3 1
 2 2 2
 3 1 3
 3 1 4"""
-    )
+    mm = StringIO(string)
     with pytest.raises(ValueError, match="Duplicate indices found"):
-        gb.io.mmread(mm)
-    mm.seek(0)
-    a = gb.io.mmread(mm, dup_op=gb.binary.plus)
+        gb.io.mmread(mm, engine)
+    # mm.seek(0)  # Doesn't work with `fast_matrix_market` 1.4.5
+    mm = StringIO(string)
+    a = gb.io.mmread(mm, engine, dup_op=gb.binary.plus)
     expected = gb.Matrix.from_coo([0, 1, 2], [2, 1, 0], [1, 2, 7])
     assert a.isequal(expected)
 
 
+@pytest.mark.skipif("not ss")
+def test_matrix_market_bad_engine():
+    A = gb.Matrix.from_coo([0, 0, 3, 5], [1, 4, 0, 2], [1, 0, 2, -1], nrows=7, ncols=6)
+    with pytest.raises(ValueError, match="Bad engine value"):
+        gb.io.mmwrite(BytesIO(), A, engine="bad_engine")
+    mm_out = BytesIO()
+    gb.io.mmwrite(mm_out, A)
+    mm_out.seek(0)
+    with pytest.raises(ValueError, match="Bad engine value"):
+        gb.io.mmread(mm_out, engine="bad_engine")
+
+
 @pytest.mark.skipif("not ss")
 def test_scipy_sparse():
     a = np.arange(12).reshape(3, 4)
@@ -334,6 +366,7 @@ def test_scipy_sparse():
 
 
 @pytest.mark.skipif("not ak")
+@pytest.mark.xfail(np.__version__[:5] in {"1.25.", "1.26."}, reason="awkward bug with numpy >=1.25")
 def test_awkward_roundtrip():
     # Vector
     v = gb.Vector.from_coo([1, 3, 5], [20, 21, -5], size=22)
@@ -355,6 +388,7 @@ def test_awkward_roundtrip():
 
 
 @pytest.mark.skipif("not ak")
+@pytest.mark.xfail(np.__version__[:5] in {"1.25.", "1.26."}, reason="awkward bug with numpy >=1.25")
 def test_awkward_iso_roundtrip():
     # Vector
     v = gb.Vector.from_coo([1, 3, 5], [20, 20, 20], size=22)
@@ -398,6 +432,7 @@ def test_awkward_errors():
 
 
 @pytest.mark.skipif("not sparse")
+@pytest.mark.slow
 def test_vector_to_from_pydata_sparse():
     coords = np.array([0, 1, 2, 3, 4], dtype="int64")
     data = np.array([10, 20, 30, 40, 50], dtype="int64")
@@ -411,6 +446,7 @@ def test_vector_to_from_pydata_sparse():
 
 
 @pytest.mark.skipif("not sparse")
+@pytest.mark.slow
 def test_matrix_to_from_pydata_sparse():
     coords = np.array([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], dtype="int64")
     data = np.array([10, 20, 30, 40, 50], dtype="int64")
diff --git a/graphblas/tests/test_matrix.py b/graphblas/tests/test_matrix.py
index 40676f71a..24f0e73d7 100644
--- a/graphblas/tests/test_matrix.py
+++ b/graphblas/tests/test_matrix.py
@@ -11,6 +11,7 @@
 import graphblas as gb
 from graphblas import agg, backend, binary, dtypes, indexunary, monoid, select, semiring, unary
+from graphblas.core import _supports_udfs as supports_udfs
 from graphblas.core import lib
 from graphblas.exceptions import (
     DimensionMismatch,
@@ -23,7 +24,7 @@
     OutputNotEmpty,
 )
 
-from .conftest import autocompute, compute
+from .conftest import autocompute, compute, pypy, shouldhave
 
 from graphblas import Matrix, Scalar, Vector  # isort:skip (for dask-graphblas)
@@ -1230,6 +1231,8 @@ def test_apply_indexunary(A):
     assert w4.isequal(A3)
     with pytest.raises(TypeError, match="left"):
         A.apply(select.valueeq, left=s3)
+    assert pickle.loads(pickle.dumps(indexunary.tril)) is indexunary.tril
+    assert pickle.loads(pickle.dumps(indexunary.tril[int])) is indexunary.tril[int]
 
 
 def test_select(A):
@@ -1259,6 +1262,16 @@ def test_select(A):
     with pytest.raises(TypeError, match="thunk"):
         A.select(select.valueeq, object())
 
+    A3rows = Matrix.from_coo([0, 0, 1, 1, 2], [1, 3, 4, 6, 5], [2, 3, 8, 4, 1], nrows=7, ncols=7)
+    w8 = select.rowle(A, 2).new()
+    w9 = A.select("row<=", 2).new()
+    w10 = select.row(A < 3).new()
+    assert w8.isequal(A3rows)
+    assert w9.isequal(A3rows)
+    assert w10.isequal(A3rows)
+    assert pickle.loads(pickle.dumps(select.tril)) is select.tril
+    assert pickle.loads(pickle.dumps(select.tril[bool])) is select.tril[bool]
+
 
 @autocompute
 def test_select_bools_and_masks(A):
@@ -1283,16 +1296,27 @@ def test_select_bools_and_masks(A):
         A.select(A[0, :].new().S)
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_indexunary_udf(A):
     def threex_minusthunk(x, row, col, thunk):  # pragma: no cover (numba)
         return 3 * x - thunk
 
-    indexunary.register_new("threex_minusthunk", threex_minusthunk)
+    assert indexunary.register_new("threex_minusthunk", threex_minusthunk) is not None
     assert hasattr(indexunary, "threex_minusthunk")
     assert not hasattr(select, "threex_minusthunk")
     with pytest.raises(ValueError, match="SelectOp must have BOOL return type"):
         select.register_anonymous(threex_minusthunk)
+    with pytest.raises(ValueError, match="SelectOp must have BOOL return type"):
+        select.register_new("bad_select", threex_minusthunk)
+    assert not hasattr(indexunary, "bad_select")
+    assert not hasattr(select, "bad_select")
+    assert select.register_new("bad_select", threex_minusthunk, lazy=True) is None
+    with pytest.raises(ValueError, match="SelectOp must have BOOL return type"):
+        select.bad_select
+    assert not hasattr(select, "bad_select")
+    assert hasattr(indexunary, "bad_select")  # Keep it
+
     expected = Matrix.from_coo(
         [3, 0, 3, 5, 6, 0, 6, 1, 6, 2, 4, 1],
         [0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6],
@@ -1308,6 +1332,8 @@ def iii(x, row, col, thunk):  # pragma: no cover (numba)
     select.register_new("iii", iii)
     assert hasattr(indexunary, "iii")
     assert hasattr(select, "iii")
+    assert indexunary.iii[int].orig_func is select.iii[int].orig_func is select.iii.orig_func
+    assert indexunary.iii[int]._numba_func is select.iii[int]._numba_func is select.iii._numba_func
     iii_apply = indexunary.register_anonymous(iii)
     expected = Matrix.from_coo(
         [3, 0, 3, 5, 6, 0, 6, 1, 6, 2, 4, 1],
@@ -1353,15 +1379,17 @@ def test_reduce_agg(A):
     expected = unary.sqrt[float](squared).new()
     w5 = A.reduce_rowwise(agg.hypot).new()
     assert w5.isclose(expected)
-    w6 = A.reduce_rowwise(monoid.numpy.hypot[float]).new()
-    assert w6.isclose(expected)
+    if shouldhave(monoid.numpy, "hypot"):
+        w6 = A.reduce_rowwise(monoid.numpy.hypot[float]).new()
+        assert w6.isclose(expected)
     w7 = Vector(w5.dtype, size=w5.size)
     w7 << A.reduce_rowwise(agg.hypot)
     assert w7.isclose(expected)
 
     w8 = A.reduce_rowwise(agg.logaddexp).new()
-    expected = A.reduce_rowwise(monoid.numpy.logaddexp[float]).new()
-    assert w8.isclose(w8)
+    if shouldhave(monoid.numpy, "logaddexp"):
+        expected = A.reduce_rowwise(monoid.numpy.logaddexp[float]).new()
+        assert w8.isclose(expected)
 
     result = Vector.from_coo([0, 1, 2, 3, 4, 5, 6], [3, 2, 9, 10, 11, 8, 4])
     w9 = A.reduce_columnwise(agg.sum).new()
@@ -1598,6 +1626,7 @@ def test_reduce_agg_empty():
     assert compute(s.value) is None
 
 
+@pytest.mark.skipif("not supports_udfs")
 def test_reduce_row_udf(A):
     result = Vector.from_coo([0, 1, 2, 3, 4, 5, 6], [5, 12, 1, 6, 7, 1, 15])
 
@@ -2007,6 +2036,12 @@ def test_ss_import_export(A, do_iso, methods):
             B4 = Matrix.ss.import_any(**d)
             assert B4.isequal(A)
             assert B4.ss.is_iso is do_iso
+            if do_iso:
+                d["values"] = 1
+                d["is_iso"] = False
+                B4b = Matrix.ss.import_any(**d)
+                assert B4b.isequal(A)
+                assert B4b.ss.is_iso is True
         else:
             A4.ss.pack_any(**d)
             assert A4.isequal(A)
@@ -2173,15 +2208,14 @@ def test_ss_import_export(A, do_iso, methods):
                 C1.ss.pack_any(**d)
                 assert C1.isequal(C)
                 assert C1.ss.is_iso is do_iso
+            elif in_method == "import":
+                D1 = Matrix.ss.import_any(**d)
+                assert D1.isequal(C)
+                assert D1.ss.is_iso is do_iso
             else:
-                if in_method == "import":
-                    D1 = Matrix.ss.import_any(**d)
-                    assert D1.isequal(C)
-                    assert D1.ss.is_iso is do_iso
-                else:
-                    C1.ss.pack_any(**d)
-                    assert C1.isequal(C)
-                    assert C1.ss.is_iso is do_iso
+                C1.ss.pack_any(**d)
+                assert C1.isequal(C)
+                assert C1.ss.is_iso is do_iso
 
         C2 = C.dup()
         d = getattr(C2.ss, out_method)("fullc")
@@ -2263,6 +2297,11 @@ def test_ss_import_on_view():
     A = Matrix.from_coo([0, 0, 1, 1], [0, 1, 0, 1], [1, 2, 3, 4])
     B = Matrix.ss.import_any(nrows=2, ncols=2, values=np.array([1, 2, 3, 4, 99, 99, 99])[:4])
     assert A.isequal(B)
+    values = np.arange(16).reshape(4, 4)[::2, ::2]
+    bitmap = np.ones((4, 4), dtype=bool)[::2, ::2]
+    C = Matrix.ss.import_any(values=values, bitmap=bitmap)
+    D = Matrix.ss.import_any(values=values.copy(), bitmap=bitmap.copy())
+    assert C.isequal(D)
 
 
 @pytest.mark.skipif("not suitesparse")
@@ -2564,12 +2603,14 @@ def test_iter(A):
         zip(
             [3, 0, 3, 5, 6, 0, 6, 1, 6, 2, 4, 1],
             [0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6],
+            strict=True,
         )
     )
     assert set(A.T) == set(
         zip(
             [0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6],
            [3, 0, 3, 5, 6, 0, 6, 1, 6, 2, 4, 1],
+            strict=True,
        )
    )
@@ -2692,8 +2733,8 @@ def test_ss_split(A):
     for results in [A.ss.split([4, 3]), A.ss.split([[4, None], 3], name="split")]:
         row_boundaries = [0, 4, 7]
         col_boundaries = [0, 3, 6, 7]
-        for i, (i1, i2) in enumerate(zip(row_boundaries[:-1], row_boundaries[1:])):
-            for j, (j1, j2) in enumerate(zip(col_boundaries[:-1], col_boundaries[1:])):
+        for i, (i1, i2) in enumerate(itertools.pairwise(row_boundaries)):
+            for j, (j1, j2) in enumerate(itertools.pairwise(col_boundaries)):
                 expected = A[i1:i2, j1:j2].new()
                 assert expected.isequal(results[i][j])
     with pytest.raises(DimensionMismatch):
@@ -2766,6 +2807,8 @@ def test_ss_nbytes(A):
 
 @autocompute
 def test_auto(A, v):
+    from graphblas.core.infix import MatrixEwiseMultExpr
+
     expected = binary.land[bool](A & A).new()
     B = A.dup(dtype=bool)
     for expr in [(B & B), binary.land[bool](A & A)]:
@@ -2788,14 +2831,26 @@ def test_auto(A, v):
         "__and__",
         "__or__",
         # "kronecker",
+        "__rand__",
+        "__ror__",
     ]:
+        # print(type(expr).__name__, method)
         val1 = getattr(expected, method)(expected).new()
-        val2 = getattr(expected, method)(expr)
-        val3 = getattr(expr, method)(expected)
-        val4 = getattr(expr, method)(expr)
-        assert val1.isequal(val2)
-        assert val1.isequal(val3)
-        assert val1.isequal(val4)
+        if method in {"__or__", "__ror__"} and type(expr) is MatrixEwiseMultExpr:
+            # Doing e.g. `plus(A & B | C)` isn't allowed; make the user be explicit
+            with pytest.raises(TypeError):
+                val2 = getattr(expected, method)(expr)
+            with pytest.raises(TypeError):
+                val3 = getattr(expr, method)(expected)
+            with pytest.raises(TypeError):
+                val4 = getattr(expr, method)(expr)
+        else:
+            val2 = getattr(expected, method)(expr)
+            assert val1.isequal(val2)
+            val3 = getattr(expr, method)(expected)
+            assert val1.isequal(val3)
+            val4 = getattr(expr, method)(expr)
+            assert val1.isequal(val4)
     for method in ["reduce_rowwise", "reduce_columnwise", "reduce_scalar"]:
         s1 = getattr(expected, method)(monoid.lor).new()
         s2 = getattr(expr, method)(monoid.lor)
@@ -2899,22 +2954,23 @@ def test_expr_is_like_matrix(A):
         "from_dicts",
         "from_edgelist",
         "from_scalar",
-        "from_values",
         "resize",
+        "setdiag",
         "update",
     }
-    assert attrs - expr_attrs == expected, (
+    ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_mxm", "_mxv"}
+    assert attrs - expr_attrs - ignore == expected, (
         "If you see this message, you probably added a method to Matrix. You may need to "
         "add an entry to `matrix` or `matrix_vector` set in `graphblas.core.automethods` "
         "and then run `python -m graphblas.core.automethods`. If you're messing with infix "
         "methods, then you may need to run `python -m graphblas.core.infixmethods`."
     )
-    assert attrs - infix_attrs == expected
+    assert attrs - infix_attrs - ignore == expected
     # TransposedMatrix is used differently than other expressions,
     # so maybe it shouldn't support everything.
     if suitesparse:
         expected.add("ss")
-    assert attrs - transposed_attrs == (expected | {"_as_vector", "S", "V"}) - {
+    assert attrs - transposed_attrs - ignore == (expected | {"_as_vector", "S", "V"}) - {
         "_prep_for_extract",
         "_extract_element",
     }
@@ -2962,11 +3018,12 @@ def test_index_expr_is_like_matrix(A):
         "from_dense",
         "from_dicts",
         "from_edgelist",
-        "from_values",
         "from_scalar",
         "resize",
+        "setdiag",
     }
-    assert attrs - expr_attrs == expected, (
+    ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_mxm", "_mxv"}
+    assert attrs - expr_attrs - ignore == expected, (
         "If you see this message, you probably added a method to Matrix. You may need to "
         "add an entry to `matrix` or `matrix_vector` set in `graphblas.core.automethods` "
        "and then run `python -m graphblas.core.automethods`. If you're messing with infix "
@@ -3013,7 +3070,7 @@ def test_ss_flatten(A):
         [3, 2, 3, 1, 5, 3, 7, 8, 3, 1, 7, 4],
     ]
     # row-wise
-    indices = [row * A.ncols + col for row, col in zip(data[0], data[1])]
+    indices = [row * A.ncols + col for row, col in zip(data[0], data[1], strict=True)]
     expected = Vector.from_coo(indices, data[2], size=A.nrows * A.ncols)
     for fmt in ["csr", "hypercsr", "bitmapr"]:
         B = Matrix.ss.import_any(**A.ss.export(format=fmt))
@@ -3032,7 +3089,7 @@ def test_ss_flatten(A):
         assert C.isequal(B)
 
     # column-wise
-    indices = [col * A.nrows + row for row, col in zip(data[0], data[1])]
+    indices = [col * A.nrows + row for row, col in zip(data[0], data[1], strict=True)]
     expected = Vector.from_coo(indices, data[2], size=A.nrows * A.ncols)
     for fmt in ["csc", "hypercsc", "bitmapc"]:
         B = Matrix.ss.import_any(**A.ss.export(format=fmt))
@@ -3095,6 +3152,10 @@ def test_ss_reshape(A):
 def test_autocompute_argument_messages(A, v):
     with pytest.raises(TypeError, match="autocompute"):
         A.ewise_mult(A & A)
+    with pytest.raises(TypeError, match="autocompute"):
+        A.ewise_mult(binary.plus(A & A))
+    with pytest.raises(TypeError, match="autocompute"):
+        A.ewise_mult(A + A)
     with pytest.raises(TypeError, match="autocompute"):
         A.mxv(A @ v)
 
@@ -3111,10 +3172,12 @@ def test_infix_sugar(A):
     assert binary.times(2, A).isequal(2 * A)
     assert binary.truediv(A, 2).isequal(A / 2)
     assert binary.truediv(5, A).isequal(5 / A)
-    assert binary.floordiv(A, 2).isequal(A // 2)
-    assert binary.floordiv(5, A).isequal(5 // A)
-    assert binary.numpy.mod(A, 2).isequal(A % 2)
-    assert binary.numpy.mod(5, A).isequal(5 % A)
+    if shouldhave(binary, "floordiv"):
+        assert binary.floordiv(A, 2).isequal(A // 2)
+        assert binary.floordiv(5, A).isequal(5 // A)
+    if shouldhave(binary.numpy, "mod"):
+        assert binary.numpy.mod(A, 2).isequal(A % 2)
+        assert binary.numpy.mod(5, A).isequal(5 % A)
     assert binary.pow(A, 2).isequal(A**2)
     assert binary.pow(2, A).isequal(2**A)
     assert binary.pow(A, 2).isequal(pow(A, 2))
@@ -3141,26 +3204,27 @@ def test_infix_sugar(A):
     assert binary.ge(A, 4).isequal(A >= 4)
     assert binary.eq(A, 4).isequal(A == 4)
     assert binary.ne(A, 4).isequal(A != 4)
-    x, y = divmod(A, 3)
-    assert binary.floordiv(A, 3).isequal(x)
-    assert binary.numpy.mod(A, 3).isequal(y)
-    assert binary.fmod(A, 3).isequal(y)
-    assert A.isequal(binary.plus((3 * x) & y))
-    x, y = divmod(-A, 3)
-    assert binary.floordiv(-A, 3).isequal(x)
-    assert binary.numpy.mod(-A, 3).isequal(y)
-    # assert binary.fmod(-A, 3).isequal(y)  # The reason we use numpy.mod
-    assert (-A).isequal(binary.plus((3 * x) & y))
-    x, y = divmod(3, A)
-    assert binary.floordiv(3, A).isequal(x)
-    assert binary.numpy.mod(3, A).isequal(y)
-    assert binary.fmod(3, A).isequal(y)
-    assert binary.plus(binary.times(A & x) & y).isequal(3 * unary.one(A))
-    x, y = divmod(-3, A)
-    assert binary.floordiv(-3, A).isequal(x)
-    assert binary.numpy.mod(-3, A).isequal(y)
-    # assert binary.fmod(-3, A).isequal(y)  # The reason we use numpy.mod
-    assert binary.plus(binary.times(A & x) & y).isequal(-3 * unary.one(A))
+    if shouldhave(binary, "floordiv") and shouldhave(binary.numpy, "mod"):
+        x, y = divmod(A, 3)
+        assert binary.floordiv(A, 3).isequal(x)
+        assert binary.numpy.mod(A, 3).isequal(y)
+        assert binary.fmod(A, 3).isequal(y)
+        assert A.isequal(binary.plus((3 * x) & y))
+        x, y = divmod(-A, 3)
+        assert binary.floordiv(-A, 3).isequal(x)
+        assert binary.numpy.mod(-A, 3).isequal(y)
+        # assert binary.fmod(-A, 3).isequal(y)  # The reason we use numpy.mod
+        assert (-A).isequal(binary.plus((3 * x) & y))
+        x, y = divmod(3, A)
+        assert binary.floordiv(3, A).isequal(x)
+        assert binary.numpy.mod(3, A).isequal(y)
+        assert binary.fmod(3, A).isequal(y)
+        assert binary.plus(binary.times(A & x) & y).isequal(3 * unary.one(A))
+        x, y = divmod(-3, A)
+        assert binary.floordiv(-3, A).isequal(x)
+        assert binary.numpy.mod(-3, A).isequal(y)
+        # assert binary.fmod(-3, A).isequal(y)  # The reason we use numpy.mod
+        assert binary.plus(binary.times(A & x) & y).isequal(-3 * unary.one(A))
 
     assert binary.eq(A & A).isequal(A == A)
     assert binary.ne(A.T & A.T).isequal(A.T != A.T)
@@ -3183,14 +3247,16 @@ def test_infix_sugar(A):
     B /= 2
     assert type(B) is Matrix
     assert binary.truediv(A, 2).isequal(B)
-    B = A.dup()
-    B //= 2
-    assert type(B) is Matrix
-    assert binary.floordiv(A, 2).isequal(B)
-    B = A.dup()
-    B %= 2
-    assert type(B) is Matrix
-    assert binary.numpy.mod(A, 2).isequal(B)
+    if shouldhave(binary, "floordiv"):
+        B = A.dup()
+        B //= 2
+        assert type(B) is Matrix
+        assert binary.floordiv(A, 2).isequal(B)
+    if shouldhave(binary.numpy, "mod"):
+        B = A.dup()
+        B %= 2
+        assert type(B) is Matrix
+        assert binary.numpy.mod(A, 2).isequal(B)
     B = A.dup()
     B **= 2
     assert type(B) is Matrix
@@ -3491,28 +3557,6 @@ def compare(A, expected, isequal=True, **kwargs):
             A.ss.compactify("bad_how")
 
 
-def test_deprecated(A):
-    if suitesparse:
-        with pytest.warns(DeprecationWarning):
-            A.ss.compactify_rowwise()
-        with pytest.warns(DeprecationWarning):
-            A.ss.compactify_columnwise()
-        with pytest.warns(DeprecationWarning):
-            A.ss.scan_rowwise()
-        with pytest.warns(DeprecationWarning):
-            A.ss.scan_columnwise()
-        with pytest.warns(DeprecationWarning):
-            A.ss.selectk_rowwise("first", 3)
-        with pytest.warns(DeprecationWarning):
-            A.ss.selectk_columnwise("first", 3)
-    with pytest.warns(DeprecationWarning):
-        A.to_values()
-    with pytest.warns(DeprecationWarning):
-        A.T.to_values()
-    with pytest.warns(DeprecationWarning):
-        A.from_values([1], [2], [3])
-
-
 def test_ndim(A):
     assert A.ndim == 2
     assert A.ewise_mult(A).ndim == 2
@@ -3521,7 +3565,7 @@ def test_ndim(A):
 
 
 def test_sizeof(A):
-    if suitesparse:
+    if suitesparse and not pypy:
         assert sys.getsizeof(A) > A.nvals * 16
     else:
         with pytest.raises(TypeError):
@@ -3584,9 +3628,9 @@ def test_ss_iteration(A):
     assert not list(B.ss.itervalues())
     assert not list(B.ss.iteritems())
     rows, columns, values = A.to_coo()
-    assert sorted(zip(rows, columns)) == sorted(A.ss.iterkeys())
+    assert sorted(zip(rows, columns, strict=True)) == sorted(A.ss.iterkeys())
     assert sorted(values) == sorted(A.ss.itervalues())
-    assert sorted(zip(rows, columns, values)) == sorted(A.ss.iteritems())
+    assert sorted(zip(rows, columns, values, strict=True)) == sorted(A.ss.iteritems())
     N = rows.size
 
     A = Matrix.ss.import_bitmapr(**A.ss.export("bitmapr"))
@@ -3608,6 +3652,7 @@ def test_ss_iteration(A):
     assert next(A.ss.iteritems()) is not None
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_udt():
     record_dtype = np.dtype([("x", np.bool_), ("y", np.float64)], align=True)
@@ -3844,7 +3889,7 @@ def test_get(A):
     assert compute(A.T.get(0, 1)) is None
     assert A.T.get(1, 0) == 2
     assert A.get(0, 1, "mittens") == 2
-    assert type(compute(A.get(0, 1))) is int
+    assert isinstance(compute(A.get(0, 1)), int)
     with pytest.raises(ValueError, match="Bad row, col"):
         # Not yet supported
         A.get(0, [0, 1])
@@ -3918,7 +3963,7 @@ def test_ss_config(A):
 
 
 def test_to_csr_from_csc(A):
-    assert Matrix.from_csr(*A.to_csr(dtype=int)).isequal(A, check_dtype=True)
+    assert Matrix.from_csr(*A.to_csr(sort=False, dtype=int)).isequal(A, check_dtype=True)
     assert Matrix.from_csr(*A.T.to_csc()).isequal(A, check_dtype=True)
     assert Matrix.from_csc(*A.to_csc()).isequal(A)
     assert Matrix.from_csc(*A.T.to_csr()).isequal(A)
@@ -4029,10 +4074,11 @@ def test_ss_pack_hyperhash(A):
     Y = C.ss.unpack_hyperhash()
     Y = C.ss.unpack_hyperhash(compute=True)
     assert C.ss.unpack_hyperhash() is None
-    assert Y.nrows == C.nrows
-    C.ss.pack_hyperhash(Y)
-    assert Y.gb_obj[0] == gb.core.NULL
-    assert C.ss.unpack_hyperhash() is not None
+    if Y is not None:  # hyperhash may or may not be computed
+        assert Y.nrows == C.nrows
+        C.ss.pack_hyperhash(Y)
+        assert Y.gb_obj[0] == gb.core.NULL
+        assert C.ss.unpack_hyperhash() is not None  # May or may not be computed
 
 
 def test_to_dicts_from_dicts():
@@ -4127,7 +4173,11 @@ def test_from_scalar():
     A = Matrix.from_scalar(1, dtype="INT64[2]", nrows=3, ncols=4)
     B = Matrix("INT64[2]", nrows=3, ncols=4)
     B << [1, 1]
-    assert A.isequal(B, check_dtype=True)
+    if supports_udfs:
+        assert A.isequal(B, check_dtype=True)
+    else:
+        with pytest.raises(KeyError, match="eq does not work with"):
+            assert A.isequal(B, check_dtype=True)
 
 
 def test_to_dense_from_dense():
@@ -4247,13 +4297,13 @@ def test_ss_descriptors(A):
         A(nthreads=4, axb_method="dot", sort=True) << A @ A
         assert A.isequal(C2)
         # Bad option should show list of valid options
-        with pytest.raises(ValueError, match="nthreads"):
+        with pytest.raises(ValueError, match="axb_method"):
             C1(bad_opt=True) << A
         with pytest.raises(ValueError, match="Duplicate descriptor"):
             (A @ A).new(nthreads=4, Nthreads=5)
         with pytest.raises(ValueError, match="escriptor"):
             A[0, 0].new(bad_opt=True)
-        A[0, 0].new(nthreads=4)  # ignored, but okay
+        A[0, 0].new(nthreads=4, sort=None)  # ignored, but okay
         with pytest.raises(ValueError, match="escriptor"):
             A.__setitem__((0, 0), 1, bad_opt=True)
         A.__setitem__((0, 0), 1, nthreads=4)  # ignored, but okay
@@ -4287,6 +4337,7 @@ def test_wait_chains(A):
     assert result == 47
 
 
+@pytest.mark.skipif("not supports_udfs")
 def test_subarray_dtypes():
     a = np.arange(3 * 4, dtype=np.int64).reshape(3, 4)
     A = Matrix.from_coo([1, 3, 5], [0, 1, 3], a)
@@ -4323,3 +4374,174 @@ def test_subarray_dtypes():
     if suitesparse:
         Full2 = Matrix.ss.import_fullr(b2)
         assert Full1.isequal(Full2, check_dtype=True)
+
+
+def test_power(A):
+    expected = A.dup()
+    for i in range(1, 50):
+        result = A.power(i).new()
+        assert result.isequal(expected)
+        expected << A @ expected
+    # Test transpose
+    expected = A.T.new()
+    for i in range(1, 10):
+        result = A.T.power(i).new()
+        assert result.isequal(expected)
+        expected << A.T @ expected
+    # Test other semiring
+    expected = A.dup()
+    for i in range(1, 10):
+        result = A.power(i, semiring.min_plus).new()
+        assert result.isequal(expected)
+        expected << semiring.min_plus(A @ expected)
+    # n == 0
+    result = A.power(0).new()
+    expected = Vector.from_scalar(1, A.nrows, A.dtype).diag()
+    assert result.isequal(expected)
+    result = A.power(0, semiring.plus_min).new()
+    identity = semiring.plus_min[A.dtype].binaryop.monoid.identity
+    assert identity != 1
+    expected = Vector.from_scalar(identity, A.nrows, A.dtype).diag()
+    assert result.isequal(expected)
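+    # Note: power(0) yields the semiring's multiplicative identity, i.e. a
+    # diagonal matrix of the "multiply" binaryop's monoid identity (for
+    # plus_min, that identity is the dtype's max, hence `identity != 1`).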
+    # Exceptional
+    with pytest.raises(TypeError, match="must be a nonnegative integer"):
+        A.power(1.5)
+    with pytest.raises(ValueError, match="must be a nonnegative integer"):
+        A.power(-1)
+    with pytest.raises(ValueError, match="binaryop must be associated with a monoid"):
+        A.power(0, semiring.min_first)
+    B = A[:2, :3].new()
+    with pytest.raises(DimensionMismatch):
+        B.power(2)
+
+
+def test_setdiag():
+    A = Matrix(int, 2, 3)
+    A.setdiag(1)
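+    # Note: setdiag writes `values` along diagonal k in-place; for this 2x3
+    # matrix, k=0 touches [0, 0] and [1, 1], k=2 touches only [0, 2], and
+    # k=-1 touches only [1, 0], as the assertions below verify.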
+    expected = Matrix(int, 2, 3)
+    expected[0, 0] = 1
+    expected[1, 1] = 1
+    assert A.isequal(expected)
+    A.setdiag(Scalar.from_value(2), 2)
+    expected[0, 2] = 2
+    assert A.isequal(expected)
+    A.setdiag(3, k=-1)
+    expected[1, 0] = 3
+    assert A.isequal(expected)
+    # List (or array) is treated as dense
+    A.setdiag([10, 20], 1)
+    expected[0, 1] = 10
+    expected[1, 2] = 20
+    assert A.isequal(expected)
+    # Size 0 diagonals do not set anything.
+    # This could be valid (esp. given a size 0 vector), but let's raise for now.
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(-1, 3)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(-1, -2)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag([], 3)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(Vector(int, 0), -2)
+    # Now we're definitely out of bounds
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(-1, 4)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(-1, -3)
+    with pytest.raises(TypeError, match="k must be an integer"):
+        A.setdiag(-1, 0.5)
+    with pytest.raises(TypeError, match="Bad type for argument `values` in Matrix.setdiag"):
+        A.setdiag(object())
+    with pytest.raises(DimensionMismatch, match="Dimensions not compatible"):
+        A.setdiag([10, 20, 30], 1)
+    with pytest.raises(DimensionMismatch, match="Dimensions not compatible"):
+        A.setdiag([10], 1)
+
+    # Special care for dimensions of length 0
+    A = Matrix(int, 0, 2, name="A")
+    A.setdiag(0, 0)
+    A.setdiag(0, 1)
+    A.setdiag([], 0)
+    A.setdiag([], 1)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(0, -1)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag([], -1)
+    A = Matrix(int, 2, 0, name="A")
+    A.setdiag(0, 0)
+    A.setdiag(0, -1)
+    A.setdiag([], 0)
+    A.setdiag([], -1)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(0, 1)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag([], 1)
+    A = Matrix(int, 0, 0, name="A")
+    A.setdiag(0, 0)
+    A.setdiag([], 0)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(0, 1)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag([], 1)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag(0, -1)
+    with pytest.raises(IndexError, match="diagonal is out of range"):
+        A.setdiag([], -1)
+
+    A = Matrix(int, 2, 2, name="A")
+    expected = Matrix(int, 2, 2, name="expected")
+    v = Vector(int, 2, name="v")
+    v[0] = 1
+    A.setdiag(v)
+    expected[0, 0] = 1
+    assert A.isequal(expected)
+    A.setdiag(v, accum=binary.plus)
+    expected[0, 0] = 2
+    assert A.isequal(expected)
+    A.setdiag(10, mask=v.S)
+    expected[0, 0] = 10
+    assert A.isequal(expected)
+    A.setdiag(10, mask=v.S, accum="+")
+    expected[0, 0] = 20
+    assert A.isequal(expected)
+    # Allow mask to be a matrix
+    A.setdiag(10, mask=A.S, accum="+")
+    expected[0, 0] = 30
+    assert A.isequal(expected)
+    # Test how to clear or not clear missing elements
+    A.clear()
+    A.setdiag(99)
+    A.setdiag(v)
+    expected[0, 0] = 1
+    assert A.isequal(expected)
+    A.setdiag(99)
+    A.setdiag(v, accum="second")
+    expected[1, 1] = 99
+    assert A.isequal(expected)
+    A.setdiag(99)
+    A.setdiag(v, mask=v.S)
+    assert A.isequal(expected)
+
+    # We handle complemented masks!
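+    # A complemented mask selects the positions the mask does *not* cover:
+    # ~v.S picks indices where v has no stored value, and ~A.V picks
+    # positions where A is missing or stored as a falsey value.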
+    A.clear()
+    expected.clear()
+    A.setdiag(42, mask=~v.S)
+    expected[1, 1] = 42
+    assert A.isequal(expected)
+    A.setdiag(7, mask=~A.V)
+    expected[0, 0] = 7
+    assert A.isequal(expected)
+
+    with pytest.raises(DimensionMismatch, match="Matrix mask in setdiag is the wrong "):
+        A.setdiag(9, mask=Matrix(int, 3, 3).S)
+    with pytest.raises(DimensionMismatch, match="Vector mask in setdiag is the wrong "):
+        A.setdiag(10, mask=Vector(int, 3).S)
+
+    A.clear()
+    A.resize(2, 3)
+    expected.clear()
+    expected.resize(2, 3)
+    A.setdiag(30, mask=v.S)
+    expected[0, 0] = 30
+    assert A.isequal(expected)
diff --git a/graphblas/tests/test_numpyops.py b/graphblas/tests/test_numpyops.py
index c528d4051..999c6d5e0 100644
--- a/graphblas/tests/test_numpyops.py
+++ b/graphblas/tests/test_numpyops.py
@@ -5,28 +5,32 @@
 import numpy as np
 import pytest
+from packaging.version import parse
 
 import graphblas as gb
 import graphblas.binary.numpy as npbinary
 import graphblas.monoid.numpy as npmonoid
 import graphblas.semiring.numpy as npsemiring
 import graphblas.unary.numpy as npunary
-from graphblas import Vector, backend
+from graphblas import Vector, backend, config
+from graphblas.core import _supports_udfs as supports_udfs
 from graphblas.dtypes import _supports_complex
 
-from .conftest import compute
+from .conftest import compute, shouldhave
 
 is_win = sys.platform.startswith("win")
 suitesparse = backend == "suitesparse"
 
 
 def test_numpyops_dir():
-    assert "exp2" in dir(npunary)
-    assert "logical_and" in dir(npbinary)
-    assert "logaddexp" in dir(npmonoid)
-    assert "add_add" in dir(npsemiring)
+    udf_or_mapped = supports_udfs or config["mapnumpy"]
+    assert ("exp2" in dir(npunary)) == udf_or_mapped
+    assert ("logical_and" in dir(npbinary)) == udf_or_mapped
+    assert ("logaddexp" in dir(npmonoid)) == supports_udfs
+    assert ("add_add" in dir(npsemiring)) == udf_or_mapped
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_bool_doesnt_get_too_large():
     a = Vector.from_coo([0, 1, 2, 3], [True, False, True, False])
@@ -70,9 +74,12 @@ def test_npunary():
         # due to limitation of MSVC with complex
         blocklist["FC64"].update({"arcsin", "arcsinh"})
         blocklist["FC32"] = {"arcsin", "arcsinh"}
-    isclose = gb.binary.isclose(1e-6, 0)
+    if shouldhave(gb.binary, "isclose"):
+        isclose = gb.binary.isclose(1e-6, 0)
+    else:
+        isclose = None
     for gb_input, np_input in data:
-        for unary_name in sorted(npunary._unary_names):
+        for unary_name in sorted(npunary._unary_names & npunary.__dir__()):
             op = getattr(npunary, unary_name)
             if gb_input.dtype not in op.types or unary_name in blocklist.get(
                 gb_input.dtype.name, ()
@@ -99,11 +106,22 @@ def test_npunary():
                     list(range(np_input.size)), list(np_result), dtype=gb_result.dtype
                 )
             assert gb_result.nvals == np_result.size
+            if compare_op is None:
+                continue  # FLAKY COVERAGE
             match = gb_result.ewise_mult(np_result, compare_op).new()
             if gb_result.dtype.name.startswith("F"):
                 match(accum=gb.binary.lor) << gb_result.apply(npunary.isnan)
             compare = match.reduce(gb.monoid.land).new()
             if not compare:  # pragma: no cover (debug)
+                import numba
+
+                if (
+                    unary_name in {"sign"}
+                    and np.__version__.startswith("2.")
+                    and parse(numba.__version__) < parse("0.61.0")
+                ):
+                    # numba <0.61.0 does not match numpy 2.0
+                    continue
                 print(unary_name, gb_input.dtype)
                 print(compute(gb_result))
                 print(np_result)
@@ -149,9 +167,24 @@ def test_npbinary():
         "FP64": {"floor_divide"},  # numba/numpy difference for 1.0 / 0.0
         "BOOL": {"gcd", "lcm", "subtract"},  # not supported by numpy
     }
-    isclose = gb.binary.isclose(1e-7, 0)
+    if shouldhave(gb.binary, "isclose"):
+        isclose = gb.binary.isclose(1e-7, 0)
+    else:
+        isclose = None
+    if shouldhave(npbinary, "equal"):
+        equal = npbinary.equal
+    else:
+        equal = gb.binary.eq
+    if shouldhave(npunary, "isnan"):
+        isnan = npunary.isnan
+    else:
+        isnan = gb.unary.isnan
+    if shouldhave(npunary, "isinf"):
+        isinf = npunary.isinf
+    else:
+        isinf = gb.unary.isinf
     for (gb_left, gb_right), (np_left, np_right) in data:
-        for binary_name in sorted(npbinary._binary_names):
+        for binary_name in sorted(npbinary._binary_names & npbinary.__dir__()):
             op = getattr(npbinary, binary_name)
             if gb_left.dtype not in op.types or binary_name in blocklist.get(
                 gb_left.dtype.name, ()
@@ -168,7 +201,10 @@ def test_npbinary():
                     compare_op = isclose
                 else:
                     np_result = getattr(np, binary_name)(np_left, np_right)
-                    compare_op = npbinary.equal
+                    if binary_name in {"arctan2"}:
+                        compare_op = isclose
+                    else:
+                        compare_op = equal
             except Exception:  # pragma: no cover (debug)
                 print(f"Error computing numpy result for {binary_name}")
                 print(f"dtypes: ({gb_left.dtype}, {gb_right.dtype}) -> {gb_result.dtype}")
@@ -176,19 +212,23 @@ def test_npbinary():
 
             np_result = Vector.from_coo(np.arange(np_left.size), np_result, dtype=gb_result.dtype)
             assert gb_result.nvals == np_result.size
+            if compare_op is None:
+                continue  # FLAKY COVERAGE
             match = gb_result.ewise_mult(np_result, compare_op).new()
             if gb_result.dtype.name.startswith("F"):
-                match(accum=gb.binary.lor) << gb_result.apply(npunary.isnan)
+                match(accum=gb.binary.lor) << gb_result.apply(isnan)
             if gb_result.dtype.name.startswith("FC"):
                 # Divide by 0j sometimes result in different behavior, such as `nan` or `(inf+0j)`
-                match(accum=gb.binary.lor) << gb_result.apply(npunary.isinf)
+                match(accum=gb.binary.lor) << gb_result.apply(isinf)
             compare = match.reduce(gb.monoid.land).new()
             if not compare:  # pragma: no cover (debug)
+                print(compare_op)
                 print(binary_name)
                 print(compute(gb_left))
                 print(compute(gb_right))
                 print(compute(gb_result))
                 print(np_result)
+                print((np_result - compute(gb_result)).new().to_coo()[1])
             assert compare
@@ -218,7 +258,7 @@ def test_npmonoid():
         ],
     ]
     # Complex monoids not working yet (they segfault upon creation in gb.core.operators)
-    if _supports_complex:  # pragma: no branch
+    if _supports_complex:
         data.append(
             [
                 [
@@ -236,13 +276,13 @@ def test_npmonoid():
         "BOOL": {"add"},
     }
     for (gb_left, gb_right), (np_left, np_right) in data:
-        for binary_name in sorted(npmonoid._monoid_identities):
+        for binary_name in sorted(npmonoid._monoid_identities.keys() & npmonoid.__dir__()):
             op = getattr(npmonoid, binary_name)
             assert len(op.types) > 0, op.name
             if gb_left.dtype not in op.types or binary_name in blocklist.get(
                 gb_left.dtype.name, ()
-            ):  # pragma: no cover (flaky)
-                continue
+            ):
+                continue  # FLAKY COVERAGE
             with np.errstate(divide="ignore", over="ignore", under="ignore", invalid="ignore"):
                 gb_result = gb_left.ewise_mult(gb_right, op).new()
                 np_result = getattr(np, binary_name)(np_left, np_right)
@@ -274,7 +314,8 @@ def test_npmonoid():
 @pytest.mark.slow
 def test_npsemiring():
     for monoid_name, binary_name in itertools.product(
-        sorted(npmonoid._monoid_identities), sorted(npbinary._binary_names)
+        sorted(npmonoid._monoid_identities.keys() & npmonoid.__dir__()),
+        sorted(npbinary._binary_names & npbinary.__dir__()),
     ):
         monoid = getattr(npmonoid, monoid_name)
         binary = getattr(npbinary, binary_name)
diff --git a/graphblas/tests/test_op.py b/graphblas/tests/test_op.py
index e32606290..41fae80ae 100644
--- a/graphblas/tests/test_op.py
+++ b/graphblas/tests/test_op.py
@@ -4,9 +4,30 @@
 import pytest
 
 import graphblas as gb
-from graphblas import agg, backend, binary, dtypes, indexunary, monoid, op, select, semiring, unary
+from graphblas import (
+    agg,
+    backend,
+    binary,
+    config,
+    dtypes,
+    indexunary,
+    monoid,
+    op,
+    select,
+    semiring,
+    unary,
+)
+from graphblas.core import _supports_udfs as supports_udfs
 from graphblas.core import lib, operator
-from graphblas.core.operator import BinaryOp, IndexUnaryOp, Monoid, Semiring, UnaryOp, get_semiring
+from graphblas.core.operator import (
+    BinaryOp,
+    IndexUnaryOp,
+    Monoid,
+    SelectOp,
+    Semiring,
+    UnaryOp,
+    get_semiring,
+)
 from graphblas.dtypes import (
     BOOL,
     FP32,
@@ -22,6 +43,8 @@
 )
 from graphblas.exceptions import DomainMismatch, UdfParseError
 
+from .conftest import shouldhave
+
 if dtypes._supports_complex:
     from graphblas.dtypes import FC32, FC64
 
@@ -142,6 +165,36 @@ def test_get_typed_op():
         operator.get_typed_op(binary.plus, dtypes.INT64, "bad dtype")
 
 
+@pytest.mark.skipif("supports_udfs")
+def test_udf_mentions_numba():
+    with pytest.raises(AttributeError, match="install numba"):
+        binary.rfloordiv
+    assert "rfloordiv" not in dir(binary)
+    with pytest.raises(AttributeError, match="install numba"):
+        semiring.any_rfloordiv
+    assert "any_rfloordiv" not in dir(semiring)
+    with pytest.raises(AttributeError, match="install numba"):
+        op.absfirst
+    assert "absfirst" not in dir(op)
+    with pytest.raises(AttributeError, match="install numba"):
+        op.plus_rpow
+    assert "plus_rpow" not in dir(op)
+    with pytest.raises(AttributeError, match="install numba"):
+        binary.numpy.gcd
+    assert "gcd" not in dir(binary.numpy)
+    assert "gcd" not in dir(op.numpy)
+
+
+@pytest.mark.skipif("supports_udfs")
+def test_unaryop_udf_no_support():
+    def plus_one(x):  # pragma: no cover (numba)
+        return x + 1
+
+    with pytest.raises(RuntimeError, match="UnaryOp.register_new.* unavailable"):
+        unary.register_new("plus_one", plus_one)
+
+
+@pytest.mark.skipif("not supports_udfs")
 def test_unaryop_udf():
     def plus_one(x):
         return x + 1  # pragma: no cover (numba)
@@ -150,6 +203,7 @@ def plus_one(x):
     assert hasattr(unary, "plus_one")
     assert unary.plus_one.orig_func is plus_one
     assert unary.plus_one[int].orig_func is plus_one
+    assert unary.plus_one[int]._numba_func(1) == 2
     comp_set = {
         INT8,
         INT16,
@@ -179,9 +233,10 @@ def plus_one(x):
         UnaryOp.register_new("bad", object())
     assert not hasattr(unary, "bad")
     with pytest.raises(UdfParseError, match="Unable to parse function using Numba"):
-        UnaryOp.register_new("bad", lambda x: v)
+        UnaryOp.register_new("bad", lambda x: v)  # pragma: no branch (numba)
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_unaryop_parameterized():
     def plus_x(x=0):
@@ -207,6 +262,7 @@ def inner(val):
     assert r10.isequal(v11, check_dtype=True)
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_binaryop_parameterized():
     def plus_plus_x(x=0):
@@ -268,6 +324,7 @@ def my_add(x, y):
     assert op.name == "my_add"
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_monoid_parameterized():
     def plus_plus_x(x=0):
@@ -363,6 +420,7 @@ def bad_identity(x=0):
     assert monoid.is_idempotent
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_semiring_parameterized():
     def plus_plus_x(x=0):
@@ -490,6 +548,7 @@ def inner(y):
     assert B.isequal(A.kronecker(A, binary.plus).new())
 
 
+@pytest.mark.skipif("not supports_udfs")
 def test_unaryop_udf_bool_result():
     # numba has trouble compiling this, but we have a work-around
     def is_positive(x):
@@ -516,12 +575,14 @@ def is_positive(x):
     assert w.isequal(result)
 
 
+@pytest.mark.skipif("not supports_udfs")
 def test_binaryop_udf():
     def times_minus_sum(x, y):
         return x * y - (x + y)  # pragma: no cover (numba)
 
     BinaryOp.register_new("bin_test_func", times_minus_sum)
     assert hasattr(binary, "bin_test_func")
+    assert binary.bin_test_func[int].orig_func is times_minus_sum
     comp_set = {
         BOOL,  # goes to INT64
         INT8,
@@ -545,6 +606,7 @@ def times_minus_sum(x, y):
     assert w.isequal(result)
 
 
+@pytest.mark.skipif("not supports_udfs")
 def test_monoid_udf():
     def plus_plus_one(x, y):
         return x + y + 1  # pragma: no cover (numba)
@@ -579,6 +641,7 @@ def plus_plus_one(x, y):
         Monoid.register_anonymous(binary.plus_plus_one, {"BOOL": -1})
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_semiring_udf():
     def plus_plus_two(x, y):
@@ -608,10 +671,12 @@ def test_binary_updates():
     vec4 = Vector.from_coo([0], [-3], dtype=dtypes.INT64)
     result2 = vec4.ewise_mult(vec2, binary.cdiv).new()
     assert result2.isequal(Vector.from_coo([0], [-1], dtype=dtypes.INT64), check_dtype=True)
-    result3 = vec4.ewise_mult(vec2, binary.floordiv).new()
-    assert result3.isequal(Vector.from_coo([0], [-2], dtype=dtypes.INT64), check_dtype=True)
+    if shouldhave(binary, "floordiv"):
+        result3 = vec4.ewise_mult(vec2, binary.floordiv).new()
+        assert result3.isequal(Vector.from_coo([0], [-2], dtype=dtypes.INT64), check_dtype=True)
 
 
+@pytest.mark.skipif("not supports_udfs")
 @pytest.mark.slow
 def test_nested_names():
     def plus_three(x):
@@ -671,12 +736,17 @@ def test_op_namespace():
     assert op.plus is binary.plus
     assert op.plus_times is semiring.plus_times
 
-    assert op.numpy.fabs is unary.numpy.fabs
-    assert op.numpy.subtract is binary.numpy.subtract
-    assert op.numpy.add is binary.numpy.add
-    assert op.numpy.add_add is semiring.numpy.add_add
+    if shouldhave(unary.numpy, "fabs"):
+        assert op.numpy.fabs is unary.numpy.fabs
+    if shouldhave(binary.numpy, "subtract"):
+        assert op.numpy.subtract is binary.numpy.subtract
+    if shouldhave(binary.numpy, "add"):
+        assert op.numpy.add is binary.numpy.add
+    if shouldhave(semiring.numpy, "add_add"):
+        assert op.numpy.add_add is semiring.numpy.add_add
     assert len(dir(op)) > 300
-    assert len(dir(op.numpy)) > 500
+    if supports_udfs:
+        assert len(dir(op.numpy)) > 500
 
     with pytest.raises(
         AttributeError, match="module 'graphblas.op.numpy' has no attribute 'bad_attr'"
@@ -740,10 +810,18 @@ def test_op_namespace():
 @pytest.mark.slow
 def test_binaryop_attributes_numpy():
     # Some coverage from this test depends on order of tests
-    assert binary.numpy.add[int].monoid is monoid.numpy.add[int]
-    assert binary.numpy.subtract[int].monoid is None
-    assert binary.numpy.add.monoid is monoid.numpy.add
-    assert binary.numpy.subtract.monoid is None
+    if shouldhave(monoid.numpy, "add"):
+        assert binary.numpy.add[int].monoid is monoid.numpy.add[int]
+        assert binary.numpy.add.monoid is monoid.numpy.add
+    if shouldhave(binary.numpy, "subtract"):
+        assert binary.numpy.subtract[int].monoid is None
+        assert binary.numpy.subtract.monoid is None
+
+
+@pytest.mark.skipif("not supports_udfs")
+@pytest.mark.slow
+def test_binaryop_monoid_numpy():
+    assert gb.binary.numpy.minimum[int].monoid is gb.monoid.numpy.minimum[int]
 
 
 @pytest.mark.slow
@@ -756,18 +834,21 @@ def test_binaryop_attributes():
     def plus(x, y):
         return x + y  # pragma: no cover (numba)
 
-    op = BinaryOp.register_anonymous(plus, name="plus")
-    assert op.monoid is None
-    assert op[int].monoid is None
+    if supports_udfs:
+        op = BinaryOp.register_anonymous(plus, name="plus")
+        assert op.monoid is None
+        assert op[int].monoid is None
+        assert op[int].parent is op
 
     assert binary.plus[int].parent is binary.plus
-    assert binary.numpy.add[int].parent is binary.numpy.add
-    assert op[int].parent is op
+    if shouldhave(binary.numpy, "add"):
+        assert binary.numpy.add[int].parent is binary.numpy.add
 
     # bad type
     assert binary.plus[bool].monoid is None
-    assert binary.numpy.equal[int].monoid is None
-    assert binary.numpy.equal[bool].monoid is monoid.numpy.equal[bool]  # sanity
+    if shouldhave(binary.numpy, "equal"):
+        assert binary.numpy.equal[int].monoid is None
+        assert binary.numpy.equal[bool].monoid is monoid.numpy.equal[bool]  # sanity
 
     for attr, val in vars(binary).items():
         if not isinstance(val, BinaryOp):
@@ -790,22 +871,25 @@ def test_monoid_attributes():
     assert monoid.plus.binaryop is binary.plus
     assert monoid.plus.identities == {typ: 0 for typ in monoid.plus.types}
 
-    assert monoid.numpy.add[int].binaryop is binary.numpy.add[int]
-    assert monoid.numpy.add[int].identity == 0
-    assert monoid.numpy.add.binaryop is binary.numpy.add
-    assert monoid.numpy.add.identities == {typ: 0 for typ in monoid.numpy.add.types}
+    if shouldhave(monoid.numpy, "add"):
+        assert monoid.numpy.add[int].binaryop is binary.numpy.add[int]
+        assert monoid.numpy.add[int].identity == 0
+        assert monoid.numpy.add.binaryop is binary.numpy.add
+        assert monoid.numpy.add.identities == {typ: 0 for typ in monoid.numpy.add.types}
 
     def plus(x, y):  # pragma: no cover (numba)
         return x + y
 
-    binop = BinaryOp.register_anonymous(plus, name="plus")
-    op = Monoid.register_anonymous(binop, 0, name="plus")
-    assert op.binaryop is binop
-    assert op[int].binaryop is binop[int]
+    if supports_udfs:
+        binop = BinaryOp.register_anonymous(plus, name="plus")
+        op = Monoid.register_anonymous(binop, 0, name="plus")
+        assert op.binaryop is binop
+        assert op[int].binaryop is binop[int]
+        assert op[int].parent is op
 
     assert monoid.plus[int].parent is monoid.plus
-    assert monoid.numpy.add[int].parent is monoid.numpy.add
-    assert op[int].parent is op
+    if shouldhave(monoid.numpy, "add"):
+        assert monoid.numpy.add[int].parent is monoid.numpy.add
 
     for attr, val in vars(monoid).items():
         if not isinstance(val, Monoid):
@@ -826,25 +910,27 @@ def test_semiring_attributes():
     assert semiring.min_plus.monoid is monoid.min
     assert semiring.min_plus.binaryop is binary.plus
 
-    assert semiring.numpy.add_subtract[int].monoid is monoid.numpy.add[int]
-    assert semiring.numpy.add_subtract[int].binaryop is binary.numpy.subtract[int]
-    assert semiring.numpy.add_subtract.monoid is monoid.numpy.add
-    assert semiring.numpy.add_subtract.binaryop is binary.numpy.subtract
+    if shouldhave(semiring.numpy, "add_subtract"):
+        assert semiring.numpy.add_subtract[int].monoid is monoid.numpy.add[int]
+        assert semiring.numpy.add_subtract[int].binaryop is binary.numpy.subtract[int]
+        assert semiring.numpy.add_subtract.monoid is monoid.numpy.add
+        assert semiring.numpy.add_subtract.binaryop is binary.numpy.subtract
+        assert semiring.numpy.add_subtract[int].parent is semiring.numpy.add_subtract
 
     def plus(x, y):
         return x + y  # pragma: no cover (numba)
 
-    binop = BinaryOp.register_anonymous(plus, name="plus")
-    mymonoid = Monoid.register_anonymous(binop, 0, name="plus")
-    op = Semiring.register_anonymous(mymonoid, binop, name="plus_plus")
-    assert op.binaryop is binop
-    assert op.binaryop[int] is binop[int]
-    assert op.monoid is mymonoid
-    assert op.monoid[int] is mymonoid[int]
+    if supports_udfs:
+        binop = BinaryOp.register_anonymous(plus, name="plus")
+        mymonoid = 
Monoid.register_anonymous(binop, 0, name="plus") + op = Semiring.register_anonymous(mymonoid, binop, name="plus_plus") + assert op.binaryop is binop + assert op.binaryop[int] is binop[int] + assert op.monoid is mymonoid + assert op.monoid[int] is mymonoid[int] + assert op[int].parent is op assert semiring.min_plus[int].parent is semiring.min_plus - assert semiring.numpy.add_subtract[int].parent is semiring.numpy.add_subtract - assert op[int].parent is op for attr, val in vars(semiring).items(): if not isinstance(val, Semiring): @@ -881,9 +967,10 @@ def test_div_semirings(): assert result[0, 0].new() == -2 assert result.dtype == dtypes.FP64 - result = A1.T.mxm(A2, semiring.plus_floordiv).new() - assert result[0, 0].new() == -3 - assert result.dtype == dtypes.INT64 + if shouldhave(semiring, "plus_floordiv"): + result = A1.T.mxm(A2, semiring.plus_floordiv).new() + assert result[0, 0].new() == -3 + assert result.dtype == dtypes.INT64 @pytest.mark.slow @@ -902,30 +989,32 @@ def test_get_semiring(): def myplus(x, y): return x + y # pragma: no cover (numba) - binop = BinaryOp.register_anonymous(myplus, name="myplus") - st = get_semiring(monoid.plus, binop) - assert st.monoid is monoid.plus - assert st.binaryop is binop + if supports_udfs: + binop = BinaryOp.register_anonymous(myplus, name="myplus") + st = get_semiring(monoid.plus, binop) + assert st.monoid is monoid.plus + assert st.binaryop is binop - binop = BinaryOp.register_new("myplus", myplus) - assert binop is binary.myplus - st = get_semiring(monoid.plus, binop) - assert st.monoid is monoid.plus - assert st.binaryop is binop + binop = BinaryOp.register_new("myplus", myplus) + assert binop is binary.myplus + st = get_semiring(monoid.plus, binop) + assert st.monoid is monoid.plus + assert st.binaryop is binop with pytest.raises(TypeError, match="Monoid"): get_semiring(None, binary.times) with pytest.raises(TypeError, match="Binary"): get_semiring(monoid.plus, None) - sr = get_semiring(monoid.plus, binary.numpy.copysign) - assert sr.monoid is monoid.plus - assert sr.binaryop is binary.numpy.copysign + if shouldhave(binary.numpy, "copysign"): + sr = get_semiring(monoid.plus, binary.numpy.copysign) + assert sr.monoid is monoid.plus + assert sr.binaryop is binary.numpy.copysign def test_create_semiring(): # stress test / sanity check - monoid_names = {x for x in dir(monoid) if not x.startswith("_")} + monoid_names = {x for x in dir(monoid) if not x.startswith("_") and x != "ss"} binary_names = {x for x in dir(binary) if not x.startswith("_") and x != "ss"} for monoid_name, binary_name in itertools.product(monoid_names, binary_names): cur_monoid = getattr(monoid, monoid_name) @@ -958,17 +1047,22 @@ def test_commutes(): assert semiring.plus_times.is_commutative if suitesparse: assert semiring.ss.min_secondi.commutes_to is semiring.ss.min_firstj - assert semiring.plus_pow.commutes_to is semiring.plus_rpow + if shouldhave(semiring, "plus_pow") and shouldhave(semiring, "plus_rpow"): + assert semiring.plus_pow.commutes_to is semiring.plus_rpow assert not semiring.plus_pow.is_commutative - assert binary.isclose.commutes_to is binary.isclose - assert binary.isclose.is_commutative - assert binary.isclose(0.1).commutes_to is binary.isclose(0.1) - assert binary.floordiv.commutes_to is binary.rfloordiv - assert not binary.floordiv.is_commutative - assert binary.numpy.add.commutes_to is binary.numpy.add - assert binary.numpy.add.is_commutative - assert binary.numpy.less.commutes_to is binary.numpy.greater - assert not binary.numpy.less.is_commutative + if 
shouldhave(binary, "isclose"): + assert binary.isclose.commutes_to is binary.isclose + assert binary.isclose.is_commutative + assert binary.isclose(0.1).commutes_to is binary.isclose(0.1) + if shouldhave(binary, "floordiv") and shouldhave(binary, "rfloordiv"): + assert binary.floordiv.commutes_to is binary.rfloordiv + assert not binary.floordiv.is_commutative + if shouldhave(binary.numpy, "add"): + assert binary.numpy.add.commutes_to is binary.numpy.add + assert binary.numpy.add.is_commutative + if shouldhave(binary.numpy, "less") and shouldhave(binary.numpy, "greater"): + assert binary.numpy.less.commutes_to is binary.numpy.greater + assert not binary.numpy.less.is_commutative # Typed assert binary.plus[int].commutes_to is binary.plus[int] @@ -985,15 +1079,20 @@ def test_commutes(): assert semiring.plus_times[int].is_commutative if suitesparse: assert semiring.ss.min_secondi[int].commutes_to is semiring.ss.min_firstj[int] - assert semiring.plus_pow[int].commutes_to is semiring.plus_rpow[int] + if shouldhave(semiring, "plus_rpow"): + assert semiring.plus_pow[int].commutes_to is semiring.plus_rpow[int] assert not semiring.plus_pow[int].is_commutative - assert binary.isclose(0.1)[int].commutes_to is binary.isclose(0.1)[int] - assert binary.floordiv[int].commutes_to is binary.rfloordiv[int] - assert not binary.floordiv[int].is_commutative - assert binary.numpy.add[int].commutes_to is binary.numpy.add[int] - assert binary.numpy.add[int].is_commutative - assert binary.numpy.less[int].commutes_to is binary.numpy.greater[int] - assert not binary.numpy.less[int].is_commutative + if shouldhave(binary, "isclose"): + assert binary.isclose(0.1)[int].commutes_to is binary.isclose(0.1)[int] + if shouldhave(binary, "floordiv") and shouldhave(binary, "rfloordiv"): + assert binary.floordiv[int].commutes_to is binary.rfloordiv[int] + assert not binary.floordiv[int].is_commutative + if shouldhave(binary.numpy, "add"): + assert binary.numpy.add[int].commutes_to is binary.numpy.add[int] + assert binary.numpy.add[int].is_commutative + if shouldhave(binary.numpy, "less") and shouldhave(binary.numpy, "greater"): + assert binary.numpy.less[int].commutes_to is binary.numpy.greater[int] + assert not binary.numpy.less[int].is_commutative # Stress test (this can create extra semirings) names = dir(semiring) @@ -1014,9 +1113,12 @@ def test_from_string(): assert unary.from_string("abs[float]") is unary.abs[float] assert binary.from_string("+") is binary.plus assert binary.from_string("-[int]") is binary.minus[int] - assert binary.from_string("true_divide") is binary.numpy.true_divide - assert binary.from_string("//") is binary.floordiv - assert binary.from_string("%") is binary.numpy.mod + if config["mapnumpy"] or shouldhave(binary.numpy, "true_divide"): + assert binary.from_string("true_divide") is binary.numpy.true_divide + if shouldhave(binary, "floordiv"): + assert binary.from_string("//") is binary.floordiv + if shouldhave(binary.numpy, "mod"): + assert binary.from_string("%") is binary.numpy.mod assert monoid.from_string("*[FP64]") is monoid.times["FP64"] assert semiring.from_string("min.plus") is semiring.min_plus assert semiring.from_string("min.+") is semiring.min_plus @@ -1053,6 +1155,7 @@ def test_from_string(): agg.from_string("bad_agg") +@pytest.mark.skipif("not supports_udfs") @pytest.mark.slow def test_lazy_op(): UnaryOp.register_new("lazy", lambda x: x, lazy=True) # pragma: no branch (numba) @@ -1115,6 +1218,7 @@ def test_positional(): assert semiring.ss.any_secondj[int].is_positional 
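A note on the gating pattern used throughout these tests: `pytest.mark.skipif` accepts a string condition, and pytest evaluates that string in the namespace of the module that defines the test, so a module-level flag such as `supports_udfs` (imported from `graphblas.core` above) is visible to the marker. A minimal, self-contained sketch of the mechanism; the `feature_flag` name is hypothetical and not part of this changeset:

import pytest

feature_flag = False  # stand-in for a module-level capability flag like `supports_udfs`


@pytest.mark.skipif("not feature_flag")  # the string is evaluated with this module's globals
def test_needs_feature():
    # Collected always, but reported as skipped whenever feature_flag is falsy.
    assert feature_flag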
+@pytest.mark.skipif("not supports_udfs") @pytest.mark.slow def test_udt(): record_dtype = np.dtype([("x", np.bool_), ("y", np.float64)], align=True) @@ -1240,6 +1344,19 @@ def badfunc2(x, y): # pragma: no cover (numba) assert binary.first[udt, dtypes.INT8].type2 is dtypes.INT8 assert monoid.any[udt].type2 is udt + def _this_or_that(val, idx, _, thunk): # pragma: no cover (numba) + return val["x"] + + sel = SelectOp.register_anonymous(_this_or_that, is_udt=True) + sel[udt] + assert udt in sel + result = v.select(sel, 0).new() + assert result.nvals == 0 + assert result.dtype == v.dtype + result = w.select(sel, 0).new() + assert result.nvals == 3 + assert result.isequal(w) + def test_dir(): for mod in [unary, binary, monoid, semiring, op]: @@ -1280,6 +1397,7 @@ def test_binaryop_commute_exists(): raise AssertionError("Missing binaryops: " + ", ".join(sorted(missing))) +@pytest.mark.skipif("not supports_udfs") def test_binom(): v = Vector.from_coo([0, 1, 2], [3, 4, 5]) result = v.apply(binary.binom, 2).new() @@ -1334,14 +1452,28 @@ def test_deprecated(): gb.agg.argmin +@pytest.mark.slow def test_is_idempotent(): assert monoid.min.is_idempotent assert monoid.max[int].is_idempotent assert monoid.lor.is_idempotent assert monoid.band.is_idempotent - assert monoid.numpy.gcd.is_idempotent + if shouldhave(monoid.numpy, "gcd"): + assert monoid.numpy.gcd.is_idempotent assert not monoid.plus.is_idempotent assert not monoid.times[float].is_idempotent - assert not monoid.numpy.equal.is_idempotent + if config["mapnumpy"] or shouldhave(monoid.numpy, "equal"): + assert not monoid.numpy.equal.is_idempotent with pytest.raises(AttributeError): binary.min.is_idempotent + + +def test_ops_have_ss(): + modules = [unary, binary, monoid, semiring, indexunary, select, op] + if suitesparse: + for mod in modules: + assert mod.ss is not None + else: + for mod in modules: + with pytest.raises(AttributeError): + mod.ss diff --git a/graphblas/tests/test_operator_types.py b/graphblas/tests/test_operator_types.py index 522b42ad2..027f02fcc 100644 --- a/graphblas/tests/test_operator_types.py +++ b/graphblas/tests/test_operator_types.py @@ -2,6 +2,7 @@ from collections import defaultdict from graphblas import backend, binary, dtypes, monoid, semiring, unary +from graphblas.core import _supports_udfs as supports_udfs from graphblas.core import operator from graphblas.dtypes import ( BOOL, @@ -83,6 +84,11 @@ BINARY[(ALL, POS)] = { "firsti", "firsti1", "firstj", "firstj1", "secondi", "secondi1", "secondj", "secondj1", } +if not supports_udfs: + udfs = {"absfirst", "abssecond", "binom", "floordiv", "rfloordiv", "rpow"} + for funcnames in BINARY.values(): + funcnames -= udfs + BINARY = {key: val for key, val in BINARY.items() if val} MONOID = { (UINT, UINT): {"band", "bor", "bxnor", "bxor"}, diff --git a/graphblas/tests/test_pickle.py b/graphblas/tests/test_pickle.py index de2d9cfda..724f43d76 100644 --- a/graphblas/tests/test_pickle.py +++ b/graphblas/tests/test_pickle.py @@ -5,6 +5,7 @@ import pytest import graphblas as gb +from graphblas.core import _supports_udfs as supports_udfs # noqa: F401 suitesparse = gb.backend == "suitesparse" @@ -36,6 +37,7 @@ def extra(): return "" +@pytest.mark.skipif("not supports_udfs") @pytest.mark.slow def test_deserialize(extra): path = Path(__file__).parent / f"pickle1{extra}.pkl" @@ -62,6 +64,7 @@ def test_deserialize(extra): assert d3["semiring_pickle"] is gb.semiring.semiring_pickle +@pytest.mark.skipif("not supports_udfs") @pytest.mark.slow def test_serialize(): v = gb.Vector.from_coo([1], 
2) @@ -232,6 +235,7 @@ def identity_par(z): return -z +@pytest.mark.skipif("not supports_udfs") @pytest.mark.slow def test_serialize_parameterized(): # unary_pickle = gb.core.operator.UnaryOp.register_new( @@ -285,6 +289,7 @@ def test_serialize_parameterized(): pickle.loads(pkl) # TODO: check results +@pytest.mark.skipif("not supports_udfs") @pytest.mark.slow def test_deserialize_parameterized(extra): path = Path(__file__).parent / f"pickle2{extra}.pkl" @@ -295,6 +300,7 @@ def test_deserialize_parameterized(extra): pickle.load(f) # TODO: check results +@pytest.mark.skipif("not supports_udfs") def test_udt(extra): record_dtype = np.dtype([("x", np.bool_), ("y", np.int64)], align=True) udt = gb.dtypes.register_new("PickleUDT", record_dtype) diff --git a/graphblas/tests/test_scalar.py b/graphblas/tests/test_scalar.py index 6ee70311c..e93511914 100644 --- a/graphblas/tests/test_scalar.py +++ b/graphblas/tests/test_scalar.py @@ -12,7 +12,7 @@ from graphblas import backend, binary, dtypes, monoid, replace, select, unary from graphblas.exceptions import EmptyObject -from .conftest import autocompute, compute +from .conftest import autocompute, compute, pypy from graphblas import Matrix, Scalar, Vector # isort:skip (for dask-graphblas) @@ -50,7 +50,7 @@ def test_dup(s): s_empty = Scalar(dtypes.FP64) s_unempty = Scalar.from_value(0.0) if s_empty.is_cscalar: - # NumPy wraps around + # NumPy <2 wraps around; >=2 raises OverflowError uint_data = [ ("UINT8", 2**8 - 2), ("UINT16", 2**16 - 2), @@ -73,6 +73,10 @@ def test_dup(s): ("FP32", -2.5), *uint_data, ]: + if dtype.startswith("UINT") and s_empty.is_cscalar and not np.__version__.startswith("1."): + with pytest.raises(OverflowError, match="out of bounds for uint"): + s4.dup(dtype=dtype, name="s5") + continue s5 = s4.dup(dtype=dtype, name="s5") assert s5.dtype == dtype assert s5.value == val @@ -128,12 +132,14 @@ def test_equal(s): def test_casting(s): assert int(s) == 5 - assert type(int(s)) is int + assert isinstance(int(s), int) assert float(s) == 5.0 - assert type(float(s)) is float + assert isinstance(float(s), float) assert range(s) == range(5) + with pytest.raises(AttributeError, match="Scalar .* only .*__index__.*integral"): + range(s.dup(float)) assert complex(s) == complex(5) - assert type(complex(s)) is complex + assert isinstance(complex(s), complex) def test_truthy(s): @@ -209,12 +215,12 @@ def test_unsupported_ops(s): s[0] with pytest.raises(TypeError, match="does not support"): s[0] = 0 - with pytest.raises(TypeError, match="doesn't support"): + with pytest.raises(TypeError, match="doesn't support|does not support"): del s[0] def test_is_empty(s): - with pytest.raises(AttributeError, match="can't set attribute"): + with pytest.raises(AttributeError, match="can't set attribute|object has no setter"): s.is_empty = True @@ -226,7 +232,7 @@ def test_update(s): s << Scalar.from_value(3) assert s == 3 if s._is_cscalar: - with pytest.raises(TypeError, match="an integer is required"): + with pytest.raises(TypeError, match="an integer is required|expected integer"): s << Scalar.from_value(4.4) else: s << Scalar.from_value(4.4) @@ -248,7 +254,7 @@ def test_update(s): def test_not_hashable(s): with pytest.raises(TypeError, match="unhashable type"): - {s} + _ = {s} with pytest.raises(TypeError, match="unhashable type"): hash(s) @@ -358,14 +364,15 @@ def test_expr_is_like_scalar(s): } if s.is_cscalar: expected.add("_empty") - assert attrs - expr_attrs == expected, ( + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union"} + assert 
attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Scalar. You may need to " "add an entry to `scalar` set in `graphblas.core.automethods` " "and then run `python -m graphblas.core.automethods`. If you're messing with infix " "methods, then you may need to run `python -m graphblas.core.infixmethods`." ) - assert attrs - infix_attrs == expected - assert attrs - scalar_infix_attrs == expected + assert attrs - infix_attrs - ignore == expected + assert attrs - scalar_infix_attrs - ignore == expected # Make sure signatures actually match. `expr.dup` has `**opts` skip = {"__init__", "__repr__", "_repr_html_", "dup"} for expr in [v.inner(v), v @ v, t & t]: @@ -399,7 +406,8 @@ def test_index_expr_is_like_scalar(s): } if s.is_cscalar: expected.add("_empty") - assert attrs - expr_attrs == expected, ( + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union"} + assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Scalar. You may need to " "add an entry to `scalar` set in `graphblas.core.automethods` " "and then run `python -m graphblas.core.automethods`. If you're messing with infix " @@ -505,10 +513,10 @@ def test_scalar_expr(s): def test_sizeof(s): - if suitesparse or s._is_cscalar: + if (suitesparse or s._is_cscalar) and not pypy: assert 1 < sys.getsizeof(s) < 1000 else: - with pytest.raises(TypeError): + with pytest.raises(TypeError): # flakey coverage (why?!) sys.getsizeof(s) @@ -576,7 +584,7 @@ def test_record_from_dict(): def test_get(s): assert s.get() == 5 assert s.get("mittens") == 5 - assert type(compute(s.get())) is int + assert isinstance(compute(s.get()), int) s.clear() assert compute(s.get()) is None assert s.get("mittens") == "mittens" diff --git a/graphblas/tests/test_ss_utils.py b/graphblas/tests/test_ss_utils.py index d21f41f03..2df7ab939 100644 --- a/graphblas/tests/test_ss_utils.py +++ b/graphblas/tests/test_ss_utils.py @@ -4,6 +4,7 @@ import graphblas as gb from graphblas import Matrix, Vector, backend +from graphblas.exceptions import InvalidValue if backend != "suitesparse": pytest.skip("gb.ss and A.ss only available with suitesparse backend", allow_module_level=True) @@ -198,6 +199,11 @@ def test_about(): assert "library_name" in repr(about) +def test_openmp_enabled(): + # SuiteSparse:GraphBLAS without OpenMP enabled is very undesirable + assert gb.ss.about["openmp"] + + def test_global_config(): d = {} config = gb.ss.config @@ -226,6 +232,65 @@ def test_global_config(): else: with pytest.raises(ValueError, match="Unable to set default value for"): config[k] = None - with pytest.raises(ValueError, match="Wrong number"): - config["memory_pool"] = [1, 2] + # with pytest.raises(ValueError, match="Wrong number"): + # config["memory_pool"] = [1, 2] # No longer used assert "format" in repr(config) + + +@pytest.mark.skipif("gb.core.ss._IS_SSGB7") +def test_context(): + context = gb.ss.Context() + prev = dict(context) + context["chunk"] += 1 + context["nthreads"] += 1 + assert context["chunk"] == prev["chunk"] + 1 + assert context["nthreads"] == prev["nthreads"] + 1 + context2 = gb.ss.Context(stack=True) + assert context2 == context + context3 = gb.ss.Context(stack=False) + assert context3 == prev + context4 = gb.ss.Context( + chunk=context["chunk"] + 1, nthreads=context["nthreads"] + 1, stack=False + ) + assert context4["chunk"] == context["chunk"] + 1 + assert context4["nthreads"] == context["nthreads"] + 1 + assert context == context.dup() + assert context4 == 
context.dup(chunk=context["chunk"] + 1, nthreads=context["nthreads"] + 1) + assert context.dup(gpu_id=-1)["gpu_id"] == -1 + + context.engage() + assert gb.core.ss.context.threadlocal.context is context + with gb.ss.Context(nthreads=1) as ctx: + assert gb.core.ss.context.threadlocal.context is ctx + v = Vector(int, 5) + v(nthreads=2) << v + v + assert gb.core.ss.context.threadlocal.context is ctx + assert gb.core.ss.context.threadlocal.context is context + with pytest.raises(InvalidValue): + # Wait, why does this raise?! + ctx.disengage() + assert gb.core.ss.context.threadlocal.context is context + context.disengage() + assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context + assert context._prev_context is None + + # hackery + gb.core.ss.context.threadlocal.context = context + context.disengage() + context.disengage() + context.disengage() + assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context + + # Actually engaged, but not set in threadlocal + context._engage() + assert gb.core.ss.context.threadlocal.context is gb.core.ss.context.global_context + context.disengage() + + context.engage() + context._engage() + assert gb.core.ss.context.threadlocal.context is context + context.disengage() + + context._context = context # This is allowed to work with config + with pytest.raises(AttributeError, match="_context"): + context._context = ctx # This is not diff --git a/graphblas/tests/test_ssjit.py b/graphblas/tests/test_ssjit.py new file mode 100644 index 000000000..4cea0b563 --- /dev/null +++ b/graphblas/tests/test_ssjit.py @@ -0,0 +1,438 @@ +import os +import pathlib +import platform +import sys +import sysconfig + +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import graphblas as gb +from graphblas import backend, binary, dtypes, indexunary, select, unary +from graphblas.core import _supports_udfs as supports_udfs +from graphblas.core.ss import _IS_SSGB7 + +from .conftest import autocompute, burble + +from graphblas import Vector # isort:skip (for dask-graphblas) + +try: + import numba +except ImportError: + numba = None + +if backend != "suitesparse": + pytest.skip("not suitesparse backend", allow_module_level=True) + + +@pytest.fixture(scope="module", autouse=True) +def _setup_jit(): + """Set up the SuiteSparse:GraphBLAS JIT.""" + if _IS_SSGB7: + # SuiteSparse JIT was added in SSGB 8 + yield + return + + if not os.environ.get("GITHUB_ACTIONS"): + # Try to run the tests with defaults from sysconfig if not running in CI + prev = gb.ss.config["jit_c_control"] + cc = sysconfig.get_config_var("CC") + cflags = sysconfig.get_config_var("CFLAGS") + include = sysconfig.get_path("include") + libs = sysconfig.get_config_var("LIBS") + if not (cc is None or cflags is None or include is None or libs is None): + gb.ss.config["jit_c_control"] = "on" + gb.ss.config["jit_c_compiler_name"] = cc + gb.ss.config["jit_c_compiler_flags"] = f"{cflags} -I{include}" + gb.ss.config["jit_c_libraries"] = libs + else: + # Should we skip or try to run if sysconfig vars aren't set? + gb.ss.config["jit_c_control"] = "on" # "off" + try: + yield + finally: + gb.ss.config["jit_c_control"] = prev + return + + if ( + sys.platform == "darwin" + or sys.platform == "linux" + and "conda" not in gb.ss.config["jit_c_compiler_name"] + ): + # XXX TODO: tests for SuiteSparse JIT are not passing on linux when using wheels or on osx + # This should be understood and fixed! 
+ gb.ss.config["jit_c_control"] = "off" + yield + return + + # Configuration values below were obtained from the output of the JIT config + # in CI, but with paths changed to use `{conda_prefix}` where appropriate. + conda_prefix = os.environ["CONDA_PREFIX"] + prev = gb.ss.config["jit_c_control"] + gb.ss.config["jit_c_control"] = "on" + if sys.platform == "linux": + gb.ss.config["jit_c_compiler_name"] = f"{conda_prefix}/bin/x86_64-conda-linux-gnu-cc" + gb.ss.config["jit_c_compiler_flags"] = ( + "-march=nocona -mtune=haswell -ftree-vectorize -fPIC -fstack-protector-strong " + f"-fno-plt -O2 -ffunction-sections -pipe -isystem {conda_prefix}/include -Wundef " + "-std=c11 -lm -Wno-pragmas -fexcess-precision=fast -fcx-limited-range " + "-fno-math-errno -fwrapv -O3 -DNDEBUG -fopenmp -fPIC" + ) + gb.ss.config["jit_c_linker_flags"] = ( + "-Wl,-O2 -Wl,--sort-common -Wl,--as-needed -Wl,-z,relro -Wl,-z,now " + "-Wl,--disable-new-dtags -Wl,--gc-sections -Wl,--allow-shlib-undefined " + f"-Wl,-rpath,{conda_prefix}/lib -Wl,-rpath-link,{conda_prefix}/lib " + f"-L{conda_prefix}/lib -shared" + ) + gb.ss.config["jit_c_libraries"] = ( + f"-lm -ldl {conda_prefix}/lib/libgomp.so " + f"{conda_prefix}/x86_64-conda-linux-gnu/sysroot/usr/lib/libpthread.so" + ) + gb.ss.config["jit_c_cmake_libs"] = ( + f"m;dl;{conda_prefix}/lib/libgomp.so;" + f"{conda_prefix}/x86_64-conda-linux-gnu/sysroot/usr/lib/libpthread.so" + ) + elif sys.platform == "darwin": + gb.ss.config["jit_c_compiler_name"] = f"{conda_prefix}/bin/clang" + gb.ss.config["jit_c_compiler_flags"] = ( + "-march=core2 -mtune=haswell -mssse3 -ftree-vectorize -fPIC -fPIE " + f"-fstack-protector-strong -O2 -pipe -isystem {conda_prefix}/include -DGBNCPUFEAT " + f"-Wno-pointer-sign -O3 -DNDEBUG -fopenmp=libomp -fPIC -arch {platform.machine()}" + ) + gb.ss.config["jit_c_linker_flags"] = ( + "-Wl,-pie -Wl,-headerpad_max_install_names -Wl,-dead_strip_dylibs " + f"-Wl,-rpath,{conda_prefix}/lib -L{conda_prefix}/lib -dynamiclib" + ) + gb.ss.config["jit_c_libraries"] = f"-lm -ldl {conda_prefix}/lib/libomp.dylib" + gb.ss.config["jit_c_cmake_libs"] = f"m;dl;{conda_prefix}/lib/libomp.dylib" + elif sys.platform == "win32": # pragma: no branch (sanity) + if "mingw" in gb.ss.config["jit_c_libraries"]: + # This probably means we're testing a `python-suitesparse-graphblas` wheel + # in a conda environment. This is not yet working. + gb.ss.config["jit_c_control"] = "off" + yield + return + + gb.ss.config["jit_c_compiler_name"] = f"{conda_prefix}/bin/cc" + gb.ss.config["jit_c_compiler_flags"] = ( + '/DWIN32 /D_WINDOWS -DGBNCPUFEAT /O2 -wd"4244" -wd"4146" -wd"4018" ' + '-wd"4996" -wd"4047" -wd"4554" /O2 /Ob2 /DNDEBUG -openmp' + ) + gb.ss.config["jit_c_linker_flags"] = "/machine:x64" + gb.ss.config["jit_c_libraries"] = "" + gb.ss.config["jit_c_cmake_libs"] = "" + + if not pathlib.Path(gb.ss.config["jit_c_compiler_name"]).exists(): + # Can't use the JIT if we don't have a compiler! 
+ gb.ss.config["jit_c_control"] = "off" + yield + return + try: + yield + finally: + gb.ss.config["jit_c_control"] = prev + + +@pytest.fixture +def v(): + return Vector.from_coo([1, 3, 4, 6], [1, 1, 2, 0]) + + +@autocompute +def test_jit_udt(): + if _IS_SSGB7: + with pytest.raises(RuntimeError, match="JIT was added"): + dtypes.ss.register_new( + "myquaternion", "typedef struct { float x [4][4] ; int color ; } myquaternion ;" + ) + return + if gb.ss.config["jit_c_control"] == "off": + return + with burble(): + dtype = dtypes.ss.register_new( + "myquaternion", "typedef struct { float x [4][4] ; int color ; } myquaternion ;" + ) + assert not hasattr(dtypes, "myquaternion") + assert dtypes.ss.myquaternion is dtype + assert dtype.name == "myquaternion" + assert str(dtype) == "myquaternion" + assert dtype.gb_name is None + v = Vector(dtype, 2) + np_type = np.dtype([("x", "= thunk - select.register_new("ii", ii) - assert hasattr(indexunary, "ii") + def iin(n): + def inner(x, idx, _, thunk): # pragma: no cover (numba) + return idx // n >= thunk + + return inner + + select.register_new("ii", ii, lazy=True) + select.register_new("iin", iin, parameterized=True) + assert "ii" in dir(select) + assert "ii" in dir(indexunary) assert hasattr(select, "ii") + assert hasattr(indexunary, "ii") ii_apply = indexunary.register_anonymous(ii) expected = Vector.from_coo([1, 3, 4, 6], [False, False, True, True], size=7) result = ii_apply(v, 2).new() assert result.isequal(expected) + result = v.apply(indexunary.iin(2), 2).new() + assert result.isequal(expected) + result = v.apply(indexunary.register_anonymous(iin, parameterized=True)(2), 2).new() + assert result.isequal(expected) + ii_select = select.register_anonymous(ii) expected = Vector.from_coo([4, 6], [2, 0], size=7) result = ii_select(v, 2).new() assert result.isequal(expected) + result = v.select(select.iin(2), 2).new() + assert result.isequal(expected) + result = v.select(select.register_anonymous(iin, parameterized=True)(2), 2).new() + assert result.isequal(expected) delattr(indexunary, "ii") delattr(select, "ii") + delattr(indexunary, "iin") + delattr(select, "iin") + with pytest.raises(UdfParseError, match="Unable to parse function using Numba"): + indexunary.register_new("bad", lambda x, row, col, thunk: result) # pragma: no branch def test_reduce(v): @@ -920,6 +950,21 @@ def test_reduce_agg(v): assert s.is_empty +def test_reduce_agg_count_is_int64(v): + """Aggregators that count should default to INT64 return dtype.""" + assert v.dtype == dtypes.INT64 + res = v.reduce(agg.count).new() + assert res.dtype == dtypes.INT64 + assert res == 4 + res = v.dup(dtypes.INT8).reduce(agg.count).new() + assert res.dtype == dtypes.INT64 + assert res == 4 + # Allow return dtype to be specified + res = v.dup(dtypes.INT8).reduce(agg.count[dtypes.INT16]).new() + assert res.dtype == dtypes.INT16 + assert res == 4 + + @pytest.mark.skipif("not suitesparse") def test_reduce_agg_argminmax(v): assert v.reduce(agg.ss.argmin).new() == 6 @@ -971,10 +1016,10 @@ def test_reduce_agg_firstlast_index(v): def test_reduce_agg_empty(): v = Vector("UINT8", size=3) - for _attr, aggr in vars(agg).items(): + for attr, aggr in vars(agg).items(): if not isinstance(aggr, agg.Aggregator): continue - s = v.reduce(aggr).new() + s = v.reduce(aggr).new(name=attr) assert compute(s.value) is None @@ -1404,7 +1449,7 @@ def test_vector_index_with_scalar(): s0 = Scalar.from_value(0, dtype=dtype) w = v[[s1, s0]].new() assert w.isequal(expected) - for dtype in ["bool", "fp32", "fp64"] + ["fc32", "fc64"] if 
dtypes._supports_complex else []: + for dtype in ["bool", "fp32", "fp64"] + (["fc32", "fc64"] if dtypes._supports_complex else []): s = Scalar.from_value(1, dtype=dtype) with pytest.raises(TypeError, match="An integer is required for indexing"): v[s] @@ -1420,14 +1465,14 @@ def test_diag(v): expected = Matrix.from_coo(rows, cols, values, nrows=size, ncols=size, dtype=v.dtype) # Construct diagonal matrix A if suitesparse: - A = gb.ss.diag(v, k=k) + A = gb.ss.diag(v, k=k, nthreads=2) assert expected.isequal(A) A = v.diag(k) assert expected.isequal(A) # Extract diagonal from A if suitesparse: - w = gb.ss.diag(A, Scalar.from_value(k)) + w = gb.ss.diag(A, Scalar.from_value(k), nthreads=2) assert v.isequal(w) assert w.dtype == "INT64" @@ -1504,6 +1549,8 @@ def test_outer(v): @autocompute def test_auto(v): + from graphblas.core.infix import VectorEwiseMultExpr + v = v.dup(dtype=bool) expected = binary.land(v & v).new() assert 0 not in expected @@ -1551,16 +1598,26 @@ def test_auto(v): "__rand__", "__ror__", ]: + # print(type(expr).__name__, method) val1 = getattr(expected, method)(expected).new() - val2 = getattr(expected, method)(expr) - val3 = getattr(expr, method)(expected) - val4 = getattr(expr, method)(expr) - assert val1.isequal(val2) - assert val1.isequal(val3) - assert val1.isequal(val4) - assert val1.isequal(val2.new()) - assert val1.isequal(val3.new()) - assert val1.isequal(val4.new()) + if method in {"__or__", "__ror__"} and type(expr) is VectorEwiseMultExpr: + # Doing e.g. `plus(x & y | z)` isn't allowed--make user be explicit + with pytest.raises(TypeError): + val2 = getattr(expected, method)(expr) + with pytest.raises(TypeError): + val3 = getattr(expr, method)(expected) + with pytest.raises(TypeError): + val4 = getattr(expr, method)(expr) + else: + val2 = getattr(expected, method)(expr) + assert val1.isequal(val2) + assert val1.isequal(val2.new()) + val3 = getattr(expr, method)(expected) + assert val1.isequal(val3) + assert val1.isequal(val3.new()) + val4 = getattr(expr, method)(expr) + assert val1.isequal(val4) + assert val1.isequal(val4.new()) s1 = expected.reduce(monoid.lor).new() s2 = expr.reduce(monoid.lor) assert s1.isequal(s2.new()) @@ -1620,17 +1677,17 @@ def test_expr_is_like_vector(v): "from_dict", "from_pairs", "from_scalar", - "from_values", "resize", "update", } - assert attrs - expr_attrs == expected, ( + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_inner", "_vxm"} + assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Vector. You may need to " "add an entry to `vector` or `matrix_vector` set in `graphblas.core.automethods` " "and then run `python -m graphblas.core.automethods`. If you're messing with infix " "methods, then you may need to run `python -m graphblas.core.infixmethods`." ) - assert attrs - infix_attrs == expected + assert attrs - infix_attrs - ignore == expected # Make sure signatures actually match skip = {"__init__", "__repr__", "_repr_html_"} for expr in [binary.times(w & w), w & w]: @@ -1669,10 +1726,10 @@ def test_index_expr_is_like_vector(v): "from_dict", "from_pairs", "from_scalar", - "from_values", "resize", } - assert attrs - expr_attrs == expected, ( + ignore = {"__sizeof__", "_ewise_add", "_ewise_mult", "_ewise_union", "_inner", "_vxm"} + assert attrs - expr_attrs - ignore == expected, ( "If you see this message, you probably added a method to Vector. 
You may need to " "add an entry to `vector` or `matrix_vector` set in `graphblas.core.automethods` " "and then run `python -m graphblas.core.automethods`. If you're messing with infix " @@ -1707,6 +1764,13 @@ def test_dup_expr(v): assert result.isequal(b) result = (b | b).dup(clear=True) assert result.isequal(b.dup(clear=True)) + result = v[:5].dup() + assert result.isequal(v[:5].new()) + if suitesparse: + result = v[:5].dup(nthreads=2) + assert result.isequal(v[:5].new()) + result = v[:5].dup(clear=True, nthreads=2) + assert result.isequal(Vector(v.dtype, size=5)) @pytest.mark.skipif("not suitesparse") @@ -1948,13 +2012,6 @@ def test_ss_split(v): assert x2.name == "split_1" -def test_deprecated(v): - with pytest.warns(DeprecationWarning): - v.to_values() - with pytest.warns(DeprecationWarning): - Vector.from_values([1], [2]) - - def test_ndim(A, v): assert v.ndim == 1 assert v.ewise_mult(v).ndim == 1 @@ -1963,7 +2020,7 @@ def test_ndim(A, v): def test_sizeof(v): - if suitesparse: + if suitesparse and not pypy: assert sys.getsizeof(v) > v.nvals * 16 else: with pytest.raises(TypeError): @@ -2006,6 +2063,7 @@ def test_delete_via_scalar(v): assert v.nvals == 0 +@pytest.mark.skipif("not supports_udfs") def test_udt(): record_dtype = np.dtype([("x", np.bool_), ("y", np.float64)], align=True) udt = dtypes.register_anonymous(record_dtype, "VectorUDT") @@ -2149,7 +2207,10 @@ def test_udt(): long_dtype = np.dtype([("x", np.bool_), ("y" * 1000, np.float64)], align=True) if suitesparse: - with pytest.warns(UserWarning, match="too large"): + if ss_version_major < 9: + with pytest.warns(UserWarning, match="too large"): + long_udt = dtypes.register_anonymous(long_dtype) + else: long_udt = dtypes.register_anonymous(long_dtype) else: # UDTs don't currently have a name in vanilla GraphBLAS @@ -2160,13 +2221,19 @@ def test_udt(): if suitesparse: vv = Vector.ss.deserialize(v.ss.serialize(), dtype=long_udt) assert v.isequal(vv, check_dtype=True) - with pytest.raises(SyntaxError): - # The size of the UDT name is limited + if ss_version_major < 9: + with pytest.raises(SyntaxError): + # The size of the UDT name is limited + Vector.ss.deserialize(v.ss.serialize()) + else: Vector.ss.deserialize(v.ss.serialize()) # May be able to look up non-anonymous dtypes by name if their names are too long named_long_dtype = np.dtype([("x", np.bool_), ("y" * 1000, np.float64)], align=False) if suitesparse: - with pytest.warns(UserWarning, match="too large"): + if ss_version_major < 9: + with pytest.warns(UserWarning, match="too large"): + named_long_udt = dtypes.register_new("LongUDT", named_long_dtype) + else: named_long_udt = dtypes.register_new("LongUDT", named_long_dtype) else: named_long_udt = dtypes.register_new("LongUDT", named_long_dtype) @@ -2214,7 +2281,7 @@ def test_ss_iteration(v): # This is what I would expect assert sorted(indices) == sorted(v.ss.iterkeys()) assert sorted(values) == sorted(v.ss.itervalues()) - assert sorted(zip(indices, values)) == sorted(v.ss.iteritems()) + assert sorted(zip(indices, values, strict=True)) == sorted(v.ss.iteritems()) N = indices.size v = Vector.ss.import_bitmap(**v.ss.export("bitmap")) @@ -2380,6 +2447,7 @@ def test_to_coo_subset(v): assert vals.dtype == np.int64 +@pytest.mark.skipif("not supports_udfs") def test_lambda_udfs(v): result = v.apply(lambda x: x + 1).new() # pragma: no branch (numba) expected = binary.plus(v, 1).new() @@ -2393,7 +2461,7 @@ def test_lambda_udfs(v): # with pytest.raises(TypeError): v.ewise_add(v, lambda x, y: x + y) # pragma: no branch (numba) with 
pytest.raises(TypeError): - v.inner(v, lambda x, y: x + y) + v.inner(v, lambda x, y: x + y) # pragma: no branch (numba) def test_get(v): @@ -2401,7 +2469,7 @@ def test_get(v): assert v.get(0, "mittens") == "mittens" assert v.get(1) == 1 assert v.get(1, "mittens") == 1 - assert type(compute(v.get(1))) is int + assert isinstance(compute(v.get(1)), int) with pytest.raises(ValueError, match="Bad index in Vector.get"): # Not yet supported v.get([0, 1]) @@ -2506,7 +2574,8 @@ def test_from_scalar(): v = Vector.from_scalar(1, dtype="INT64[2]", size=3) w = Vector("INT64[2]", size=3) w << [1, 1] - assert v.isequal(w, check_dtype=True) + if supports_udfs: + assert v.isequal(w, check_dtype=True) def test_to_dense_from_dense(): @@ -2559,9 +2628,10 @@ def test_ss_sort(v): v.ss.sort(binary.plus) # Like compactify - _, p = v.ss.sort(lambda x, y: False, values=False) # pragma: no branch (numba) - expected_p = Vector.from_coo([0, 1, 2, 3], [1, 3, 4, 6], size=7) - assert p.isequal(expected_p) + if supports_udfs: + _, p = v.ss.sort(lambda x, y: False, values=False) # pragma: no branch (numba) + expected_p = Vector.from_coo([0, 1, 2, 3], [1, 3, 4, 6], size=7) + assert p.isequal(expected_p) # reversed _, p = v.ss.sort(binary.pair[bool], values=False) expected_p = Vector.from_coo([0, 1, 2, 3], [6, 4, 3, 1], size=7) @@ -2569,6 +2639,7 @@ def test_ss_sort(v): w, p = v.ss.sort(monoid.lxor) # Weird, but user-defined monoids may not commute, so okay +@pytest.mark.skipif("not supports_udfs") def test_subarray_dtypes(): a = np.arange(3 * 4, dtype=np.int64).reshape(3, 4) v = Vector.from_coo([1, 3, 5], a) diff --git a/graphblas/unary/numpy.py b/graphblas/unary/numpy.py index 06086569d..0c36565ec 100644 --- a/graphblas/unary/numpy.py +++ b/graphblas/unary/numpy.py @@ -5,11 +5,13 @@ https://numba.readthedocs.io/en/stable/reference/numpysupported.html#math-operations """ + import numpy as _np from .. import _STANDARD_OPERATOR_NAMES from .. import config as _config from .. 
import unary as _unary +from ..core import _supports_udfs from ..dtypes import _supports_complex _delayed = {} @@ -119,7 +121,12 @@ def __dir__(): - return globals().keys() | _delayed.keys() | _unary_names + if not _supports_udfs and not _config.get("mapnumpy"): + return globals().keys() # FLAKY COVERAGE + attrs = _delayed.keys() | _unary_names + if not _supports_udfs: + attrs &= _numpy_to_graphblas.keys() + return attrs | globals().keys() def __getattr__(name): @@ -132,20 +139,23 @@ def __getattr__(name): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") if _config.get("mapnumpy") and name in _numpy_to_graphblas: globals()[name] = getattr(_unary, _numpy_to_graphblas[name]) + elif not _supports_udfs: + raise AttributeError( + f"module {__name__!r} unable to compile UDF for {name!r}; " + "install numba for UDF support" + ) else: - from ..core import operator - numpy_func = getattr(_np, name) def func(x): # pragma: no cover (numba) return numpy_func(x) - operator.UnaryOp.register_new(f"numpy.{name}", func) + _unary.register_new(f"numpy.{name}", func) if name == "reciprocal": # numba doesn't match numpy here def reciprocal(x): # pragma: no cover (numba) return 1 if x else 0 - op = operator.UnaryOp.register_anonymous(reciprocal) + op = _unary.register_anonymous(reciprocal) globals()[name]._add(op["BOOL"]) return globals()[name] diff --git a/graphblas/unary/ss.py b/graphblas/unary/ss.py index e45cbcda0..e97784612 100644 --- a/graphblas/unary/ss.py +++ b/graphblas/unary/ss.py @@ -1,3 +1,6 @@ from ..core import operator +from ..core.ss.unary import register_new # noqa: F401 + +_delayed = {} del operator diff --git a/graphblas/viz.py b/graphblas/viz.py index 89010bc3d..b6d5f6ba7 100644 --- a/graphblas/viz.py +++ b/graphblas/viz.py @@ -35,8 +35,7 @@ def _get_imports(names, within): except ImportError: modname = _LAZY_IMPORTS[name].split(".")[0] raise ImportError(f"`{within}` requires {modname} to be installed") from None - finally: - globals()[name] = val + globals()[name] = val rv.append(val) if is_string: return rv[0] @@ -67,7 +66,7 @@ def draw(m): # pragma: no cover def spy(M, *, centered=False, show=True, figure=None, axes=None, figsize=None, **kwargs): - """Plot the sparsity pattern of a Matrix using `matplotlib.spy`. + """Plot the sparsity pattern of a Matrix using ``matplotlib.spy``. See: - https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.spy.html @@ -80,6 +79,7 @@ def spy(M, *, centered=False, show=True, figure=None, axes=None, figsize=None, * See Also -------- datashade + """ mpl, plt, ss = _get_imports(["mpl", "plt", "ss"], "spy") A = to_scipy_sparse(M, "coo") @@ -106,8 +106,8 @@ def spy(M, *, centered=False, show=True, figure=None, axes=None, figsize=None, * def datashade(M, agg="count", *, width=None, height=None, opts_kwargs=None, **kwargs): """Interactive plot of the sparsity pattern of a Matrix using hvplot and datashader. - The `datashader` library rasterizes large data into a 2d grid of pixels. Each pixel - may contain multiple data points, which are combined by an aggregator (`agg="count"`). + The ``datashader`` library rasterizes large data into a 2d grid of pixels. Each pixel + may contain multiple data points, which are combined by an aggregator (``agg="count"``). Common aggregators are "count", "sum", "mean", "min", and "max". 
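For example, ``datashade(M, agg="max")`` would color each pixel by the largest value it covers; the call is illustrative, but it uses only the signature shown above, and any reduction named below is passed the same way.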
See full list here: - https://datashader.org/api.html#reductions @@ -130,6 +130,7 @@ def datashade(M, agg="count", *, width=None, height=None, opts_kwargs=None, **kw See Also -------- spy + """ np, pd, bk, hv, hp, ds = _get_imports(["np", "pd", "bk", "hv", "hp", "ds"], "datashade") if "df" not in kwargs: @@ -182,30 +183,30 @@ def datashade(M, agg="count", *, width=None, height=None, opts_kwargs=None, **kw images.extend(image_row) return hv.Layout(images).cols(ncols) - kwds = dict( # noqa: C408 pylint: disable=use-dict-literal - x="col", - y="row", - c="val", - aggregator=agg, - frame_width=width, - frame_height=height, - cmap="fire", - cnorm="eq_hist", - xlim=(0, M.ncols), - ylim=(0, M.nrows), - rasterize=True, - flip_yaxis=True, - hover=True, - xlabel="", - ylabel="", - data_aspect=1, - x_sampling=1, - y_sampling=1, - xaxis="top", - xformatter="%d", - yformatter="%d", - rot=60, - ) + kwds = { + "x": "col", + "y": "row", + "c": "val", + "aggregator": agg, + "frame_width": width, + "frame_height": height, + "cmap": "fire", + "cnorm": "eq_hist", + "xlim": (0, M.ncols), + "ylim": (0, M.nrows), + "rasterize": True, + "flip_yaxis": True, + "hover": True, + "xlabel": "", + "ylabel": "", + "data_aspect": 1, + "x_sampling": 1, + "y_sampling": 1, + "xaxis": "top", + "xformatter": "%d", + "yformatter": "%d", + "rot": 60, + } # Only show axes on outer-most plots if kwargs.pop("_col", 0) != 0: kwds["yaxis"] = None diff --git a/notebooks/Example B.1 -- Level BFS.ipynb b/notebooks/Example B.1 -- Level BFS.ipynb index cdee2f2fc..e96d6d7d5 100644 --- a/notebooks/Example B.1 -- Level BFS.ipynb +++ b/notebooks/Example B.1 -- Level BFS.ipynb @@ -6,7 +6,7 @@ "source": [ "## Example B.1 Level Breadth-first Search\n", "\n", - "Examples come from http://people.eecs.berkeley.edu/~aydin/GraphBLAS_API_C_v13.pdf" + "Examples come from https://people.eecs.berkeley.edu/~aydin/GraphBLAS_API_C_v13.pdf" ] }, { diff --git a/notebooks/Example B.3 -- Parent BFS.ipynb b/notebooks/Example B.3 -- Parent BFS.ipynb index d1fbd82c5..d3c7c761f 100644 --- a/notebooks/Example B.3 -- Parent BFS.ipynb +++ b/notebooks/Example B.3 -- Parent BFS.ipynb @@ -6,7 +6,7 @@ "source": [ "## Example B.3 Parent Breadth-first Search\n", "\n", - "Examples come from http://people.eecs.berkeley.edu/~aydin/GraphBLAS_API_C_v13.pdf" + "Examples come from https://people.eecs.berkeley.edu/~aydin/GraphBLAS_API_C_v13.pdf" ] }, { diff --git a/notebooks/logos_and_colors.ipynb b/notebooks/logos_and_colors.ipynb new file mode 100644 index 000000000..7b64a2208 --- /dev/null +++ b/notebooks/logos_and_colors.ipynb @@ -0,0 +1,1467 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1ade2e62-38f4-4017-a0d3-e09f8587c376", + "metadata": {}, + "source": [ + "# Logos and Color Palette of Python-graphblas\n", + "\n", + "To create a minimal environment to run this notebook:\n", + "```bash\n", + "$ conda create -n drawsvg -c conda-forge drawsvg cairosvg scipy jupyter\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "bf42676c-190a-4803-a567-09e0ed260d6a", + "metadata": {}, + "outputs": [], + "source": [ + "import drawsvg as draw\n", + "import numpy as np\n", + "from scipy.spatial.transform import Rotation" + ] + }, + { + "cell_type": "markdown", + "id": "876a6128-94e4-4fb0-938d-0980a2033701", + "metadata": {}, + "source": [ + "## Define color palette" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "786f9c9e-d999-4286-bf79-009ca1681604", + "metadata": {}, + "outputs": [], + "source": [ + "# primary\n", + "blue = 
\"#409DC1\"\n", + "orange = \"#FF8552\"\n", + "dark_gray = \"#39393A\"\n", + "light_gray = \"#C3C3C7\"\n", + "\n", + "# Neutral, light/dark compatible\n", + "medium_gray = \"#848487\"\n", + "\n", + "# secondary\n", + "light_blue = \"#81B7CC\"\n", + "light_orange = \"#FFBB9E\"\n", + "red = \"#6D213C\"\n", + "light_red = \"#BA708A\"\n", + "green = \"#85FFC7\"\n", + "\n", + "french_rose = \"#FA4B88\" # ;)" + ] + }, + { + "cell_type": "markdown", + "id": "adb66550-f1e8-4846-a12a-e178fe801295", + "metadata": {}, + "source": [ + "## Display color palette" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "983b0cb8-db8b-4ad0-ad5a-36975d59289e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "Primary\n", + "\n", + "#409DC1\n", + "\n", + "#FF8552\n", + "\n", + "#39393A\n", + "\n", + "#C3C3C7\n", + "\n", + "#848487\n", + "Secondary\n", + "\n", + "#81B7CC\n", + "\n", + "#FFBB9E\n", + "\n", + "#6D213C\n", + "\n", + "#BA708A\n", + "\n", + "#85FFC7\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d = draw.Drawing(750, 500, origin=\"center\")\n", + "d.append(\n", + " draw.Rectangle(-375, -250, 750, 500, fill=\"white\")\n", + ") # Add `stroke=\"black\"` border to see boundaries for testing\n", + "\n", + "dy = 25\n", + "dx = 0\n", + "w = h = 150\n", + "b = 25\n", + "x = -400 + 62.5 + dx\n", + "y = -200 + dy\n", + "\n", + "d.draw(\n", + " draw.Text(\n", + " \"Primary\",\n", + " x=x + 1.5 * (b + w) + w / 2,\n", + " y=y - b,\n", + " font_size=1.5 * b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x, y, w, h, fill=blue))\n", + "d.draw(\n", + " draw.Text(\n", + " blue.upper(),\n", + " x=x + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x + b + w, y, w, h, fill=orange))\n", + "d.draw(\n", + " draw.Text(\n", + " orange.upper(),\n", + " x=x + (b + w) + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x + 2 * (b + w), y, w, h, fill=dark_gray))\n", + "d.draw(\n", + " draw.Text(\n", + " dark_gray.upper(),\n", + " x=x + 2 * (b + w) + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"white\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x + 3 * (b + w), y, w, h, fill=light_gray))\n", + "d.draw(\n", + " draw.Text(\n", + " light_gray.upper(),\n", + " x=x + 3 * (b + w) + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "\n", + "d.draw(draw.Rectangle(x, -25 + dy, 675, 45, fill=medium_gray))\n", + "d.draw(\n", + " draw.Text(\n", + " medium_gray.upper(),\n", + " x=x + 675 / 2,\n", + " y=-25 + 30 + dy,\n", + " font_size=22.5,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "\n", + "y = 40 + dy\n", + "w = h = 119\n", + "b = 20\n", + "d.draw(\n", + " draw.Text(\n", + " \"Secondary\",\n", + " x=x + 2 * (b + w) + w / 2,\n", + " y=y + h + 2 * b,\n", + " font_size=1.5 * b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " 
fill=\"black\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x, y, w, h, fill=light_blue))\n", + "d.draw(\n", + " draw.Text(\n", + " light_blue.upper(),\n", + " x=x + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x + b + w, y, w, h, fill=light_orange))\n", + "d.draw(\n", + " draw.Text(\n", + " light_orange.upper(),\n", + " x=x + (b + w) + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x + 2 * (b + w), y, w, h, fill=red))\n", + "d.draw(\n", + " draw.Text(\n", + " red.upper(),\n", + " x=x + 2 * (b + w) + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"white\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x + 3 * (b + w), y, w, h, fill=light_red))\n", + "d.draw(\n", + " draw.Text(\n", + " light_red.upper(),\n", + " x=x + 3 * (b + w) + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "d.draw(draw.Rectangle(x + 4 * (b + w), y, w, h, fill=green))\n", + "d.draw(\n", + " draw.Text(\n", + " green.upper(),\n", + " x=x + 4 * (b + w) + w / 2,\n", + " y=y + h - b,\n", + " font_size=b,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Arial\",\n", + " fill=\"black\",\n", + " )\n", + ")\n", + "\n", + "color_palette = d\n", + "d" + ] + }, + { + "cell_type": "markdown", + "id": "e59c3941-c73b-455e-88f2-4b3aae228421", + "metadata": {}, + "source": [ + "## Display color wheel" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c27e8ef2-04f2-4752-9c3b-cf297a0c87a5", + "metadata": {}, + "outputs": [], + "source": [ + "def create_color_wheel(color_wheel):\n", + " d = draw.Drawing(300, 300, origin=\"center\")\n", + " theta = np.pi / 3\n", + "\n", + " angle = 0\n", + " for i, color in enumerate(color_wheel):\n", + " angle = i * np.pi / 3\n", + " clip = draw.ClipPath()\n", + " if i == 5:\n", + " angle_offset = theta\n", + " else:\n", + " angle_offset = theta * 1.05\n", + " clip.append(\n", + " draw.Lines(\n", + " 0,\n", + " 0,\n", + " 300 * np.sin(angle),\n", + " 300 * np.cos(angle),\n", + " 300 * np.sin(angle + angle_offset),\n", + " 300 * np.cos(angle + angle_offset),\n", + " close=True,\n", + " )\n", + " )\n", + " if i == 0:\n", + " clip = None\n", + " d.append(draw.Circle(0, 0, 145, fill=color, clip_path=clip))\n", + "\n", + " angle = 3 * theta\n", + " for i, color in enumerate(color_wheel):\n", + " angle = ((i + 3) % 6) * np.pi / 3\n", + " clip = draw.ClipPath()\n", + " if i == 5:\n", + " angle_offset = theta\n", + " else:\n", + " angle_offset = theta * 1.05\n", + " clip.append(\n", + " draw.Lines(\n", + " 0,\n", + " 0,\n", + " 300 * np.sin(angle),\n", + " 300 * np.cos(angle),\n", + " 300 * np.sin(angle + angle_offset),\n", + " 300 * np.cos(angle + angle_offset),\n", + " close=True,\n", + " )\n", + " )\n", + " if i == 0:\n", + " clip = None\n", + " d.append(draw.Circle(0, 0, 105, fill=color, clip_path=clip))\n", + "\n", + " angle = theta\n", + " for i, color in enumerate(color_wheel):\n", + " angle = ((i + 1) % 6) * np.pi / 3\n", + " clip = draw.ClipPath()\n", + " if i == 5:\n", + " angle_offset = theta\n", + " else:\n", + " angle_offset = theta * 1.05\n", + " clip.append(\n", + " draw.Lines(\n", + " 0,\n", + " 0,\n", + " 
300 * np.sin(angle),\n", + " 300 * np.cos(angle),\n", + " 300 * np.sin(angle + angle_offset),\n", + " 300 * np.cos(angle + angle_offset),\n", + " close=True,\n", + " )\n", + " )\n", + " if i == 0:\n", + " clip = None\n", + " d.append(draw.Circle(0, 0, 65, fill=color, clip_path=clip))\n", + "\n", + " d.append(draw.Circle(0, 0, 25, fill=medium_gray))\n", + " return d" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2564bf63-8293-4828-8e38-d00a3b96b067", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Standard\n", + "standard_wheel = create_color_wheel(\n", + " [\n", + " blue,\n", + " light_gray,\n", + " light_blue,\n", + " dark_gray,\n", + " orange,\n", + " light_orange,\n", + " ]\n", + ")\n", + "standard_wheel" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7a500a39-4114-49bb-aa19-912c6a8a8d95", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# High contrast\n", + "high_wheel = create_color_wheel(\n", + " [\n", + " light_gray,\n", + " blue,\n", + " green,\n", + " dark_gray,\n", + " orange,\n", + " red,\n", + " ]\n", + ")\n", + "high_wheel" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8f404efe-2b88-4bdf-9102-2e6ad9389ca3", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Low contrast\n", + "low_wheel = create_color_wheel(\n", + " [\n", + " green,\n", + " light_red,\n", + " orange,\n", + " light_blue,\n", + " light_orange,\n", + " 
blue,\n", + " ]\n", + ")\n", + "low_wheel" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "fd913698-ea45-4219-8003-0fd30124d091", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Warm :)\n", + "warm_wheel = create_color_wheel(\n", + " [\n", + " light_gray, # or dark_gray\n", + " light_red,\n", + " french_rose, # ;)\n", + " red,\n", + " orange,\n", + " light_orange,\n", + " ]\n", + ")\n", + "warm_wheel" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c7a3a5e6-4be4-4def-9687-00d1e3f80375", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Cool\n", + "cool_wheel = create_color_wheel(\n", + " [\n", + " light_blue,\n", + " light_gray,\n", + " blue,\n", + " light_red,\n", + " green,\n", + " dark_gray,\n", + " ]\n", + ")\n", + "cool_wheel" + ] + }, + { + "cell_type": "markdown", + "id": "343256c8-35a7-4c89-aa60-c6bf60930c09", + "metadata": {}, + "source": [ + "## Create logos" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "7855cd3f-8155-4d11-9730-b6041578e112", + "metadata": {}, + "outputs": [], + "source": [ + "default_angles = [\n", + " 180, # Don't modify this\n", + " 30, # How much of the \"left face\" to see\n", + " 22.5, # How much of the \"top face\" to see\n", + "]\n", + "R = Rotation.from_euler(\"ZYX\", default_angles, degrees=True).as_matrix()\n", + "\n", + "gcube = np.array(\n", + " [\n", + " [-1, 1, -1],\n", + " [-1, 1, 1],\n", + " [1, 1, 1],\n", + " [-1, -1, 1],\n", + " [1, -1, 1],\n", + " [1, 0, 1],\n", + " [0, 0, 1],\n", + " ]\n", + ")\n", + "gcube_major = gcube[:5] # Big circles\n", + "gcube_minor = gcube[5:] # Small circles\n", + "lines = np.array(\n", + " [\n", + " [gcube[1], gcube[0]],\n", + " ]\n", + ")\n", + "Gpath = np.array(\n", + " [\n", + " gcube[2],\n", + " gcube[1],\n", + " gcube[3],\n", + " gcube[4],\n", + " gcube[5],\n", + " gcube[6],\n", + " ]\n", + ")\n", + "\n", + "\n", + "def create_logo(\n", + " *,\n", + " bracket_color=None,\n", + " bg_color=None,\n", + " edge_color=None,\n", + " edge_width=8,\n", + " edge_border_color=\"white\",\n", + " edge_border_width=16,\n", + " node_color=None,\n", + " 
large_node_width=16,\n", + " small_node_width=8,\n", + " node_border_color=\"white\",\n", + " node_stroke_width=4,\n", + " large_border=True,\n", + " g_color=None,\n", + " angles=None,\n", + "):\n", + " if angles is None:\n", + " angles = default_angles\n", + " if edge_color is None:\n", + " edge_color = blue\n", + " if bracket_color is None:\n", + " bracket_color = edge_color\n", + " if node_color is None:\n", + " node_color = orange\n", + " if g_color is None:\n", + " g_color = edge_color\n", + "\n", + " d = draw.Drawing(190, 190, origin=\"center\")\n", + " if bg_color:\n", + " d.append(\n", + " draw.Rectangle(-95, -95, 190, 190, fill=bg_color)\n", + " ) # Add `stroke=\"black\"` border to see boundaries for testing\n", + "\n", + " scale = 40\n", + " dx = 0\n", + " dy = -2\n", + "\n", + " if edge_border_width:\n", + " # Add white border around lines\n", + " d.append(\n", + " draw.Lines(\n", + " *(((Gpath @ R) * scale)[:, :2] * [-1, 1]).ravel().tolist(),\n", + " fill=\"none\",\n", + " stroke=edge_border_color,\n", + " stroke_width=edge_border_width,\n", + " )\n", + " )\n", + " for (x0, y0, z0), (x1, y1, z1) in ((lines @ R) * scale).tolist():\n", + " x0 = -x0\n", + " x1 = -x1 # Just live with this\n", + " d.append(\n", + " draw.Line(\n", + " x0 + dx,\n", + " y0 + dy,\n", + " x1 + dx,\n", + " y1 + dy,\n", + " stroke=edge_border_color,\n", + " stroke_width=edge_border_width,\n", + " )\n", + " )\n", + "\n", + " # Add edges\n", + " d.append(\n", + " draw.Lines(\n", + " *(((Gpath @ R) * scale)[:, :2] * [-1, 1]).ravel().tolist(),\n", + " fill=\"none\",\n", + " stroke=g_color,\n", + " stroke_width=edge_width,\n", + " )\n", + " )\n", + " for (x0, y0, z0), (x1, y1, z1) in ((lines @ R) * scale).tolist():\n", + " x0 = -x0\n", + " x1 = -x1\n", + " d.append(\n", + " draw.Line(\n", + " x0 + dx, y0 + dy, x1 + dx, y1 + dy, stroke=edge_color, stroke_width=edge_width\n", + " )\n", + " )\n", + "\n", + " # Add vertices\n", + " for x, y, z in ((gcube_major @ R) * scale).tolist():\n", + " x = -x\n", + " d.append(\n", + " draw.Circle(\n", + " x + dx,\n", + " y + dy,\n", + " large_node_width,\n", + " fill=node_color,\n", + " stroke=node_border_color,\n", + " stroke_width=node_stroke_width if large_border else 0,\n", + " )\n", + " )\n", + " for x, y, z in ((gcube_minor @ R) * scale).tolist():\n", + " x = -x\n", + " d.append(\n", + " draw.Circle(\n", + " x + dx,\n", + " y + dy,\n", + " small_node_width,\n", + " fill=node_color,\n", + " stroke=node_border_color,\n", + " stroke_width=node_stroke_width,\n", + " )\n", + " )\n", + "\n", + " # Add brackets\n", + " d.append(\n", + " draw.Text(\n", + " \"[\",\n", + " x=-85,\n", + " y=52,\n", + " font_size=214,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Courier New\",\n", + " fill=bracket_color,\n", + " )\n", + " )\n", + " d.append(\n", + " draw.Text(\n", + " \"]\",\n", + " x=85,\n", + " y=52,\n", + " font_size=214,\n", + " text_anchor=\"middle\",\n", + " font_family=\"Courier New\",\n", + " fill=bracket_color,\n", + " )\n", + " )\n", + "\n", + " return d" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4325e0b8-dbbb-4219-a2b3-4d9cdee2bdc8", + "metadata": {}, + "outputs": [], + "source": [ + "logo_defaults = dict(\n", + " bracket_color=blue,\n", + " edge_color=blue,\n", + " node_color=orange,\n", + " edge_border_width=0,\n", + " edge_width=12,\n", + " small_node_width=11,\n", + " large_node_width=17,\n", + " node_border_color=\"none\",\n", + " node_stroke_width=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": 
"f886df89-b3b5-4671-bcc0-98e8705feb5a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "[\n", + "]\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "create_logo(bg_color=\"white\", **logo_defaults)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "68e01137-55e3-4973-bf97-4fcd36c8c662", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "[\n", + "]\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "create_logo(bg_color=\"black\", **logo_defaults)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b1d5e928-16c5-4377-aee1-1489ab45efc8", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "[\n", + "]\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Transparent background\n", + "logo = create_logo(**logo_defaults)\n", + "logo" + ] + }, + { + "cell_type": "markdown", + "id": "b187c131-d337-4a7b-ab54-80ebe0f48ab4", + "metadata": {}, + "source": [ + "## Alternatives with gray brackets\n", + "### Background-agnostic (works with light and dark mode)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "acca9b2e-2f54-4b86-9a33-2c57502f6160", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "[\n", + "]\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "medium_logo = create_logo(**{**logo_defaults, \"bracket_color\": medium_gray})\n", + "create_logo(bg_color=\"white\", **{**logo_defaults, \"bracket_color\": medium_gray})" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f5d0086d-b50e-49eb-9aae-b0953cdc0045", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "[\n", + "]\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "create_logo(bg_color=\"black\", **{**logo_defaults, \"bracket_color\": medium_gray})" + ] + }, + { + "cell_type": "markdown", + "id": "c4dce89d-e34c-4190-a068-7e78cdeea745", + "metadata": {}, + "source": [ + "### For light mode" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "27137343-141a-422e-abd6-123af3416ea4", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "[\n", + "]\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "light_logo = create_logo(**{**logo_defaults, \"bracket_color\": dark_gray})\n", + 
"create_logo(bg_color=\"white\", **{**logo_defaults, \"bracket_color\": dark_gray})" + ] + }, + { + "cell_type": "markdown", + "id": "8a70b0f7-c3c4-44ae-af09-8992400f362e", + "metadata": {}, + "source": [ + "### For dark mode" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "3ab9bb40-d7a8-4788-9971-54a5779d284d", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "[\n", + "]\n", + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dark_logo = create_logo(**{**logo_defaults, \"bracket_color\": light_gray})\n", + "create_logo(bg_color=\"black\", **{**logo_defaults, \"bracket_color\": light_gray})" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "d53046c1-8cbb-47fa-a88b-4d98958df26b", + "metadata": {}, + "outputs": [], + "source": [ + "if False:\n", + " logo.save_svg(\"python-graphblas-logo.svg\")\n", + " light_logo.save_svg(\"python-graphblas-logo-light.svg\")\n", + " medium_logo.save_svg(\"python-graphblas-logo-medium.svg\")\n", + " dark_logo.save_svg(\"python-graphblas-logo-dark.svg\")\n", + " color_palette.save_svg(\"color-palette.svg\")\n", + " standard_wheel.save_svg(\"color-wheel.svg\")\n", + " high_wheel.save_svg(\"color-wheel-high.svg\")\n", + " low_wheel.save_svg(\"color-wheel-low.svg\")\n", + " warm_wheel.save_svg(\"color-wheel-warm.svg\")\n", + " cool_wheel.save_svg(\"color-wheel-cool.svg\")" + ] + }, + { + "cell_type": "markdown", + "id": "51093fab-600b-47d7-9809-fa0f16e7246f", + "metadata": {}, + "source": [ + "### *NOTE: The font in the SVG files should be converted to paths, because not all systems have Courier New*\n", + "Also, SVG files can be minified here: https://vecta.io/nano" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index adbf2d5b0..1bad95118 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,65 +1,68 @@ [build-system] build-backend = "setuptools.build_meta" -requires = [ - "setuptools >=64", - "setuptools-git-versioning", -] +requires = ["setuptools >=64", "setuptools-git-versioning"] [project] name = "python-graphblas" dynamic = ["version"] description = "Python library for GraphBLAS: high-performance sparse linear algebra for scalable graph analytics" readme = "README.md" -requires-python = ">=3.8" -license = {file = "LICENSE"} +requires-python = ">=3.10" +license = { file = "LICENSE" } authors = [ - {name = "Erik Welch"}, - {name = "Jim Kitchen"}, + { name = "Erik Welch", email = "erik.n.welch@gmail.com" }, + { name = "Jim Kitchen" }, + { name = "Python-graphblas contributors" }, ] maintainers = [ - {name = "Erik Welch", email = "erik.n.welch@gmail.com"}, - {name = "Jim Kitchen", email = "jim22k@gmail.com"}, + { name = "Erik Welch", email = "erik.n.welch@gmail.com" }, + { name = "Jim Kitchen", email = "jim22k@gmail.com" }, + { name = "Sultan Orazbayev", email = "contact@econpoint.com" }, ] keywords = [ - "graphblas", - "graph", - "sparse", - "matrix", - "lagraph", 
- "suitesparse", - "Networks", - "Graph Theory", - "Mathematics", - "network", - "discrete mathematics", - "math", + "graphblas", + "graph", + "sparse", + "matrix", + "lagraph", + "suitesparse", + "Networks", + "Graph Theory", + "Mathematics", + "network", + "discrete mathematics", + "math", ] classifiers = [ - "Development Status :: 5 - Production/Stable", - "License :: OSI Approved :: Apache Software License", - "Operating System :: MacOS :: MacOS X", - "Operating System :: POSIX :: Linux", - "Operating System :: Microsoft :: Windows", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3 :: Only", - "Intended Audience :: Developers", - "Intended Audience :: Other Audience", - "Intended Audience :: Science/Research", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Information Analysis", - "Topic :: Scientific/Engineering :: Mathematics", - "Topic :: Software Development :: Libraries :: Python Modules", + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3 :: Only", + "Intended Audience :: Developers", + "Intended Audience :: Other Audience", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Information Analysis", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development :: Libraries :: Python Modules", ] dependencies = [ - "suitesparse-graphblas >=7.4.0.0, <7.5", - "numpy >=1.21", - "numba >=0.55", - "donfig >=0.6", - "pyyaml >=5.4", + "numpy >=1.23", + "donfig >=0.6", + "pyyaml >=5.4", + # These won't be installed by default after 2024.3.0 + # once pep-771 is supported: https://peps.python.org/pep-0771/ + # Use e.g. 
"python-graphblas[suitesparse]" or "python-graphblas[default]" instead + "suitesparse-graphblas >=7.4.0.0, <10", + "numba >=0.55; python_version<'3.14'", # make optional where numba is not supported ] [project.urls] @@ -69,33 +72,41 @@ repository = "https://github.com/python-graphblas/python-graphblas" changelog = "https://github.com/python-graphblas/python-graphblas/releases" [project.optional-dependencies] -repr = [ - "pandas >=1.2", +suitesparse = ["suitesparse-graphblas >=7.4.0.0, <10"] +networkx = ["networkx >=2.8"] +numba = ["numba >=0.55"] +pandas = ["pandas >=1.5"] +scipy = ["scipy >=1.9"] +suitesparse-udf = [ # udf requires numba + "python-graphblas[suitesparse,numba]", ] +repr = ["python-graphblas[pandas]"] io = [ - "networkx >=2.8", - "scipy >=1.8", - "awkward >=1.9", - "sparse >=0.12", + "python-graphblas[networkx,scipy]", + "python-graphblas[numba]; python_version<'3.14'", + "awkward >=2.0", + "sparse >=0.14; python_version<'3.13'", # make optional, b/c sparse needs numba + "fast-matrix-market >=1.4.5; python_version<'3.13'", # py3.13 not supported yet ] -viz = [ - "matplotlib >=3.5", +viz = ["python-graphblas[networkx,scipy]", "matplotlib >=3.6"] +datashade = [ # datashade requires numba + "python-graphblas[numba,pandas,scipy]", + "datashader >=0.14", + "hvplot >=0.8", ] test = [ - "pytest", - "packaging", - "pandas >=1.2", - "scipy >=1.8", + "python-graphblas[suitesparse,pandas,scipy]", + "packaging >=21", + "pytest >=6.2", + "tomli >=1", +] +default = [ + "python-graphblas[suitesparse,pandas,scipy]", + "python-graphblas[numba]; python_version<'3.14'", # make optional where numba is not supported ] -complete = [ - "pandas >=1.2", - "networkx >=2.8", - "scipy >=1.8", - "awkward >=1.9", - "sparse >=0.12", - "matplotlib >=3.5", - "pytest", - "packaging", +all = [ + "python-graphblas[default,io,viz,test]", + "python-graphblas[datashade]; python_version<'3.14'", # make optional, b/c datashade needs numba ] [tool.setuptools] @@ -104,19 +115,22 @@ complete = [ # $ find graphblas/ -name __init__.py -print | sort | sed -e 's/\/__init__.py//g' -e 's/\//./g' # $ python -c 'import tomli ; [print(x) for x in sorted(tomli.load(open("pyproject.toml", "rb"))["tool"]["setuptools"]["packages"])]' packages = [ - "graphblas", - "graphblas.agg", - "graphblas.binary", - "graphblas.core", - "graphblas.core.ss", - "graphblas.indexunary", - "graphblas.monoid", - "graphblas.op", - "graphblas.semiring", - "graphblas.select", - "graphblas.ss", - "graphblas.tests", - "graphblas.unary", + "graphblas", + "graphblas.agg", + "graphblas.binary", + "graphblas.core", + "graphblas.core.operator", + "graphblas.core.ss", + "graphblas.dtypes", + "graphblas.indexunary", + "graphblas.io", + "graphblas.monoid", + "graphblas.op", + "graphblas.semiring", + "graphblas.select", + "graphblas.ss", + "graphblas.tests", + "graphblas.unary", ] [tool.setuptools-git-versioning] @@ -126,7 +140,7 @@ dirty_template = "{tag}+{ccount}.g{sha}.dirty" [tool.black] line-length = 100 -target-version = ["py38", "py39", "py310", "py311"] +target-version = ["py310", "py311", "py312", "py313"] [tool.isort] sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] @@ -138,31 +152,56 @@ known_first_party = "graphblas" line_length = 100 [tool.pytest.ini_options] +minversion = "6.0" testpaths = "graphblas/tests" -xfail_strict = true -markers = [ - "slow: Skipped unless --runslow passed", +xfail_strict = false # 2023-07-23: awkward and numpy 1.25 sometimes conflict +addopts = [ + "--strict-config", # Force error if config is 
misspelled + "--strict-markers", # Force error if marker is misspelled (must be defined in config) + "-ra", # Print summary of all fails/errors ] +markers = ["slow: Skipped unless --runslow passed"] +log_cli_level = "info" filterwarnings = [ - # See: https://docs.python.org/3/library/warnings.html#describing-warning-filters - # and: https://docs.pytest.org/en/7.2.x/how-to/capture-warnings.html#controlling-warnings - "error", - # MAINT: we can drop support for sparse <0.13 at any time - "ignore:`np.bool` is a deprecated alias:DeprecationWarning:sparse._umath", # sparse <0.13 - # sparse 0.14.0 (2022-02-24) began raising this warning; it has been reported and fixed upstream. - "ignore:coords should be an ndarray. This will raise a ValueError:DeprecationWarning:sparse._coo.core", - - # setuptools v67.3.0 deprecated `pkg_resources.declare_namespace` on 13 Feb 2023. See: - # https://setuptools.pypa.io/en/latest/history.html#v67-3-0 - # MAINT: check if this is still necessary in 2025 - "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning:pkg_resources", + # See: https://docs.python.org/3/library/warnings.html#describing-warning-filters + # and: https://docs.pytest.org/en/7.2.x/how-to/capture-warnings.html#controlling-warnings + "error", + + # sparse 0.14.0 (2022-02-24) began raising this warning; it has been reported and fixed upstream. + "ignore:coords should be an ndarray. This will raise a ValueError:DeprecationWarning:sparse._coo.core", + + # setuptools v67.3.0 deprecated `pkg_resources.declare_namespace` on 13 Feb 2023. See: + # https://setuptools.pypa.io/en/latest/history.html#v67-3-0 + # MAINT: check if this is still necessary in 2025 + "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning:pkg_resources", + + # This deprecation warning was added in setuptools v67.5.0 (8 Mar 2023). See: + # https://setuptools.pypa.io/en/latest/history.html#v67-5-0 + "ignore:pkg_resources is deprecated as an API:DeprecationWarning:", + + # sre_parse deprecated in 3.11; this is triggered by awkward 0.10 + "ignore:module 'sre_parse' is deprecated:DeprecationWarning:", + "ignore:module 'sre_constants' is deprecated:DeprecationWarning:", + + # numpy 1.25.0 (2023-06-17) deprecated `np.find_common_type`; many other dependencies use it. + # See if we can remove this filter in 2025.
+ "ignore:np.find_common_type is deprecated:DeprecationWarning:", + + # pypy gives this warning + "ignore:can't resolve package from __spec__ or __package__:ImportWarning:", + + # Python 3.12 introduced this deprecation, which is triggered by pandas 2.1.1 + "ignore:datetime.datetime.utcfromtimestamp:DeprecationWarning:dateutil", + + # Pandas 2.2 warns that pyarrow will become a required dependency in pandas 3.0 + "ignore:\\nPyarrow will become a required dependency of pandas:DeprecationWarning:", ] [tool.coverage.run] branch = true source = ["graphblas"] omit = [ - "graphblas/viz.py", # TODO: test and get coverage for viz.py + "graphblas/viz.py", # TODO: test and get coverage for viz.py ] [tool.coverage.report] @@ -172,9 +211,9 @@ fail_under = 0 skip_covered = true skip_empty = true exclude_lines = [ - "pragma: no cover", - "raise AssertionError", - "raise NotImplementedError", + "pragma: no cover", + "raise AssertionError", + "raise NotImplementedError", ] [tool.codespell] @@ -183,224 +222,278 @@ ignore-words-list = "coo,ba" [tool.ruff] # https://github.com/charliermarsh/ruff/ line-length = 100 -target-version = "py38" +target-version = "py310" + +[tool.ruff.format] +exclude = ["*.ipynb"] # Consider enabling auto-formatting of notebooks + +[tool.ruff.lint] +exclude = ["*.ipynb"] # Consider enabling auto-formatting of notebooks +unfixable = [ + "F841", # unused-variable (Note: can leave useless expression) + "B905", # zip-without-explicit-strict (Note: prefer `zip(x, y, strict=True)`) +] select = [ - # Have we enabled too many checks that they'll become a nuisance? We'll see... - "F", # pyflakes - "E", # pycodestyle Error - "W", # pycodestyle Warning - # "C90", # mccabe (Too strict, but maybe we should make things less complex) - # "I", # isort (Should we replace `isort` with this?) - "N", # pep8-naming - "D", # pydocstyle - "UP", # pyupgrade - "YTT", # flake8-2020 - # "ANN", # flake8-annotations (We don't use annotations yet) - "S", # bandit - # "BLE", # flake8-blind-except (Maybe consider) - # "FBT", # flake8-boolean-trap (Why?) - "B", # flake8-bugbear - "A", # flake8-builtins - "COM", # flake8-commas - "C4", # flake8-comprehensions - "DTZ", # flake8-datetimez - "T10", # flake8-debugger - # "DJ", # flake8-django (We don't use django) - # "EM", # flake8-errmsg (Perhaps nicer, but too much work) - "EXE", # flake8-executable - "ISC", # flake8-implicit-str-concat - # "ICN", # flake8-import-conventions (Doesn't allow "_" prefix such as `_np`) - "G", # flake8-logging-format - "INP", # flake8-no-pep420 - "PIE", # flake8-pie - "T20", # flake8-print - # "PYI", # flake8-pyi (We don't have stub files yet) - "PT", # flake8-pytest-style - "Q", # flake8-quotes - "RSE", # flake8-raise - "RET", # flake8-return - # "SLF", # flake8-self (We can use our own private variables--sheesh!) - "SIM", # flake8-simplify - # "TID", # flake8-tidy-imports (Rely on isort and our own judgement) - # "TCH", # flake8-type-checking (Note: figure out type checking later) - # "ARG", # flake8-unused-arguments (Sometimes helpful, but too strict) - "PTH", # flake8-use-pathlib (Often better, but not always) - # "ERA", # eradicate (We like code in comments!) 
- # "PD", # pandas-vet (Intended for scripts that use pandas, not libraries) - "PGH", # pygrep-hooks - "PL", # pylint - "PLC", # pylint Convention - "PLE", # pylint Error - "PLR", # pylint Refactor - "PLW", # pylint Warning - "TRY", # tryceratops - "NPY", # NumPy-specific rules - "RUF", # ruff-specific rules - "ALL", # Try new categories by default (making the above list unnecessary) + # Have we enabled too many checks that they'll become a nuisance? We'll see... + "F", # pyflakes + "E", # pycodestyle Error + "W", # pycodestyle Warning + # "C90", # mccabe (Too strict, but maybe we should make things less complex) + # "I", # isort (Should we replace `isort` with this?) + "N", # pep8-naming + "D", # pydocstyle + "UP", # pyupgrade + "YTT", # flake8-2020 + # "ANN", # flake8-annotations (We don't use annotations yet) + "S", # bandit + # "BLE", # flake8-blind-except (Maybe consider) + # "FBT", # flake8-boolean-trap (Why?) + "B", # flake8-bugbear + "A", # flake8-builtins + "COM", # flake8-commas + "C4", # flake8-comprehensions + "DTZ", # flake8-datetimez + "T10", # flake8-debugger + # "DJ", # flake8-django (We don't use django) + # "EM", # flake8-errmsg (Perhaps nicer, but too much work) + "EXE", # flake8-executable + "ISC", # flake8-implicit-str-concat + # "ICN", # flake8-import-conventions (Doesn't allow "_" prefix such as `_np`) + "G", # flake8-logging-format + "INP", # flake8-no-pep420 + "PIE", # flake8-pie + "T20", # flake8-print + # "PYI", # flake8-pyi (We don't have stub files yet) + "PT", # flake8-pytest-style + "Q", # flake8-quotes + "RSE", # flake8-raise + "RET", # flake8-return + # "SLF", # flake8-self (We can use our own private variables--sheesh!) + "SIM", # flake8-simplify + # "TID", # flake8-tidy-imports (Rely on isort and our own judgement) + # "TCH", # flake8-type-checking (Note: figure out type checking later) + # "ARG", # flake8-unused-arguments (Sometimes helpful, but too strict) + "PTH", # flake8-use-pathlib (Often better, but not always) + # "ERA", # eradicate (We like code in comments!) + # "PD", # pandas-vet (Intended for scripts that use pandas, not libraries) + "PGH", # pygrep-hooks + "PL", # pylint + "PLC", # pylint Convention + "PLE", # pylint Error + "PLR", # pylint Refactor + "PLW", # pylint Warning + "TRY", # tryceratops + "NPY", # NumPy-specific rules + "RUF", # ruff-specific rules + "ALL", # Try new categories by default (making the above list unnecessary) ] external = [ - # noqa codes that ruff doesn't know about: https://github.com/charliermarsh/ruff#external + # noqa codes that ruff doesn't know about: https://github.com/charliermarsh/ruff#external + "F811", ] ignore = [ - # Would be nice to fix these - "D100", # Missing docstring in public module - "D101", # Missing docstring in public class - "D102", # Missing docstring in public method - "D103", # Missing docstring in public function - "D104", # Missing docstring in public package - "D105", # Missing docstring in magic method - # "D107", # Missing docstring in `__init__` - "D205", # 1 blank line required between summary line and description - "D401", # First line of docstring should be in imperative mood: - # "D417", # Missing argument description in the docstring: - "PLE0605", # Invalid format for `__all__`, must be `tuple` or `list` (Note: broken in v0.0.237) - - # Maybe consider - # "SIM300", # Yoda conditions are discouraged, use ... instead (Note: we're not this picky) - # "SIM401", # Use dict.get ... 
instead of if-else-block (Note: if-else better for coverage and sometimes clearer) - "TRY004", # Prefer `TypeError` exception for invalid type (Note: good advice, but not worth the nuisance) - "TRY200", # Use `raise from` to specify exception cause (Note: sometimes okay to raise original exception) - - # Intentionally ignored - "COM812", # Trailing comma missing - "D203", # 1 blank line required before class docstring (Note: conflicts with D211, which is preferred) - "D400", # First line should end with a period (Note: prefer D415, which also allows "?" and "!") - "N801", # Class name ... should use CapWords convention (Note:we have a few exceptions to this) - "N802", # Function name ... should be lowercase - "N803", # Argument name ... should be lowercase (Maybe okay--except in tests) - "N806", # Variable ... in function should be lowercase - "N807", # Function name should not start and end with `__` - "N818", # Exception name ... should be named with an Error suffix (Note: good advice) - "PLR0911", # Too many return statements - "PLR0912", # Too many branches - "PLR0913", # Too many arguments to function call - "PLR0915", # Too many statements - "PLR2004", # Magic number used in comparison, consider replacing magic with a constant variable - "PLW2901", # Outer for loop variable ... overwritten by inner assignment target (Note: good advice, but too strict) - "RET502", # Do not implicitly `return None` in function able to return non-`None` value - "RET503", # Missing explicit `return` at the end of function able to return non-`None` value - "RET504", # Unnecessary variable assignment before `return` statement - "S110", # `try`-`except`-`pass` detected, consider logging the exception (Note: good advice, but we don't log) - "S112", # `try`-`except`-`continue` detected, consider logging the exception (Note: good advice, but we don't log) - "SIM102", # Use a single `if` statement instead of nested `if` statements (Note: often necessary) - "SIM105", # Use contextlib.suppress(...) instead of try-except-pass (Note: try-except-pass is much faster) - "SIM108", # Use ternary operator ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer) - "TRY003", # Avoid specifying long messages outside the exception class (Note: why?) - - # Ignored categories - "C90", # mccabe (Too strict, but maybe we should make things less complex) - "I", # isort (Should we replace `isort` with this?) - "ANN", # flake8-annotations (We don't use annotations yet) - "BLE", # flake8-blind-except (Maybe consider) - "FBT", # flake8-boolean-trap (Why?) - "DJ", # flake8-django (We don't use django) - "EM", # flake8-errmsg (Perhaps nicer, but too much work) - "ICN", # flake8-import-conventions (Doesn't allow "_" prefix such as `_np`) - "PYI", # flake8-pyi (We don't have stub files yet) - "SLF", # flake8-self (We can use our own private variables--sheesh!) - "TID", # flake8-tidy-imports (Rely on isort and our own judgement) - "TCH", # flake8-type-checking (Note: figure out type checking later) - "ARG", # flake8-unused-arguments (Sometimes helpful, but too strict) - "ERA", # eradicate (We like code in comments!) 
- "PD", # pandas-vet (Intended for scripts that use pandas, not libraries) + # Would be nice to fix these + "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + "D107", # Missing docstring in `__init__` + "D205", # 1 blank line required between summary line and description + "D401", # First line of docstring should be in imperative mood: + "D417", # D417 Missing argument description in the docstring for ...: ... + "PLE0605", # Invalid format for `__all__`, must be `tuple` or `list` (Note: broken in v0.0.237) + + # Maybe consider + # "SIM300", # Yoda conditions are discouraged, use ... instead (Note: we're not this picky) + # "SIM401", # Use dict.get ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer) + "B904", # Use `raise from` to specify exception cause (Note: sometimes okay to raise original exception) + "TRY004", # Prefer `TypeError` exception for invalid type (Note: good advice, but not worth the nuisance) + "RUF012", # Mutable class attributes should be annotated with `typing.ClassVar` (Note: no annotations yet) + "RUF021", # parenthesize-chained-operators (Note: results don't look good yet) + "RUF023", # unsorted-dunder-slots (Note: maybe fine, but noisy changes) + "PERF401", # Use a list comprehension to create a transformed list (Note: poorly implemented atm) + + # Intentionally ignored + "COM812", # Trailing comma missing + "D203", # 1 blank line required before class docstring (Note: conflicts with D211, which is preferred) + "D213", # (Note: conflicts with D212, which is preferred) + "D400", # First line should end with a period (Note: prefer D415, which also allows "?" and "!") + "N801", # Class name ... should use CapWords convention (Note:we have a few exceptions to this) + "N802", # Function name ... should be lowercase + "N803", # Argument name ... should be lowercase (Maybe okay--except in tests) + "N806", # Variable ... in function should be lowercase + "N807", # Function name should not start and end with `__` + "N818", # Exception name ... should be named with an Error suffix (Note: good advice) + "PERF203", # `try`-`except` within a loop incurs performance overhead (Note: too strict) + "PLC0205", # Class `__slots__` should be a non-string iterable (Note: string is fine) + "PLR0124", # Name compared with itself, consider replacing `x == x` (Note: too strict) + "PLR0911", # Too many return statements + "PLR0912", # Too many branches + "PLR0913", # Too many arguments to function call + "PLR0915", # Too many statements + "PLR2004", # Magic number used in comparison, consider replacing magic with a constant variable + "PLW0603", # Using the global statement to update ... is discouraged (Note: yeah, discouraged, but too strict) + "PLW0642", # Reassigned `self` variable in instance method (Note: too strict for us) + "PLW2901", # Outer for loop variable ... 
overwritten by inner assignment target (Note: good advice, but too strict) + "RET502", # Do not implicitly `return None` in function able to return non-`None` value + "RET503", # Missing explicit `return` at the end of function able to return non-`None` value + "RET504", # Unnecessary variable assignment before `return` statement + "S110", # `try`-`except`-`pass` detected, consider logging the exception (Note: good advice, but we don't log) + "S112", # `try`-`except`-`continue` detected, consider logging the exception (Note: good advice, but we don't log) + "S603", # `subprocess` call: check for execution of untrusted input (Note: not important for us) + "S607", # Starting a process with a partial executable path (Note: not important for us) + "SIM102", # Use a single `if` statement instead of nested `if` statements (Note: often necessary) + "SIM105", # Use contextlib.suppress(...) instead of try-except-pass (Note: try-except-pass is much faster) + "SIM108", # Use ternary operator ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer) + "TRY003", # Avoid specifying long messages outside the exception class (Note: why?) + "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` (Note: using `|` is slower atm) + + # Ignored categories + "C90", # mccabe (Too strict, but maybe we should make things less complex) + "I", # isort (Should we replace `isort` with this?) + "ANN", # flake8-annotations (We don't use annotations yet) + "BLE", # flake8-blind-except (Maybe consider) + "FBT", # flake8-boolean-trap (Why?) + "DJ", # flake8-django (We don't use django) + "EM", # flake8-errmsg (Perhaps nicer, but too much work) + "ICN", # flake8-import-conventions (Doesn't allow "_" prefix such as `_np`) + "PYI", # flake8-pyi (We don't have stub files yet) + "SLF", # flake8-self (We can use our own private variables--sheesh!) + "TID", # flake8-tidy-imports (Rely on isort and our own judgement) + "TCH", # flake8-type-checking (Note: figure out type checking later) + "ARG", # flake8-unused-arguments (Sometimes helpful, but too strict) + "TD", # flake8-todos (Maybe okay to add some of these) + "FIX", # flake8-fixme (like flake8-todos) + "ERA", # eradicate (We like code in comments!) 
+ "PD", # pandas-vet (Intended for scripts that use pandas, not libraries) +] + +[tool.ruff.lint.per-file-ignores] +"graphblas/core/operator/__init__.py" = ["A005"] +"graphblas/io/__init__.py" = ["A005"] # shadows a standard-library module +"graphblas/core/operator/base.py" = ["S102"] # exec is used for UDF +"graphblas/core/ss/matrix.py" = [ + "NPY002", # numba doesn't support rng generator yet + "PLR1730", +] +"graphblas/core/ss/vector.py" = [ + "NPY002", # numba doesn't support rng generator yet +] +"graphblas/core/utils.py" = ["PLE0302"] # `__set__` is used as a property +"graphblas/ss/_core.py" = ["N999"] # We want _core.py to be underscopre +# Allow useless expressions, assert, pickle, RNG, print, no docstring, and yoda in tests +"graphblas/tests/*py" = [ + "B018", + "S101", + "S301", + "S311", + "T201", + "D103", + "D100", + "SIM300", +] +"graphblas/tests/test_formatting.py" = ["E501"] # Allow long lines +"graphblas/**/__init__.py" = [ + "F401", # Allow unused imports (w/o defining `__all__`) ] +"scripts/*.py" = ["INP001"] # Not a package +"scripts/create_pickle.py" = ["F403", "F405"] # Allow `from foo import *` +"docs/*.py" = ["INP001"] # Not a package -[tool.ruff.per-file-ignores] -"graphblas/core/operator.py" = ["S102"] # exec is used for UDF -"graphblas/core/ss/matrix.py" = ["NPY002"] # numba doesn't support rng generator yet -"graphblas/core/ss/vector.py" = ["NPY002"] # numba doesn't support rng generator yet -"graphblas/ss/_core.py" = ["N999"] # We want _core.py to be underscopre -"graphblas/tests/*py" = ["S101", "T201", "D103", "D100", "SIM300"] # Allow assert, print, no docstring, and yoda -"graphblas/tests/test_formatting.py" = ["E501"] # Allow long lines -"graphblas/**/__init__.py" = ["F401"] # Allow unused imports (w/o defining `__all__`) -"scripts/*.py" = ["INP001"] # Not a package -"scripts/create_pickle.py" = ["F403", "F405"] # Allow `from foo import *` -"docs/*.py" = ["INP001"] # Not a package - -[tool.ruff.flake8-builtins] + +[tool.ruff.lint.flake8-builtins] builtins-ignorelist = ["copyright", "format", "min", "max"] +builtins-allowed-modules = ["select"] -[tool.ruff.flake8-pytest-style] +[tool.ruff.lint.flake8-pytest-style] fixture-parentheses = false mark-parentheses = false -[tool.ruff.pydocstyle] +[tool.lint.ruff.pydocstyle] convention = "numpy" +[tool.bandit] +exclude_dirs = ["graphblas/tests", "scripts"] +skips = [ + "B110", # Try, Except, Pass detected. 
(Note: it would be nice to not have this pattern) +] + [tool.pylint.messages_control] # To run a single check, do: pylint graphblas --disable E,W,R,C,I --enable assignment-from-no-return max-line-length = 100 -py-version = "3.8" +py-version = "3.10" enable = ["I"] disable = [ - # Error - "assignment-from-no-return", - - # Warning - "arguments-differ", - "arguments-out-of-order", - "expression-not-assigned", - "fixme", - "global-statement", - "non-parent-init-called", - "redefined-builtin", - "redefined-outer-name", - "super-init-not-called", - "unbalanced-tuple-unpacking", - "unnecessary-lambda", - "unspecified-encoding", - "unused-argument", - "unused-variable", - - # Refactor - "cyclic-import", - "duplicate-code", - "inconsistent-return-statements", - "too-few-public-methods", - - # Convention - "missing-class-docstring", - "missing-function-docstring", - "missing-module-docstring", - "too-many-lines", - - # Intentionally turned off - # error - "class-variable-slots-conflict", - "invalid-unary-operand-type", - "no-member", - "no-name-in-module", - "not-an-iterable", - "too-many-function-args", - "unexpected-keyword-arg", - # warning - "broad-except", - "pointless-statement", - "protected-access", - "undefined-loop-variable", - "unused-import", - # refactor - "comparison-with-itself", - "too-many-arguments", - "too-many-boolean-expressions", - "too-many-branches", - "too-many-instance-attributes", - "too-many-locals", - "too-many-nested-blocks", - "too-many-public-methods", - "too-many-return-statements", - "too-many-statements", - # convention - "import-outside-toplevel", - "invalid-name", - "line-too-long", - "singleton-comparison", - "single-string-used-for-slots", - "unidiomatic-typecheck", - "unnecessary-dunder-call", - "wrong-import-order", - "wrong-import-position", - # informative - "locally-disabled", - "suppressed-message", + # Error + "assignment-from-no-return", + + # Warning + "arguments-differ", + "arguments-out-of-order", + "expression-not-assigned", + "fixme", + "global-statement", + "non-parent-init-called", + "redefined-builtin", + "redefined-outer-name", + "super-init-not-called", + "unbalanced-tuple-unpacking", + "unnecessary-lambda", + "unspecified-encoding", + "unused-argument", + "unused-variable", + + # Refactor + "cyclic-import", + "duplicate-code", + "inconsistent-return-statements", + "too-few-public-methods", + + # Convention + "missing-class-docstring", + "missing-function-docstring", + "missing-module-docstring", + "too-many-lines", + + # Intentionally turned off + # error + "class-variable-slots-conflict", + "invalid-unary-operand-type", + "no-member", + "no-name-in-module", + "not-an-iterable", + "too-many-function-args", + "unexpected-keyword-arg", + # warning + "broad-except", + "pointless-statement", + "protected-access", + "undefined-loop-variable", + "unused-import", + # refactor + "comparison-with-itself", + "too-many-arguments", + "too-many-boolean-expressions", + "too-many-branches", + "too-many-instance-attributes", + "too-many-locals", + "too-many-nested-blocks", + "too-many-public-methods", + "too-many-return-statements", + "too-many-statements", + # convention + "import-outside-toplevel", + "invalid-name", + "line-too-long", + "singleton-comparison", + "single-string-used-for-slots", + "unidiomatic-typecheck", + "unnecessary-dunder-call", + "wrong-import-order", + "wrong-import-position", + # informative + "locally-disabled", + "suppressed-message", ] diff --git a/scripts/check_versions.sh b/scripts/check_versions.sh index d42952cf0..5aa88e045 
100755 --- a/scripts/check_versions.sh +++ b/scripts/check_versions.sh @@ -3,14 +3,15 @@ # Use, adjust, copy/paste, etc. as necessary to answer your questions. # This may be helpful when updating dependency versions in CI. # Tip: add `--json` for more information. -conda search 'numpy[channel=conda-forge]>=1.24.2' -conda search 'pandas[channel=conda-forge]>=1.5.3' -conda search 'scipy[channel=conda-forge]>=1.10.1' -conda search 'networkx[channel=conda-forge]>=3.0' -conda search 'awkward[channel=conda-forge]>=2.0.8' -conda search 'sparse[channel=conda-forge]>=0.14.0' -conda search 'numba[channel=conda-forge]>=0.56.4' -conda search 'pyyaml[channel=conda-forge]>=6.0' -conda search 'flake8-comprehensions[channel=conda-forge]>=3.10.1' -conda search 'flake8-bugbear[channel=conda-forge]>=23.2.13' -conda search 'flake8-simplify[channel=conda-forge]>=0.19.3' +conda search 'flake8-bugbear[channel=conda-forge]>=24.12.12' +conda search 'flake8-simplify[channel=conda-forge]>=0.21.0' +conda search 'numpy[channel=conda-forge]>=2.2.3' +conda search 'pandas[channel=conda-forge]>=2.2.3' +conda search 'scipy[channel=conda-forge]>=1.15.2' +conda search 'networkx[channel=conda-forge]>=3.4.2' +conda search 'awkward[channel=conda-forge]>=2.7.4' +conda search 'sparse[channel=conda-forge]>=0.15.5' +conda search 'fast_matrix_market[channel=conda-forge]>=1.7.6' +conda search 'numba[channel=conda-forge]>=0.61.0' +conda search 'pyyaml[channel=conda-forge]>=6.0.2' +# conda search 'python[channel=conda-forge]>=3.10 *pypy*' diff --git a/scripts/create_pickle.py b/scripts/create_pickle.py index 9ee672c41..10fe58630 100755 --- a/scripts/create_pickle.py +++ b/scripts/create_pickle.py @@ -6,7 +6,7 @@ """ import argparse import pickle -from pathlib import PurePath +from pathlib import Path import graphblas as gb from graphblas.tests.test_pickle import * @@ -158,7 +158,7 @@ def pickle3(filepath): extra = "-vanilla" else: extra = "" - path = PurePath(gb.tests.__file__).parent + path = Path(gb.tests.__file__).parent pickle1(path / f"pickle1{extra}.pkl") pickle2(path / f"pickle2{extra}.pkl") pickle3(path / f"pickle3{extra}.pkl") diff --git a/scripts/test_imports.sh b/scripts/test_imports.sh index c38e41d3e..6ce88c83e 100755 --- a/scripts/test_imports.sh +++ b/scripts/test_imports.sh @@ -3,7 +3,7 @@ # Make sure imports work. Also, this is a good way to measure import performance. if ! python -c "from graphblas import * ; Matrix" ; then exit 1 ; fi if ! python -c "from graphblas import agg" ; then exit 1 ; fi -if ! python -c "from graphblas.core import agg" ; then exit 1 ; fi +if ! python -c "from graphblas.core.operator import agg" ; then exit 1 ; fi if ! python -c "from graphblas.agg import count" ; then exit 1 ; fi if ! python -c "from graphblas.binary import plus" ; then exit 1 ; fi if ! python -c "from graphblas.indexunary import tril" ; then exit 1 ; fi @@ -13,14 +13,14 @@ if ! python -c "from graphblas.select import tril" ; then exit 1 ; fi if ! python -c "from graphblas.semiring import plus_times" ; then exit 1 ; fi if ! python -c "from graphblas.unary import exp" ; then exit 1 ; fi if ! (for attr in Matrix Scalar Vector Recorder agg binary dtypes exceptions \ - init io monoid op select semiring tests unary ss viz + init io monoid op select semiring tests unary ss viz MAX_SIZE do echo python -c \"from graphblas import $attr\" if ! python -c "from graphblas import $attr" then exit 1 fi done ) ; then exit 1 ; fi -if ! (for attr in agg base descriptor expr formatting ffi infix lib mask \ +if ! 
(for attr in base descriptor expr formatting ffi infix lib mask \ matrix operator scalar vector recorder automethods infixmethods slice ss do echo python -c \"from graphblas.core import $attr\" if ! python -c "from graphblas.core import $attr" @@ -44,7 +44,7 @@ if ! (for attr in agg binary binary.numpy dtypes exceptions io monoid monoid.num fi done ) ; then exit 1 ; fi -if ! (for attr in agg base descriptor expr formatting infix mask matrix \ +if ! (for attr in base descriptor expr formatting infix mask matrix \ operator scalar vector recorder automethods infixmethods slice ss do echo python -c \"import graphblas.core.$attr\" if ! python -c "import graphblas.core.$attr" @@ -60,3 +60,10 @@ if ! python -c "from graphblas import op ; op.plus" ; then exit 1 ; fi if ! python -c "from graphblas import select ; select.tril" ; then exit 1 ; fi if ! python -c "from graphblas import semiring ; semiring.plus_times" ; then exit 1 ; fi if ! python -c "from graphblas import unary ; unary.exp" ; then exit 1 ; fi +if ! (for attr in agg unary binary monoid semiring select indexunary base utils + do echo python -c \"import graphblas.core.operator.$attr\" + if ! python -c "import graphblas.core.operator.$attr" + then exit 1 + fi + done +) ; then exit 1 ; fi
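Note on the import-check loops added to scripts/test_imports.sh above: each loop echoes and then runs `python -c "import ..."` for every listed module, exiting non-zero on the first failure. A rough Python equivalent of the final loop, given only as an illustrative sketch (this file is not part of the diff; the submodule list is copied from the loop above):

# Illustrative sketch, not part of the repository: fail-fast import smoke test
# mirroring the last loop added to scripts/test_imports.sh.
import importlib
import sys

# Submodules exercised by the new graphblas.core.operator loop
SUBMODULES = ["agg", "unary", "binary", "monoid", "semiring", "select", "indexunary", "base", "utils"]

for attr in SUBMODULES:
    name = f"graphblas.core.operator.{attr}"
    print(f'python -c "import {name}"')  # mirror the shell loop's echo
    try:
        importlib.import_module(name)
    except ImportError as exc:  # fail fast, like `exit 1` in the script
        sys.exit(f"import failed for {name}: {exc}")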