diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 8c717d6a..f6c688df 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import List, Optional, Tuple +import matplotlib.style import napari from matplotlib.axes import Axes from matplotlib.backends.backend_qtagg import ( @@ -41,9 +42,11 @@ def __init__( super().__init__(parent=parent) self.viewer = napari_viewer + has_mpl_stylesheet = self._apply_user_stylesheet_if_present() self.canvas = FigureCanvas() - self.canvas.figure.patch.set_facecolor("none") + if not has_mpl_stylesheet: + self.canvas.figure.patch.set_facecolor("none") self.canvas.figure.set_layout_engine("constrained") self.toolbar = NapariNavigationToolbar( self.canvas, parent=self @@ -70,10 +73,16 @@ def add_single_axes(self) -> None: The Axes is saved on the ``.axes`` attribute for later access. """ self.axes = self.figure.subplots() - self.apply_napari_colorscheme(self.axes) + self.apply_style(self.axes) + + def apply_style(self, ax: Axes) -> None: + """ + Use the user-supplied stylesheet if present, otherwise apply the + napari-compatible colorscheme (theme-dependent) to an Axes. + """ + if self._apply_user_stylesheet_if_present(): + return - def apply_napari_colorscheme(self, ax: Axes) -> None: - """Apply napari-compatible colorscheme to an Axes.""" # get the foreground colours from current theme theme = napari.utils.theme.get_theme(self.viewer.theme, as_dict=False) fg_colour = theme.foreground.as_hex() # fg is a muted contrast to bg @@ -93,6 +102,20 @@ def apply_napari_colorscheme(self, ax: Axes) -> None: ax.tick_params(axis="x", colors=text_colour) ax.tick_params(axis="y", colors=text_colour) + def _apply_user_stylesheet_if_present(self) -> bool: + """ + Apply the user-supplied stylesheet if present. + + Returns + ------- + True if the stylesheet was present and applied. + False otherwise. + """ + if (Path.cwd() / "user.mplstyle").exists(): + matplotlib.style.use("./user.mplstyle") + return True + return False + def _on_theme_change(self) -> None: """Update MPL toolbar and axis styling when `napari.Viewer.theme` is changed. @@ -101,7 +124,7 @@ def _on_theme_change(self) -> None: """ self._replace_toolbar_icons() if self.figure.gca(): - self.apply_napari_colorscheme(self.figure.gca()) + self.apply_style(self.figure.gca()) def _theme_has_light_bg(self) -> bool: """ @@ -245,7 +268,7 @@ def _draw(self) -> None: isinstance(layer, self.input_layer_types) for layer in self.layers ): self.draw() - self.apply_napari_colorscheme(self.figure.gca()) + self.apply_style(self.figure.gca()) self.canvas.draw() def clear(self) -> None: diff --git a/src/napari_matplotlib/tests/test_theme.py b/src/napari_matplotlib/tests/test_theme.py index cf841d2b..988b6c35 100644 --- a/src/napari_matplotlib/tests/test_theme.py +++ b/src/napari_matplotlib/tests/test_theme.py @@ -1,8 +1,13 @@ +import shutil +from pathlib import Path + +import matplotlib import napari import numpy as np import pytest +from matplotlib.colors import to_rgba -from napari_matplotlib import ScatterWidget +from napari_matplotlib import HistogramWidget, ScatterWidget from napari_matplotlib.base import NapariMPLWidget @@ -91,7 +96,7 @@ def test_titles_respect_theme( @pytest.mark.mpl_image_compare -def test_no_theme_side_effects(make_napari_viewer): +def test_no_theme_side_effects(tmpdir, make_napari_viewer): """Ensure that napari-matplotlib doesn't pollute the globally set style. A MWE to guard aganst issue matplotlib/#64. Should always reproduce a plot @@ -102,9 +107,11 @@ def test_no_theme_side_effects(make_napari_viewer): np.random.seed(12345) # should not affect global matplotlib plot style - viewer = make_napari_viewer() - viewer.theme = "dark" - NapariMPLWidget(viewer) + with tmpdir.as_cwd(): + shutil.copy(find_mpl_stylesheet("Solarize_Light2"), "./user.mplstyle") + viewer = make_napari_viewer() + viewer.theme = "dark" + NapariMPLWidget(viewer) # some plotting unrelated to napari-matplotlib normal_dist = np.random.normal(size=1000) @@ -118,3 +125,34 @@ def test_no_theme_side_effects(make_napari_viewer): unrelated_figure.tight_layout() return unrelated_figure + + +def find_mpl_stylesheet(name: str) -> Path: + """Find the built-in matplotlib stylesheet.""" + return Path(matplotlib.__path__[0]) / f"mpl-data/stylelib/{name}.mplstyle" + + +def test_stylesheet_in_cwd(tmpdir, make_napari_viewer, image_data): + """ + Test that a stylesheet in the current directory is given precidence. + + Do this by copying over a stylesheet from matplotlib's built in styles, + naming it correctly, and checking the colours are as expected. + """ + with tmpdir.as_cwd(): + # Copy Solarize_Light2 to current dir as if it was a user-overriden stylesheet. + shutil.copy(find_mpl_stylesheet("Solarize_Light2"), "./user.mplstyle") + viewer = make_napari_viewer() + viewer.add_image(image_data[0], **image_data[1]) + widget = HistogramWidget(viewer) + ax = widget.figure.gca() + + # The axes should have a light brownish grey background: + assert ax.get_facecolor() == to_rgba("#eee8d5") + assert ax.patch.get_facecolor() == to_rgba("#eee8d5") + + # The figure background and axis gridlines are light yellow: + assert widget.figure.patch.get_facecolor() == to_rgba("#fdf6e3") + for gridline in ax.get_xgridlines() + ax.get_ygridlines(): + assert gridline.get_visible() is True + assert gridline.get_color() == "#fdf6e3"