diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index ed7f4cd6..f3da9081 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from typing import Tuple import matplotlib as mpl import napari @@ -10,6 +11,8 @@ from qtpy.QtGui import QIcon from qtpy.QtWidgets import QVBoxLayout, QWidget +from .util import Interval + mpl.rc("axes", edgecolor="white") mpl.rc("axes", facecolor="#262930") mpl.rc("axes", labelcolor="white") @@ -65,6 +68,11 @@ def __init__(self, napari_viewer: napari.viewer.Viewer): self.setup_callbacks() + # Accept any number of input layers by default + n_layers_input = Interval(None, None) + # Accept any type of input layer by default + input_layer_types: Tuple[napari.layers.Layer, ...] = (napari.layers.Layer,) + @property def n_selected_layers(self) -> int: """ @@ -104,10 +112,10 @@ def _draw(self) -> None: figure if so. """ self.clear() - if self.n_selected_layers != self.n_layers_input: - self.canvas.draw() - return - self.draw() + if self.n_selected_layers in self.n_layers_input and all( + isinstance(layer, self.input_layer_types) for layer in self.layers + ): + self.draw() self.canvas.draw() def clear(self) -> None: diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index 2180bf08..1f479ad7 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -6,6 +6,8 @@ import napari +from .util import Interval + _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} @@ -14,7 +16,8 @@ class HistogramWidget(NapariMPLWidget): Display a histogram of the currently selected layer. """ - n_layers_input = 1 + n_layers_input = Interval(1, 1) + input_layer_types = (napari.layers.Image,) def __init__(self, napari_viewer: napari.viewer.Viewer): super().__init__(napari_viewer) diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 324e9126..54689868 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -6,6 +6,7 @@ from magicgui import magicgui from .base import NapariMPLWidget +from .util import Interval __all__ = ["ScatterWidget", "FeaturesScatterWidget"] @@ -84,7 +85,8 @@ class ScatterWidget(ScatterBaseWidget): of a scatter plot, to avoid too many scatter points. """ - n_layers_input = 2 + n_layers_input = Interval(2, 2) + input_layer_types = (napari.layers.Image,) def __init__( self, @@ -114,7 +116,15 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]: class FeaturesScatterWidget(ScatterBaseWidget): - n_layers_input = 1 + n_layers_input = Interval(1, 1) + # All layers that have a .features attributes + input_layer_types = ( + napari.layers.Labels, + napari.layers.Points, + napari.layers.Shapes, + napari.layers.Tracks, + napari.layers.Vectors, + ) def __init__( self, @@ -162,7 +172,8 @@ def _set_axis_keys(self, x_axis_key: str, y_axis_key: str): self._draw() def _get_valid_axis_keys(self, combo_widget=None) -> List[str]: - """Get the valid axis keys from the layer FeatureTable. + """ + Get the valid axis keys from the layer FeatureTable. Returns ------- diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index 1d6407f5..5ab122c0 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -4,7 +4,8 @@ import numpy as np from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox -from napari_matplotlib.base import NapariMPLWidget +from .base import NapariMPLWidget +from .util import Interval __all__ = ["SliceWidget"] @@ -17,7 +18,8 @@ class SliceWidget(NapariMPLWidget): Plot a 1D slice along a given dimension. """ - n_layers_input = 1 + n_layers_input = Interval(1, 1) + input_layer_types = (napari.layers.Image,) def __init__(self, napari_viewer: napari.viewer.Viewer): # Setup figure/axes diff --git a/src/napari_matplotlib/tests/test_util.py b/src/napari_matplotlib/tests/test_util.py new file mode 100644 index 00000000..22d943c9 --- /dev/null +++ b/src/napari_matplotlib/tests/test_util.py @@ -0,0 +1,15 @@ +import pytest + +from napari_matplotlib.util import Interval + + +def test_interval(): + interval = Interval(4, 9) + for i in range(4, 10): + assert i in interval + + assert 3 not in interval + assert 10 not in interval + + with pytest.raises(ValueError, match="must be an integer"): + "string" in interval diff --git a/src/napari_matplotlib/util.py b/src/napari_matplotlib/util.py new file mode 100644 index 00000000..be6a79c8 --- /dev/null +++ b/src/napari_matplotlib/util.py @@ -0,0 +1,29 @@ +from typing import Optional + + +class Interval: + def __init__(self, lower_bound: Optional[int], upper_bound: Optional[int]): + """ + Parameters + ---------- + lower_bound, upper_bound: + Bounds. Use `None` to specify an open bound. + """ + if ( + lower_bound is not None + and upper_bound is not None + and lower_bound > upper_bound + ): + raise ValueError("lower_bound must be <= upper_bound") + + self.lower = lower_bound + self.upper = upper_bound + + def __contains__(self, val): + if not isinstance(val, int): + raise ValueError("variable must be an integer") + if self.lower is not None and val < self.lower: + return False + if self.upper is not None and val > self.upper: + return False + return True