diff --git a/doc/users/next_whats_new/stem_orientation.rst b/doc/users/next_whats_new/stem_orientation.rst new file mode 100644 index 000000000000..727c9c6ec60f --- /dev/null +++ b/doc/users/next_whats_new/stem_orientation.rst @@ -0,0 +1,13 @@ +Added *orientation* parameter for stem plots +-------------------------------------------- + +By default, stem lines are vertical. They can be changed to horizontal using +the *orientation* parameter of `.Axes.stem` or `.pyplot.stem`: + +.. plot:: + + locs = np.linspace(0.1, 2 * np.pi, 25) + heads = np.cos(locs) + + fig, ax = plt.subplots() + ax.stem(locs, heads, orientation='horizontal') diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index e5846422b9f9..d0ecdcdca7a1 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -2723,27 +2723,32 @@ def broken_barh(self, xranges, yrange, **kwargs): @_preprocess_data() def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0, - label=None, use_line_collection=True): + label=None, use_line_collection=True, orientation='vertical'): """ Create a stem plot. - A stem plot plots vertical lines at each *x* location from the baseline - to *y*, and places a marker there. + A stem plot draws lines perpendicular to a baseline at each location + *locs* from the baseline to *heads*, and places a marker there. For + vertical stem plots (the default), the *locs* are *x* positions, and + the *heads* are *y* values. For horizontal stem plots, the *locs* are + *y* positions, and the *heads* are *x* values. Call signature:: - stem([x,] y, linefmt=None, markerfmt=None, basefmt=None) + stem([locs,] heads, linefmt=None, markerfmt=None, basefmt=None) - The x-positions are optional. The formats may be provided either as - positional or as keyword-arguments. + The *locs*-positions are optional. The formats may be provided either + as positional or as keyword-arguments. Parameters ---------- - x : array-like, optional - The x-positions of the stems. Default: (0, 1, ..., len(y) - 1). + locs : array-like, default: (0, 1, ..., len(heads) - 1) + For vertical stem plots, the x-positions of the stems. + For horizontal stem plots, the y-positions of the stems. - y : array-like - The y-values of the stem heads. + heads : array-like + For vertical stem plots, the y-values of the stem heads. + For horizontal stem plots, the x-values of the stem heads. linefmt : str, optional A string defining the properties of the vertical lines. Usually, @@ -2774,8 +2779,12 @@ def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0, basefmt : str, default: 'C3-' ('C2-' in classic mode) A format string defining the properties of the baseline. + orientation : str, default: 'vertical' + If 'vertical', will produce a plot with stems oriented vertically, + otherwise the stems will be oriented horizontally. + bottom : float, default: 0 - The y-position of the baseline. + The y/x-position of the baseline (depending on orientation). label : str, default: None The label to use for the stems in legends. @@ -2803,17 +2812,24 @@ def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0, if not 1 <= len(args) <= 5: raise TypeError('stem expected between 1 and 5 positional ' 'arguments, got {}'.format(args)) + cbook._check_in_list(['horizontal', 'vertical'], + orientation=orientation) if len(args) == 1: - y, = args - x = np.arange(len(y)) + heads, = args + locs = np.arange(len(heads)) args = () else: - x, y, *args = args + locs, heads, *args = args - self._process_unit_info(xdata=x, ydata=y) - x = self.convert_xunits(x) - y = self.convert_yunits(y) + if orientation == 'vertical': + self._process_unit_info(xdata=locs, ydata=heads) + locs = self.convert_xunits(locs) + heads = self.convert_yunits(heads) + else: + self._process_unit_info(xdata=heads, ydata=locs) + heads = self.convert_xunits(heads) + locs = self.convert_yunits(locs) # defaults for formats if linefmt is None: @@ -2864,7 +2880,14 @@ def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0, # New behaviour in 3.1 is to use a LineCollection for the stemlines if use_line_collection: - stemlines = [((xi, bottom), (xi, yi)) for xi, yi in zip(x, y)] + if orientation == 'horizontal': + stemlines = [ + ((bottom, loc), (head, loc)) + for loc, head in zip(locs, heads)] + else: + stemlines = [ + ((loc, bottom), (loc, head)) + for loc, head in zip(locs, heads)] if linestyle is None: linestyle = rcParams['lines.linestyle'] stemlines = mcoll.LineCollection(stemlines, linestyles=linestyle, @@ -2874,16 +2897,34 @@ def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0, # Old behaviour is to plot each of the lines individually else: stemlines = [] - for xi, yi in zip(x, y): - l, = self.plot([xi, xi], [bottom, yi], + for loc, head in zip(locs, heads): + if orientation == 'horizontal': + xs = [bottom, head] + ys = [loc, loc] + else: + xs = [loc, loc] + ys = [bottom, head] + l, = self.plot(xs, ys, color=linecolor, linestyle=linestyle, marker=linemarker, label="_nolegend_") stemlines.append(l) - markerline, = self.plot(x, y, color=markercolor, linestyle=markerstyle, + if orientation == 'horizontal': + marker_x = heads + marker_y = locs + baseline_x = [bottom, bottom] + baseline_y = [np.min(locs), np.max(locs)] + else: + marker_x = locs + marker_y = heads + baseline_x = [np.min(locs), np.max(locs)] + baseline_y = [bottom, bottom] + + markerline, = self.plot(marker_x, marker_y, + color=markercolor, linestyle=markerstyle, marker=markermarker, label="_nolegend_") - baseline, = self.plot([np.min(x), np.max(x)], [bottom, bottom], + baseline, = self.plot(baseline_x, baseline_y, color=basecolor, linestyle=basestyle, marker=basemarker, label="_nolegend_") diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 70389be97e39..62076a4a6e36 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -3018,11 +3018,13 @@ def stackplot( @_copy_docstring_and_deprecators(Axes.stem) def stem( *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0, - label=None, use_line_collection=True, data=None): + label=None, use_line_collection=True, orientation='vertical', + data=None): return gca().stem( *args, linefmt=linefmt, markerfmt=markerfmt, basefmt=basefmt, bottom=bottom, label=label, use_line_collection=use_line_collection, + orientation=orientation, **({"data": data} if data is not None else {})) diff --git a/lib/matplotlib/tests/baseline_images/test_axes/stem_orientation.png b/lib/matplotlib/tests/baseline_images/test_axes/stem_orientation.png new file mode 100644 index 000000000000..21614d974311 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_axes/stem_orientation.png differ diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 9b739aabc1ae..84b13405af33 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -3245,14 +3245,13 @@ def test_hist_stacked_weighted(): @image_comparison(['stem.png'], style='mpl20', remove_text=True) def test_stem(use_line_collection): x = np.linspace(0.1, 2 * np.pi, 100) - args = (x, np.cos(x)) - # Label is a single space to force a legend to be drawn, but to avoid any - # text being drawn - kwargs = dict(linefmt='C2-.', markerfmt='k+', basefmt='C1-.', - label=' ', use_line_collection=use_line_collection) fig, ax = plt.subplots() - ax.stem(*args, **kwargs) + # Label is a single space to force a legend to be drawn, but to avoid any + # text being drawn + ax.stem(x, np.cos(x), + linefmt='C2-.', markerfmt='k+', basefmt='C1-.', label=' ', + use_line_collection=use_line_collection) ax.legend() @@ -3279,6 +3278,18 @@ def test_stem_dates(): ax.stem(xs, ys, "*-") +@pytest.mark.parametrize("use_line_collection", [True, False], + ids=['w/ line collection', 'w/o line collection']) +@image_comparison(['stem_orientation.png'], style='mpl20', remove_text=True) +def test_stem_orientation(use_line_collection): + x = np.linspace(0.1, 2*np.pi, 50) + + fig, ax = plt.subplots() + ax.stem(x, np.cos(x), + linefmt='C2-.', markerfmt='kx', basefmt='C1-.', + use_line_collection=use_line_collection, orientation='horizontal') + + @image_comparison(['hist_stacked_stepfilled_alpha']) def test_hist_stacked_stepfilled_alpha(): # make some data