Skip to content

Commit 952a322

Browse files
committed
Simplify and robustify ConnectionPatch coordinates conversion.
Dedupe a bit of code in ConnectionPatch._get_xy which now takes xy as a single argument rather than an unpacked one. By converting xy to an array, one additionally gets an error if xy is passed as a `set`, rather than nonsensical results with no error.
1 parent f2b6c66 commit 952a322

File tree

2 files changed

+61
-115
lines changed

2 files changed

+61
-115
lines changed

lib/matplotlib/patches.py

Lines changed: 40 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -4287,108 +4287,54 @@ def __init__(self, xyA, xyB, coordsA, coordsB=None,
42874287
# if True, draw annotation only if self.xy is inside the axes
42884288
self._annotation_clip = None
42894289

4290-
def _get_xy(self, x, y, s, axes=None):
4290+
def _get_xy(self, xy, s, axes=None):
42914291
"""Calculate the pixel position of given point."""
4292+
s0 = s # For the error message, if needed.
42924293
if axes is None:
42934294
axes = self.axes
4295+
xy = np.array(xy)
4296+
if s in ["figure points", "axes points"]:
4297+
xy *= self.figure.dpi / 72
4298+
s = s.replace("points", "pixels")
4299+
elif s == "figure fraction":
4300+
s = self.figure.transFigure
4301+
elif s == "axes fraction":
4302+
s = axes.transAxes
4303+
x, y = xy
42944304

42954305
if s == 'data':
42964306
trans = axes.transData
42974307
x = float(self.convert_xunits(x))
42984308
y = float(self.convert_yunits(y))
42994309
return trans.transform((x, y))
43004310
elif s == 'offset points':
4301-
# convert the data point
4302-
dx, dy = self.xy
4303-
4304-
# prevent recursion
4305-
if self.xycoords == 'offset points':
4306-
return self._get_xy(dx, dy, 'data')
4307-
4308-
dx, dy = self._get_xy(dx, dy, self.xycoords)
4309-
4310-
# convert the offset
4311-
dpi = self.figure.get_dpi()
4312-
x *= dpi / 72.
4313-
y *= dpi / 72.
4314-
4315-
# add the offset to the data point
4316-
x += dx
4317-
y += dy
4318-
4319-
return x, y
4311+
if self.xycoords == 'offset points': # prevent recursion
4312+
return self._get_xy(self.xy, 'data')
4313+
return (
4314+
self._get_xy(self.xy, self.xycoords) # converted data point
4315+
+ xy * self.figure.dpi / 72) # converted offset
43204316
elif s == 'polar':
43214317
theta, r = x, y
43224318
x = r * np.cos(theta)
43234319
y = r * np.sin(theta)
43244320
trans = axes.transData
43254321
return trans.transform((x, y))
4326-
elif s == 'figure points':
4327-
# points from the lower left corner of the figure
4328-
dpi = self.figure.dpi
4329-
l, b, w, h = self.figure.bbox.bounds
4330-
r = l + w
4331-
t = b + h
4332-
4333-
x *= dpi / 72.
4334-
y *= dpi / 72.
4335-
if x < 0:
4336-
x = r + x
4337-
if y < 0:
4338-
y = t + y
4339-
return x, y
43404322
elif s == 'figure pixels':
43414323
# pixels from the lower left corner of the figure
4342-
l, b, w, h = self.figure.bbox.bounds
4343-
r = l + w
4344-
t = b + h
4345-
if x < 0:
4346-
x = r + x
4347-
if y < 0:
4348-
y = t + y
4349-
return x, y
4350-
elif s == 'figure fraction':
4351-
# (0, 0) is lower left, (1, 1) is upper right of figure
4352-
trans = self.figure.transFigure
4353-
return trans.transform((x, y))
4354-
elif s == 'axes points':
4355-
# points from the lower left corner of the axes
4356-
dpi = self.figure.dpi
4357-
l, b, w, h = axes.bbox.bounds
4358-
r = l + w
4359-
t = b + h
4360-
if x < 0:
4361-
x = r + x * dpi / 72.
4362-
else:
4363-
x = l + x * dpi / 72.
4364-
if y < 0:
4365-
y = t + y * dpi / 72.
4366-
else:
4367-
y = b + y * dpi / 72.
4324+
bb = self.figure.bbox
4325+
x = bb.x0 + x if x >= 0 else bb.x1 + x
4326+
y = bb.y0 + y if y >= 0 else bb.y1 + y
43684327
return x, y
43694328
elif s == 'axes pixels':
43704329
# pixels from the lower left corner of the axes
4371-
l, b, w, h = axes.bbox.bounds
4372-
r = l + w
4373-
t = b + h
4374-
if x < 0:
4375-
x = r + x
4376-
else:
4377-
x = l + x
4378-
if y < 0:
4379-
y = t + y
4380-
else:
4381-
y = b + y
4330+
bb = axes.bbox
4331+
x = bb.x0 + x if x >= 0 else bb.x1 + x
4332+
y = bb.y0 + y if y >= 0 else bb.y1 + y
43824333
return x, y
4383-
elif s == 'axes fraction':
4384-
# (0, 0) is lower left, (1, 1) is upper right of axes
4385-
trans = axes.transAxes
4386-
return trans.transform((x, y))
43874334
elif isinstance(s, transforms.Transform):
4388-
return s.transform((x, y))
4335+
return s.transform(xy)
43894336
else:
4390-
raise ValueError("{} is not a valid coordinate "
4391-
"transformation.".format(s))
4337+
raise ValueError(f"{s0} is not a valid coordinate transformation")
43924338

43934339
def set_annotation_clip(self, b):
43944340
"""
@@ -4418,39 +4364,29 @@ def get_annotation_clip(self):
44184364

44194365
def get_path_in_displaycoord(self):
44204366
"""Return the mutated path of the arrow in display coordinates."""
4421-
44224367
dpi_cor = self.get_dpi_cor()
4423-
4424-
x, y = self.xy1
4425-
posA = self._get_xy(x, y, self.coords1, self.axesA)
4426-
4427-
x, y = self.xy2
4428-
posB = self._get_xy(x, y, self.coords2, self.axesB)
4429-
4430-
_path = self.get_connectionstyle()(posA, posB,
4431-
patchA=self.patchA,
4432-
patchB=self.patchB,
4433-
shrinkA=self.shrinkA * dpi_cor,
4434-
shrinkB=self.shrinkB * dpi_cor
4435-
)
4436-
4437-
_path, fillable = self.get_arrowstyle()(
4438-
_path,
4439-
self.get_mutation_scale() * dpi_cor,
4440-
self.get_linewidth() * dpi_cor,
4441-
self.get_mutation_aspect()
4442-
)
4443-
4444-
return _path, fillable
4368+
posA = self._get_xy(self.xy1, self.coords1, self.axesA)
4369+
posB = self._get_xy(self.xy2, self.coords2, self.axesB)
4370+
path = self.get_connectionstyle()(
4371+
posA, posB,
4372+
patchA=self.patchA, patchB=self.patchB,
4373+
shrinkA=self.shrinkA * dpi_cor, shrinkB=self.shrinkB * dpi_cor,
4374+
)
4375+
path, fillable = self.get_arrowstyle()(
4376+
path,
4377+
self.get_mutation_scale() * dpi_cor,
4378+
self.get_linewidth() * dpi_cor,
4379+
self.get_mutation_aspect()
4380+
)
4381+
return path, fillable
44454382

44464383
def _check_xy(self, renderer):
44474384
"""Check whether the annotation needs to be drawn."""
44484385

44494386
b = self.get_annotation_clip()
44504387

44514388
if b or (b is None and self.coords1 == "data"):
4452-
x, y = self.xy1
4453-
xy_pixel = self._get_xy(x, y, self.coords1, self.axesA)
4389+
xy_pixel = self._get_xy(self.xy1, self.coords1, self.axesA)
44544390
if self.axesA is None:
44554391
axes = self.axes
44564392
else:
@@ -4459,8 +4395,7 @@ def _check_xy(self, renderer):
44594395
return False
44604396

44614397
if b or (b is None and self.coords2 == "data"):
4462-
x, y = self.xy2
4463-
xy_pixel = self._get_xy(x, y, self.coords2, self.axesB)
4398+
xy_pixel = self._get_xy(self.xy2, self.coords2, self.axesB)
44644399
if self.axesB is None:
44654400
axes = self.axes
44664401
else:

lib/matplotlib/tests/test_patches.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -394,16 +394,27 @@ def test_connection_patch():
394394
ax2.add_artist(con)
395395

396396

397-
def test_connection_patch_fig():
398-
# Test that connection patch can be added as figure artist
399-
fig, (ax1, ax2) = plt.subplots(1, 2)
400-
xy = (0.3, 0.2)
401-
con = mpatches.ConnectionPatch(xyA=xy, xyB=xy,
402-
coordsA="data", coordsB="data",
403-
axesA=ax1, axesB=ax2,
404-
arrowstyle="->", shrinkB=5)
405-
fig.add_artist(con)
406-
fig.canvas.draw()
397+
@check_figures_equal(extensions=["png"])
398+
def test_connection_patch_fig(fig_test, fig_ref):
399+
# Test that connection patch can be added as figure artist, and that figure
400+
# pixels count negative values from the top right corner (this API may be
401+
# changed in the future).
402+
ax1, ax2 = fig_test.subplots(1, 2)
403+
con = mpatches.ConnectionPatch(
404+
xyA=(.3, .2), coordsA="data", axesA=ax1,
405+
xyB=(-30, -20), coordsB="figure pixels",
406+
arrowstyle="->", shrinkB=5)
407+
fig_test.add_artist(con)
408+
409+
ax1, ax2 = fig_ref.subplots(1, 2)
410+
bb = fig_ref.bbox
411+
# Necessary so that pixel counts match on both sides.
412+
plt.rcParams["savefig.dpi"] = plt.rcParams["figure.dpi"]
413+
con = mpatches.ConnectionPatch(
414+
xyA=(.3, .2), coordsA="data", axesA=ax1,
415+
xyB=(bb.width - 30, bb.height - 20), coordsB="figure pixels",
416+
arrowstyle="->", shrinkB=5)
417+
fig_ref.add_artist(con)
407418

408419

409420
def test_datetime_rectangle():

0 commit comments

Comments
 (0)