From caca5d0d1b59ab47b3fb0556d05cba1134a3bc1c Mon Sep 17 00:00:00 2001 From: David Stansby Date: Sun, 8 May 2022 10:45:03 +0100 Subject: [PATCH 1/3] Add layer number/type validation --- src/napari_matplotlib/base.py | 15 +++++++++++---- src/napari_matplotlib/histogram.py | 5 ++++- src/napari_matplotlib/scatter.py | 4 +++- src/napari_matplotlib/slice.py | 6 ++++-- src/napari_matplotlib/util.py | 29 +++++++++++++++++++++++++++++ 5 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 src/napari_matplotlib/util.py diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index ed7f4cd6..4613edec 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -10,6 +10,8 @@ from qtpy.QtGui import QIcon from qtpy.QtWidgets import QVBoxLayout, QWidget +from .util import Interval + mpl.rc("axes", edgecolor="white") mpl.rc("axes", facecolor="#262930") mpl.rc("axes", labelcolor="white") @@ -65,6 +67,11 @@ def __init__(self, napari_viewer: napari.viewer.Viewer): self.setup_callbacks() + # Accept any number of input layers by default + n_layers_input = Interval(None, None) + # Accept any type of input layer by default + input_layer_types = (napari.layers.Layer,) + @property def n_selected_layers(self) -> int: """ @@ -104,10 +111,10 @@ def _draw(self) -> None: figure if so. """ self.clear() - if self.n_selected_layers != self.n_layers_input: - self.canvas.draw() - return - self.draw() + if self.n_selected_layers in self.n_layers_input and all( + isinstance(layer, self.input_layer_types) for layer in self.layers + ): + self.draw() self.canvas.draw() def clear(self) -> None: diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index 2180bf08..1f479ad7 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -6,6 +6,8 @@ import napari +from .util import Interval + _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} @@ -14,7 +16,8 @@ class HistogramWidget(NapariMPLWidget): Display a histogram of the currently selected layer. """ - n_layers_input = 1 + n_layers_input = Interval(1, 1) + input_layer_types = (napari.layers.Image,) def __init__(self, napari_viewer: napari.viewer.Viewer): super().__init__(napari_viewer) diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 324e9126..01878fc9 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -6,6 +6,7 @@ from magicgui import magicgui from .base import NapariMPLWidget +from .util import Interval __all__ = ["ScatterWidget", "FeaturesScatterWidget"] @@ -84,7 +85,8 @@ class ScatterWidget(ScatterBaseWidget): of a scatter plot, to avoid too many scatter points. """ - n_layers_input = 2 + n_layers_input = Interval(2, 2) + input_layer_types = (napari.layers.Image,) def __init__( self, diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index 1d6407f5..5ab122c0 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -4,7 +4,8 @@ import numpy as np from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox -from napari_matplotlib.base import NapariMPLWidget +from .base import NapariMPLWidget +from .util import Interval __all__ = ["SliceWidget"] @@ -17,7 +18,8 @@ class SliceWidget(NapariMPLWidget): Plot a 1D slice along a given dimension. """ - n_layers_input = 1 + n_layers_input = Interval(1, 1) + input_layer_types = (napari.layers.Image,) def __init__(self, napari_viewer: napari.viewer.Viewer): # Setup figure/axes diff --git a/src/napari_matplotlib/util.py b/src/napari_matplotlib/util.py new file mode 100644 index 00000000..9d3db2f3 --- /dev/null +++ b/src/napari_matplotlib/util.py @@ -0,0 +1,29 @@ +from typing import Optional + + +class Interval: + def __init__(self, lower_bound: Optional[int], upper_bound: Optional[int]): + """ + Parameters + ---------- + lower_bound, upper_bound: + Bounds. Use `None` to specify an open bound. + """ + if ( + lower_bound is not None + and upper_bound is not None + and lower_bound > upper_bound + ): + raise ValueError("lower_bound must be <= upper_bound") + + self.lower = lower_bound + self.upper = upper_bound + + def __contains__(self, val): + if not isinstance(val, int): + return NotImplemented + if self.lower is not None and val < self.lower: + return False + if self.upper is not None and val > self.upper: + return False + return True From 578703a7dd163c0fe61a9166197faafb0e8a7da4 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Thu, 19 May 2022 18:21:51 +0100 Subject: [PATCH 2/3] Add typing to FeaturesScatterWidget --- src/napari_matplotlib/base.py | 3 ++- src/napari_matplotlib/scatter.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 4613edec..f3da9081 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from typing import Tuple import matplotlib as mpl import napari @@ -70,7 +71,7 @@ def __init__(self, napari_viewer: napari.viewer.Viewer): # Accept any number of input layers by default n_layers_input = Interval(None, None) # Accept any type of input layer by default - input_layer_types = (napari.layers.Layer,) + input_layer_types: Tuple[napari.layers.Layer, ...] = (napari.layers.Layer,) @property def n_selected_layers(self) -> int: diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 01878fc9..54689868 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -116,7 +116,15 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]: class FeaturesScatterWidget(ScatterBaseWidget): - n_layers_input = 1 + n_layers_input = Interval(1, 1) + # All layers that have a .features attributes + input_layer_types = ( + napari.layers.Labels, + napari.layers.Points, + napari.layers.Shapes, + napari.layers.Tracks, + napari.layers.Vectors, + ) def __init__( self, @@ -164,7 +172,8 @@ def _set_axis_keys(self, x_axis_key: str, y_axis_key: str): self._draw() def _get_valid_axis_keys(self, combo_widget=None) -> List[str]: - """Get the valid axis keys from the layer FeatureTable. + """ + Get the valid axis keys from the layer FeatureTable. Returns ------- From 5dbde280b61b84abe59462e2fae86c4e8ee1f1f3 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Thu, 19 May 2022 18:30:49 +0100 Subject: [PATCH 3/3] Add interval tests --- src/napari_matplotlib/tests/test_util.py | 15 +++++++++++++++ src/napari_matplotlib/util.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 src/napari_matplotlib/tests/test_util.py diff --git a/src/napari_matplotlib/tests/test_util.py b/src/napari_matplotlib/tests/test_util.py new file mode 100644 index 00000000..22d943c9 --- /dev/null +++ b/src/napari_matplotlib/tests/test_util.py @@ -0,0 +1,15 @@ +import pytest + +from napari_matplotlib.util import Interval + + +def test_interval(): + interval = Interval(4, 9) + for i in range(4, 10): + assert i in interval + + assert 3 not in interval + assert 10 not in interval + + with pytest.raises(ValueError, match="must be an integer"): + "string" in interval diff --git a/src/napari_matplotlib/util.py b/src/napari_matplotlib/util.py index 9d3db2f3..be6a79c8 100644 --- a/src/napari_matplotlib/util.py +++ b/src/napari_matplotlib/util.py @@ -21,7 +21,7 @@ def __init__(self, lower_bound: Optional[int], upper_bound: Optional[int]): def __contains__(self, val): if not isinstance(val, int): - return NotImplemented + raise ValueError("variable must be an integer") if self.lower is not None and val < self.lower: return False if self.upper is not None and val > self.upper: