Skip to content

Commit 6aec225

Browse files
authored
Merge pull request #13832 from anntzer/histvec
Vectorize handling of stacked/cumulative in hist().
2 parents 6c4bc56 + 5f1ba00 commit 6aec225

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6650,36 +6650,28 @@ def hist(self, x, bins=None, range=None, density=False, weights=None,
66506650
hist_kwargs['density'] = density
66516651

66526652
# List to store all the top coordinates of the histograms
6653-
tops = []
6654-
mlast = None
6653+
tops = [] # Will have shape (n_datasets, n_bins).
66556654
# Loop through datasets
66566655
for i in range(nx):
66576656
# this will automatically overwrite bins,
66586657
# so that each histogram uses the same bins
66596658
m, bins = np.histogram(x[i], bins, weights=w[i], **hist_kwargs)
6660-
m = m.astype(float) # causes problems later if it's an int
6661-
if mlast is None:
6662-
mlast = np.zeros(len(bins)-1, m.dtype)
6663-
if stacked:
6664-
m += mlast
6665-
mlast[:] = m
66666659
tops.append(m)
6667-
6668-
# If a stacked density plot, normalize so the area of all the stacked
6669-
# histograms together is 1
6670-
if stacked and density:
6671-
db = np.diff(bins)
6672-
for m in tops:
6673-
m[:] = (m / db) / tops[-1].sum()
6660+
tops = np.array(tops, float) # causes problems later if it's an int
6661+
if stacked:
6662+
tops = tops.cumsum(axis=0)
6663+
# If a stacked density plot, normalize so the area of all the
6664+
# stacked histograms together is 1
6665+
if density:
6666+
tops = (tops / np.diff(bins)) / tops[-1].sum()
66746667
if cumulative:
66756668
slc = slice(None)
66766669
if isinstance(cumulative, Number) and cumulative < 0:
66776670
slc = slice(None, None, -1)
6678-
66796671
if density:
6680-
tops = [(m * np.diff(bins))[slc].cumsum()[slc] for m in tops]
6672+
tops = (tops * np.diff(bins))[:, slc].cumsum(axis=1)[:, slc]
66816673
else:
6682-
tops = [m[slc].cumsum()[slc] for m in tops]
6674+
tops = tops[:, slc].cumsum(axis=1)[:, slc]
66836675

66846676
patches = []
66856677

0 commit comments

Comments
 (0)