diff --git a/doc/users/whats_new/wireframe3d.rst b/doc/users/whats_new/wireframe3d.rst new file mode 100644 index 000000000000..d7bfe31a5391 --- /dev/null +++ b/doc/users/whats_new/wireframe3d.rst @@ -0,0 +1,14 @@ +Zero r/cstride support in plot_wireframe +---------------------------------------- + +Adam Hughes added support to mplot3d's plot_wireframe to draw only row or +column line plots. + + +Example:: + + from mpl_toolkits.mplot3d import Axes3D, axes3d + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + X, Y, Z = axes3d.get_test_data(0.05) + ax.plot_wireframe(X, Y, Z, rstride=10, cstride=0) diff --git a/examples/mplot3d/wire3d_zero_stride.py b/examples/mplot3d/wire3d_zero_stride.py new file mode 100644 index 000000000000..693a2b1676e4 --- /dev/null +++ b/examples/mplot3d/wire3d_zero_stride.py @@ -0,0 +1,12 @@ +from mpl_toolkits.mplot3d import axes3d +import matplotlib.pyplot as plt +import numpy as np + +fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(8, 12), subplot_kw={'projection': '3d'}) +X, Y, Z = axes3d.get_test_data(0.05) +ax1.plot_wireframe(X, Y, Z, rstride=10, cstride=0) +ax1.set_title("Column stride 0") +ax2.plot_wireframe(X, Y, Z, rstride=0, cstride=10) +ax2.set_title("Row stride 0") +plt.tight_layout() +plt.show() diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index ccb0dda5ad03..3542c85d6ddd 100755 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -1717,7 +1717,9 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs): Plot a 3D wireframe. The `rstride` and `cstride` kwargs set the stride used to - sample the input data to generate the graph. + sample the input data to generate the graph. If either is 0 + the input data in not sampled along this direction producing a + 3D line plot rather than a wireframe plot. ========== ================================================ Argument Description @@ -1748,14 +1750,23 @@ def plot_wireframe(self, X, Y, Z, *args, **kwargs): # This transpose will make it easy to obtain the columns. tX, tY, tZ = np.transpose(X), np.transpose(Y), np.transpose(Z) - rii = list(xrange(0, rows, rstride)) - cii = list(xrange(0, cols, cstride)) + if rstride: + rii = list(xrange(0, rows, rstride)) + # Add the last index only if needed + if rows > 0 and rii[-1] != (rows - 1) : + rii += [rows-1] + else: + rii = [] + if cstride: + cii = list(xrange(0, cols, cstride)) + # Add the last index only if needed + if cols > 0 and cii[-1] != (cols - 1) : + cii += [cols-1] + else: + cii = [] - # Add the last index only if needed - if rows > 0 and rii[-1] != (rows - 1) : - rii += [rows-1] - if cols > 0 and cii[-1] != (cols - 1) : - cii += [cols-1] + if rstride == 0 and cstride == 0: + raise ValueError("Either rstride or cstride must be non zero") # If the inputs were empty, then just # reset everything. diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/wireframe3dzerocstride.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/wireframe3dzerocstride.png new file mode 100644 index 000000000000..ca39d6a5df82 Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/wireframe3dzerocstride.png differ diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/wireframe3dzerorstride.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/wireframe3dzerorstride.png new file mode 100644 index 000000000000..8a8b814f8156 Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/wireframe3dzerorstride.png differ diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index c82d06c08d67..42718166d32d 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -1,6 +1,9 @@ +import sys +import nose +from nose.tools import assert_raises from mpl_toolkits.mplot3d import Axes3D, axes3d from matplotlib import cm -from matplotlib.testing.decorators import image_comparison +from matplotlib.testing.decorators import image_comparison, cleanup import matplotlib.pyplot as plt import numpy as np @@ -172,6 +175,34 @@ def test_wireframe3d(): ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10) +@image_comparison(baseline_images=['wireframe3dzerocstride'], remove_text=True, + extensions=['png']) +def test_wireframe3dzerocstride(): + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + X, Y, Z = axes3d.get_test_data(0.05) + ax.plot_wireframe(X, Y, Z, rstride=10, cstride=0) + + +@image_comparison(baseline_images=['wireframe3dzerorstride'], remove_text=True, + extensions=['png']) +def test_wireframe3dzerorstride(): + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + X, Y, Z = axes3d.get_test_data(0.05) + ax.plot_wireframe(X, Y, Z, rstride=0, cstride=10) + +@cleanup +def test_wireframe3dzerostrideraises(): + if sys.version_info[:2] < (2, 7): + raise nose.SkipTest("assert_raises as context manager " + "not supported with Python < 2.7") + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + X, Y, Z = axes3d.get_test_data(0.05) + with assert_raises(ValueError): + ax.plot_wireframe(X, Y, Z, rstride=0, cstride=0) + @image_comparison(baseline_images=['quiver3d'], remove_text=True) def test_quiver3d(): fig = plt.figure()