diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..bde0bc1e --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,58 @@ +name: Build docs + +on: + pull_request: + branches: + - main + push: + branches: + - main + +jobs: + build-docs: + name: Build & Upload Artifact + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v3 + with: + python-version: "3.10" + + - uses: tlambert03/setup-qt-libs@v1 + + - name: Install Dependencies + run: | + python -m pip install --upgrade pip + python -m pip install ".[docs]" + sudo apt install graphviz --yes + + - name: Build Docs + uses: aganders3/headless-gui@v1 + with: + run: make html + working-directory: ./docs + + - name: Upload artifact + uses: actions/upload-artifact@v3 + with: + name: docs + path: docs/_build + + upload-docs: + name: Upload docs to GitHub pages + runs-on: ubuntu-latest + needs: build-docs + if: contains(github.ref, 'tags') + steps: + - uses: actions/checkout@v3 + - uses: actions/download-artifact@v3 + with: + name: docs + + - name: Push to GitHub pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: html + ssh-key: ${{ secrets.DEPLOY_KEY }} + repository-name: napari-matplotlib/napari-matplotlib.github.io diff --git a/.github/workflows/napari_hub_preview.yml b/.github/workflows/napari_hub_preview.yml index 416ceff5..c204ac45 100644 --- a/.github/workflows/napari_hub_preview.yml +++ b/.github/workflows/napari_hub_preview.yml @@ -13,9 +13,9 @@ jobs: steps: - name: Checkout repo - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: napari hub Preview Page Builder - uses: chanzuckerberg/napari-hub-preview-action@v0.1.5 + uses: chanzuckerberg/napari-hub-preview-action@v0.1 with: hub-ref: main diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index a8a31e4c..df170bdf 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -1,6 +1,3 @@ -# This workflows will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - name: tests on: @@ -13,14 +10,16 @@ on: branches: - main workflow_dispatch: + merge_group: jobs: test: name: ${{ matrix.platform }} py${{ matrix.python-version }} runs-on: ${{ matrix.platform }} strategy: + fail-fast: false matrix: - platform: [ubuntu-latest] + platform: [ubuntu-latest, macos-latest, windows-latest] python-version: ['3.8', '3.9', '3.10'] steps: @@ -50,8 +49,22 @@ jobs: - name: Test with tox run: python -m tox + - name: Upload pytest test results + uses: actions/upload-artifact@v3 + with: + name: pytest-results-${{ matrix.platform }} py${{ matrix.python-version }} + path: reports/ + # Use always() to always run this step to publish test results when there are test failures + if: ${{ always() }} + - name: Coverage - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 + # Don't run coverage on merge queue CI to avoid duplicating reports + # to codecov. See https://github.com/matplotlib/napari-matplotlib/issues/155 + if: github.event_name != 'merge_group' + with: + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true deploy: # this will run when you have tagged a commit, starting with "v*" @@ -60,21 +73,24 @@ jobs: needs: [test] runs-on: ubuntu-latest if: contains(github.ref, 'tags') + environment: pypi-deploy + permissions: + id-token: write steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: python-version: "3.x" - - name: Install dependencies + - name: Install build run: | python -m pip install --upgrade pip - pip install -U setuptools setuptools_scm wheel twine build - - name: Build and publish - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} + python -m pip install --upgrade build + + - name: Build package run: | git tag python -m build . - twine upload dist/* + + - name: Publish package + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 708dea14..90b59392 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +report/ .coverage .coverage.* .cache @@ -84,3 +85,6 @@ venv/ # written by setuptools_scm **/_version.py + +# Generated by test runs +reports diff --git a/.napari/DESCRIPTION.md b/.napari/DESCRIPTION.md deleted file mode 100644 index 45776cd1..00000000 --- a/.napari/DESCRIPTION.md +++ /dev/null @@ -1,92 +0,0 @@ - - - - -The developer has not yet provided a napari-hub specific description. diff --git a/.napari/config.yml b/.napari/config.yml new file mode 100644 index 00000000..d24d3acb --- /dev/null +++ b/.napari/config.yml @@ -0,0 +1,8 @@ +labels: + ontology: EDAM-BIOIMAGING:alpha06 + terms: + - 2D Image + - 3D Image + - Plotting + - Slice rendering + - Scientific visualisation diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 216f555a..767aedf6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,43 +1,37 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v4.4.0 hooks: - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace + - repo: https://github.com/asottile/setup-cfg-fmt - rev: v1.20.1 + rev: v2.3.0 hooks: - id: setup-cfg-fmt - - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 - hooks: - - id: flake8 - additional_dependencies: [flake8-typing-imports>=1.9.0] - - repo: https://github.com/myint/autoflake - rev: v1.4 - hooks: - - id: autoflake - args: ["--in-place", "--remove-all-unused-imports"] - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort + - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.3.0 hooks: - id: black - - repo: https://github.com/asottile/pyupgrade - rev: v2.32.0 - hooks: - - id: pyupgrade - args: [--py38-plus, --keep-runtime-typing] + - repo: https://github.com/tlambert03/napari-plugin-checks - rev: v0.2.0 + rev: v0.3.0 hooks: - id: napari-plugin-checks + - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.910-1 + rev: v1.4.1 hooks: - id: mypy - args: ["--disallow-incomplete-defs", "--ignore-missing-imports"] + additional_dependencies: [numpy, matplotlib] + + - repo: https://github.com/charliermarsh/ruff-pre-commit + # Ruff version. + rev: 'v0.0.275' + hooks: + - id: ruff + +ci: + autofix_prs: false diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..797cee13 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,15 @@ +version: 2 + +python: + install: + - method: pip + path: . + extra_requirements: + - docs + +build: + os: ubuntu-22.04 + tools: + python: "3.10" + apt_packages: + - xvfb diff --git a/CHANGELOG.rst b/CHANGELOG.rst deleted file mode 100644 index fafd2e0b..00000000 --- a/CHANGELOG.rst +++ /dev/null @@ -1,12 +0,0 @@ -0.0.2 -===== - -New features ------------- -- `HistogramWidget` now shows individual histograms for RGB channels when - present. - - -Bug fixes ---------- -- `HistogramWidget` now works properly with 2D images. diff --git a/MANIFEST.in b/MANIFEST.in index 7ce16f9b..d625d95e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ include LICENSE include README.md +recursive-include * *.mplstyle recursive-exclude * __pycache__ recursive-exclude * *.py[co] diff --git a/README.md b/README.md index de86a434..855c4991 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # napari-matplotlib -[![License](https://img.shields.io/pypi/l/napari-matplotlib.svg?color=green)](https://github.com/dstansby/napari-matplotlib/raw/main/LICENSE) +[![License](https://img.shields.io/pypi/l/napari-matplotlib.svg?color=green)](https://github.com/matplotlib/napari-matplotlib/raw/main/LICENSE) [![PyPI](https://img.shields.io/pypi/v/napari-matplotlib.svg?color=green)](https://pypi.org/project/napari-matplotlib) [![Python Version](https://img.shields.io/pypi/pyversions/napari-matplotlib.svg?color=green)](https://python.org) -[![tests](https://github.com/dstansby/napari-matplotlib/workflows/tests/badge.svg)](https://github.com/dstansby/napari-matplotlib/actions) -[![codecov](https://codecov.io/gh/dstansby/napari-matplotlib/branch/main/graph/badge.svg)](https://codecov.io/gh/dstansby/napari-matplotlib) +[![tests](https://github.com/matplotlib/napari-matplotlib/workflows/tests/badge.svg)](https://github.com/matplotlib/napari-matplotlib/actions) +[![codecov](https://codecov.io/gh/matplotlib/napari-matplotlib/branch/main/graph/badge.svg)](https://codecov.io/gh/matplotlib/napari-matplotlib) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/matplotlib/pytest-mpl/master.svg)](https://results.pre-commit.ci/latest/github/matplotlib/pytest-mpl/master) [![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-matplotlib)](https://napari-hub.org/plugins/napari-matplotlib) @@ -19,15 +19,15 @@ A plugin to create Matplotlib plots from napari layers ### `Slice` Plots 1D slices of data along a specified axis. -![](https://raw.githubusercontent.com/dstansby/napari-matplotlib/main/examples/slice.png) +![](https://raw.githubusercontent.com/matplotlib/napari-matplotlib/main/examples/slice.png) ### `Histogram` Plots histograms of individual image layers, or RGB histograms of an RGB image -![](https://raw.githubusercontent.com/dstansby/napari-matplotlib/main/examples/hist.png) +![](https://raw.githubusercontent.com/matplotlib/napari-matplotlib/main/examples/hist.png) ### `Scatter` Scatters the values of two similarly sized images layers against each other. -![](https://raw.githubusercontent.com/dstansby/napari-matplotlib/main/examples/scatter.png) +![](https://raw.githubusercontent.com/matplotlib/napari-matplotlib/main/examples/scatter.png) ## Installation @@ -39,7 +39,7 @@ You can install `napari-matplotlib` via [pip]: To install latest development version : - pip install git+https://github.com/dstansby/napari-matplotlib.git + pip install git+https://github.com/matplotlib/napari-matplotlib.git ## Contributing diff --git a/docs/Makefile b/docs/Makefile index b50c24dc..7bb83626 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= -W +SPHINXOPTS ?= -W --keep-going SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build @@ -12,6 +12,11 @@ BUILDDIR = _build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +clean: + @echo Removing files created by sphinx-build + rm -rf $(BUILDDIR) + rm -rf $(SOURCEDIR)/api/ + .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new diff --git a/docs/_static/hist.png b/docs/_static/hist.png new file mode 100644 index 00000000..95850b71 Binary files /dev/null and b/docs/_static/hist.png differ diff --git a/docs/_static/logo.png b/docs/_static/logo.png new file mode 100644 index 00000000..92232b6c Binary files /dev/null and b/docs/_static/logo.png differ diff --git a/docs/api/napari_matplotlib.HistogramWidget.rst b/docs/api/napari_matplotlib.HistogramWidget.rst deleted file mode 100644 index 1a8aba80..00000000 --- a/docs/api/napari_matplotlib.HistogramWidget.rst +++ /dev/null @@ -1,19 +0,0 @@ -HistogramWidget -=============== - -.. currentmodule:: napari_matplotlib - -.. autoclass:: HistogramWidget - :show-inheritance: - - .. rubric:: Methods Summary - - .. autosummary:: - - ~HistogramWidget.hist_current_layer - ~HistogramWidget.update_layer - - .. rubric:: Methods Documentation - - .. automethod:: hist_current_layer - .. automethod:: update_layer diff --git a/docs/changelog.rst b/docs/changelog.rst new file mode 100644 index 00000000..64771d30 --- /dev/null +++ b/docs/changelog.rst @@ -0,0 +1,43 @@ +Changelog +========= + +1.0.0 +----- + +New features +~~~~~~~~~~~~ +- Added ``MPLWidget`` as a widget containing just a Matplotlib canvas + without any association with a napari viewer. +- Added text to each widget indicating how many layers need to be selected + for the widget to plot something. + +Visual improvements +~~~~~~~~~~~~~~~~~~~ +- The background of ``napari-matplotlib`` figures and axes is now transparent, and the text and axis colour respects the ``napari`` theme. +- The icons in the Matplotlib toolbar are now the same size as icons in the napari window. +- Custom style sheets can now be set to customise plots. See the user guide + for more information. + +Changes +~~~~~~~ +- The scatter widgets no longer use a LogNorm() for 2D histogram scaling. + This is to move the widget in line with the philosophy of using Matplotlib default + settings throughout ``napari-matplotlib``. This still leaves open the option of + adding the option to change the normalization in the future. If this is something + you would be interested in please open an issue at https://github.com/matplotlib/napari-matplotlib. +- Labels plotting with the features scatter widget no longer have underscores + replaced with spaces. +- ``NapariMPLWidget.update_layers()`` has been removed as it is intended to be + private API. Use ``NapariMPLWidget.on_update_layers`` instead to implement + funcitonality when layer selection is changed. +- The slice widget now only plots x-ticks at integer locations. + +Bug fixes +~~~~~~~~~ +- Importing ``napari-matplotlib`` no longer affects how plots are rendered in + Jupyter notebooks. + +Other +~~~~~ +- ``napari-matplotlib`` is now tested on macOS and Windows. +- Type annotations have been completed throughout the code base. diff --git a/docs/conf.py b/docs/conf.py index a6e60fa3..2517a59c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,7 +17,7 @@ # -- Project information ----------------------------------------------------- -project = "matplotlib-napari" +project = "napari-matplotlib" copyright = "2022, David Stansby" author = "David Stansby" @@ -41,13 +41,36 @@ "reset_modules": (qtgallery.reset_qapp,), } +qtgallery_conf = { + "xvfb_size": (640, 480), + "xvfb_color_depth": 24, + "xfvb_use_xauth": False, + "xfvb_extra_args": [], +} + numpydoc_show_class_members = False +automodapi_inheritance_diagram = True +inheritance_graph_attrs = {"rankdir": "TR"} intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), "napari": ("https://napari.org/", None), + "numpy": ("https://numpy.org/doc/stable/", None), "matplotlib": ("https://matplotlib.org/", None), + "PyQT6": ("https://www.riverbankcomputing.com/static/Docs/PyQt6/", None), } +nitpicky = True +# Can't work out how to link this properly using intersphinx and the PyQT6 docs. +# TODO: fix at some point +nitpick_ignore = [ + ("py:class", "PyQt5.QtWidgets.QWidget"), + ("py:class", "PyQt5.QtCore.QObject"), + ("py:class", "PyQt5.QtGui.QPaintDevice"), + ("py:class", "sip.simplewrapper"), + ("py:class", "sip.wrapper"), +] + # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -64,8 +87,23 @@ # a list of builtin themes. # html_theme = "pydata_sphinx_theme" +html_logo = "_static/logo.png" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] + +html_theme_options = { + "icon_links": [ + { + "name": "GitHub", + "url": "https://github.com/matplotlib/napari-matplotlib", + "icon": "fa-brands fa-square-github", + "type": "fontawesome", + } + ], + "logo": { + "text": "napari-matplotlib", + }, +} diff --git a/docs/index.rst b/docs/index.rst index 2cf4d2b4..0eb92b04 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,27 +1,22 @@ -matplotlib-napari +napari-matplotlib ================= -``matplotlib-napari`` is a ``napari`` plugin for generating ``Matplotlib`` plots -from one or more ``napari`` Layers. +A `napari `_ plugin for generating +`Matplotlib `_ plots from one or more ``napari`` Layers. -Design ------- -``matplotlib-napari`` contains a number of different ``napari`` widgets. Each -widget is designed to map one or more ``napari`` Layers on to a ``Matplotlib`` -plot. As an example, the `~.HistogramWidget` is used to +This plugin provides a number of different ``napari`` Widgets. Each +widget is designed to map one or more ``napari`` Layers on to a ``Matplotlib`` plot. +As an example, the `~.HistogramWidget` is used to map one or more Image layers on to a 1D histogram plot. +.. image:: _static/hist.png + :alt: Screenshot of the napari viewer showing the napari-matplotlib histogram widget on the right hand side of the screen. + .. toctree:: :maxdepth: 1 - :caption: Contents: + :hidden: + user_guide auto_examples/index + third_party api - - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` + changelog diff --git a/docs/make.bat b/docs/make.bat index 32bb2452..4b78115a 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -7,6 +7,9 @@ REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) +if "%SPHINXOPTS%" == "" ( + set SPHINXOPTS="--keep-going" +) set SOURCEDIR=. set BUILDDIR=_build diff --git a/docs/third_party.rst b/docs/third_party.rst new file mode 100644 index 00000000..c09eb532 --- /dev/null +++ b/docs/third_party.rst @@ -0,0 +1,51 @@ +Third-party plugins +=================== +``napari-matplotlib`` provides a ready-to-go widget with a Matplotlib toolbar and figure to third party plugin developers. +This widget is customised to match the theme of the main napari window. + +The widget can be found at `napari_matplotlib.base.NapariMPLWidget`. +This class inherits from `QWidget `_. + +The recommended way to use `~napari_matplotlib.base.NapariMPLWidget` is inside a new widget, adding it to the layout. +This means you can add additional elements to your plugin layout alongside the Matplotlib figure. +Here's a short example: + +.. code-block:: python + + from qtpy.QtWidgets import QWidget + from napari_matplotlib.base import NapariMPLWidget + + class MyPlugin(QWidget): + def __init__(self, napari_viewer: napari.viewer.Viewer, parent=None): + super().__init__(parent=parent) + + # Any custom setup for your custom widget + ... + + # Set up the plot widget + plot_widget = NapariMPLWidget(napari_viewer, parent=self) + self.layout().addWidget(plot_widget) + +The following properties and methods are useful for working with the figure and any axes within the widget: + +- `~.BaseNapariMPLWidget.figure` provides access to the figure +- :meth:`~.BaseNapariMPLWidget.add_single_axes` adds a single axes to the figure, which can be accessed using the ``.axes`` attribute. + +Working with napari layers +-------------------------- +When either the layer selection or z-step in the napari viewer is changed +:meth:`~.NapariMPLWidget.clear` and :meth:`~.NapariMPLWidget.draw` are called +in turn. By default these do nothing, and are designed to be overriden by +plugins to automatically re-draw any figures within the widget. Plugins can +also override :meth:`~.NapariMPLWidget.on_update_layers` to do something when +the layer selection changes. This can be used to do something without clearing +or re-drawing any plots. + +Validating layer numbers and types +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +By default :meth:`~.NapariMPLWidget.draw` will be called when any number of any +type of napari layers are selected. The `~.NapariMPLWidget.n_layers_input` +and `~.NapariMPLWidget.input_layer_types` class variables can be overriden to +specify the number of selected napari layers and valid layer +types that are taken as input. If the number of selected layers and their +types do not match up with these class variables, no re-draw is called. diff --git a/docs/user_guide.rst b/docs/user_guide.rst new file mode 100644 index 00000000..0872e540 --- /dev/null +++ b/docs/user_guide.rst @@ -0,0 +1,49 @@ +User guide +========== + +Overview +-------- +``napari-matplotlib`` contains a number of different ``napari`` Widgets. Each +widget is designed to map one or more ``napari`` Layers on to a ``matplotlib`` plot. +As an example, the `~.HistogramWidget` is used to +map one or more Image layers on to a 1D histogram plot. + +The widgets split into two categories: + +Layer plotting +~~~~~~~~~~~~~~ +These widgets plot the data stored directly in napari layers. +Currently available are widgets to plot: + +- 1D histograms +- 2D scatter plots (switching to 2D histograms for a large number of points) +- 1D slice line plots + +To use these: + +1. Open the desired widget using the ``Plugins > napari-matplotlib`` menu in napari. +2. Select the required number of layers using the napari layers list in the bottom left-hand side of the window. + +Features plotting +~~~~~~~~~~~~~~~~~ +These widgets plot the data stored in the ``.features`` attribute of individual napari layers. +Currently available are: + +- 2D scatter plots of two features against each other. + +To use these: + +1. Open the desired widget using the ``Plugins > napari-matplotlib`` menu in napari. +2. Select a single layer that has a features table using the napari layers list in the bottom left-hand side of the window. +3. Use the drop down menu(s) under the Matplotlib figure to select the feature(s) to plot. + +Customising plots +----------------- +`Matplotlib style sheets `__ can be used to customise +the plots generated by ``napari-matplotlib``. +To use a custom style sheet: + +1. Save it as ``napari-matplotlib.mplstyle`` +2. Put it in the Matplotlib configuration directory. + The location of this directory varies on different computers, + and can be found by calling :func:`matplotlib.get_configdir()`. diff --git a/examples/README.rst b/examples/README.rst index bac945d5..ae12bff3 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -1,2 +1,2 @@ -Examples -======== +Gallery +======= diff --git a/examples/features_scatter.py b/examples/features_scatter.py index ac8580d7..f4ea7f66 100644 --- a/examples/features_scatter.py +++ b/examples/features_scatter.py @@ -1,9 +1,15 @@ +""" +Scattering features +=================== +""" + import napari import numpy as np +import numpy.typing as npt from skimage.measure import regionprops_table # make a test label image -label_image = np.zeros((100, 100), dtype=np.uint16) +label_image: npt.NDArray[np.uint16] = np.zeros((100, 100), dtype=np.uint16) label_image[10:20, 10:20] = 1 label_image[50:70, 50:70] = 2 diff --git a/pyproject.toml b/pyproject.toml index b1a57c9b..7c7dbbdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,69 @@ requires = ["setuptools", "wheel", "setuptools_scm"] build-backend = "setuptools.build_meta" - [tool.setuptools_scm] write_to = "src/napari_matplotlib/_version.py" +[tool.pytest.ini_options] +qt_api = "pyqt6" +addopts = "--mpl" +filterwarnings = [ + "error", + # Coming from vispy + "ignore:distutils Version classes are deprecated:DeprecationWarning", + "ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning", +] + [tool.black] line-length = 79 [tool.isort] profile = "black" line_length = 79 + +[tool.ruff] +target-version = "py38" +select = ["I", "UP", "F", "E", "W", "D"] +ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D200", # One-line docstring should fit on one line + "D205", # 1 blank line required between summary line and description + "D400", # First line should end with a period + "D401", # First line of docstring should be in imperative mood + +] +fix = true + +[tool.ruff.per-file-ignores] +"docs/*" = ["D"] +"examples/*" = ["D"] +"src/napari_matplotlib/tests/*" = ["D"] + +[tool.ruff.pydocstyle] +convention = "numpy" + +[tool.mypy] +python_version = "3.8" +# Block below are checks that form part of mypy 'strict' mode +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true +strict_equality = true +strict_concatenate = true +check_untyped_defs = true +disallow_subclassing_any = false # TODO: fix +disallow_untyped_decorators = true +disallow_any_generics = true +disallow_untyped_calls = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +no_implicit_reexport = true +warn_return_any = false # TODO: fix +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = [ + "napari_matplotlib/tests/*", +] +disallow_untyped_defs = false diff --git a/setup.cfg b/setup.cfg index 919e1387..b03f446c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ url = https://github.com/matplotlib/napari-matplotlib author = David Stansby author_email = d.stansby@ucl.ac.uk license = BSD-3-Clause -license_file = LICENSE +license_files = LICENSE classifiers = Development Status :: 3 - Alpha Framework :: napari @@ -17,13 +17,10 @@ classifiers = Programming Language :: Python Programming Language :: Python :: 3 Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 Topic :: Software Development :: Testing project_urls = Bug Tracker = https://github.com/matplotlib/napari-matplotlib/issues - Documentation = https://github.com/matplotlib/napari-matplotlib#README.md + Documentation = https://napari-matplotlib.github.io Source Code = https://github.com/matplotlib/napari-matplotlib User Support = https://github.com/matplotlib/napari-matplotlib/issues @@ -33,6 +30,7 @@ install_requires = matplotlib napari numpy + tinycss2 python_requires = >=3.8 include_package_data = True package_dir = @@ -49,6 +47,7 @@ napari.manifest = [options.extras_require] docs = + napari[all] numpydoc pydata-sphinx-theme qtgallery @@ -56,9 +55,13 @@ docs = sphinx-automodapi sphinx-gallery testing = - napari[pyqt5] + napari[pyqt6-experimental] + pooch + pyqt6 pytest pytest-cov + pytest-mock + pytest-mpl pytest-qt tox pytest-xvfb;sys_platform == 'linux' diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index b4d8c7cf..792b5aff 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -1,78 +1,203 @@ import os from pathlib import Path -from typing import List, Tuple +from typing import List, Optional, Tuple -import matplotlib as mpl +import matplotlib +import matplotlib.style as mplstyle import napari -from matplotlib.backends.backend_qt5agg import ( +from matplotlib.backends.backend_qtagg import ( FigureCanvas, NavigationToolbar2QT, ) +from matplotlib.figure import Figure from qtpy.QtGui import QIcon -from qtpy.QtWidgets import QVBoxLayout, QWidget +from qtpy.QtWidgets import QLabel, QVBoxLayout, QWidget -from .util import Interval +from .util import Interval, from_napari_css_get_size_of -mpl.rc("axes", edgecolor="white") -mpl.rc("axes", facecolor="#262930") -mpl.rc("axes", labelcolor="white") -mpl.rc("savefig", facecolor="#262930") -mpl.rc("text", color="white") +__all__ = ["BaseNapariMPLWidget", "NapariMPLWidget", "SingleAxesWidget"] + +_CUSTOM_STYLE_PATH = ( + Path(matplotlib.get_configdir()) / "napari-matplotlib.mplstyle" +) -mpl.rc("xtick", color="white") -mpl.rc("ytick", color="white") -# Icons modified from -# https://github.com/matplotlib/matplotlib/tree/main/lib/matplotlib/mpl-data/images -ICON_ROOT = Path(__file__).parent / "icons" -__all__ = ["NapariMPLWidget"] +class BaseNapariMPLWidget(QWidget): + """ + Widget containing Matplotlib canvas and toolbar themed to match napari. + This creates a single FigureCanvas, which contains a single + `~matplotlib.figure.Figure`, and an associated toolbar. Both of these + are customised to match the visual style of the main napari window. + It is not responsible for creating any Axes, because different + widgets may want to implement different subplot layouts. -class NapariMPLWidget(QWidget): + See Also + -------- + NapariMPLWidget : A child class that also contains helpful attributes and + methods for working with napari layers. """ - Base widget that can be embedded as a napari widget and contains a - Matplotlib canvas. - This creates a single FigureCanvas, which contains a single Figure. + def __init__( + self, + napari_viewer: napari.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(parent=parent) + self.viewer = napari_viewer + self._mpl_style_sheet_path: Optional[Path] = None + + # Sets figure.* style + with mplstyle.context(self.mpl_style_sheet_path): + self.canvas = FigureCanvas() + + self.canvas.figure.set_layout_engine("constrained") + self.toolbar = NapariNavigationToolbar( + self.canvas, parent=self + ) # type: ignore[no-untyped-call] + self._replace_toolbar_icons() + # callback to update when napari theme changed + # TODO: this isn't working completely (see issue #140) + # most of our styling respects the theme change but not all + self.viewer.events.theme.connect(self._on_napari_theme_changed) + + self.setLayout(QVBoxLayout()) + self.layout().addWidget(self.toolbar) + self.layout().addWidget(self.canvas) + + @property + def figure(self) -> Figure: + """Matplotlib figure.""" + return self.canvas.figure - This class also handles callbacks to automatically update figures when - the layer selection or z-step is changed in the napari viewer. To take - advantage of this sub-classes should implement the ``clear()`` and - ``draw()`` methods. + @property + def mpl_style_sheet_path(self) -> Path: + """ + Path to the set Matplotlib style sheet. + """ + if self._mpl_style_sheet_path is not None: + return self._mpl_style_sheet_path + elif (_CUSTOM_STYLE_PATH).exists(): + return _CUSTOM_STYLE_PATH + elif self._napari_theme_has_light_bg(): + return Path(__file__).parent / "styles" / "light.mplstyle" + else: + return Path(__file__).parent / "styles" / "dark.mplstyle" + + @mpl_style_sheet_path.setter + def mpl_style_sheet_path(self, path: Path) -> None: + self._mpl_style_sheet_path = Path(path) + + def add_single_axes(self) -> None: + """ + Add a single Axes to the figure. + + The Axes is saved on the ``.axes`` attribute for later access. + """ + # Sets axes.* style. + # Does not set any text styling set by axes.* keys + with mplstyle.context(self.mpl_style_sheet_path): + self.axes = self.figure.subplots() + + def _on_napari_theme_changed(self) -> None: + """ + Called when the napari theme is changed. + """ + self._replace_toolbar_icons() + + def _napari_theme_has_light_bg(self) -> bool: + """ + Does this theme have a light background? + + Returns + ------- + bool + True if theme's background colour has hsl lighter than 50%, False if darker. + """ + theme = napari.utils.theme.get_theme(self.viewer.theme, as_dict=False) + _, _, bg_lightness = theme.background.as_hsl_tuple() + return bg_lightness > 0.5 + + def _get_path_to_icon(self) -> Path: + """ + Get the icons directory (which is theme-dependent). + + Icons modified from + https://github.com/matplotlib/matplotlib/tree/main/lib/matplotlib/mpl-data/images + """ + icon_root = Path(__file__).parent / "icons" + if self._napari_theme_has_light_bg(): + return icon_root / "black" + else: + return icon_root / "white" + + def _replace_toolbar_icons(self) -> None: + """ + Modifies toolbar icons to match the napari theme, and add some tooltips. + """ + icon_dir = self._get_path_to_icon() + for action in self.toolbar.actions(): + text = action.text() + if text == "Pan": + action.setToolTip( + "Pan/Zoom: Left button pans; Right button zooms; " + "Click once to activate; Click again to deactivate" + ) + if text == "Zoom": + action.setToolTip( + "Zoom to rectangle; Click once to activate; " + "Click again to deactivate" + ) + if len(text) > 0: # i.e. not a separator item + icon_path = os.path.join(icon_dir, text + ".png") + action.setIcon(QIcon(icon_path)) + + +class NapariMPLWidget(BaseNapariMPLWidget): + """ + Widget containing a Matplotlib canvas and toolbar. + + In addition to ``BaseNapariMPLWidget``, this class handles callbacks + to automatically update figures when the layer selection or z-step + is changed in the napari viewer. To take advantage of this sub-classes + should implement the ``clear()`` and ``draw()`` methods. + + When both the z-step and layer selection is changed, ``clear()`` is called + and if the number a type of selected layers are valid for the widget + ``draw()`` is then called. When layer selection is changed ``on_update_layers()`` + is also called, which can be useful e.g. for updating a layer list in a + selection widget. Attributes ---------- viewer : `napari.Viewer` Main napari viewer. - figure : `matplotlib.figure.Figure` - Matplotlib figure. - canvas : matplotlib.backends.backend_qt5agg.FigureCanvas - Matplotlib canvas. layers : `list` List of currently selected napari layers. - """ - - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__() - self.viewer = napari_viewer - self.canvas = FigureCanvas() - self.canvas.figure.set_tight_layout(True) - self.canvas.figure.patch.set_facecolor("#262930") - self.toolbar = NapariNavigationToolbar(self.canvas, self) - self._replace_toolbar_icons() + See Also + -------- + BaseNapariMPLWidget : The parent class of this widget. Contains helpful methods + for creating and working with the Matplotlib figure and any axes. + """ - self.setLayout(QVBoxLayout()) - self.layout().addWidget(self.toolbar) - self.layout().addWidget(self.canvas) + #: Number of layers taken as input + n_layers_input = Interval(None, None) + #: Type of layer taken as input + input_layer_types: Tuple[napari.layers.Layer, ...] = (napari.layers.Layer,) - self.setup_callbacks() + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer=napari_viewer, parent=parent) + self._setup_callbacks() self.layers: List[napari.layers.Layer] = [] - # Accept any number of input layers by default - n_layers_input = Interval(None, None) - # Accept any type of input layer by default - input_layer_types: Tuple[napari.layers.Layer, ...] = (napari.layers.Layer,) + helper_text = self.n_layers_input._helper_text + if helper_text is not None: + self.layout().insertWidget(0, QLabel(helper_text)) @property def n_selected_layers(self) -> int: @@ -84,27 +209,42 @@ def n_selected_layers(self) -> int: @property def current_z(self) -> int: """ - Current z-step of the viewer. + Current z-step of the napari viewer. """ return self.viewer.dims.current_step[0] - def setup_callbacks(self) -> None: + def _on_napari_theme_changed(self) -> None: + """Update MPL toolbar and axis styling when `napari.Viewer.theme` is changed. + + Note: + At the moment we only handle the default 'light' and 'dark' napari themes. + """ + super()._on_napari_theme_changed() + self.clear() + self.draw() + + def _setup_callbacks(self) -> None: """ - Setup callbacks for: - - Layer selection changing - - z-step changing + Sets up callbacks. + + Sets up callbacks for when: + - Layer selection is changed + - z-step is changed """ # z-step changed in viewer self.viewer.dims.events.current_step.connect(self._draw) # Layer selection changed in viewer - self.viewer.layers.selection.events.changed.connect(self.update_layers) + self.viewer.layers.selection.events.changed.connect( + self._update_layers + ) - def update_layers(self, event: napari.utils.events.Event) -> None: + def _update_layers(self, event: napari.utils.events.Event) -> None: """ - Update the layers attribute with currently selected layers and re-draw. + Update the ``layers`` attribute with currently selected layers and re-draw. """ self.layers = list(self.viewer.layers.selection) - self._on_update_layers() + self.layers = sorted(self.layers, key=lambda layer: layer.name) + self.on_update_layers() self._draw() def _draw(self) -> None: @@ -112,7 +252,10 @@ def _draw(self) -> None: Clear current figure, check selected layers are correct, and draw new figure if so. """ - self.clear() + # Clearing axes sets new defaults, so need to make sure style is applied when + # this happens + with mplstyle.context(self.mpl_style_sheet_path): + self.clear() if self.n_selected_layers in self.n_layers_input and all( isinstance(layer, self.input_layer_types) for layer in self.layers ): @@ -133,55 +276,68 @@ def draw(self) -> None: This is a no-op, and is intended for derived classes to override. """ - def _on_update_layers(self) -> None: + def on_update_layers(self) -> None: """ - This function is called when self.layers is updated via - ``self.update_layers()``. + Called when the selected layers are updated. This is a no-op, and is intended for derived classes to override. """ - def _replace_toolbar_icons(self): - # Modify toolbar icons and some tooltips - for action in self.toolbar.actions(): - text = action.text() - if text == "Pan": - action.setToolTip( - "Pan/Zoom: Left button pans; Right button zooms; " - "Click once to activate; Click again to deactivate" - ) - if text == "Zoom": - action.setToolTip( - "Zoom to rectangle; Click once to activate; " - "Click again to deactivate" - ) - if len(text) > 0: # i.e. not a separator item - icon_path = os.path.join(ICON_ROOT, text + ".png") - action.setIcon(QIcon(icon_path)) + +class SingleAxesWidget(NapariMPLWidget): + """ + In addition to `NapariMPLWidget`, this sets up a single axes and + the callback to clear it. + """ + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer=napari_viewer, parent=parent) + self.add_single_axes() + + def clear(self) -> None: + """ + Clear the axes. + """ + with mplstyle.context(self.mpl_style_sheet_path): + self.axes.clear() class NapariNavigationToolbar(NavigationToolbar2QT): """Custom Toolbar style for Napari.""" - def _update_buttons_checked(self): + def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.setIconSize( + from_napari_css_get_size_of( + "QtViewerPushButton", fallback=(28, 28) + ) + ) + + def _update_buttons_checked(self) -> None: """Update toggle tool icons when selected/unselected.""" super()._update_buttons_checked() + icon_dir = self.parentWidget()._get_path_to_icon() + # changes pan/zoom icons depending on state (checked or not) if "pan" in self._actions: if self._actions["pan"].isChecked(): self._actions["pan"].setIcon( - QIcon(os.path.join(ICON_ROOT, "Pan_checked.png")) + QIcon(os.path.join(icon_dir, "Pan_checked.png")) ) else: self._actions["pan"].setIcon( - QIcon(os.path.join(ICON_ROOT, "Pan.png")) + QIcon(os.path.join(icon_dir, "Pan.png")) ) if "zoom" in self._actions: if self._actions["zoom"].isChecked(): self._actions["zoom"].setIcon( - QIcon(os.path.join(ICON_ROOT, "Zoom_checked.png")) + QIcon(os.path.join(icon_dir, "Zoom_checked.png")) ) else: self._actions["zoom"].setIcon( - QIcon(os.path.join(ICON_ROOT, "Zoom.png")) + QIcon(os.path.join(icon_dir, "Zoom.png")) ) diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index 1f479ad7..39ad41a3 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -1,17 +1,18 @@ -import numpy as np - -from .base import NapariMPLWidget - -__all__ = ["HistogramWidget"] +from typing import Optional import napari +import numpy as np +from qtpy.QtWidgets import QWidget +from .base import SingleAxesWidget from .util import Interval +__all__ = ["HistogramWidget"] + _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} -class HistogramWidget(NapariMPLWidget): +class HistogramWidget(SingleAxesWidget): """ Display a histogram of the currently selected layer. """ @@ -19,13 +20,13 @@ class HistogramWidget(NapariMPLWidget): n_layers_input = Interval(1, 1) input_layer_types = (napari.layers.Image,) - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__(napari_viewer) - self.axes = self.canvas.figure.subplots() - self.update_layers(None) - - def clear(self) -> None: - self.axes.clear() + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer, parent=parent) + self._update_layers(None) def draw(self) -> None: """ diff --git a/src/napari_matplotlib/icons/black/Back.png b/src/napari_matplotlib/icons/black/Back.png new file mode 100644 index 00000000..d7c65b43 Binary files /dev/null and b/src/napari_matplotlib/icons/black/Back.png differ diff --git a/src/napari_matplotlib/icons/black/Customize.png b/src/napari_matplotlib/icons/black/Customize.png new file mode 100644 index 00000000..9f56bb6d Binary files /dev/null and b/src/napari_matplotlib/icons/black/Customize.png differ diff --git a/src/napari_matplotlib/icons/black/Forward.png b/src/napari_matplotlib/icons/black/Forward.png new file mode 100644 index 00000000..52770f6f Binary files /dev/null and b/src/napari_matplotlib/icons/black/Forward.png differ diff --git a/src/napari_matplotlib/icons/black/Home.png b/src/napari_matplotlib/icons/black/Home.png new file mode 100644 index 00000000..9e527bfd Binary files /dev/null and b/src/napari_matplotlib/icons/black/Home.png differ diff --git a/src/napari_matplotlib/icons/black/Pan.png b/src/napari_matplotlib/icons/black/Pan.png new file mode 100644 index 00000000..36332c34 Binary files /dev/null and b/src/napari_matplotlib/icons/black/Pan.png differ diff --git a/src/napari_matplotlib/icons/black/Pan_checked.png b/src/napari_matplotlib/icons/black/Pan_checked.png new file mode 100644 index 00000000..eb0b908f Binary files /dev/null and b/src/napari_matplotlib/icons/black/Pan_checked.png differ diff --git a/src/napari_matplotlib/icons/black/Save.png b/src/napari_matplotlib/icons/black/Save.png new file mode 100644 index 00000000..79b0d030 Binary files /dev/null and b/src/napari_matplotlib/icons/black/Save.png differ diff --git a/src/napari_matplotlib/icons/black/Subplots.png b/src/napari_matplotlib/icons/black/Subplots.png new file mode 100644 index 00000000..aa15d760 Binary files /dev/null and b/src/napari_matplotlib/icons/black/Subplots.png differ diff --git a/src/napari_matplotlib/icons/black/Zoom.png b/src/napari_matplotlib/icons/black/Zoom.png new file mode 100644 index 00000000..4d2898b7 Binary files /dev/null and b/src/napari_matplotlib/icons/black/Zoom.png differ diff --git a/src/napari_matplotlib/icons/black/Zoom_checked.png b/src/napari_matplotlib/icons/black/Zoom_checked.png new file mode 100644 index 00000000..ad769e66 Binary files /dev/null and b/src/napari_matplotlib/icons/black/Zoom_checked.png differ diff --git a/src/napari_matplotlib/icons/Back.png b/src/napari_matplotlib/icons/white/Back.png similarity index 100% rename from src/napari_matplotlib/icons/Back.png rename to src/napari_matplotlib/icons/white/Back.png diff --git a/src/napari_matplotlib/icons/Customize.png b/src/napari_matplotlib/icons/white/Customize.png similarity index 100% rename from src/napari_matplotlib/icons/Customize.png rename to src/napari_matplotlib/icons/white/Customize.png diff --git a/src/napari_matplotlib/icons/Forward.png b/src/napari_matplotlib/icons/white/Forward.png similarity index 100% rename from src/napari_matplotlib/icons/Forward.png rename to src/napari_matplotlib/icons/white/Forward.png diff --git a/src/napari_matplotlib/icons/Home.png b/src/napari_matplotlib/icons/white/Home.png similarity index 100% rename from src/napari_matplotlib/icons/Home.png rename to src/napari_matplotlib/icons/white/Home.png diff --git a/src/napari_matplotlib/icons/Pan.png b/src/napari_matplotlib/icons/white/Pan.png similarity index 100% rename from src/napari_matplotlib/icons/Pan.png rename to src/napari_matplotlib/icons/white/Pan.png diff --git a/src/napari_matplotlib/icons/Pan_checked.png b/src/napari_matplotlib/icons/white/Pan_checked.png similarity index 100% rename from src/napari_matplotlib/icons/Pan_checked.png rename to src/napari_matplotlib/icons/white/Pan_checked.png diff --git a/src/napari_matplotlib/icons/Save.png b/src/napari_matplotlib/icons/white/Save.png similarity index 100% rename from src/napari_matplotlib/icons/Save.png rename to src/napari_matplotlib/icons/white/Save.png diff --git a/src/napari_matplotlib/icons/Subplots.png b/src/napari_matplotlib/icons/white/Subplots.png similarity index 100% rename from src/napari_matplotlib/icons/Subplots.png rename to src/napari_matplotlib/icons/white/Subplots.png diff --git a/src/napari_matplotlib/icons/Zoom.png b/src/napari_matplotlib/icons/white/Zoom.png similarity index 100% rename from src/napari_matplotlib/icons/Zoom.png rename to src/napari_matplotlib/icons/white/Zoom.png diff --git a/src/napari_matplotlib/icons/Zoom_checked.png b/src/napari_matplotlib/icons/white/Zoom_checked.png similarity index 100% rename from src/napari_matplotlib/icons/Zoom_checked.png rename to src/napari_matplotlib/icons/white/Zoom_checked.png diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 3b0f918c..db86c7f3 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,86 +1,63 @@ -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union -import matplotlib.colors as mcolor import napari -import numpy as np -from magicgui import magicgui -from magicgui.widgets import ComboBox +import numpy.typing as npt +from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget -from .base import NapariMPLWidget +from .base import SingleAxesWidget from .util import Interval -__all__ = ["ScatterWidget", "FeaturesScatterWidget"] +__all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"] -class ScatterBaseWidget(NapariMPLWidget): - # opacity value for the markers - _marker_alpha = 0.5 - - # flag set to True if histogram should be used - # for plotting large points - _histogram_for_large_data = True +class ScatterBaseWidget(SingleAxesWidget): + """ + Base class for widgets that scatter two datasets against each other. + """ # if the number of points is greater than this value, - # the scatter is plotted as a 2dhist + # the scatter is plotted as a 2D histogram _threshold_to_switch_to_histogram = 500 - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__(napari_viewer) - - self.axes = self.canvas.figure.subplots() - self.update_layers(None) - - def clear(self) -> None: - """ - Clear the axes. - """ - self.axes.clear() - def draw(self) -> None: """ Scatter the currently selected layers. """ - data, x_axis_name, y_axis_name = self._get_data() - - if len(data) == 0: - # don't plot if there isn't data + if len(self.layers) == 0: return + x, y, x_axis_name, y_axis_name = self._get_data() - if self._histogram_for_large_data and ( - data[0].size > self._threshold_to_switch_to_histogram - ): + if x.size > self._threshold_to_switch_to_histogram: self.axes.hist2d( - data[0].ravel(), - data[1].ravel(), + x.ravel(), + y.ravel(), bins=100, - norm=mcolor.LogNorm(), ) else: - self.axes.scatter(data[0], data[1], alpha=self._marker_alpha) + self.axes.scatter(x, y, alpha=0.5) self.axes.set_xlabel(x_axis_name) self.axes.set_ylabel(y_axis_name) - def _get_data(self) -> Tuple[List[np.ndarray], str, str]: - """Get the plot data. + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: + """ + Get the plot data. This must be implemented on the subclass. Returns ------- - data : np.ndarray - The list containing the scatter plot data. - x_axis_name : str - The label to display on the x axis - y_axis_name: str - The label to display on the y axis + x, y : np.ndarray + x and y values of plot data. + x_axis_name, y_axis_name : str + Label to display on the x/y axis """ raise NotImplementedError class ScatterWidget(ScatterBaseWidget): """ - Widget to display scatter plot of two similarly shaped image layers. + Scatter data in two similarly shaped layers. If there are more than 500 data points, a 2D histogram is displayed instead of a scatter plot, to avoid too many scatter points. @@ -89,8 +66,9 @@ class ScatterWidget(ScatterBaseWidget): n_layers_input = Interval(2, 2) input_layer_types = (napari.layers.Image,) - def _get_data(self) -> Tuple[List[np.ndarray], str, str]: - """Get the plot data. + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: + """ + Get the plot data. Returns ------- @@ -101,14 +79,19 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]: y_axis_name: str The title to display on the y axis """ - data = [layer.data[self.current_z] for layer in self.layers] + x = self.layers[0].data[self.current_z] + y = self.layers[1].data[self.current_z] x_axis_name = self.layers[0].name y_axis_name = self.layers[1].name - return data, x_axis_name, y_axis_name + return x, y, x_axis_name, y_axis_name class FeaturesScatterWidget(ScatterBaseWidget): + """ + Widget to scatter data stored in two layer feature attributes. + """ + n_layers_input = Interval(1, 1) # All layers that have a .features attributes input_layer_types = ( @@ -119,46 +102,57 @@ class FeaturesScatterWidget(ScatterBaseWidget): napari.layers.Vectors, ) - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__(napari_viewer) - self._key_selection_widget = magicgui( - self._set_axis_keys, - x_axis_key={"choices": self._get_valid_axis_keys}, - y_axis_key={"choices": self._get_valid_axis_keys}, - call_button="plot", - ) + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer, parent=parent) - self.layout().addWidget(self._key_selection_widget.native) + self.layout().addLayout(QVBoxLayout()) + + self._selectors: Dict[str, QComboBox] = {} + for dim in ["x", "y"]: + self._selectors[dim] = QComboBox() + # Re-draw when combo boxes are updated + self._selectors[dim].currentTextChanged.connect(self._draw) + + self.layout().addWidget(QLabel(f"{dim}-axis:")) + self.layout().addWidget(self._selectors[dim]) + + self._update_layers(None) @property - def x_axis_key(self) -> Optional[str]: - """Key to access x axis data from the FeaturesTable""" - return self._x_axis_key + def x_axis_key(self) -> Union[str, None]: + """ + Key for the x-axis data. + """ + if self._selectors["x"].count() == 0: + return None + else: + return self._selectors["x"].currentText() @x_axis_key.setter - def x_axis_key(self, key: Optional[str]) -> None: - self._x_axis_key = key + def x_axis_key(self, key: str) -> None: + self._selectors["x"].setCurrentText(key) self._draw() @property - def y_axis_key(self) -> Optional[str]: - """Key to access y axis data from the FeaturesTable""" - return self._y_axis_key + def y_axis_key(self) -> Union[str, None]: + """ + Key for the y-axis data. + """ + if self._selectors["y"].count() == 0: + return None + else: + return self._selectors["y"].currentText() @y_axis_key.setter - def y_axis_key(self, key: Optional[str]) -> None: - self._y_axis_key = key - self._draw() - - def _set_axis_keys(self, x_axis_key: str, y_axis_key: str) -> None: - """Set both axis keys and then redraw the plot""" - self._x_axis_key = x_axis_key - self._y_axis_key = y_axis_key + def y_axis_key(self, key: str) -> None: + self._selectors["y"].setCurrentText(key) self._draw() - def _get_valid_axis_keys( - self, combo_widget: Optional[ComboBox] = None - ) -> List[str]: + def _get_valid_axis_keys(self) -> List[str]: """ Get the valid axis keys from the layer FeatureTable. @@ -173,8 +167,34 @@ def _get_valid_axis_keys( else: return self.layers[0].features.keys() - def _get_data(self) -> Tuple[List[np.ndarray], str, str]: - """Get the plot data. + def _ready_to_scatter(self) -> bool: + """ + Return True if selected layer has a feature table we can scatter with, + and the two columns to be scatterd have been selected. + """ + if not hasattr(self.layers[0], "features"): + return False + + feature_table = self.layers[0].features + valid_keys = self._get_valid_axis_keys() + return ( + feature_table is not None + and len(feature_table) > 0 + and self.x_axis_key in valid_keys + and self.y_axis_key in valid_keys + ) + + def draw(self) -> None: + """ + Scatter two features from the currently selected layer. + """ + if self._ready_to_scatter(): + super().draw() + + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: + """ + Get the plot data from the ``features`` attribute of the first + selected layer. Returns ------- @@ -188,37 +208,23 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]: The title to display on the y axis. Returns an empty string if nothing to plot. """ - if not hasattr(self.layers[0], "features"): - # if the selected layer doesn't have a featuretable, - # skip draw - return [], "", "" - feature_table = self.layers[0].features - if ( - (len(feature_table) == 0) - or (self.x_axis_key is None) - or (self.y_axis_key is None) - ): - return [], "", "" + x = feature_table[self.x_axis_key] + y = feature_table[self.y_axis_key] - data_x = feature_table[self.x_axis_key] - data_y = feature_table[self.y_axis_key] - data = [data_x, data_y] + x_axis_name = str(self.x_axis_key) + y_axis_name = str(self.y_axis_key) - x_axis_name = self.x_axis_key.replace("_", " ") - y_axis_name = self.y_axis_key.replace("_", " ") + return x, y, x_axis_name, y_axis_name - return data, x_axis_name, y_axis_name - - def _on_update_layers(self) -> None: + def on_update_layers(self) -> None: """ - This is called when the layer selection changes by - ``self.update_layers()``. + Called when the layer selection changes by ``self.update_layers()``. """ - if hasattr(self, "_key_selection_widget"): - self._key_selection_widget.reset_choices() - - # reset the axis keys - self._x_axis_key = None - self._y_axis_key = None + # Clear combobox + for dim in ["x", "y"]: + while self._selectors[dim].count() > 0: + self._selectors[dim].removeItem(0) + # Add keys for newly selected layer + self._selectors[dim].addItems(self._get_valid_axis_keys()) diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index 5ab122c0..e3aa80b2 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -1,10 +1,12 @@ -from typing import Dict, Tuple +from typing import Any, Dict, Optional, Tuple +import matplotlib.ticker as mticker import napari import numpy as np -from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox +import numpy.typing as npt +from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget -from .base import NapariMPLWidget +from .base import SingleAxesWidget from .util import Interval __all__ = ["SliceWidget"] @@ -13,7 +15,7 @@ _dims = ["x", "y", "z"] -class SliceWidget(NapariMPLWidget): +class SliceWidget(SingleAxesWidget): """ Plot a 1D slice along a given dimension. """ @@ -21,10 +23,13 @@ class SliceWidget(NapariMPLWidget): n_layers_input = Interval(1, 1) input_layer_types = (napari.layers.Image,) - def __init__(self, napari_viewer: napari.viewer.Viewer): + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): # Setup figure/axes - super().__init__(napari_viewer) - self.axes = self.canvas.figure.subplots() + super().__init__(napari_viewer, parent=parent) button_layout = QHBoxLayout() self.layout().addLayout(button_layout) @@ -46,10 +51,13 @@ def __init__(self, napari_viewer: napari.viewer.Viewer): for d in _dims_sel: self.slice_selectors[d].textChanged.connect(self._draw) - self.update_layers(None) + self._update_layers(None) @property - def layer(self): + def _layer(self) -> napari.layers.Layer: + """ + Layer being plotted. + """ return self.layers[0] @property @@ -69,25 +77,19 @@ def current_dim_index(self) -> int: return _dims[::-1].index(self.current_dim) @property - def selector_values(self) -> Dict[str, int]: - return {d: self.slice_selectors[d].value() for d in _dims_sel} - - def update_slice_selectors(self) -> None: + def _selector_values(self) -> Dict[str, int]: """ - Update range and enabled status of the slice selectors, and the value - of the z slice selector. + Values of the slice selectors. """ - # Update min/max - for i, dim in enumerate(_dims_sel): - self.slice_selectors[dim].setRange(0, self.layer.data.shape[i]) + return {d: self.slice_selectors[d].value() for d in _dims_sel} - def get_xy(self) -> Tuple[np.ndarray, np.ndarray]: + def _get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]: """ Get data for plotting. """ - x = np.arange(self.layer.data.shape[self.current_dim_index]) + x = np.arange(self._layer.data.shape[self.current_dim_index]) - vals = self.selector_values + vals = self._selector_values vals.update({"z": self.current_z}) slices = [] @@ -102,19 +104,20 @@ def get_xy(self) -> Tuple[np.ndarray, np.ndarray]: # Reverse since z is the first axis in napari slices = slices[::-1] - y = self.layer.data[tuple(slices)].ravel() + y = self._layer.data[tuple(slices)].ravel() return x, y - def clear(self) -> None: - self.axes.cla() - def draw(self) -> None: """ Clear axes and draw a 1D plot. """ - x, y = self.get_xy() + x, y = self._get_xy() self.axes.plot(x, y) self.axes.set_xlabel(self.current_dim) - self.axes.set_title(self.layer.name) + self.axes.set_title(self._layer.name) + # Make sure all ticks lie on integer values + self.axes.xaxis.set_major_locator( + mticker.MaxNLocator(steps=[1, 2, 5, 10], integer=True) + ) diff --git a/src/napari_matplotlib/styles/README.md b/src/napari_matplotlib/styles/README.md new file mode 100644 index 00000000..79d3c417 --- /dev/null +++ b/src/napari_matplotlib/styles/README.md @@ -0,0 +1,3 @@ +This folder contains default built-in Matplotlib style sheets. +See https://matplotlib.org/stable/tutorials/introductory/customizing.html#defining-your-own-style +for more info on Matplotlib style sheets. diff --git a/src/napari_matplotlib/styles/dark.mplstyle b/src/napari_matplotlib/styles/dark.mplstyle new file mode 100644 index 00000000..1658f9b4 --- /dev/null +++ b/src/napari_matplotlib/styles/dark.mplstyle @@ -0,0 +1,12 @@ +# Dark-theme napari colour scheme for matplotlib plots + +# text (very light grey - almost white): #f0f1f2 +# foreground (mid grey): #414851 +# background (dark blue-gray): #262930 + +figure.facecolor : none +axes.labelcolor : f0f1f2 +axes.facecolor : none +axes.edgecolor : 414851 +xtick.color : f0f1f2 +ytick.color : f0f1f2 diff --git a/src/napari_matplotlib/styles/light.mplstyle b/src/napari_matplotlib/styles/light.mplstyle new file mode 100644 index 00000000..3b8d7d1d --- /dev/null +++ b/src/napari_matplotlib/styles/light.mplstyle @@ -0,0 +1,12 @@ +# Light-theme napari colour scheme for matplotlib plots + +# text (very dark grey - almost black): #3b3a39 +# foreground (mid grey): #d6d0ce +# background (brownish beige): #efebe9 + +figure.facecolor : none +axes.labelcolor : 3b3a39 +axes.facecolor : none +axes.edgecolor : d6d0ce +xtick.color : 3b3a39 +ytick.color : 3b3a39 diff --git a/src/napari_matplotlib/tests/baseline/test_custom_theme.png b/src/napari_matplotlib/tests/baseline/test_custom_theme.png new file mode 100644 index 00000000..65c43a49 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_custom_theme.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_histogram_2D.png b/src/napari_matplotlib/tests/baseline/test_histogram_2D.png new file mode 100644 index 00000000..b76d1e10 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_histogram_2D.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_histogram_3D.png b/src/napari_matplotlib/tests/baseline/test_histogram_3D.png new file mode 100644 index 00000000..2dffdcb2 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_histogram_3D.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_no_theme_side_effects.png b/src/napari_matplotlib/tests/baseline/test_no_theme_side_effects.png new file mode 100644 index 00000000..4888f305 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_no_theme_side_effects.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_slice_2D.png b/src/napari_matplotlib/tests/baseline/test_slice_2D.png new file mode 100644 index 00000000..5b73091c Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_slice_2D.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_slice_3D.png b/src/napari_matplotlib/tests/baseline/test_slice_3D.png new file mode 100644 index 00000000..43c8c3b6 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_slice_3D.png differ diff --git a/src/napari_matplotlib/tests/conftest.py b/src/napari_matplotlib/tests/conftest.py index 274f1226..6b2a813f 100644 --- a/src/napari_matplotlib/tests/conftest.py +++ b/src/napari_matplotlib/tests/conftest.py @@ -1,5 +1,9 @@ +import os +from pathlib import Path + import numpy as np import pytest +from skimage import data @pytest.fixture( @@ -11,3 +15,46 @@ ) def image_data(request): return np.ones(request.param[0]), request.param[1] + + +@pytest.fixture +def astronaut_data(): + return data.astronaut(), {"rgb": True} + + +@pytest.fixture +def brain_data(): + return data.brain(), {"rgb": False} + + +@pytest.fixture +def points_with_features_data(): + n_points = 100 + np.random.seed(10) + points_data = 100 * np.random.random((100, 2)) + points_features = { + "feature_0": np.random.random((n_points,)), + "feature_1": np.random.random((n_points,)), + "feature_2": np.random.random((n_points,)), + } + + return points_data, {"features": points_features} + + +@pytest.fixture(autouse=True, scope="session") +def set_strict_qt(): + env_var = "NAPARI_STRICT_QT" + old_val = os.environ.get(env_var) + os.environ[env_var] = "1" + # Run tests + yield + # Reset to original value + if old_val is not None: + os.environ[env_var] = old_val + else: + del os.environ[env_var] + + +@pytest.fixture +def theme_path(): + return Path(__file__).parent / "data" / "test_theme.mplstyle" diff --git a/src/napari_matplotlib/tests/data/test_theme.mplstyle b/src/napari_matplotlib/tests/data/test_theme.mplstyle new file mode 100644 index 00000000..2f94b31f --- /dev/null +++ b/src/napari_matplotlib/tests/data/test_theme.mplstyle @@ -0,0 +1,15 @@ +# Dark-theme napari colour scheme for matplotlib plots + +#f4b8b2 # light red +#b2e4f4 # light blue +#0aa3fc # dark blue +#008939 # dark green + +figure.facecolor : f4b8b2 # light red +axes.facecolor : b2e4f4 # light blue +axes.edgecolor : 0aa3fc # dark blue + +xtick.color : 008939 # dark green +xtick.labelcolor : 008939 # dark green +ytick.color : 008939 # dark green +ytick.labelcolor : 008939 # dark green diff --git a/src/napari_matplotlib/tests/helpers.py b/src/napari_matplotlib/tests/helpers.py new file mode 100644 index 00000000..b00c0af4 --- /dev/null +++ b/src/napari_matplotlib/tests/helpers.py @@ -0,0 +1,29 @@ +from io import BytesIO + +import numpy as np +import numpy.typing as npt +import pytest +from matplotlib.figure import Figure + + +def fig_to_array(fig: Figure) -> npt.NDArray[np.uint8]: + """ + Convert a figure to an RGB array. + """ + with BytesIO() as io_buf: + fig.savefig(io_buf, format="raw") + io_buf.seek(0) + img_arr: npt.NDArray[np.uint8] = np.reshape( + np.frombuffer(io_buf.getvalue(), dtype=np.uint8), + newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1), + ) + return img_arr + + +def assert_figures_equal(fig1: Figure, fig2: Figure) -> None: + np.testing.assert_equal(fig_to_array(fig1), fig_to_array(fig2)) + + +def assert_figures_not_equal(fig1: Figure, fig2: Figure) -> None: + with pytest.raises(AssertionError, match="Arrays are not equal"): + assert_figures_equal(fig1, fig2) diff --git a/src/napari_matplotlib/tests/scatter/__init__.py b/src/napari_matplotlib/tests/scatter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/napari_matplotlib/tests/scatter/baseline/test_features_scatter_widget_2D.png b/src/napari_matplotlib/tests/scatter/baseline/test_features_scatter_widget_2D.png new file mode 100644 index 00000000..269ebd01 Binary files /dev/null and b/src/napari_matplotlib/tests/scatter/baseline/test_features_scatter_widget_2D.png differ diff --git a/src/napari_matplotlib/tests/scatter/baseline/test_scatter_2D.png b/src/napari_matplotlib/tests/scatter/baseline/test_scatter_2D.png new file mode 100644 index 00000000..3b550666 Binary files /dev/null and b/src/napari_matplotlib/tests/scatter/baseline/test_scatter_2D.png differ diff --git a/src/napari_matplotlib/tests/scatter/baseline/test_scatter_3D.png b/src/napari_matplotlib/tests/scatter/baseline/test_scatter_3D.png new file mode 100644 index 00000000..27e7d673 Binary files /dev/null and b/src/napari_matplotlib/tests/scatter/baseline/test_scatter_3D.png differ diff --git a/src/napari_matplotlib/tests/scatter/test_scatter.py b/src/napari_matplotlib/tests/scatter/test_scatter.py new file mode 100644 index 00000000..a225863d --- /dev/null +++ b/src/napari_matplotlib/tests/scatter/test_scatter.py @@ -0,0 +1,57 @@ +from copy import deepcopy + +import pytest + +from napari_matplotlib import ScatterBaseWidget, ScatterWidget + + +@pytest.mark.mpl_image_compare +def test_scatter_2D(make_napari_viewer, astronaut_data): + viewer = make_napari_viewer() + viewer.theme = "light" + widget = ScatterWidget(viewer) + fig = widget.figure + + viewer.add_image(astronaut_data[0], **astronaut_data[1], name="astronaut") + + viewer.add_image( + astronaut_data[0] * -1, **astronaut_data[1], name="astronaut_reversed" + ) + # De-select existing selection + viewer.layers.selection.clear() + + # Select images + viewer.layers.selection.add(viewer.layers[0]) + viewer.layers.selection.add(viewer.layers[1]) + return deepcopy(fig) + + +@pytest.mark.mpl_image_compare +def test_scatter_3D(make_napari_viewer, brain_data): + viewer = make_napari_viewer() + viewer.theme = "light" + widget = ScatterWidget(viewer) + fig = widget.figure + + viewer.add_image(brain_data[0], **brain_data[1], name="brain") + + viewer.add_image( + brain_data[0] * -1, **brain_data[1], name="brain_reversed" + ) + # De-select existing selection + viewer.layers.selection.clear() + axis = viewer.dims.last_used + slice_no = brain_data[0].shape[0] - 1 + viewer.dims.set_current_step(axis, slice_no) + # Select images + viewer.layers.selection.add(viewer.layers[0]) + viewer.layers.selection.add(viewer.layers[1]) + + return deepcopy(fig) + + +def test_get_data_notimplemented_on_base(make_napari_viewer): + viewer = make_napari_viewer() + widget = ScatterBaseWidget(viewer) + with pytest.raises(NotImplementedError): + widget._get_data() diff --git a/src/napari_matplotlib/tests/test_scatter.py b/src/napari_matplotlib/tests/scatter/test_scatter_features.py similarity index 62% rename from src/napari_matplotlib/tests/test_scatter.py rename to src/napari_matplotlib/tests/scatter/test_scatter_features.py index 8103968e..b5a396fd 100644 --- a/src/napari_matplotlib/tests/test_scatter.py +++ b/src/napari_matplotlib/tests/scatter/test_scatter_features.py @@ -1,26 +1,42 @@ +from copy import deepcopy +from typing import Any, Dict, Tuple + import numpy as np +import numpy.typing as npt +import pytest -from napari_matplotlib import FeaturesScatterWidget, ScatterWidget +from napari_matplotlib import FeaturesScatterWidget -def test_scatter(make_napari_viewer): - # Smoke test adding a scatter widget +@pytest.mark.mpl_image_compare +def test_features_scatter_widget_2D( + make_napari_viewer, points_with_features_data +): viewer = make_napari_viewer() - viewer.add_image(np.random.random((100, 100))) - viewer.add_image(np.random.random((100, 100))) - ScatterWidget(viewer) + viewer.theme = "light" + widget = FeaturesScatterWidget(viewer) + viewer.add_points( + points_with_features_data[0], **points_with_features_data[1] + ) + assert len(viewer.layers) == 1 + # De-select existing selection + viewer.layers.selection.clear() -def test_features_scatter_widget(make_napari_viewer): - # Smoke test adding a features scatter widget - viewer = make_napari_viewer() - viewer.add_image(np.random.random((100, 100))) - viewer.add_labels(np.random.randint(0, 5, (100, 100))) - FeaturesScatterWidget(viewer) + # Select points data and chosen features + viewer.layers.selection.add(viewer.layers[0]) # images need to be selected + widget.x_axis_key = "feature_0" + widget.y_axis_key = "feature_1" + + fig = widget.figure + + return deepcopy(fig) -def make_labels_layer_with_features(): - label_image = np.zeros((100, 100), dtype=np.uint16) +def make_labels_layer_with_features() -> ( + Tuple[npt.NDArray[np.uint16], Dict[str, Any]] +): + label_image: npt.NDArray[np.uint16] = np.zeros((100, 100), dtype=np.uint16) for label_value, start_index in enumerate([10, 30, 50], start=1): end_index = start_index + 10 label_image[start_index:end_index, start_index:end_index] = label_value @@ -34,7 +50,9 @@ def make_labels_layer_with_features(): def test_features_scatter_get_data(make_napari_viewer): - """test the get data method""" + """ + Test the get data method. + """ # make the label image label_image, feature_table = make_labels_layer_with_features() @@ -50,17 +68,16 @@ def test_features_scatter_get_data(make_napari_viewer): y_column = "feature_2" scatter_widget.y_axis_key = y_column - data, x_axis_name, y_axis_name = scatter_widget._get_data() - np.testing.assert_allclose( - data, np.stack((feature_table[x_column], feature_table[y_column])) - ) - assert x_axis_name == x_column.replace("_", " ") - assert y_axis_name == y_column.replace("_", " ") + x, y, x_axis_name, y_axis_name = scatter_widget._get_data() + np.testing.assert_allclose(x, feature_table[x_column]) + np.testing.assert_allclose(y, np.stack(feature_table[y_column])) + assert x_axis_name == x_column + assert y_axis_name == y_column def test_get_valid_axis_keys(make_napari_viewer): - """test the values returned from - FeaturesScatterWidget._get_valid_keys() when there + """ + Test the values returned from _get_valid_keys() when there are valid keys. """ # make the label image @@ -76,7 +93,7 @@ def test_get_valid_axis_keys(make_napari_viewer): def test_get_valid_axis_keys_no_valid_keys(make_napari_viewer): - """test the values returned from + """Test the values returned from FeaturesScatterWidget._get_valid_keys() when there are not valid keys. """ diff --git a/src/napari_matplotlib/tests/test_histogram.py b/src/napari_matplotlib/tests/test_histogram.py index f497a1a9..4d170014 100644 --- a/src/napari_matplotlib/tests/test_histogram.py +++ b/src/napari_matplotlib/tests/test_histogram.py @@ -1,8 +1,30 @@ +from copy import deepcopy + +import pytest + from napari_matplotlib import HistogramWidget -def test_example_q_widget(make_napari_viewer, image_data): - # Smoke test adding a histogram widget +@pytest.mark.mpl_image_compare +def test_histogram_2D(make_napari_viewer, astronaut_data): + viewer = make_napari_viewer() + viewer.theme = "light" + viewer.add_image(astronaut_data[0], **astronaut_data[1]) + fig = HistogramWidget(viewer).figure + # Need to return a copy, as original figure is too eagerley garbage + # collected by the widget + return deepcopy(fig) + + +@pytest.mark.mpl_image_compare +def test_histogram_3D(make_napari_viewer, brain_data): viewer = make_napari_viewer() - viewer.add_image(image_data[0], **image_data[1]) - HistogramWidget(viewer) + viewer.theme = "light" + viewer.add_image(brain_data[0], **brain_data[1]) + axis = viewer.dims.last_used + slice_no = brain_data[0].shape[0] - 1 + viewer.dims.set_current_step(axis, slice_no) + fig = HistogramWidget(viewer).figure + # Need to return a copy, as original figure is too eagerley garbage + # collected by the widget + return deepcopy(fig) diff --git a/src/napari_matplotlib/tests/test_layer_changes.py b/src/napari_matplotlib/tests/test_layer_changes.py new file mode 100644 index 00000000..bdd6c600 --- /dev/null +++ b/src/napari_matplotlib/tests/test_layer_changes.py @@ -0,0 +1,109 @@ +from copy import deepcopy +from typing import Any, Dict, Tuple, Type + +import numpy as np +import numpy.typing as npt +import pytest +from napari.viewer import Viewer + +from napari_matplotlib import ( + FeaturesScatterWidget, + HistogramWidget, + ScatterWidget, + SliceWidget, +) +from napari_matplotlib.base import NapariMPLWidget +from napari_matplotlib.tests.helpers import ( + assert_figures_equal, + assert_figures_not_equal, +) + + +@pytest.mark.parametrize( + "widget_cls, n_layers", + [(HistogramWidget, 1), (SliceWidget, 1), (ScatterWidget, 2)], +) +def test_change_one_layer( + make_napari_viewer, + brain_data, + astronaut_data, + widget_cls, + n_layers, +): + """ + Test all widgets that take one layer as input to make sure the plot changes + when the napari layer selection changes. + """ + viewer = make_napari_viewer() + + widget = widget_cls(viewer) + # Add n copies of two different datasets + for _ in range(n_layers): + viewer.add_image(brain_data[0], **brain_data[1]) + for _ in range(n_layers): + viewer.add_image(astronaut_data[0], **astronaut_data[1]) + + assert len(viewer.layers) == 2 * n_layers + assert_plot_changes(viewer, widget, n_layers=n_layers) + + +@pytest.mark.parametrize("widget_cls", [FeaturesScatterWidget]) +def test_change_features_layer( + make_napari_viewer, points_with_features_data, widget_cls +): + """ + Test all widgets that take one layer with features as input to make sure the + plot changes when the napari layer selection changes. + """ + viewer = make_napari_viewer() + assert_features_plot_changes(viewer, widget_cls, points_with_features_data) + + +def assert_features_plot_changes( + viewer: Viewer, + widget_cls: Type[NapariMPLWidget], + data: Tuple[npt.NDArray[np.generic], Dict[str, Any]], +) -> None: + """ + When the selected layer is changed, make sure the plot generated + by `widget_cls` also changes. + """ + widget = widget_cls(viewer) + viewer.add_points(data[0], **data[1]) + # Change the features data for the second layer + data[1]["features"] = { + name: data + 1 for name, data in data[1]["features"].items() + } + viewer.add_points(data[0], **data[1]) + assert_plot_changes(viewer, widget, n_layers=1) + + +def assert_plot_changes( + viewer: Viewer, widget: NapariMPLWidget, *, n_layers: int +) -> None: + """ + Assert that a widget plot changes when the layer selection + is changed. The passed viewer must already have (2 * n_layers) layers + loaded. + """ + # Select first layer(s) + viewer.layers.selection.clear() + + for i in range(n_layers): + viewer.layers.selection.add(viewer.layers[i]) + assert len(viewer.layers.selection) == n_layers + fig1 = deepcopy(widget.figure) + + # Re-selecting first layer(s) should produce identical plot + viewer.layers.selection.clear() + for i in range(n_layers): + viewer.layers.selection.add(viewer.layers[i]) + assert len(viewer.layers.selection) == n_layers + assert_figures_equal(widget.figure, fig1) + + # Plotting the second layer(s) should produce a different plot + viewer.layers.selection.clear() + for i in range(n_layers): + viewer.layers.selection.add(viewer.layers[n_layers + i]) + assert len(viewer.layers.selection) == n_layers + assert_figures_not_equal(widget.figure, fig1) diff --git a/src/napari_matplotlib/tests/test_slice.py b/src/napari_matplotlib/tests/test_slice.py index d0be3cc1..412e71c3 100644 --- a/src/napari_matplotlib/tests/test_slice.py +++ b/src/napari_matplotlib/tests/test_slice.py @@ -1,10 +1,30 @@ -import numpy as np +from copy import deepcopy + +import pytest from napari_matplotlib import SliceWidget -def test_scatter(make_napari_viewer): - # Smoke test adding a histogram widget +@pytest.mark.mpl_image_compare +def test_slice_3D(make_napari_viewer, brain_data): + viewer = make_napari_viewer() + viewer.theme = "light" + viewer.add_image(brain_data[0], **brain_data[1]) + axis = viewer.dims.last_used + slice_no = brain_data[0].shape[0] - 1 + viewer.dims.set_current_step(axis, slice_no) + fig = SliceWidget(viewer).figure + # Need to return a copy, as original figure is too eagerley garbage + # collected by the widget + return deepcopy(fig) + + +@pytest.mark.mpl_image_compare +def test_slice_2D(make_napari_viewer, astronaut_data): viewer = make_napari_viewer() - viewer.add_image(np.random.random((100, 100, 100))) - SliceWidget(viewer) + viewer.theme = "light" + viewer.add_image(astronaut_data[0], **astronaut_data[1]) + fig = SliceWidget(viewer).figure + # Need to return a copy, as original figure is too eagerley garbage + # collected by the widget + return deepcopy(fig) diff --git a/src/napari_matplotlib/tests/test_theme.py b/src/napari_matplotlib/tests/test_theme.py new file mode 100644 index 00000000..a3642f8f --- /dev/null +++ b/src/napari_matplotlib/tests/test_theme.py @@ -0,0 +1,190 @@ +import os +import shutil +from copy import deepcopy +from pathlib import Path + +import matplotlib +import napari +import numpy as np +import pytest +from matplotlib.colors import to_rgba + +from napari_matplotlib import HistogramWidget, ScatterWidget +from napari_matplotlib.base import NapariMPLWidget + + +@pytest.mark.parametrize( + "theme_name, expected_icons", + [("dark", "white"), ("light", "black")], +) +def test_theme_mpl_toolbar_icons( + make_napari_viewer, theme_name, expected_icons +): + """Check that the icons are taken from the correct folder for each napari theme.""" + viewer = make_napari_viewer() + viewer.theme = theme_name + path_to_icons = NapariMPLWidget(viewer)._get_path_to_icon() + assert path_to_icons.exists(), "The theme points to non-existant icons." + assert ( + path_to_icons.stem == expected_icons + ), "The theme is selecting unexpected icons." + + +def _mock_up_theme() -> None: + """Mock up a new color theme based on dark mode but with a tasteful blue background. + + Based on: + https://napari.org/stable/gallery/new_theme.html + """ + blue_theme = napari.utils.theme.get_theme("dark", False) + blue_theme.name = "blue" + blue_theme.background = "#4169e1" # my favourite shade of blue + napari.utils.theme.register_theme("blue", blue_theme) + + +def test_theme_background_check(make_napari_viewer): + """ + Check that the hue saturation lightness can distinguish dark and light backgrounds. + """ + viewer = make_napari_viewer() + widget = NapariMPLWidget(viewer) + + viewer.theme = "dark" + assert widget._napari_theme_has_light_bg() is False + + viewer.theme = "light" + assert widget._napari_theme_has_light_bg() is True + + _mock_up_theme() + viewer.theme = "blue" + assert widget._napari_theme_has_light_bg() is True + + +@pytest.mark.parametrize( + "theme_name, expected_text_colour", + [ + ("dark", "#f0f1f2"), # #f0f1f2 is a light grey (almost white) + ("light", "#3b3a39"), # #3b3a39 is a brownish dark grey (almost black) + ], +) +def test_titles_respect_theme( + make_napari_viewer, theme_name, expected_text_colour +): + """ + Test that the axis labels and titles are the correct color for the napari theme. + """ + viewer = make_napari_viewer() + widget = ScatterWidget(viewer) + viewer.theme = theme_name + + # make a scatter plot of two random layers + viewer.add_image(np.random.random((10, 10)), name="first test image") + viewer.add_image(np.random.random((10, 10)), name="second test image") + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[0]) + viewer.layers.selection.add(viewer.layers[1]) + + ax = widget.figure.gca() + + # sanity test to make sure we've got the correct image names + assert ax.xaxis.label.get_text() == "first test image" + assert ax.yaxis.label.get_text() == "second test image" + + # print(dir(ax.yaxis.label)) + # TODO: put checks of the axis tick labels here + + assert ax.xaxis.label.get_color() == expected_text_colour + assert ax.yaxis.label.get_color() == expected_text_colour + + +@pytest.mark.mpl_image_compare +def test_no_theme_side_effects(make_napari_viewer): + """Ensure that napari-matplotlib doesn't pollute the globally set style. + + A MWE to guard aganst issue matplotlib/#64. Should always reproduce a plot + with the default matplotlib style. + """ + import matplotlib.pyplot as plt + + np.random.seed(12345) + + # should not affect global matplotlib plot style + viewer = make_napari_viewer() + viewer.theme = "dark" + NapariMPLWidget(viewer) + + # some plotting unrelated to napari-matplotlib + normal_dist = np.random.normal(size=1000) + unrelated_figure, ax = plt.subplots() + ax.hist(normal_dist, bins=100) + ax.set_xlabel("something unrelated to napari (x)") + ax.set_ylabel("something unrelated to napari (y)") + ax.set_title( + "this plot style should not change with napari styles or themes" + ) + unrelated_figure.tight_layout() + + return unrelated_figure + + +@pytest.mark.mpl_image_compare +def test_custom_theme(make_napari_viewer, theme_path, brain_data): + viewer = make_napari_viewer() + viewer.theme = "dark" + + widget = ScatterWidget(viewer) + widget.mpl_style_sheet_path = theme_path + + viewer.add_image(brain_data[0], **brain_data[1], name="brain") + viewer.add_image( + brain_data[0] * -1, **brain_data[1], name="brain_reversed" + ) + + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[0]) + viewer.layers.selection.add(viewer.layers[1]) + + return deepcopy(widget.figure) + + +def find_mpl_stylesheet(name: str) -> Path: + """Find the built-in matplotlib stylesheet.""" + return Path(matplotlib.__path__[0]) / f"mpl-data/stylelib/{name}.mplstyle" + + +def test_custom_stylesheet(make_napari_viewer, image_data): + """ + Test that a stylesheet in the current directory is given precidence. + + Do this by copying over a stylesheet from matplotlib's built in styles, + naming it correctly, and checking the colours are as expected. + """ + # Copy Solarize_Light2 as if it was a user-overriden stylesheet. + style_sheet_path = ( + Path(matplotlib.get_configdir()) / "napari-matplotlib.mplstyle" + ) + if style_sheet_path.exists(): + pytest.skip("Won't ovewrite existing custom style sheet.") + shutil.copy( + find_mpl_stylesheet("Solarize_Light2"), + style_sheet_path, + ) + + try: + viewer = make_napari_viewer() + viewer.add_image(image_data[0], **image_data[1]) + widget = HistogramWidget(viewer) + assert widget.mpl_style_sheet_path == style_sheet_path + ax = widget.figure.gca() + + # The axes should have a light brownish grey background: + assert ax.get_facecolor() == to_rgba("#eee8d5") + assert ax.patch.get_facecolor() == to_rgba("#eee8d5") + + # The figure background and axis gridlines are light yellow: + assert widget.figure.patch.get_facecolor() == to_rgba("#fdf6e3") + for gridline in ax.get_xgridlines() + ax.get_ygridlines(): + assert gridline.get_visible() is True + assert gridline.get_color() == "#fdf6e3" + finally: + os.remove(style_sheet_path) diff --git a/src/napari_matplotlib/tests/test_ui.py b/src/napari_matplotlib/tests/test_ui.py new file mode 100644 index 00000000..3b30ad3a --- /dev/null +++ b/src/napari_matplotlib/tests/test_ui.py @@ -0,0 +1,49 @@ +import pytest +from qtpy.QtCore import QSize +from qtpy.QtGui import QImage + +from napari_matplotlib import HistogramWidget, ScatterWidget, SliceWidget + + +def _are_different(a: QImage, b: QImage) -> bool: + """ + Check that a and b are identical, pixel by pixel. Via a stupid nested for loop. + """ + assert not a.isNull() + assert not b.isNull() + assert a.size() == b.size() + for x in range(a.width()): + for y in range(a.height()): + if a.pixel(x, y) != b.pixel(x, y): + return True # exit quickly + return False + + +@pytest.mark.parametrize( + "Widget", [HistogramWidget, ScatterWidget, SliceWidget] +) +def test_mpl_toolbar_buttons_checked(make_napari_viewer, Widget): + """Test that the icons for checkable actions change when when a tool is selected. + + A simple test of NapariNavigationToolbar._update_buttons_checked. Make sure the + checked and unchecked icons are not the same. + """ + checkable_actions = ["Zoom", "Pan"] + + viewer = make_napari_viewer() + widget = Widget(viewer) + + # search through all of the icons for the ones whose icons are expected to + # change when checked + for action in widget.toolbar.actions(): + if action.text() in checkable_actions: + assert action.isChecked() is False + assert action.isCheckable() is True + unchecked = action.icon().pixmap(QSize(48, 48)).toImage() + + # simulate the user click (QTest.mouseClick can't take a QAction) + action.trigger() + + assert action.isChecked() is True + checked = action.icon().pixmap(QSize(48, 48)).toImage() + assert _are_different(unchecked, checked) diff --git a/src/napari_matplotlib/tests/test_util.py b/src/napari_matplotlib/tests/test_util.py index 22d943c9..a8792d41 100644 --- a/src/napari_matplotlib/tests/test_util.py +++ b/src/napari_matplotlib/tests/test_util.py @@ -1,6 +1,20 @@ +import importlib +import sys + import pytest +from qtpy.QtCore import QSize + +from napari_matplotlib.util import Interval, from_napari_css_get_size_of + + +def test_version_fallback(mocker): + """Test the versioning fallback (in case setuptools_scm didn't work)""" + import napari_matplotlib # fmt: skip + assert napari_matplotlib.__version__ != "unknown" # type: ignore[attr-defined] -from napari_matplotlib.util import Interval + mocker.patch.dict(sys.modules, {"napari_matplotlib._version": None}) + importlib.reload(napari_matplotlib) + assert napari_matplotlib.__version__ == "unknown" # type: ignore[attr-defined] def test_interval(): @@ -12,4 +26,54 @@ def test_interval(): assert 10 not in interval with pytest.raises(ValueError, match="must be an integer"): - "string" in interval + "string" in interval # type: ignore + + with pytest.raises(ValueError, match="must be <= upper_bound"): + Interval(5, 3) + + +@pytest.mark.parametrize( + "lower, upper, text", + [ + (None, None, None), + (1, None, "Select at least 1 layer to generate plot"), + (4, None, "Select at least 4 layers to generate plot"), + (None, 1, "Select at most 1 layer to generate plot"), + (None, 5939, "Select at most 5939 layers to generate plot"), + (1, 1, "Select 1 layer to generate plot"), + (2, 2, "Select 2 layers to generate plot"), + (1, 2, "Select between 1 and 2 layers to generate plot"), + ], +) +def test_interval_helper_text(lower, upper, text): + assert Interval(lower, upper)._helper_text == text + + +def test_get_size_from_css(mocker): + """Test getting the max-width and max-height from something in css""" + test_css = """ + Flibble { + min-width : 0; + max-width : 123px; + min-height : 0px; + max-height : 456px; + padding: 0px; + } + """ + mocker.patch("napari.qt.get_current_stylesheet").return_value = test_css + assert from_napari_css_get_size_of("Flibble", (1, 2)) == QSize(123, 456) + + +def test_fallback_if_missing_dimensions(mocker): + """Test fallback if given something that doesn't have dimensions""" + test_css = " Flobble { background-color: rgb(0, 97, 163); } " + mocker.patch("napari.qt.get_current_stylesheet").return_value = test_css + with pytest.warns(RuntimeWarning, match="Unable to find DimensionToken"): + assert from_napari_css_get_size_of("Flobble", (1, 2)) == QSize(1, 2) + + +def test_fallback_if_prelude_not_in_css(): + """Test fallback if given something not in the css""" + doesntexist = "AQButtonThatDoesntExist" + with pytest.warns(RuntimeWarning, match=f"Unable to find {doesntexist}"): + assert from_napari_css_get_size_of(doesntexist, (1, 2)) == QSize(1, 2) diff --git a/src/napari_matplotlib/util.py b/src/napari_matplotlib/util.py index be6a79c8..2aa15ddd 100644 --- a/src/napari_matplotlib/util.py +++ b/src/napari_matplotlib/util.py @@ -1,7 +1,16 @@ -from typing import Optional +from typing import List, Optional, Tuple, Union +from warnings import warn + +import napari.qt +import tinycss2 +from qtpy.QtCore import QSize class Interval: + """ + An integer interval. + """ + def __init__(self, lower_bound: Optional[int], upper_bound: Optional[int]): """ Parameters @@ -19,7 +28,16 @@ def __init__(self, lower_bound: Optional[int], upper_bound: Optional[int]): self.lower = lower_bound self.upper = upper_bound - def __contains__(self, val): + def __repr__(self) -> str: + """ + Get string representation. + """ + return f"Interval({self.lower}, {self.upper})" + + def __contains__(self, val: int) -> bool: + """ + Return True if val is in the current interval. + """ if not isinstance(val, int): raise ValueError("variable must be an integer") if self.lower is not None and val < self.lower: @@ -27,3 +45,96 @@ def __contains__(self, val): if self.upper is not None and val > self.upper: return False return True + + @property + def _helper_text(self) -> Optional[str]: + """ + Helper text for widgets. + """ + if self.lower is None and self.upper is None: + helper_text = None + elif self.lower is not None and self.upper is None: + helper_text = ( + f"Select at least {self.lower} layers to generate plot" + ) + elif self.lower is None and self.upper is not None: + helper_text = ( + f"Select at most {self.upper} layers to generate plot" + ) + elif self.lower == self.upper: + helper_text = f"Select {self.lower} layers to generate plot" + + else: + helper_text = ( + f"Select between {self.lower} and " + f"{self.upper} layers to generate plot" + ) + + if helper_text is not None: + helper_text = helper_text.replace("1 layers", "1 layer") + + return helper_text + + +def _has_id(nodes: List[tinycss2.ast.Node], id_name: str) -> bool: + """ + Is `id_name` in IdentTokens in the list of CSS `nodes`? + """ + return any( + [node.type == "ident" and node.value == id_name for node in nodes] + ) + + +def _get_dimension( + nodes: List[tinycss2.ast.Node], id_name: str +) -> Union[int, None]: + """ + Get the value of the DimensionToken for the IdentToken `id_name`. + + Returns + ------- + None if no IdentToken is found. + """ + cleaned_nodes = [node for node in nodes if node.type != "whitespace"] + for name, _, value, _ in zip(*(iter(cleaned_nodes),) * 4): + if ( + name.type == "ident" + and value.type == "dimension" + and name.value == id_name + ): + return value.int_value + warn(f"Unable to find DimensionToken for {id_name}", RuntimeWarning) + return None + + +def from_napari_css_get_size_of( + qt_element_name: str, fallback: Tuple[int, int] +) -> QSize: + """ + Get the size of `qt_element_name` from napari's current stylesheet. + + TODO: Confirm that the napari.qt.get_current_stylesheet changes with napari + theme (docs seem to indicate it should) + + Returns + ------- + QSize of the element if it's found, the `fallback` if it's not found.. + """ + rules = tinycss2.parse_stylesheet( + napari.qt.get_current_stylesheet(), + skip_comments=True, + skip_whitespace=True, + ) + w, h = None, None + for rule in rules: + if _has_id(rule.prelude, qt_element_name): + w = _get_dimension(rule.content, "max-width") + h = _get_dimension(rule.content, "max-height") + if w and h: + return QSize(w, h) + warn( + f"Unable to find {qt_element_name} or unable to find its size in " + f"the current Napari stylesheet, falling back to {fallback}", + RuntimeWarning, + ) + return QSize(*fallback) diff --git a/tox.ini b/tox.ini index 0e084df6..298887e1 100644 --- a/tox.ini +++ b/tox.ini @@ -1,4 +1,3 @@ -# For more information about tox, see https://tox.readthedocs.io/en/latest/ [tox] envlist = py{38,39,310} isolated_build = true @@ -11,4 +10,10 @@ python = [testenv] extras = testing -commands = python -m pytest -v --color=yes --cov=napari_matplotlib --cov-report=xml +allowlist_externals = + cp + ls +commands = + cp -R {toxinidir}/src/napari_matplotlib/tests/baseline {envdir}/baseline + ls {toxinidir}/src/napari_matplotlib/tests/baseline + python -m pytest --mpl --mpl-generate-summary=html --mpl-results-path={toxinidir}/reports -v --color=yes --cov=napari_matplotlib --cov-report=xml