@@ -92,6 +92,34 @@ def _add_pad(self, x_min, x_max, y_min, y_max):
92
92
return x_min - dx , x_max + dx , y_min - dy , y_max + dy
93
93
94
94
95
+ class _User2DTransform (Transform ):
96
+ """A transform defined by two user-set functions."""
97
+
98
+ input_dims = output_dims = 2
99
+
100
+ def __init__ (self , forward , backward ):
101
+ """
102
+ Parameters
103
+ ----------
104
+ forward, backward : callable
105
+ The forward and backward transforms, taking ``x`` and ``y`` as
106
+ separate arguments and returning ``(tr_x, tr_y)``.
107
+ """
108
+ # The normal Matplotlib convention would be to take and return an
109
+ # (N, 2) array but axisartist uses the transposed version.
110
+ super ().__init__ ()
111
+ self ._forward = forward
112
+ self ._backward = backward
113
+
114
+ def transform_non_affine (self , values ):
115
+ # docstring inherited
116
+ return np .transpose (self ._forward (* np .transpose (values )))
117
+
118
+ def inverted (self ):
119
+ # docstring inherited
120
+ return type (self )(self ._backward , self ._forward )
121
+
122
+
95
123
class GridFinder :
96
124
def __init__ (self ,
97
125
transform ,
@@ -123,7 +151,7 @@ def __init__(self,
123
151
self .grid_locator2 = grid_locator2
124
152
self .tick_formatter1 = tick_formatter1
125
153
self .tick_formatter2 = tick_formatter2
126
- self .update_transform (transform )
154
+ self .set_transform (transform )
127
155
128
156
def get_grid_info (self , x1 , y1 , x2 , y2 ):
129
157
"""
@@ -214,27 +242,26 @@ def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
214
242
215
243
return gi
216
244
217
- def update_transform (self , aux_trans ):
218
- if not isinstance (aux_trans , Transform ) and len (aux_trans ) != 2 :
219
- raise TypeError ("'aux_trans' must be either a Transform instance "
220
- "or a pair of callables" )
221
- self ._aux_transform = aux_trans
245
+ def set_transform (self , aux_trans ):
246
+ if isinstance (aux_trans , Transform ):
247
+ self ._aux_transform = aux_trans
248
+ elif len (aux_trans ) == 2 and all (map (callable , aux_trans )):
249
+ self ._aux_transform = _User2DTransform (* aux_trans )
250
+ else :
251
+ raise TypeError ("'aux_trans' must be either a Transform "
252
+ "instance or a pair of callables" )
253
+
254
+ def get_transform (self ):
255
+ return self ._aux_transform
256
+
257
+ update_transform = set_transform # backcompat alias.
222
258
223
259
def transform_xy (self , x , y ):
224
- aux_trf = self ._aux_transform
225
- if isinstance (aux_trf , Transform ):
226
- return aux_trf .transform (np .column_stack ([x , y ])).T
227
- else :
228
- transform_xy , inv_transform_xy = aux_trf
229
- return transform_xy (x , y )
260
+ return self ._aux_transform .transform (np .column_stack ([x , y ])).T
230
261
231
262
def inv_transform_xy (self , x , y ):
232
- aux_trf = self ._aux_transform
233
- if isinstance (aux_trf , Transform ):
234
- return aux_trf .inverted ().transform (np .column_stack ([x , y ])).T
235
- else :
236
- transform_xy , inv_transform_xy = aux_trf
237
- return inv_transform_xy (x , y )
263
+ return self ._aux_transform .inverted ().transform (
264
+ np .column_stack ([x , y ])).T
238
265
239
266
def update (self , ** kw ):
240
267
for k in kw :
0 commit comments