Skip to content

Commit 57a38c7

Browse files
committed
ENH: Add grouped_bar() method
This is a WIP to implement #24313. It will be updated incrementally. As a first step, I've designed the data and label input API. Feedback is welcome.
1 parent c56027b commit 57a38c7

File tree

2 files changed

+191
-0
lines changed

2 files changed

+191
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
=================
3+
Grouped bar chart
4+
=================
5+
6+
This is an ex
7+
8+
9+
10+
Case 1: multiple separate datasets
11+
----------------------------------
12+
13+
"""
14+
import pandas as pd
15+
16+
import matplotlib.pyplot as plt
17+
import numpy as np
18+
19+
x = ['A', 'B']
20+
data1 = [1, 1.2]
21+
data2 = [2, 2.4]
22+
data3 = [3, 3.6]
23+
24+
25+
fig, axs = plt.subplots(1, 2)
26+
27+
# current solution: manual positioning with multiple bar)= calls
28+
label_pos = np.array([0, 1])
29+
bar_width = 0.8 / 3
30+
data_shift = -1*bar_width + np.array([0, bar_width, 2*bar_width])
31+
axs[0].bar(label_pos + data_shift[0], data1, width=bar_width, label="data1")
32+
axs[0].bar(label_pos + data_shift[1], data2, width=bar_width, label="data2")
33+
axs[0].bar(label_pos + data_shift[2], data3, width=bar_width, label="data3")
34+
axs[0].set_xticks(label_pos, x)
35+
axs[0].legend()
36+
37+
# grouped_bar() with list of datasets
38+
# note also that this is a straight-forward generalization of the single-dataset case:
39+
# bar(x, data1, label="data1")
40+
axs[1].grouped_bar(x, [data1, data2, data3], dataset_labels=["data1", "data2", "data3"])
41+
42+
43+
# %%
44+
# Case 1b: multiple datasets as dict
45+
# ----------------------------------
46+
# instead of carrying a list of datasets and a list of dataset labels, users may
47+
# want to organized their datasets in a dict.
48+
49+
datasets = {
50+
'data1': data1,
51+
'data2': data2,
52+
'data3': data3,
53+
}
54+
55+
# %%
56+
# While you can feed keys and values into the above API, it may be convenient to pass
57+
# the whole dict as "data" and automatically extract the labels from the keys:
58+
59+
fig, axs = plt.subplots(1, 2)
60+
61+
# explicitly extract values and labels from a dict and feed to grouped_bar():
62+
axs[0].grouped_bar(x, datasets.values(), dataset_labels=datasets.keys())
63+
# accepting a dict as input
64+
axs[1].grouped_bar(x, datasets)
65+
66+
# %%
67+
# Case 2: 2D array data
68+
# ---------------------
69+
# When receiving a 2D array, we interpret the data as
70+
#
71+
# .. code-block:: none
72+
#
73+
# dataset_0 dataset_1 dataset_2
74+
# x[0]='A' ds0_a ds1_a ds2_a
75+
# x[1]='B' ds0_b ds1_b ds2_b
76+
#
77+
# This is consistent with the standard data science interpretation of instances
78+
# on the vertical and features on the horizontal. And also matches how pandas is
79+
# interpreting rows and columns.
80+
#
81+
# Note that a list of individual datasets and a 2D array behave structurally different,
82+
# i.e. hen turning a list into a numpy array, you have to transpose that array to get
83+
# the correct representation. Those two behave the same::
84+
#
85+
# grouped_bar(x, [data1, data2])
86+
# grouped_bar(x, np.array([data1, data2]).T)
87+
#
88+
# This is a conscious decision, because the commonly understood dimension ordering
89+
# semantics of "list of datasets" and 2D array of datasets is different.
90+
91+
x = ['A', 'B']
92+
data = np.array([
93+
[1, 2, 3],
94+
[1.2, 2.4, 3.6],
95+
])
96+
columns = ["data1", "data2", "data3"]
97+
98+
fig, axs = plt.subplots(1, 2)
99+
100+
axs[0].grouped_bar(x, data, dataset_labels=columns)
101+
102+
df = pd.DataFrame(data, index=x, columns=columns)
103+
df.plot.bar(ax=axs[1])

lib/matplotlib/axes/_axes.py

+88
Original file line numberDiff line numberDiff line change
@@ -3000,6 +3000,94 @@ def broken_barh(self, xranges, yrange, **kwargs):
30003000

30013001
return col
30023002

3003+
def grouped_bar(self, x, heights, dataset_labels=None):
3004+
"""
3005+
Parameters
3006+
-----------
3007+
x : array-like of str
3008+
The labels.
3009+
heights : list of array-like or dict of array-like or 2D array
3010+
The heights for all x and groups. One of:
3011+
3012+
- list of array-like: A list of datasets, each dataset must have
3013+
``len(x)`` elements.
3014+
3015+
.. code-block:: none
3016+
3017+
x = ['a', 'b']
3018+
group_labels = ['ds0', 'ds1', 'ds2']
3019+
3020+
# group_labels: ds0 ds1 dw2
3021+
heights = [dataset_0, dataset_1, dataset_2]
3022+
3023+
# x[0] x[1]
3024+
dataset_0 = [ds0_a, ds0_b]
3025+
3026+
# x[0] x[1]
3027+
heights = [[ds0_a, ds0_b], # dataset_0
3028+
[ds1_a, ds1_b], # dataset_1
3029+
[ds2_a, ds2_b], # dataset_2
3030+
]
3031+
3032+
- dict of array-like: A names to datasets, each dataset (dict value)
3033+
must have ``len(x)`` elements.
3034+
3035+
group_labels = heights.keys()
3036+
heights = heights.values()
3037+
3038+
- a 2D array: columns map to *x*, columns are the different datasets.
3039+
3040+
.. code-block:: none
3041+
3042+
dataset_0 dataset_1 dataset_2
3043+
x[0]='a' ds0_a ds1_a ds2_a
3044+
x[1]='b' ds0_b ds1_b ds2_b
3045+
3046+
Note that this is consistent with pandas. These two calls produce
3047+
the same bar plot structure::
3048+
3049+
grouped_bar(x, array, group_labels=group_labels)
3050+
pd.DataFrame(array, index=x, columns=group_labels).plot.bar()
3051+
3052+
3053+
An iterable of array-like: The iteration runs over the groups.
3054+
Each individual array-like is the list of label values for that group.
3055+
dataset_labels : array-like of str, optional
3056+
The labels of the datasets.
3057+
"""
3058+
if hasattr(heights, 'keys'):
3059+
if dataset_labels is not None:
3060+
raise ValueError(
3061+
"'dataset_labels' cannot be used if 'heights' are a mapping")
3062+
dataset_labels = heights.keys()
3063+
heights = heights.values()
3064+
elif hasattr(heights, 'shape'):
3065+
heights = heights.T
3066+
3067+
num_labels = len(x)
3068+
num_datasets = len(heights)
3069+
3070+
for dataset in heights:
3071+
assert len(dataset) == num_labels
3072+
3073+
margin = 0.1
3074+
bar_width = (1 - 2 * margin) / num_datasets
3075+
block_centers = np.arange(num_labels)
3076+
3077+
if dataset_labels is None:
3078+
dataset_labels = [None] * num_datasets
3079+
else:
3080+
assert len(dataset_labels) == num_datasets
3081+
3082+
for i, (hs, dataset_label) in enumerate(zip(heights, dataset_labels)):
3083+
lefts = block_centers - 0.5 + margin + i * bar_width
3084+
print(i, x, lefts, hs, dataset_label)
3085+
self.bar(lefts, hs, width=bar_width, align="edge", label=dataset_label)
3086+
3087+
self.xaxis.set_ticks(block_centers, labels=x)
3088+
3089+
# TODO: does not return anything for now
3090+
30033091
@_preprocess_data()
30043092
def stem(self, *args, linefmt=None, markerfmt=None, basefmt=None, bottom=0,
30053093
label=None, orientation='vertical'):

0 commit comments

Comments
 (0)