Skip to content

Commit 7c800f9

Browse files
committed
Fix hexbin marginals
1 parent c941e69 commit 7c800f9

File tree

3 files changed

+47
-42
lines changed

3 files changed

+47
-42
lines changed

lib/matplotlib/axes/_axes.py

+45-41
Original file line numberDiff line numberDiff line change
@@ -4646,6 +4646,11 @@ def reduce_C_function(C: array) -> float
46464646
# Count the number of data in each hexagon
46474647
x = np.array(x, float)
46484648
y = np.array(y, float)
4649+
4650+
if marginals:
4651+
xorig = x.copy()
4652+
yorig = y.copy()
4653+
46494654
if xscale == 'log':
46504655
if np.any(x <= 0.0):
46514656
raise ValueError("x contains non-positive values, so can not"
@@ -4674,10 +4679,6 @@ def reduce_C_function(C: array) -> float
46744679
sx = (xmax - xmin) / nx
46754680
sy = (ymax - ymin) / ny
46764681

4677-
if marginals:
4678-
xorig = x.copy()
4679-
yorig = y.copy()
4680-
46814682
x = (x - xmin) / sx
46824683
y = (y - ymin) / sy
46834684
ix1 = np.round(x).astype(int)
@@ -4833,40 +4834,42 @@ def reduce_C_function(C: array) -> float
48334834
if not marginals:
48344835
return collection
48354836

4837+
# Process marginals
48364838
if C is None:
48374839
C = np.ones(len(x))
48384840

4839-
def coarse_bin(x, y, coarse):
4840-
ind = coarse.searchsorted(x).clip(0, len(coarse) - 1)
4841-
mus = np.zeros(len(coarse))
4842-
for i in range(len(coarse)):
4843-
yi = y[ind == i]
4841+
def coarse_bin(x, y, bin_edges):
4842+
"""
4843+
Sort x-values into bins defined by *bin_edges*, then for all the
4844+
corresponding y-values in each bin use *reduce_c_function* to
4845+
compute the bin value.
4846+
"""
4847+
nbins = len(bin_edges) - 1
4848+
# Sort x-values into bins
4849+
bin_idxs = np.searchsorted(bin_edges, x) - 1
4850+
mus = np.zeros(nbins) * np.nan
4851+
for i in range(nbins):
4852+
# Get y-values for each bin
4853+
yi = y[bin_idxs == i]
48444854
if len(yi) > 0:
4845-
mu = reduce_C_function(yi)
4846-
else:
4847-
mu = np.nan
4848-
mus[i] = mu
4855+
mus[i] = reduce_C_function(yi)
48494856
return mus
48504857

4851-
coarse = np.linspace(xmin, xmax, gridsize)
4858+
if xscale == 'log':
4859+
bin_edges = np.geomspace(xmin, xmax, nx + 1)
4860+
else:
4861+
bin_edges = np.linspace(xmin, xmax, nx + 1)
4862+
xcoarse = coarse_bin(xorig, C, bin_edges)
48524863

4853-
xcoarse = coarse_bin(xorig, C, coarse)
4854-
valid = ~np.isnan(xcoarse)
48554864
verts, values = [], []
4856-
for i, val in enumerate(xcoarse):
4857-
thismin = coarse[i]
4858-
if i < len(coarse) - 1:
4859-
thismax = coarse[i + 1]
4860-
else:
4861-
thismax = thismin + np.diff(coarse)[-1]
4862-
4863-
if not valid[i]:
4865+
for bin_left, bin_right, val in zip(
4866+
bin_edges[:-1], bin_edges[1:], xcoarse):
4867+
if np.isnan(val):
48644868
continue
4865-
4866-
verts.append([(thismin, 0),
4867-
(thismin, 0.05),
4868-
(thismax, 0.05),
4869-
(thismax, 0)])
4869+
verts.append([(bin_left, 0),
4870+
(bin_left, 0.05),
4871+
(bin_right, 0.05),
4872+
(bin_right, 0)])
48704873
values.append(val)
48714874

48724875
values = np.array(values)
@@ -4881,20 +4884,21 @@ def coarse_bin(x, y, coarse):
48814884
hbar.update(kwargs)
48824885
self.add_collection(hbar, autolim=False)
48834886

4884-
coarse = np.linspace(ymin, ymax, gridsize)
4885-
ycoarse = coarse_bin(yorig, C, coarse)
4886-
valid = ~np.isnan(ycoarse)
4887+
if yscale == 'log':
4888+
bin_edges = np.geomspace(ymin, ymax, 2 * ny + 1)
4889+
else:
4890+
bin_edges = np.linspace(ymin, ymax, 2 * ny + 1)
4891+
ycoarse = coarse_bin(yorig, C, bin_edges)
4892+
48874893
verts, values = [], []
4888-
for i, val in enumerate(ycoarse):
4889-
thismin = coarse[i]
4890-
if i < len(coarse) - 1:
4891-
thismax = coarse[i + 1]
4892-
else:
4893-
thismax = thismin + np.diff(coarse)[-1]
4894-
if not valid[i]:
4894+
for bin_bottom, bin_top, val in zip(
4895+
bin_edges[:-1], bin_edges[1:], ycoarse):
4896+
if np.isnan(val):
48954897
continue
4896-
verts.append([(0, thismin), (0.0, thismax),
4897-
(0.05, thismax), (0.05, thismin)])
4898+
verts.append([(0, bin_bottom),
4899+
(0, bin_top),
4900+
(0.05, bin_top),
4901+
(0.05, bin_bottom)])
48984902
values.append(val)
48994903

49004904
values = np.array(values)

lib/matplotlib/tests/test_axes.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,8 @@ def test_hexbin_log():
781781
y = np.power(2, y * 0.5)
782782

783783
fig, ax = plt.subplots()
784-
h = ax.hexbin(x, y, yscale='log', bins='log')
784+
h = ax.hexbin(x, y, yscale='log', bins='log',
785+
marginals=True, reduce_C_function=np.sum)
785786
plt.colorbar(h)
786787

787788

0 commit comments

Comments
 (0)