Commit dfcd52f

timhoffm and story645 committed
Apply suggestions from code review
Co-authored-by: hannah <story645@gmail.com>
1 parent bf2894a commit dfcd52f

2 files changed: +48 -51 lines changed

doc/users/next_whats_new/grouped_bar.rst

+5 -4
@@ -10,16 +10,17 @@ Example:
1010

1111
.. plot::
1212
:include-source: true
13+
:alt: Diagram of a grouped bar chart of 3 datasets with 2 categories.
1314

1415
import matplotlib.pyplot as plt
1516

1617
categories = ['A', 'B']
1718
datasets = {
18-
'dataset 0': [1.0, 3.0],
19-
'dataset 1': [1.4, 3.4],
20-
'dataset 2': [1.8, 3.8],
19+
'dataset 0': [1, 11],
20+
'dataset 1': [3, 13],
21+
'dataset 2': [5, 15],
2122
}
2223

23-
fig, ax = plt.subplots(figsize=(4, 2.2))
24+
fig, ax = plt.subplots()
2425
ax.grouped_bar(datasets, tick_labels=categories)
2526
ax.legend()
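
The dict in this example supplies both the dataset labels and the heights. According to the docstring changed in this commit, `grouped_bar(data_dict)` is equivalent to `grouped_bar(data_dict.values(), labels=data_dict.keys())`. A minimal sketch of that split in plain Python, without calling matplotlib, might look like this. The helper name `split_datasets` is hypothetical and not part of the library:

```python
def split_datasets(datasets):
    """Split a name -> heights mapping into (labels, heights).

    Hypothetical helper for illustration only; it mirrors the documented
    equivalence, not matplotlib's internal code.
    """
    labels = list(datasets.keys())
    heights = list(datasets.values())
    return labels, heights


categories = ['A', 'B']
datasets = {
    'dataset 0': [1, 11],
    'dataset 1': [3, 13],
    'dataset 2': [5, 15],
}

labels, heights = split_datasets(datasets)
# One group per category, one bar per dataset within each group.
groups = {cat: [h[i] for h in heights] for i, cat in enumerate(categories)}
```

Here `groups` collects, for each category, the heights that end up side by side in one bar group.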

lib/matplotlib/axes/_axes.py

+43 -47
@@ -3073,22 +3073,20 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
30733073
"""
30743074
Make a grouped bar plot.
30753075
3076-
.. note::
3076+
.. versionadded:: 3.11
3077+
30773078
This function is new in v3.11, and the API is still provisional.
30783079
We may still fine-tune some aspects based on user-feedback.
30793080
3080-
This is a convenience function to plot bars for multiple datasets.
3081-
In particular, it simplifies positioning of the bars compared to individual
3082-
`~.Axes.bar` plots.
3083-
3084-
Bar plots present categorical data as a sequence of bars, one bar per category.
3085-
We call one set of such values a *dataset* and it's bars all share the same
3086-
color. Grouped bar plots show multiple such datasets, where the values per
3087-
category are grouped together. The category names are drawn as tick labels
3088-
below the bar groups. Each dataset has a distinct bar color, and can optionally
3089-
get a label that is used for the legend.
3081+
Grouped bar charts visualize a collection of multiple categorical datasets.
3082+
A categorical dataset is a mapping *name* -> *value*. The values of the
3083+
dataset are represented by a sequence of bars of the same color.
3084+
In a grouped bar chart, the bars of all datasets are grouped together by
3085+
category. The category names are drawn as tick labels next to the bar group.
3086+
Each dataset has a distinct bar color, and can optionally get a label that
3087+
is used for the legend.
30903088
3091-
Here is an example call structure and the corresponding plot:
3089+
Example:
30923090
30933091
.. code-block:: python
30943092
@@ -3121,25 +3119,20 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
31213119
- dict of array-like: A mapping from names to datasets. Each dataset
31223120
(dict value) must have the same number of elements.
31233121
3124-
This is similar to passing a list of array-like, with the addition that
3125-
each dataset gets a name.
3126-
31273122
Example call:
31283123
31293124
.. code-block:: python
31303125
3131-
grouped_bar({'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2]})
3126+
data_dict = {'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2}
3127+
grouped_bar(data_dict)
31323128
3133-
The names are used as *labels*, i.e. the following two calls are
3134-
equivalent:
3129+
The names are used as *labels*, i.e. this is equivalent to
31353130
31363131
.. code-block:: python
31373132
3138-
data_dict = {'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2]}
3139-
grouped_bar(data_dict)
31403133
grouped_bar(data_dict.values(), labels=data_dict.keys())
31413134
3142-
When using a dict-like input, you must not pass *labels* explicitly.
3135+
When using a dict input, you must not pass *labels* explicitly.
31433136
31443137
- a 2D array: The rows are the categories, the columns are the different
31453138
datasets.
@@ -3154,30 +3147,31 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
31543147
31553148
.. code-block:: python
31563149
3157-
group_labels = ["group_A", "group_B"]
3150+
categories = ["A", "B"]
31583151
dataset_labels = ["dataset_0", "dataset_1", "dataset_2"]
31593152
array = np.random.random((2, 3))
3160-
3161-
Note that this is consistent with pandas. These two calls produce
3162-
the same bar plot structure:
3163-
3164-
.. code-block:: python
3165-
31663153
grouped_bar(array, tick_labels=categories, labels=dataset_labels)
3167-
df = pd.DataFrame(array, index=categories, columns=dataset_labels)
3168-
df.plot.bar()
31693154
31703155
- a `pandas.DataFrame`.
31713156
3157+
The index is used for the categories, the columns are used for the
3158+
datasets.
3159+
31723160
.. code-block:: python
31733161
31743162
df = pd.DataFrame(
3175-
np.random.random((2, 3))
3176-
index=["group_A", "group_B"],
3163+
np.random.random((2, 3)),
3164+
index=["A", "B"],
31773165
columns=["dataset_0", "dataset_1", "dataset_2"]
31783166
)
31793167
grouped_bar(df)
31803168
3169+
i.e. this is equivalent to
3170+
3171+
.. code-block::
3172+
3173+
grouped_bar(df.to_numpy(), tick_labels=df.index, labels=df.columns)
3174+
31813175
Note that ``grouped_bar(df)`` produces a structurally equivalent plot like
31823176
``df.plot.bar()``.
31833177
@@ -3187,22 +3181,21 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
31873181
31883182
tick_labels : list of str, optional
31893183
The category labels, which are placed on ticks at the center *positions*
3190-
of the bar groups.
3191-
3192-
If not set, the axis ticks (positions and labels) are left unchanged.
3184+
of the bar groups. If not set, the axis ticks (positions and labels) are
3185+
left unchanged.
31933186
31943187
labels : list of str, optional
31953188
The labels of the datasets, i.e. the bars within one group.
31963189
These will show up in the legend.
31973190
31983191
group_spacing : float, default: 1.5
3199-
The space between two bar groups in units of bar width.
3192+
The space between two bar groups as multiples of bar width.
32003193
32013194
The default value of 1.5 thus means that there's a gap of
32023195
1.5 bar widths between bar groups.
32033196
32043197
bar_spacing : float, default: 0
3205-
The space between bars in units of bar width.
3198+
The space between bars as multiples of bar width.
32063199
32073200
orientation : {"vertical", "horizontal"}, default: "vertical"
32083201
The direction of the bars.
@@ -3249,7 +3242,7 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
32493242
aspects. ``bar(x, y)`` is a lower-level API and places bars with height *y*
32503243
at explicit positions *x*. It also allows to specify individual bar widths
32513244
and colors. This kind of detailed control and flexibility is difficult to
3252-
manage and often not needed when plotting multiple datasets as grouped bar
3245+
manage and often not needed when plotting multiple datasets as a grouped bar
32533246
plot. Therefore, ``grouped_bar`` focusses on the abstraction of bar plots
32543247
as visualization of categorical data.
32553248
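
The *group_spacing* and *bar_spacing* descriptions in this docstring measure gaps in multiples of the bar width. As a rough sketch of the arithmetic this implies, assume that group centers are a fixed distance apart (1 for the default positions) and that the bars of one group, the gaps between them, and one group gap together fill that distance. The function name and the formula below are an illustration derived from the parameter descriptions, not a copy of the library code:

```python
def bar_width(num_datasets, group_spacing=1.5, bar_spacing=0.0,
              group_distance=1.0):
    """Return the bar width implied by the spacing parameters (assumed model)."""
    # Width units per group: the bars, the gaps between bars, one group gap.
    units = num_datasets + (num_datasets - 1) * bar_spacing + group_spacing
    return group_distance / units


width = bar_width(3)              # three datasets, default spacing
group_gap = 1.5 * width           # gap between neighboring groups
filled = 3 * width + group_gap    # bars plus gap span one group distance
```

Under these assumptions, three datasets with the default spacing leave each bar 1/4.5 of the distance between group centers.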
@@ -3309,8 +3302,18 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
33093302
heights = heights.T
33103303

33113304
num_datasets = len(heights)
3312-
dataset_0 = next(iter(heights))
3313-
num_groups = len(dataset_0)
3305+
num_groups = len(next(iter(heights))) # inferred from first dataset
3306+
3307+
# validate that all datasets have the same length, i.e. num_groups
3308+
# - can be skipped if heights is an array
3309+
if not hasattr(heights, 'shape'):
3310+
for i, dataset in enumerate(heights):
3311+
if len(dataset) != num_groups:
3312+
raise ValueError(
3313+
"'heights' contains datasets with different number of "
3314+
f"elements. dataset 0 has {num_groups} elements but "
3315+
f"dataset {i} has {len(dataset)} elements."
3316+
)
33143317

33153318
if positions is None:
33163319
group_centers = np.arange(num_groups)
@@ -3325,13 +3328,6 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
33253328
else:
33263329
group_distance = 1
33273330

3328-
for i, dataset in enumerate(heights):
3329-
if len(dataset) != num_groups:
3330-
raise ValueError(
3331-
f"'x' indicates {num_groups} groups, but dataset {i} "
3332-
f"has {len(dataset)} groups"
3333-
)
3334-
33353331
_api.check_in_list(["vertical", "horizontal"], orientation=orientation)
33363332

33373333
if colors is None:
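
The length check added in this commit can be tried in isolation. The sketch below copies the validation logic from the hunk above into a standalone function. The name `check_equal_lengths` is hypothetical, and the function assumes *heights* is already a sequence of datasets (not a dict) or an array-like object with a `shape` attribute:

```python
def check_equal_lengths(heights):
    """Return the number of groups, or raise if the datasets differ in length."""
    num_groups = len(next(iter(heights)))  # inferred from first dataset
    # Objects with a shape (arrays) are rectangular, so the loop is skipped.
    if not hasattr(heights, 'shape'):
        for i, dataset in enumerate(heights):
            if len(dataset) != num_groups:
                raise ValueError(
                    "'heights' contains datasets with different number of "
                    f"elements. dataset 0 has {num_groups} elements but "
                    f"dataset {i} has {len(dataset)} elements."
                )
    return num_groups


num_groups = check_equal_lengths([[1, 11], [3, 13], [5, 15]])

try:
    check_equal_lengths([[1, 11], [3, 13, 23]])
except ValueError as err:
    message = str(err)
```

Compared with the removed check, the error message no longer refers to `'x'` and names both the reference dataset and the offending one.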
