diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index caf38d16..b78e43e1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,11 +6,6 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/asottile/setup-cfg-fmt - rev: v2.4.0 - hooks: - - id: setup-cfg-fmt - - repo: https://github.com/psf/black rev: 23.7.0 hooks: @@ -22,14 +17,14 @@ repos: - id: napari-plugin-checks - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.4.1 + rev: v1.5.1 hooks: - id: mypy additional_dependencies: [numpy, matplotlib] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: 'v0.0.280' + rev: 'v0.0.285' hooks: - id: ruff diff --git a/baseline/test_feature_histogram2.png b/baseline/test_feature_histogram2.png new file mode 100644 index 00000000..b7bb19e0 Binary files /dev/null and b/baseline/test_feature_histogram2.png differ diff --git a/docs/changelog.rst b/docs/changelog.rst index 103cd80e..6f77e0c3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,5 +1,23 @@ Changelog ========= +1.1.0 +----- +Additions +~~~~~~~~~ +- Added a widget to draw a histogram of features. + +Changes +~~~~~~~ +- The slice widget is now limited to slicing along the x/y dimensions. Support + for slicing along z has been removed for now to make the code simpler. +- The slice widget now uses a slider to select the slice value. + +Bug fixes +~~~~~~~~~ +- Fixed creating 1D slices of 2D images. +- Removed the limitation that only the first 99 indices could be sliced using + the slice widget. + 1.0.2 ----- Bug fixes diff --git a/docs/user_guide.rst b/docs/user_guide.rst index 0872e540..fbd48db1 100644 --- a/docs/user_guide.rst +++ b/docs/user_guide.rst @@ -30,6 +30,7 @@ These widgets plot the data stored in the ``.features`` attribute of individual Currently available are: - 2D scatter plots of two features against each other. +- Histograms of individual features. To use these: diff --git a/setup.cfg b/setup.cfg index 229b5777..e1fc9e73 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,7 +55,7 @@ docs = sphinx-automodapi sphinx-gallery testing = - napari[pyqt6-experimental]>=0.4.18 + napari[pyqt6_experimental]>=0.4.18 pooch pyqt6 pytest diff --git a/src/napari_matplotlib/features.py b/src/napari_matplotlib/features.py new file mode 100644 index 00000000..34abf104 --- /dev/null +++ b/src/napari_matplotlib/features.py @@ -0,0 +1,9 @@ +from napari.layers import Labels, Points, Shapes, Tracks, Vectors + +FEATURES_LAYER_TYPES = ( + Labels, + Points, + Shapes, + Tracks, + Vectors, +) diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index ab098f38..66aa7acc 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -1,13 +1,15 @@ -from typing import Optional +from typing import Any, List, Optional, Tuple import napari import numpy as np -from qtpy.QtWidgets import QWidget +import numpy.typing as npt +from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget from .base import SingleAxesWidget +from .features import FEATURES_LAYER_TYPES from .util import Interval -__all__ = ["HistogramWidget"] +__all__ = ["HistogramWidget", "FeaturesHistogramWidget"] _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} @@ -61,3 +63,112 @@ def draw(self) -> None: self.axes.hist(data.ravel(), bins=bins, label=layer.name) self.axes.legend() + + +class FeaturesHistogramWidget(SingleAxesWidget): + """ + Display a histogram of selected feature attached to selected layer. + """ + + n_layers_input = Interval(1, 1) + # All layers that have a .features attributes + input_layer_types = FEATURES_LAYER_TYPES + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer, parent=parent) + + self.layout().addLayout(QVBoxLayout()) + self._key_selection_widget = QComboBox() + self.layout().addWidget(QLabel("Key:")) + self.layout().addWidget(self._key_selection_widget) + + self._key_selection_widget.currentTextChanged.connect( + self._set_axis_keys + ) + + self._update_layers(None) + + @property + def x_axis_key(self) -> Optional[str]: + """Key to access x axis data from the FeaturesTable""" + return self._x_axis_key + + @x_axis_key.setter + def x_axis_key(self, key: Optional[str]) -> None: + self._x_axis_key = key + self._draw() + + def _set_axis_keys(self, x_axis_key: str) -> None: + """Set both axis keys and then redraw the plot""" + self._x_axis_key = x_axis_key + self._draw() + + def _get_valid_axis_keys(self) -> List[str]: + """ + Get the valid axis keys from the layer FeatureTable. + + Returns + ------- + axis_keys : List[str] + The valid axis keys in the FeatureTable. If the table is empty + or there isn't a table, returns an empty list. + """ + if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): + return [] + else: + return self.layers[0].features.keys() + + def _get_data(self) -> Tuple[Optional[npt.NDArray[Any]], str]: + """Get the plot data. + + Returns + ------- + data : List[np.ndarray] + List contains X and Y columns from the FeatureTable. Returns + an empty array if nothing to plot. + x_axis_name : str + The title to display on the x axis. Returns + an empty string if nothing to plot. + """ + if not hasattr(self.layers[0], "features"): + # if the selected layer doesn't have a featuretable, + # skip draw + return None, "" + + feature_table = self.layers[0].features + + if (len(feature_table) == 0) or (self.x_axis_key is None): + return None, "" + + data = feature_table[self.x_axis_key] + x_axis_name = self.x_axis_key.replace("_", " ") + + return data, x_axis_name + + def on_update_layers(self) -> None: + """ + Called when the layer selection changes by ``self.update_layers()``. + """ + # reset the axis keys + self._x_axis_key = None + + # Clear combobox + self._key_selection_widget.clear() + self._key_selection_widget.addItems(self._get_valid_axis_keys()) + + def draw(self) -> None: + """Clear the axes and histogram the currently selected layer/slice.""" + data, x_axis_name = self._get_data() + + if data is None: + return + + self.axes.hist(data, bins=50, edgecolor="white", linewidth=0.3) + + # set ax labels + self.axes.set_xlabel(x_axis_name) + self.axes.set_ylabel("Counts [#]") diff --git a/src/napari_matplotlib/napari.yaml b/src/napari_matplotlib/napari.yaml index b736592b..71af0ca6 100644 --- a/src/napari_matplotlib/napari.yaml +++ b/src/napari_matplotlib/napari.yaml @@ -14,6 +14,10 @@ contributions: python_name: napari_matplotlib:FeaturesScatterWidget title: Make a scatter plot of layer features + - id: napari-matplotlib.features_histogram + python_name: napari_matplotlib:FeaturesHistogramWidget + title: Plot feature histograms + - id: napari-matplotlib.slice python_name: napari_matplotlib:SliceWidget title: Plot a 1D slice @@ -28,5 +32,8 @@ contributions: - command: napari-matplotlib.features_scatter display_name: FeaturesScatter + - command: napari-matplotlib.features_histogram + display_name: FeaturesHistogram + - command: napari-matplotlib.slice display_name: 1D slice diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index db86c7f3..a4148bd2 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -5,6 +5,7 @@ from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget from .base import SingleAxesWidget +from .features import FEATURES_LAYER_TYPES from .util import Interval __all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"] @@ -94,13 +95,7 @@ class FeaturesScatterWidget(ScatterBaseWidget): n_layers_input = Interval(1, 1) # All layers that have a .features attributes - input_layer_types = ( - napari.layers.Labels, - napari.layers.Points, - napari.layers.Shapes, - napari.layers.Tracks, - napari.layers.Vectors, - ) + input_layer_types = FEATURES_LAYER_TYPES def __init__( self, diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index e3aa80b2..393f2e45 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -1,19 +1,23 @@ -from typing import Any, Dict, Optional, Tuple +from typing import Any, List, Optional, Tuple import matplotlib.ticker as mticker import napari import numpy as np import numpy.typing as npt -from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox, QWidget +from qtpy.QtCore import Qt +from qtpy.QtWidgets import ( + QComboBox, + QLabel, + QSlider, + QVBoxLayout, + QWidget, +) from .base import SingleAxesWidget from .util import Interval __all__ = ["SliceWidget"] -_dims_sel = ["x", "y"] -_dims = ["x", "y", "z"] - class SliceWidget(SingleAxesWidget): """ @@ -31,28 +35,46 @@ def __init__( # Setup figure/axes super().__init__(napari_viewer, parent=parent) - button_layout = QHBoxLayout() - self.layout().addLayout(button_layout) - self.dim_selector = QComboBox() + self.dim_selector.addItems(["x", "y"]) + + self.slice_selector = QSlider(orientation=Qt.Orientation.Horizontal) + + # Create widget layout + button_layout = QVBoxLayout() button_layout.addWidget(QLabel("Slice axis:")) button_layout.addWidget(self.dim_selector) - self.dim_selector.addItems(_dims) - - self.slice_selectors = {} - for d in _dims_sel: - self.slice_selectors[d] = QSpinBox() - button_layout.addWidget(QLabel(f"{d}:")) - button_layout.addWidget(self.slice_selectors[d]) + button_layout.addWidget(self.slice_selector) + self.layout().addLayout(button_layout) # Setup callbacks - # Re-draw when any of the combon/spin boxes are updated + # Re-draw when any of the combo/slider is updated self.dim_selector.currentTextChanged.connect(self._draw) - for d in _dims_sel: - self.slice_selectors[d].textChanged.connect(self._draw) + self.slice_selector.valueChanged.connect(self._draw) self._update_layers(None) + def on_update_layers(self) -> None: + """ + Called when layer selection is updated. + """ + if not len(self.layers): + return + if self.current_dim_name == "x": + max = self._layer.data.shape[-2] + elif self.current_dim_name == "y": + max = self._layer.data.shape[-1] + else: + raise RuntimeError("dim name must be x or y") + self.slice_selector.setRange(0, max - 1) + + @property + def _slice_width(self) -> int: + """ + Width of the slice being plotted. + """ + return self._layer.data.shape[self.current_dim_index] + @property def _layer(self) -> napari.layers.Layer: """ @@ -61,7 +83,7 @@ def _layer(self) -> napari.layers.Layer: return self.layers[0] @property - def current_dim(self) -> str: + def current_dim_name(self) -> str: """ Currently selected slice dimension. """ @@ -74,36 +96,40 @@ def current_dim_index(self) -> int: """ # Note the reversed list because in napari the z-axis is the first # numpy axis - return _dims[::-1].index(self.current_dim) + return self._dim_names.index(self.current_dim_name) @property - def _selector_values(self) -> Dict[str, int]: + def _dim_names(self) -> List[str]: """ - Values of the slice selectors. + List of dimension names. This is a property as it varies depending on the + dimensionality of the currently selected data. """ - return {d: self.slice_selectors[d].value() for d in _dims_sel} + if self._layer.data.ndim == 2: + return ["y", "x"] + elif self._layer.data.ndim == 3: + return ["z", "y", "x"] + else: + raise RuntimeError("Don't know how to handle ndim != 2 or 3") def _get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]: """ Get data for plotting. """ - x = np.arange(self._layer.data.shape[self.current_dim_index]) - - vals = self._selector_values - vals.update({"z": self.current_z}) + val = self.slice_selector.value() slices = [] - for d in _dims: - if d == self.current_dim: + for dim_name in self._dim_names: + if dim_name == self.current_dim_name: # Select all data along this axis slices.append(slice(None)) + elif dim_name == "z": + # Only select the currently viewed z-index + slices.append(slice(self.current_z, self.current_z + 1)) else: # Select specific index - val = vals[d] slices.append(slice(val, val + 1)) - # Reverse since z is the first axis in napari - slices = slices[::-1] + x = np.arange(self._slice_width) y = self._layer.data[tuple(slices)].ravel() return x, y @@ -115,7 +141,7 @@ def draw(self) -> None: x, y = self._get_xy() self.axes.plot(x, y) - self.axes.set_xlabel(self.current_dim) + self.axes.set_xlabel(self.current_dim_name) self.axes.set_title(self._layer.name) # Make sure all ticks lie on integer values self.axes.xaxis.set_major_locator( diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram.png new file mode 100644 index 00000000..1892af44 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_feature_histogram.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png new file mode 100644 index 00000000..b7bb19e0 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_slice_2D.png b/src/napari_matplotlib/tests/baseline/test_slice_2D.png index 639f25b8..ee3ce3b6 100644 Binary files a/src/napari_matplotlib/tests/baseline/test_slice_2D.png and b/src/napari_matplotlib/tests/baseline/test_slice_2D.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_slice_3D.png b/src/napari_matplotlib/tests/baseline/test_slice_3D.png index 43c8c3b6..c30211da 100644 Binary files a/src/napari_matplotlib/tests/baseline/test_slice_3D.png and b/src/napari_matplotlib/tests/baseline/test_slice_3D.png differ diff --git a/src/napari_matplotlib/tests/test_histogram.py b/src/napari_matplotlib/tests/test_histogram.py index 4d170014..006c042f 100644 --- a/src/napari_matplotlib/tests/test_histogram.py +++ b/src/napari_matplotlib/tests/test_histogram.py @@ -1,8 +1,13 @@ from copy import deepcopy +import numpy as np import pytest -from napari_matplotlib import HistogramWidget +from napari_matplotlib import FeaturesHistogramWidget, HistogramWidget +from napari_matplotlib.tests.helpers import ( + assert_figures_equal, + assert_figures_not_equal, +) @pytest.mark.mpl_image_compare @@ -28,3 +33,90 @@ def test_histogram_3D(make_napari_viewer, brain_data): # Need to return a copy, as original figure is too eagerley garbage # collected by the widget return deepcopy(fig) + + +def test_feature_histogram(make_napari_viewer): + n_points = 1000 + random_points = np.random.random((n_points, 3)) * 10 + feature1 = np.random.random(n_points) + feature2 = np.random.normal(size=n_points) + + viewer = make_napari_viewer() + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points1", + ) + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points2", + ) + + widget = FeaturesHistogramWidget(viewer) + viewer.window.add_dock_widget(widget) + + # Check whether changing the selected key changes the plot + widget._set_axis_keys("feature1") + fig1 = deepcopy(widget.figure) + + widget._set_axis_keys("feature2") + assert_figures_not_equal(widget.figure, fig1) + + # check whether selecting a different layer produces the same plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[1]) + assert_figures_equal(widget.figure, fig1) + + +@pytest.mark.mpl_image_compare +def test_feature_histogram2(make_napari_viewer): + import numpy as np + + np.random.seed(0) + n_points = 1000 + random_points = np.random.random((n_points, 3)) * 10 + feature1 = np.random.random(n_points) + feature2 = np.random.normal(size=n_points) + + viewer = make_napari_viewer() + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points1", + ) + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points2", + ) + + widget = FeaturesHistogramWidget(viewer) + viewer.window.add_dock_widget(widget) + widget._set_axis_keys("feature1") + + fig = FeaturesHistogramWidget(viewer).figure + return deepcopy(fig) + + +def test_change_layer(make_napari_viewer, brain_data, astronaut_data): + viewer = make_napari_viewer() + widget = HistogramWidget(viewer) + + viewer.add_image(brain_data[0], **brain_data[1]) + viewer.add_image(astronaut_data[0], **astronaut_data[1]) + + # Select first layer + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[0]) + fig1 = deepcopy(widget.figure) + + # Re-selecting first layer should produce identical plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[0]) + assert_figures_equal(widget.figure, fig1) + + # Plotting the second layer should produce a different plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[1]) + assert_figures_not_equal(widget.figure, fig1) diff --git a/src/napari_matplotlib/tests/test_slice.py b/src/napari_matplotlib/tests/test_slice.py index 412e71c3..368a7ded 100644 --- a/src/napari_matplotlib/tests/test_slice.py +++ b/src/napari_matplotlib/tests/test_slice.py @@ -9,9 +9,13 @@ def test_slice_3D(make_napari_viewer, brain_data): viewer = make_napari_viewer() viewer.theme = "light" - viewer.add_image(brain_data[0], **brain_data[1]) + + data = brain_data[0] + assert data.ndim == 3, data.shape + viewer.add_image(data, **brain_data[1]) + axis = viewer.dims.last_used - slice_no = brain_data[0].shape[0] - 1 + slice_no = data.shape[0] - 1 viewer.dims.set_current_step(axis, slice_no) fig = SliceWidget(viewer).figure # Need to return a copy, as original figure is too eagerley garbage @@ -23,8 +27,37 @@ def test_slice_3D(make_napari_viewer, brain_data): def test_slice_2D(make_napari_viewer, astronaut_data): viewer = make_napari_viewer() viewer.theme = "light" - viewer.add_image(astronaut_data[0], **astronaut_data[1]) + + # Take first RGB channel + data = astronaut_data[0][:, :, 0] + assert data.ndim == 2, data.shape + viewer.add_image(data) + fig = SliceWidget(viewer).figure # Need to return a copy, as original figure is too eagerley garbage # collected by the widget return deepcopy(fig) + + +def test_slice_axes(make_napari_viewer, astronaut_data): + viewer = make_napari_viewer() + viewer.theme = "light" + + # Take first RGB channel + data = astronaut_data[0][:256, :, 0] + # Shape: + # x: 0 > 512 + # y: 0 > 256 + assert data.ndim == 2, data.shape + # Make sure data isn't square for later tests + assert data.shape[0] != data.shape[1] + viewer.add_image(data) + + widget = SliceWidget(viewer) + assert widget._dim_names == ["y", "x"] + assert widget.current_dim_name == "x" + assert widget.slice_selector.value() == 0 + assert widget.slice_selector.minimum() == 0 + assert widget.slice_selector.maximum() == data.shape[0] - 1 + # x/y are flipped in napari + assert widget._slice_width == data.shape[1]