@@ -4287,108 +4287,54 @@ def __init__(self, xyA, xyB, coordsA, coordsB=None,
4287
4287
# if True, draw annotation only if self.xy is inside the axes
4288
4288
self ._annotation_clip = None
4289
4289
4290
- def _get_xy (self , x , y , s , axes = None ):
4290
+ def _get_xy (self , xy , s , axes = None ):
4291
4291
"""Calculate the pixel position of given point."""
4292
+ s0 = s # For the error message, if needed.
4292
4293
if axes is None :
4293
4294
axes = self .axes
4295
+ xy = np .array (xy )
4296
+ if s in ["figure points" , "axes points" ]:
4297
+ xy *= self .figure .dpi / 72
4298
+ s = s .replace ("points" , "pixels" )
4299
+ elif s == "figure fraction" :
4300
+ s = self .figure .transFigure
4301
+ elif s == "axes fraction" :
4302
+ s = axes .transAxes
4303
+ x , y = xy
4294
4304
4295
4305
if s == 'data' :
4296
4306
trans = axes .transData
4297
4307
x = float (self .convert_xunits (x ))
4298
4308
y = float (self .convert_yunits (y ))
4299
4309
return trans .transform ((x , y ))
4300
4310
elif s == 'offset points' :
4301
- # convert the data point
4302
- dx , dy = self .xy
4303
-
4304
- # prevent recursion
4305
- if self .xycoords == 'offset points' :
4306
- return self ._get_xy (dx , dy , 'data' )
4307
-
4308
- dx , dy = self ._get_xy (dx , dy , self .xycoords )
4309
-
4310
- # convert the offset
4311
- dpi = self .figure .get_dpi ()
4312
- x *= dpi / 72.
4313
- y *= dpi / 72.
4314
-
4315
- # add the offset to the data point
4316
- x += dx
4317
- y += dy
4318
-
4319
- return x , y
4311
+ if self .xycoords == 'offset points' : # prevent recursion
4312
+ return self ._get_xy (self .xy , 'data' )
4313
+ return (
4314
+ self ._get_xy (self .xy , self .xycoords ) # converted data point
4315
+ + xy * self .figure .dpi / 72 ) # converted offset
4320
4316
elif s == 'polar' :
4321
4317
theta , r = x , y
4322
4318
x = r * np .cos (theta )
4323
4319
y = r * np .sin (theta )
4324
4320
trans = axes .transData
4325
4321
return trans .transform ((x , y ))
4326
- elif s == 'figure points' :
4327
- # points from the lower left corner of the figure
4328
- dpi = self .figure .dpi
4329
- l , b , w , h = self .figure .bbox .bounds
4330
- r = l + w
4331
- t = b + h
4332
-
4333
- x *= dpi / 72.
4334
- y *= dpi / 72.
4335
- if x < 0 :
4336
- x = r + x
4337
- if y < 0 :
4338
- y = t + y
4339
- return x , y
4340
4322
elif s == 'figure pixels' :
4341
4323
# pixels from the lower left corner of the figure
4342
- l , b , w , h = self .figure .bbox .bounds
4343
- r = l + w
4344
- t = b + h
4345
- if x < 0 :
4346
- x = r + x
4347
- if y < 0 :
4348
- y = t + y
4349
- return x , y
4350
- elif s == 'figure fraction' :
4351
- # (0, 0) is lower left, (1, 1) is upper right of figure
4352
- trans = self .figure .transFigure
4353
- return trans .transform ((x , y ))
4354
- elif s == 'axes points' :
4355
- # points from the lower left corner of the axes
4356
- dpi = self .figure .dpi
4357
- l , b , w , h = axes .bbox .bounds
4358
- r = l + w
4359
- t = b + h
4360
- if x < 0 :
4361
- x = r + x * dpi / 72.
4362
- else :
4363
- x = l + x * dpi / 72.
4364
- if y < 0 :
4365
- y = t + y * dpi / 72.
4366
- else :
4367
- y = b + y * dpi / 72.
4324
+ bb = self .figure .bbox
4325
+ x = bb .x0 + x if x >= 0 else bb .x1 + x
4326
+ y = bb .y0 + y if y >= 0 else bb .y1 + y
4368
4327
return x , y
4369
4328
elif s == 'axes pixels' :
4370
4329
# pixels from the lower left corner of the axes
4371
- l , b , w , h = axes .bbox .bounds
4372
- r = l + w
4373
- t = b + h
4374
- if x < 0 :
4375
- x = r + x
4376
- else :
4377
- x = l + x
4378
- if y < 0 :
4379
- y = t + y
4380
- else :
4381
- y = b + y
4330
+ bb = axes .bbox
4331
+ x = bb .x0 + x if x >= 0 else bb .x1 + x
4332
+ y = bb .y0 + y if y >= 0 else bb .y1 + y
4382
4333
return x , y
4383
- elif s == 'axes fraction' :
4384
- # (0, 0) is lower left, (1, 1) is upper right of axes
4385
- trans = axes .transAxes
4386
- return trans .transform ((x , y ))
4387
4334
elif isinstance (s , transforms .Transform ):
4388
- return s .transform (( x , y ) )
4335
+ return s .transform (xy )
4389
4336
else :
4390
- raise ValueError ("{} is not a valid coordinate "
4391
- "transformation." .format (s ))
4337
+ raise ValueError (f"{ s0 } is not a valid coordinate transformation" )
4392
4338
4393
4339
def set_annotation_clip (self , b ):
4394
4340
"""
@@ -4418,39 +4364,29 @@ def get_annotation_clip(self):
4418
4364
4419
4365
def get_path_in_displaycoord (self ):
4420
4366
"""Return the mutated path of the arrow in display coordinates."""
4421
-
4422
4367
dpi_cor = self .get_dpi_cor ()
4423
-
4424
- x , y = self .xy1
4425
- posA = self ._get_xy (x , y , self .coords1 , self .axesA )
4426
-
4427
- x , y = self .xy2
4428
- posB = self ._get_xy (x , y , self .coords2 , self .axesB )
4429
-
4430
- _path = self .get_connectionstyle ()(posA , posB ,
4431
- patchA = self .patchA ,
4432
- patchB = self .patchB ,
4433
- shrinkA = self .shrinkA * dpi_cor ,
4434
- shrinkB = self .shrinkB * dpi_cor
4435
- )
4436
-
4437
- _path , fillable = self .get_arrowstyle ()(
4438
- _path ,
4439
- self .get_mutation_scale () * dpi_cor ,
4440
- self .get_linewidth () * dpi_cor ,
4441
- self .get_mutation_aspect ()
4442
- )
4443
-
4444
- return _path , fillable
4368
+ posA = self ._get_xy (self .xy1 , self .coords1 , self .axesA )
4369
+ posB = self ._get_xy (self .xy2 , self .coords2 , self .axesB )
4370
+ path = self .get_connectionstyle ()(
4371
+ posA , posB ,
4372
+ patchA = self .patchA , patchB = self .patchB ,
4373
+ shrinkA = self .shrinkA * dpi_cor , shrinkB = self .shrinkB * dpi_cor ,
4374
+ )
4375
+ path , fillable = self .get_arrowstyle ()(
4376
+ path ,
4377
+ self .get_mutation_scale () * dpi_cor ,
4378
+ self .get_linewidth () * dpi_cor ,
4379
+ self .get_mutation_aspect ()
4380
+ )
4381
+ return path , fillable
4445
4382
4446
4383
def _check_xy (self , renderer ):
4447
4384
"""Check whether the annotation needs to be drawn."""
4448
4385
4449
4386
b = self .get_annotation_clip ()
4450
4387
4451
4388
if b or (b is None and self .coords1 == "data" ):
4452
- x , y = self .xy1
4453
- xy_pixel = self ._get_xy (x , y , self .coords1 , self .axesA )
4389
+ xy_pixel = self ._get_xy (self .xy1 , self .coords1 , self .axesA )
4454
4390
if self .axesA is None :
4455
4391
axes = self .axes
4456
4392
else :
@@ -4459,8 +4395,7 @@ def _check_xy(self, renderer):
4459
4395
return False
4460
4396
4461
4397
if b or (b is None and self .coords2 == "data" ):
4462
- x , y = self .xy2
4463
- xy_pixel = self ._get_xy (x , y , self .coords2 , self .axesB )
4398
+ xy_pixel = self ._get_xy (self .xy2 , self .coords2 , self .axesB )
4464
4399
if self .axesB is None :
4465
4400
axes = self .axes
4466
4401
else :
0 commit comments