Skip to content

Commit f34af16

Browse files
committed
DEV: add HexBar3DCollection for plotting hexagonal prisms
1 parent 670e527 commit f34af16

17 files changed

+625
-278
lines changed

lib/matplotlib/axes/_axes.py

+6-104
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_AxesBase, _TransformedBoundsLocator, _process_plot_format)
3636
from matplotlib.axes._secondary_axes import SecondaryAxis
3737
from matplotlib.container import BarContainer, ErrorbarContainer, StemContainer
38+
from matplotlib.hexbin import hexbin
3839

3940
_log = logging.getLogger(__name__)
4041

@@ -4931,110 +4932,11 @@ def reduce_C_function(C: array) -> float
49314932

49324933
x, y, C = cbook.delete_masked_points(x, y, C)
49334934

4934-
# Set the size of the hexagon grid
4935-
if np.iterable(gridsize):
4936-
nx, ny = gridsize
4937-
else:
4938-
nx = gridsize
4939-
ny = int(nx / math.sqrt(3))
4940-
# Count the number of data in each hexagon
4941-
x = np.asarray(x, float)
4942-
y = np.asarray(y, float)
4943-
4944-
# Will be log()'d if necessary, and then rescaled.
4945-
tx = x
4946-
ty = y
4947-
4948-
if xscale == 'log':
4949-
if np.any(x <= 0.0):
4950-
raise ValueError(
4951-
"x contains non-positive values, so cannot be log-scaled")
4952-
tx = np.log10(tx)
4953-
if yscale == 'log':
4954-
if np.any(y <= 0.0):
4955-
raise ValueError(
4956-
"y contains non-positive values, so cannot be log-scaled")
4957-
ty = np.log10(ty)
4958-
if extent is not None:
4959-
xmin, xmax, ymin, ymax = extent
4960-
else:
4961-
xmin, xmax = (tx.min(), tx.max()) if len(x) else (0, 1)
4962-
ymin, ymax = (ty.min(), ty.max()) if len(y) else (0, 1)
4963-
4964-
# to avoid issues with singular data, expand the min/max pairs
4965-
xmin, xmax = mtransforms.nonsingular(xmin, xmax, expander=0.1)
4966-
ymin, ymax = mtransforms.nonsingular(ymin, ymax, expander=0.1)
4967-
4968-
nx1 = nx + 1
4969-
ny1 = ny + 1
4970-
nx2 = nx
4971-
ny2 = ny
4972-
n = nx1 * ny1 + nx2 * ny2
4973-
4974-
# In the x-direction, the hexagons exactly cover the region from
4975-
# xmin to xmax. Need some padding to avoid roundoff errors.
4976-
padding = 1.e-9 * (xmax - xmin)
4977-
xmin -= padding
4978-
xmax += padding
4979-
sx = (xmax - xmin) / nx
4980-
sy = (ymax - ymin) / ny
4981-
# Positions in hexagon index coordinates.
4982-
ix = (tx - xmin) / sx
4983-
iy = (ty - ymin) / sy
4984-
ix1 = np.round(ix).astype(int)
4985-
iy1 = np.round(iy).astype(int)
4986-
ix2 = np.floor(ix).astype(int)
4987-
iy2 = np.floor(iy).astype(int)
4988-
# flat indices, plus one so that out-of-range points go to position 0.
4989-
i1 = np.where((0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1),
4990-
ix1 * ny1 + iy1 + 1, 0)
4991-
i2 = np.where((0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2),
4992-
ix2 * ny2 + iy2 + 1, 0)
4993-
4994-
d1 = (ix - ix1) ** 2 + 3.0 * (iy - iy1) ** 2
4995-
d2 = (ix - ix2 - 0.5) ** 2 + 3.0 * (iy - iy2 - 0.5) ** 2
4996-
bdist = (d1 < d2)
4997-
4998-
if C is None: # [1:] drops out-of-range points.
4999-
counts1 = np.bincount(i1[bdist], minlength=1 + nx1 * ny1)[1:]
5000-
counts2 = np.bincount(i2[~bdist], minlength=1 + nx2 * ny2)[1:]
5001-
accum = np.concatenate([counts1, counts2]).astype(float)
5002-
if mincnt is not None:
5003-
accum[accum < mincnt] = np.nan
5004-
C = np.ones(len(x))
5005-
else:
5006-
# store the C values in a list per hexagon index
5007-
Cs_at_i1 = [[] for _ in range(1 + nx1 * ny1)]
5008-
Cs_at_i2 = [[] for _ in range(1 + nx2 * ny2)]
5009-
for i in range(len(x)):
5010-
if bdist[i]:
5011-
Cs_at_i1[i1[i]].append(C[i])
5012-
else:
5013-
Cs_at_i2[i2[i]].append(C[i])
5014-
if mincnt is None:
5015-
mincnt = 0
5016-
accum = np.array(
5017-
[reduce_C_function(acc) if len(acc) >= mincnt else np.nan
5018-
for Cs_at_i in [Cs_at_i1, Cs_at_i2]
5019-
for acc in Cs_at_i[1:]], # [1:] drops out-of-range points.
5020-
float)
5021-
5022-
good_idxs = ~np.isnan(accum)
5023-
5024-
offsets = np.zeros((n, 2), float)
5025-
offsets[:nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
5026-
offsets[:nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
5027-
offsets[nx1 * ny1:, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
5028-
offsets[nx1 * ny1:, 1] = np.tile(np.arange(ny2), nx2) + 0.5
5029-
offsets[:, 0] *= sx
5030-
offsets[:, 1] *= sy
5031-
offsets[:, 0] += xmin
5032-
offsets[:, 1] += ymin
5033-
# remove accumulation bins with no data
5034-
offsets = offsets[good_idxs, :]
5035-
accum = accum[good_idxs]
5036-
5037-
polygon = [sx, sy / 3] * np.array(
4935+
offsets, accum, (sx, sy) = hexbin(x, y, C, gridsize,
4936+
xscale, yscale, extent,
4937+
reduce_C_function, mincnt)
4938+
4939+
polygon = [sx, sy] * np.array(
50384940
[[.5, -.5], [.5, .5], [0., 1.], [-.5, .5], [-.5, -.5], [0., -1.]])
50394941

50404942
if linewidths is None:

lib/matplotlib/cbook.py

+52
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,41 @@ def is_scalar_or_string(val):
495495
return isinstance(val, str) or not np.iterable(val)
496496

497497

498+
def duplicate_if_scalar(obj, n=2, raises=True):
499+
"""
500+
Ensure object size or duplicate if necessary.
501+
502+
Parameters
503+
----------
504+
obj : scalar, str or Sized
505+
506+
Returns
507+
-------
508+
509+
"""
510+
511+
if is_scalar_or_string(obj):
512+
return [obj] * n
513+
514+
size = len(obj)
515+
if size == 0:
516+
if raises:
517+
raise ValueError(f'Cannot duplicate empty {type(obj)}.')
518+
return [obj] * n
519+
520+
if size == 1:
521+
return list(obj) * n
522+
523+
if (size != n) and raises:
524+
raise ValueError(
525+
f'Input object of type {type(obj)} has incorrect size. Expected '
526+
f'either a scalar type object, or a Container with length in {{1, '
527+
f'{n}}}.'
528+
)
529+
530+
return obj
531+
532+
498533
@_api.delete_parameter(
499534
"3.8", "np_load", alternative="open(get_sample_data(..., asfileobj=False))")
500535
def get_sample_data(fname, asfileobj=True, *, np_load=True):
@@ -559,6 +594,23 @@ def flatten(seq, scalarp=is_scalar_or_string):
559594
yield from flatten(item, scalarp)
560595

561596

597+
def pairwise(iterable):
598+
"""
599+
Returns an iterator of paired items, overlapping, from the original
600+
601+
take(4, pairwise(count()))
602+
[(0, 1), (1, 2), (2, 3), (3, 4)]
603+
604+
From more_itertools:
605+
https://more-itertools.readthedocs.io/en/stable/_modules/more_itertools/recipes.html#pairwise
606+
607+
Can be removed on python >3.10 in favour of itertools.pairwise
608+
"""
609+
a, b = itertools.tee(iterable)
610+
next(b, None)
611+
return zip(a, b)
612+
613+
562614
class Stack:
563615
"""
564616
Stack of elements with a movable cursor.

lib/matplotlib/hexbin.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
Function to support histogramming over hexagonal tesselations.
3+
"""
4+
5+
import math
6+
7+
import numpy as np
8+
import matplotlib.transforms as mtransforms
9+
10+
11+
def hexbin(x, y, C=None, gridsize=100,
12+
xscale='linear', yscale='linear', extent=None,
13+
reduce_C_function=np.mean, mincnt=None):
14+
15+
# Set the size of the hexagon grid
16+
if np.iterable(gridsize):
17+
nx, ny = gridsize
18+
else:
19+
nx = gridsize
20+
ny = int(nx / math.sqrt(3))
21+
22+
# Count the number of data in each hexagon
23+
x = np.asarray(x, float)
24+
y = np.asarray(y, float)
25+
26+
# Will be log()'d if necessary, and then rescaled.
27+
tx = x
28+
ty = y
29+
30+
if xscale == 'log':
31+
if np.any(x <= 0.0):
32+
raise ValueError(
33+
"x contains non-positive values, so cannot be log-scaled")
34+
tx = np.log10(tx)
35+
if yscale == 'log':
36+
if np.any(y <= 0.0):
37+
raise ValueError(
38+
"y contains non-positive values, so cannot be log-scaled")
39+
ty = np.log10(ty)
40+
if extent is not None:
41+
xmin, xmax, ymin, ymax = extent
42+
else:
43+
xmin, xmax = (tx.min(), tx.max()) if len(x) else (0, 1)
44+
ymin, ymax = (ty.min(), ty.max()) if len(y) else (0, 1)
45+
46+
# to avoid issues with singular data, expand the min/max pairs
47+
xmin, xmax = mtransforms.nonsingular(xmin, xmax, expander=0.1)
48+
ymin, ymax = mtransforms.nonsingular(ymin, ymax, expander=0.1)
49+
50+
nx1 = nx + 1
51+
ny1 = ny + 1
52+
nx2 = nx
53+
ny2 = ny
54+
n = nx1 * ny1 + nx2 * ny2
55+
56+
# In the x-direction, the hexagons exactly cover the region from
57+
# xmin to xmax. Need some padding to avoid roundoff errors.
58+
padding = 1.e-9 * (xmax - xmin)
59+
xmin -= padding
60+
xmax += padding
61+
sx = (xmax - xmin) / nx
62+
sy = (ymax - ymin) / ny
63+
# Positions in hexagon index coordinates.
64+
ix = (tx - xmin) / sx
65+
iy = (ty - ymin) / sy
66+
ix1 = np.round(ix).astype(int)
67+
iy1 = np.round(iy).astype(int)
68+
ix2 = np.floor(ix).astype(int)
69+
iy2 = np.floor(iy).astype(int)
70+
# flat indices, plus one so that out-of-range points go to position 0.
71+
i1 = np.where((0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1),
72+
ix1 * ny1 + iy1 + 1, 0)
73+
i2 = np.where((0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2),
74+
ix2 * ny2 + iy2 + 1, 0)
75+
76+
d1 = (ix - ix1) ** 2 + 3.0 * (iy - iy1) ** 2
77+
d2 = (ix - ix2 - 0.5) ** 2 + 3.0 * (iy - iy2 - 0.5) ** 2
78+
bdist = (d1 < d2)
79+
80+
if C is None: # [1:] drops out-of-range points.
81+
counts1 = np.bincount(i1[bdist], minlength=1 + nx1 * ny1)[1:]
82+
counts2 = np.bincount(i2[~bdist], minlength=1 + nx2 * ny2)[1:]
83+
accum = np.concatenate([counts1, counts2]).astype(float)
84+
if mincnt is not None:
85+
accum[accum < mincnt] = np.nan
86+
C = np.ones(len(x))
87+
else:
88+
# store the C values in a list per hexagon index
89+
Cs_at_i1 = [[] for _ in range(1 + nx1 * ny1)]
90+
Cs_at_i2 = [[] for _ in range(1 + nx2 * ny2)]
91+
for i in range(len(x)):
92+
if bdist[i]:
93+
Cs_at_i1[i1[i]].append(C[i])
94+
else:
95+
Cs_at_i2[i2[i]].append(C[i])
96+
if mincnt is None:
97+
mincnt = 0
98+
accum = np.array(
99+
[reduce_C_function(acc) if len(acc) > mincnt else np.nan
100+
for Cs_at_i in [Cs_at_i1, Cs_at_i2]
101+
for acc in Cs_at_i[1:]], # [1:] drops out-of-range points.
102+
float)
103+
104+
good_idxs = ~np.isnan(accum)
105+
106+
offsets = np.zeros((n, 2), float)
107+
offsets[:nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
108+
offsets[:nx1 * ny1, 1] = np.tile(np.arange(ny1), nx1)
109+
offsets[nx1 * ny1:, 0] = np.repeat(np.arange(nx2) + 0.5, ny2)
110+
offsets[nx1 * ny1:, 1] = np.tile(np.arange(ny2), nx2) + 0.5
111+
offsets[:, 0] *= sx
112+
offsets[:, 1] *= sy
113+
offsets[:, 0] += xmin
114+
offsets[:, 1] += ymin
115+
# remove accumulation bins with no data
116+
offsets = offsets[good_idxs, :]
117+
accum = accum[good_idxs]
118+
119+
return (*offsets.T, accum), (sx, sy / 3)

0 commit comments

Comments
 (0)