Commit 6acdf0e

timhoffm and story645 committed
Apply suggestions from code review
Co-authored-by: hannah <story645@gmail.com>
1 parent bf2894a commit 6acdf0e

2 files changed: +32 -36 lines changed

doc/users/next_whats_new/grouped_bar.rst

+1
@@ -10,6 +10,7 @@ Example:

 .. plot::
     :include-source: true
+    :alt: Diagram of a grouped bar chart of 3 datasets with 2 categories.

     import matplotlib.pyplot as plt


lib/matplotlib/axes/_axes.py

+31 -36
@@ -3077,18 +3077,14 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
 This function is new in v3.11, and the API is still provisional.
 We may still fine-tune some aspects based on user-feedback.

-This is a convenience function to plot bars for multiple datasets.
-In particular, it simplifies positioning of the bars compared to individual
-`~.Axes.bar` plots.
-
 Bar plots present categorical data as a sequence of bars, one bar per category.
-We call one set of such values a *dataset* and it's bars all share the same
+We call one set of such values a *dataset* and its bars all share the same
 color. Grouped bar plots show multiple such datasets, where the values per
 category are grouped together. The category names are drawn as tick labels
 below the bar groups. Each dataset has a distinct bar color, and can optionally
 get a label that is used for the legend.

-Here is an example call structure and the corresponding plot:
+Example:

 .. code-block:: python

@@ -3098,6 +3094,10 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing

 .. plot:: _embedded_plots/grouped_bar.py

+``grouped_bar()`` is a high-level plotting function for grouped bar charts.
+Use `~.Axes.bar` instead if you need finer-grained control over individual bar
+positions or widths.
+
 Parameters
 ----------
 heights : list of array-like or dict of array-like or 2D array \
@@ -3121,25 +3121,20 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
 - dict of array-like: A mapping from names to datasets. Each dataset
   (dict value) must have the same number of elements.

-  This is similar to passing a list of array-like, with the addition that
-  each dataset gets a name.
-
   Example call:

   .. code-block:: python

       grouped_bar({'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2})

-  The names are used as *labels*, i.e. the following two calls are
-  equivalent:
+  The names are used as *labels*, which is equivalent to passing in a list
+  of array-like together with *labels*:

   .. code-block:: python

-      data_dict = {'ds0': dataset_0, 'ds1': dataset_1, 'ds2': dataset_2]}
-      grouped_bar(data_dict)
       grouped_bar(data_dict.values(), labels=data_dict.keys())

-  When using a dict-like input, you must not pass *labels* explicitly.
+  When using a dict input, you must not pass *labels* explicitly.

 - a 2D array: The rows are the categories, the columns are the different
   datasets.
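
Note on the hunk above: the dict form can be read as shorthand for passing the datasets and their labels separately. Below is a minimal sketch of that unpacking in plain Python. `split_named_datasets` is a hypothetical helper used only for illustration and is not part of Matplotlib.

```python
def split_named_datasets(data_dict):
    """Split a name -> dataset mapping into parallel lists.

    Hypothetical helper for illustration only; not a Matplotlib function.
    """
    datasets = list(data_dict.values())  # one entry per dataset
    labels = list(data_dict.keys())      # dataset names, in the same order
    return datasets, labels


data_dict = {"ds0": [1, 2], "ds1": [3, 4], "ds2": [5, 6]}
datasets, labels = split_named_datasets(data_dict)

# Per the docstring, grouped_bar(data_dict) would then correspond to
# grouped_bar(datasets, labels=labels).
```

Dicts preserve insertion order, so the labels stay aligned with their datasets.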
@@ -3154,21 +3149,16 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing

   .. code-block:: python

-      group_labels = ["group_A", "group_B"]
+      categories = ["A", "B"]
       dataset_labels = ["dataset_0", "dataset_1", "dataset_2"]
       array = np.random.random((2, 3))
-
-  Note that this is consistent with pandas. These two calls produce
-  the same bar plot structure:
-
-  .. code-block:: python
-
       grouped_bar(array, tick_labels=categories, labels=dataset_labels)
-      df = pd.DataFrame(array, index=categories, columns=dataset_labels)
-      df.plot.bar()

 - a `pandas.DataFrame`.

+  The index is used for the categories, the columns are used for the
+  datasets.
+
   .. code-block:: python

       df = pd.DataFrame(
@@ -3178,6 +3168,9 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
       )
       grouped_bar(df)

+      # is equivalent to
+      grouped_bar(df.to_numpy(), tick_labels=df.index, labels=df.columns)
+
   Note that ``grouped_bar(df)`` produces a plot that is structurally equivalent
   to ``df.plot.bar()``.

@@ -3187,9 +3180,8 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing

 tick_labels : list of str, optional
     The category labels, which are placed on ticks at the center *positions*
-    of the bar groups.
-
-    If not set, the axis ticks (positions and labels) are left unchanged.
+    of the bar groups. If not set, the axis ticks (positions and labels) are
+    left unchanged.

 labels : list of str, optional
     The labels of the datasets, i.e. the bars within one group.
@@ -3249,7 +3241,7 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
 aspects. ``bar(x, y)`` is a lower-level API and places bars with height *y*
 at explicit positions *x*. It also allows specifying individual bar widths
 and colors. This kind of detailed control and flexibility is difficult to
-manage and often not needed when plotting multiple datasets as grouped bar
+manage and often not needed when plotting multiple datasets as a grouped bar
 plot. Therefore, ``grouped_bar`` focusses on the abstraction of bar plots
 as visualization of categorical data.

@@ -3309,8 +3301,18 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
     heights = heights.T

 num_datasets = len(heights)
-dataset_0 = next(iter(heights))
-num_groups = len(dataset_0)
+num_groups = len(next(iter(heights)))  # inferred from first dataset
+
+# validate that all datasets have the same length, i.e. num_groups
+# - can be skipped if heights is an array
+if not hasattr(heights, 'shape'):
+    for i, dataset in enumerate(heights):
+        if len(dataset) != num_groups:
+            raise ValueError(
+                "'heights' contains datasets with different number of "
+                f"elements. dataset 0 has {num_groups} elements but "
+                f"dataset {i} has {len(dataset)} elements."
+            )

 if positions is None:
     group_centers = np.arange(num_groups)
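
Note on the hunk above: the new check infers the number of groups from the first dataset and rejects ragged input early, with a message that names the offending dataset. Below is a minimal standalone sketch of that logic, assuming `heights` is already a sequence of datasets. The function name `infer_num_groups` is hypothetical and not part of Matplotlib.

```python
def infer_num_groups(heights):
    """Return the common dataset length, or raise if the datasets differ.

    Standalone illustration of the check in this hunk; the function name
    is hypothetical and not part of Matplotlib.
    """
    num_groups = len(next(iter(heights)))  # inferred from first dataset

    # Objects exposing a ``shape`` attribute (arrays) are skipped, mirroring
    # the "can be skipped if heights is an array" comment in the diff.
    if not hasattr(heights, "shape"):
        for i, dataset in enumerate(heights):
            if len(dataset) != num_groups:
                raise ValueError(
                    "'heights' contains datasets with different number of "
                    f"elements. dataset 0 has {num_groups} elements but "
                    f"dataset {i} has {len(dataset)} elements."
                )
    return num_groups


# Consistent input: three datasets with two values each.
print(infer_num_groups([[1, 2], [3, 4], [5, 6]]))

# Ragged input: the second dataset has an extra value.
try:
    infer_num_groups([[1, 2], [3, 4, 5]])
except ValueError as err:
    print(err)
```

Because the check now sits next to the `num_groups` inference, the error is phrased in terms of `heights`, whereas the message removed in the next hunk referred to `'x'`.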
@@ -3325,13 +3327,6 @@ def grouped_bar(self, heights, *, positions=None, group_spacing=1.5, bar_spacing
     else:
         group_distance = 1

-for i, dataset in enumerate(heights):
-    if len(dataset) != num_groups:
-        raise ValueError(
-            f"'x' indicates {num_groups} groups, but dataset {i} "
-            f"has {len(dataset)} groups"
-        )
-
 _api.check_in_list(["vertical", "horizontal"], orientation=orientation)

 if colors is None: