Skip to content

feat(3d): improve plot_surface shading logic #30424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2198,9 +2198,12 @@
vmin, vmax : float, optional
Bounds for the normalization.

shade : bool, default: True
Whether to shade the facecolors. Shading is always disabled when
*cmap* is specified.
shade : bool or "auto", default: "auto"
Whether to shade the facecolors. "auto" will shade only if the facecolor is uniform,

Check warning on line 2202 in lib/mpl_toolkits/mplot3d/axes3d.py

View workflow job for this annotation

GitHub Actions / ruff

[rdjson] reported by reviewdog 🐶 Line too long (97 > 88) Raw Output: message:"Line too long (97 > 88)" location:{path:"/home/runner/work/matplotlib/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py" range:{start:{line:2202 column:89} end:{line:2202 column:98}}} severity:WARNING source:{name:"ruff" url:"https://docs.astral.sh/ruff"} code:{value:"E501" url:"https://docs.astral.sh/ruff/rules/line-too-long"}
i.e. neither *cmap* nor *facecolors* is given.

Furthermore, shading is generally not compatible with colormapping and
``shade=True, cmap=...`` will raise an error.

lightsource : `~matplotlib.colors.LightSource`, optional
The lightsource to use when *shade* is True.
Expand Down Expand Up @@ -2251,8 +2254,10 @@
fcolors = kwargs.pop('facecolors', None)

cmap = kwargs.get('cmap', None)
shade = kwargs.pop('shade', cmap is None)
if shade is None:
shade = kwargs.pop('shade', "auto")
if shade == "auto":
shade = cmap is None and fcolors is None
elif shade is None:
raise ValueError("shade cannot be None.")

colset = [] # the sampled facecolor
Expand Down
2 changes: 1 addition & 1 deletion lib/mpl_toolkits/mplot3d/tests/test_axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def test_surface3d_masked():
z = np.ma.masked_less(matrix, 0)
norm = mcolors.Normalize(vmax=z.max(), vmin=z.min())
colors = mpl.colormaps["plasma"](norm(z))
ax.plot_surface(x, y, z, facecolors=colors)
ax.plot_surface(x, y, z, facecolors=colors, shade=True)
ax.view_init(30, -80, 0)


Expand Down
132 changes: 132 additions & 0 deletions lib/mpl_toolkits/mplot3d/tests/test_plot_surface_shade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Tests for plot_surface shade parameter behavior.
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize


def test_plot_surface_auto_shade_with_facecolors():
"""Test that plot_surface with facecolors uses shade=False by default."""
X = np.linspace(0, 1, 10)
Y = np.linspace(0, 1, 10)
X_mesh, Y_mesh = np.meshgrid(X, Y)
Z = np.cos((1-X_mesh) * np.pi) * np.cos((1-Y_mesh) * np.pi) * 1e+14 + 1.4e+15
Z_colors = np.cos(X_mesh * np.pi)

norm = Normalize(vmin=np.min(Z_colors), vmax=np.max(Z_colors))
colors = cm.viridis(norm(Z_colors))[:-1, :-1]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Test that when facecolors is provided, shade defaults to False
surf = ax.plot_surface(X_mesh, Y_mesh, Z, facecolors=colors, edgecolor='none')

# We can't directly check shade attribute, but we can verify the plot works
# and doesn't crash, which indicates our logic is working
assert surf is not None
plt.close(fig)


def test_plot_surface_auto_shade_without_facecolors():
"""Test that plot_surface without facecolors uses shade=True by default."""
X = np.linspace(0, 1, 10)
Y = np.linspace(0, 1, 10)
X_mesh, Y_mesh = np.meshgrid(X, Y)
Z = np.cos((1-X_mesh) * np.pi) * np.cos((1-Y_mesh) * np.pi) * 1e+14 + 1.4e+15

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Test that when no facecolors or cmap is provided, shade defaults to True
surf = ax.plot_surface(X_mesh, Y_mesh, Z)

assert surf is not None
plt.close(fig)


def test_plot_surface_auto_shade_with_cmap():
"""Test that plot_surface with cmap uses shade=False by default."""
X = np.linspace(0, 1, 10)
Y = np.linspace(0, 1, 10)
X_mesh, Y_mesh = np.meshgrid(X, Y)
Z = np.cos((1-X_mesh) * np.pi) * np.cos((1-Y_mesh) * np.pi) * 1e+14 + 1.4e+15

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Test that when cmap is provided, shade defaults to False
surf = ax.plot_surface(X_mesh, Y_mesh, Z, cmap=cm.viridis)

assert surf is not None
plt.close(fig)


def test_plot_surface_explicit_shade_with_facecolors():
"""Test that explicit shade parameter overrides auto behavior with facecolors."""
X = np.linspace(0, 1, 10)
Y = np.linspace(0, 1, 10)
X_mesh, Y_mesh = np.meshgrid(X, Y)
Z = np.cos((1-X_mesh) * np.pi) * np.cos((1-Y_mesh) * np.pi) * 1e+14 + 1.4e+15
Z_colors = np.cos(X_mesh * np.pi)

norm = Normalize(vmin=np.min(Z_colors), vmax=np.max(Z_colors))
colors = cm.viridis(norm(Z_colors))[:-1, :-1]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Test that explicit shade=True works with facecolors
surf = ax.plot_surface(X_mesh, Y_mesh, Z, facecolors=colors, shade=True)

assert surf is not None
plt.close(fig)


def test_plot_surface_explicit_shade_false_without_facecolors():
"""Test that explicit shade=False overrides auto behavior without facecolors."""
X = np.linspace(0, 1, 10)
Y = np.linspace(0, 1, 10)
X_mesh, Y_mesh = np.meshgrid(X, Y)
Z = np.cos((1-X_mesh) * np.pi) * np.cos((1-Y_mesh) * np.pi) * 1e+14 + 1.4e+15

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Test that explicit shade=False works without facecolors
surf = ax.plot_surface(X_mesh, Y_mesh, Z, shade=False)

assert surf is not None
plt.close(fig)


def test_plot_surface_shade_auto_behavior_comprehensive():
"""Test the auto behavior logic comprehensively."""
X = np.linspace(0, 1, 5)
Y = np.linspace(0, 1, 5)
X_mesh, Y_mesh = np.meshgrid(X, Y)
Z = np.ones_like(X_mesh)
Z_colors = np.ones_like(X_mesh)
colors = cm.viridis(Z_colors)[:-1, :-1]

test_cases = [
# (kwargs, description)
({}, "no facecolors, no cmap -> shade=True"),
({'facecolors': colors}, "facecolors provided -> shade=False"),
({'cmap': cm.viridis}, "cmap provided -> shade=False"),
({'facecolors': colors, 'cmap': cm.viridis}, "both facecolors and cmap -> shade=False"),

Check warning on line 119 in lib/mpl_toolkits/mplot3d/tests/test_plot_surface_shade.py

View workflow job for this annotation

GitHub Actions / ruff

[rdjson] reported by reviewdog 🐶 Line too long (96 > 88) Raw Output: message:"Line too long (96 > 88)" location:{path:"/home/runner/work/matplotlib/matplotlib/lib/mpl_toolkits/mplot3d/tests/test_plot_surface_shade.py" range:{start:{line:119 column:89} end:{line:119 column:97}}} severity:WARNING source:{name:"ruff" url:"https://docs.astral.sh/ruff"} code:{value:"E501" url:"https://docs.astral.sh/ruff/rules/line-too-long"}
({'facecolors': colors, 'shade': True}, "explicit shade=True overrides auto"),
({'facecolors': colors, 'shade': False}, "explicit shade=False overrides auto"),
({}, "no parameters -> shade=True"),
]

for kwargs, description in test_cases:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# All these should work without crashing
surf = ax.plot_surface(X_mesh, Y_mesh, Z, **kwargs)
assert surf is not None, f"Failed: {description}"
plt.close(fig)
Loading