Skip to content

Commit 16e3b9a

Browse files
authored
Merge pull request #7164 from WeatherGod/mplot3d/surface_strides
ENH: Added rcount/ccount to plot_surface()
2 parents f367003 + 58d3de2 commit 16e3b9a

File tree

7 files changed

+117
-17
lines changed

7 files changed

+117
-17
lines changed

doc/users/whats_new/plot_surface.rst

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
`rcount` and `ccount` for `plot_surface()`
2+
------------------------------------------
3+
4+
As of v2.0, mplot3d's :func:`~mpl_toolkits.mplot3d.axes3d.plot_surface` now
5+
accepts `rcount` and `ccount` arguments for controlling the sampling of the
6+
input data for plotting. These arguments specify the maximum number of
7+
evenly spaced samples to take from the input data. These arguments are
8+
also the new default sampling method for the function, and is
9+
considered a style change.
10+
11+
The old `rstride` and `cstride` arguments, which specified the size of the
12+
evenly spaced samples, become the default when 'classic' mode is invoked,
13+
and are still available for use. There are no plans for deprecating these
14+
arguments.
15+

examples/mplot3d/surface3d_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
Z = np.sin(R)
2929

3030
# Plot the surface.
31-
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
31+
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
3232
linewidth=0, antialiased=False)
3333

3434
# Customize the z axis.

examples/mplot3d/surface3d_demo2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@
2222
z = 10 * np.outer(np.ones(np.size(u)), np.cos(v))
2323

2424
# Plot the surface
25-
ax.plot_surface(x, y, z, rstride=4, cstride=4, color='b')
25+
ax.plot_surface(x, y, z, color='b')
2626

2727
plt.show()

examples/mplot3d/surface3d_demo3.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@
3434
colors[x, y] = colortuple[(x + y) % len(colortuple)]
3535

3636
# Plot the surface with face colors taken from the array we made.
37-
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors,
38-
linewidth=0)
37+
surf = ax.plot_surface(X, Y, Z, facecolors=colors, linewidth=0)
3938

4039
# Customize the z axis.
4140
ax.set_zlim(-1, 1)

examples/mplot3d/surface3d_radial_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
X, Y = R*np.cos(P), R*np.sin(P)
2929

3030
# Plot the surface.
31-
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.cm.YlGnBu_r)
31+
ax.plot_surface(X, Y, Z, cmap=plt.cm.YlGnBu_r)
3232

3333
# Tweak the limits and add latex math labels.
3434
ax.set_zlim(0, 1)

lib/mpl_toolkits/mplot3d/axes3d.py

+81-8
Original file line numberDiff line numberDiff line change
@@ -1553,15 +1553,28 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
15531553
15541554
The `rstride` and `cstride` kwargs set the stride used to
15551555
sample the input data to generate the graph. If 1k by 1k
1556-
arrays are passed in the default values for the strides will
1557-
result in a 100x100 grid being plotted.
1556+
arrays are passed in, the default values for the strides will
1557+
result in a 100x100 grid being plotted. Defaults to 10.
1558+
Raises a ValueError if both stride and count kwargs
1559+
(see next section) are provided.
1560+
1561+
The `rcount` and `ccount` kwargs supersedes `rstride` and
1562+
`cstride` for default sampling method for surface plotting.
1563+
These arguments will determine at most how many evenly spaced
1564+
samples will be taken from the input data to generate the graph.
1565+
This is the default sampling method unless using the 'classic'
1566+
style. Will raise ValueError if both stride and count are
1567+
specified.
1568+
Added in v2.0.0.
15581569
15591570
============= ================================================
15601571
Argument Description
15611572
============= ================================================
15621573
*X*, *Y*, *Z* Data values as 2D arrays
1563-
*rstride* Array row stride (step size), defaults to 10
1564-
*cstride* Array column stride (step size), defaults to 10
1574+
*rstride* Array row stride (step size)
1575+
*cstride* Array column stride (step size)
1576+
*rcount* Use at most this many rows, defaults to 50
1577+
*ccount* Use at most this many columns, defaults to 50
15651578
*color* Color of the surface patches
15661579
*cmap* A colormap for the surface patches.
15671580
*facecolors* Face colors for the individual patches
@@ -1582,8 +1595,30 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
15821595
X, Y, Z = np.broadcast_arrays(X, Y, Z)
15831596
rows, cols = Z.shape
15841597

1598+
has_stride = 'rstride' in kwargs or 'cstride' in kwargs
1599+
has_count = 'rcount' in kwargs or 'ccount' in kwargs
1600+
1601+
if has_stride and has_count:
1602+
raise ValueError("Cannot specify both stride and count arguments")
1603+
15851604
rstride = kwargs.pop('rstride', 10)
15861605
cstride = kwargs.pop('cstride', 10)
1606+
rcount = kwargs.pop('rcount', 50)
1607+
ccount = kwargs.pop('ccount', 50)
1608+
1609+
if rcParams['_internal.classic_mode']:
1610+
# Strides have priority over counts in classic mode.
1611+
# So, only compute strides from counts
1612+
# if counts were explicitly given
1613+
if has_count:
1614+
rstride = int(np.ceil(rows / rcount))
1615+
cstride = int(np.ceil(cols / ccount))
1616+
else:
1617+
# If the strides are provided then it has priority.
1618+
# Otherwise, compute the strides from the counts.
1619+
if not has_stride:
1620+
rstride = int(np.ceil(rows / rcount))
1621+
cstride = int(np.ceil(cols / ccount))
15871622

15881623
if 'facecolors' in kwargs:
15891624
fcolors = kwargs.pop('facecolors')
@@ -1733,7 +1768,21 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
17331768
The `rstride` and `cstride` kwargs set the stride used to
17341769
sample the input data to generate the graph. If either is 0
17351770
the input data in not sampled along this direction producing a
1736-
3D line plot rather than a wireframe plot.
1771+
3D line plot rather than a wireframe plot. The stride arguments
1772+
are only used by default if in the 'classic' mode. They are
1773+
now superseded by `rcount` and `ccount`. Will raise ValueError
1774+
if both stride and count are used.
1775+
1776+
` The `rcount` and `ccount` kwargs supersedes `rstride` and
1777+
`cstride` for default sampling method for wireframe plotting.
1778+
These arguments will determine at most how many evenly spaced
1779+
samples will be taken from the input data to generate the graph.
1780+
This is the default sampling method unless using the 'classic'
1781+
style. Will raise ValueError if both stride and count are
1782+
specified. If either is zero, then the input data is not sampled
1783+
along this direction, producing a 3D line plot rather than a
1784+
wireframe plot.
1785+
Added in v2.0.0.
17371786
17381787
========== ================================================
17391788
Argument Description
@@ -1742,6 +1791,8 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
17421791
*Z*
17431792
*rstride* Array row stride (step size), defaults to 1
17441793
*cstride* Array column stride (step size), defaults to 1
1794+
*rcount* Use at most this many rows, defaults to 50
1795+
*ccount* Use at most this many columns, defaults to 50
17451796
========== ================================================
17461797
17471798
Keyword arguments are passed on to
@@ -1750,15 +1801,37 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs):
17501801
Returns a :class:`~mpl_toolkits.mplot3d.art3d.Line3DCollection`
17511802
'''
17521803

1753-
rstride = kwargs.pop("rstride", 1)
1754-
cstride = kwargs.pop("cstride", 1)
1755-
17561804
had_data = self.has_data()
17571805
Z = np.atleast_2d(Z)
17581806
# FIXME: Support masked arrays
17591807
X, Y, Z = np.broadcast_arrays(X, Y, Z)
17601808
rows, cols = Z.shape
17611809

1810+
has_stride = 'rstride' in kwargs or 'cstride' in kwargs
1811+
has_count = 'rcount' in kwargs or 'ccount' in kwargs
1812+
1813+
if has_stride and has_count:
1814+
raise ValueError("Cannot specify both stride and count arguments")
1815+
1816+
rstride = kwargs.pop('rstride', 1)
1817+
cstride = kwargs.pop('cstride', 1)
1818+
rcount = kwargs.pop('rcount', 50)
1819+
ccount = kwargs.pop('ccount', 50)
1820+
1821+
if rcParams['_internal.classic_mode']:
1822+
# Strides have priority over counts in classic mode.
1823+
# So, only compute strides from counts
1824+
# if counts were explicitly given
1825+
if has_count:
1826+
rstride = int(np.ceil(rows / rcount)) if rcount else 0
1827+
cstride = int(np.ceil(cols / ccount)) if ccount else 0
1828+
else:
1829+
# If the strides are provided then it has priority.
1830+
# Otherwise, compute the strides from the counts.
1831+
if not has_stride:
1832+
rstride = int(np.ceil(rows / rcount)) if rcount else 0
1833+
cstride = int(np.ceil(cols / ccount)) if ccount else 0
1834+
17621835
# We want two sets of lines, one running along the "rows" of
17631836
# Z and another set of lines running along the "columns" of Z.
17641837
# This transpose will make it easy to obtain the columns.

lib/mpl_toolkits/tests/test_mplot3d.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def f(t):
105105
R = np.sqrt(X ** 2 + Y ** 2)
106106
Z = np.sin(R)
107107

108-
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
108+
surf = ax.plot_surface(X, Y, Z, rcount=40, ccount=40,
109109
linewidth=0, antialiased=False)
110110

111111
ax.set_zlim3d(-1, 1)
@@ -141,7 +141,7 @@ def test_surface3d():
141141
X, Y = np.meshgrid(X, Y)
142142
R = np.sqrt(X ** 2 + Y ** 2)
143143
Z = np.sin(R)
144-
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
144+
surf = ax.plot_surface(X, Y, Z, rcount=40, ccount=40, cmap=cm.coolwarm,
145145
lw=0, antialiased=False)
146146
ax.set_zlim(-1.01, 1.01)
147147
fig.colorbar(surf, shrink=0.5, aspect=5)
@@ -194,7 +194,7 @@ def test_wireframe3d():
194194
fig = plt.figure()
195195
ax = fig.add_subplot(111, projection='3d')
196196
X, Y, Z = axes3d.get_test_data(0.05)
197-
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10)
197+
ax.plot_wireframe(X, Y, Z, rcount=13, ccount=13)
198198

199199

200200
@image_comparison(baseline_images=['wireframe3dzerocstride'], remove_text=True,
@@ -203,7 +203,7 @@ def test_wireframe3dzerocstride():
203203
fig = plt.figure()
204204
ax = fig.add_subplot(111, projection='3d')
205205
X, Y, Z = axes3d.get_test_data(0.05)
206-
ax.plot_wireframe(X, Y, Z, rstride=10, cstride=0)
206+
ax.plot_wireframe(X, Y, Z, rcount=13, ccount=0)
207207

208208

209209
@image_comparison(baseline_images=['wireframe3dzerorstride'], remove_text=True,
@@ -214,6 +214,7 @@ def test_wireframe3dzerorstride():
214214
X, Y, Z = axes3d.get_test_data(0.05)
215215
ax.plot_wireframe(X, Y, Z, rstride=0, cstride=10)
216216

217+
217218
@cleanup
218219
def test_wireframe3dzerostrideraises():
219220
fig = plt.figure()
@@ -222,6 +223,18 @@ def test_wireframe3dzerostrideraises():
222223
with assert_raises(ValueError):
223224
ax.plot_wireframe(X, Y, Z, rstride=0, cstride=0)
224225

226+
227+
@cleanup
228+
def test_mixedsamplesraises():
229+
fig = plt.figure()
230+
ax = fig.add_subplot(111, projection='3d')
231+
X, Y, Z = axes3d.get_test_data(0.05)
232+
with assert_raises(ValueError):
233+
ax.plot_wireframe(X, Y, Z, rstride=10, ccount=50)
234+
with assert_raises(ValueError):
235+
ax.plot_surface(X, Y, Z, cstride=50, rcount=10)
236+
237+
225238
@image_comparison(baseline_images=['quiver3d'], remove_text=True)
226239
def test_quiver3d():
227240
fig = plt.figure()

0 commit comments

Comments
 (0)