Skip to content

Commit 5f1ba00

Browse files
committed
Vectorize handling of stacked/cumulative in hist().
1 parent c94d0ed commit 5f1ba00

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6656,36 +6656,28 @@ def hist(self, x, bins=None, range=None, density=None, weights=None,
66566656
hist_kwargs = dict(range=bin_range)
66576657

66586658
# List to store all the top coordinates of the histograms
6659-
tops = []
6660-
mlast = None
6659+
tops = [] # Will have shape (n_datasets, n_bins).
66616660
# Loop through datasets
66626661
for i in range(nx):
66636662
# this will automatically overwrite bins,
66646663
# so that each histogram uses the same bins
66656664
m, bins = np.histogram(x[i], bins, weights=w[i], **hist_kwargs)
6666-
m = m.astype(float) # causes problems later if it's an int
6667-
if mlast is None:
6668-
mlast = np.zeros(len(bins)-1, m.dtype)
6669-
if stacked:
6670-
m += mlast
6671-
mlast[:] = m
66726665
tops.append(m)
6673-
6674-
# If a stacked density plot, normalize so the area of all the stacked
6675-
# histograms together is 1
6676-
if stacked and density:
6677-
db = np.diff(bins)
6678-
for m in tops:
6679-
m[:] = (m / db) / tops[-1].sum()
6666+
tops = np.array(tops, float) # causes problems later if it's an int
6667+
if stacked:
6668+
tops = tops.cumsum(axis=0)
6669+
# If a stacked density plot, normalize so the area of all the
6670+
# stacked histograms together is 1
6671+
if density:
6672+
tops = (tops / np.diff(bins)) / tops[-1].sum()
66806673
if cumulative:
66816674
slc = slice(None)
66826675
if isinstance(cumulative, Number) and cumulative < 0:
66836676
slc = slice(None, None, -1)
6684-
66856677
if density:
6686-
tops = [(m * np.diff(bins))[slc].cumsum()[slc] for m in tops]
6678+
tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
66876679
else:
6688-
tops = [m[slc].cumsum()[slc] for m in tops]
6680+
tops = tops[:, slc].cumsum(axis=1)[:, slc]
66896681

66906682
patches = []
66916683

0 commit comments

Comments
 (0)