10
10
import numpy as np
11
11
12
12
13
- class ParasiteAxesBase ( object ) :
13
+ class ParasiteAxesBase :
14
14
15
15
def get_images_artists (self ):
16
16
artists = {a for a in self .get_children () if a .get_visible ()}
@@ -21,11 +21,10 @@ def get_images_artists(self):
21
21
def __init__ (self , parent_axes , ** kwargs ):
22
22
self ._parent_axes = parent_axes
23
23
kwargs ["frameon" ] = False
24
- self ._get_base_axes_attr ("__init__" )(
25
- self , parent_axes .figure , parent_axes ._position , ** kwargs )
24
+ super ().__init__ (parent_axes .figure , parent_axes ._position , ** kwargs )
26
25
27
26
def cla (self ):
28
- self . _get_base_axes_attr ( "cla" )( self )
27
+ super (). cla ( )
29
28
30
29
martist .setp (self .get_children (), visible = False )
31
30
self ._get_lines = self ._parent_axes ._get_lines
@@ -45,18 +44,14 @@ def parasite_axes_class_factory(axes_class=None):
45
44
if axes_class is None :
46
45
axes_class = Axes
47
46
48
- def _get_base_axes_attr (self , attrname ):
49
- return getattr (axes_class , attrname )
50
-
51
47
return type ("%sParasite" % axes_class .__name__ ,
52
- (ParasiteAxesBase , axes_class ),
53
- {'_get_base_axes_attr' : _get_base_axes_attr })
48
+ (ParasiteAxesBase , axes_class ), {})
54
49
55
50
56
51
ParasiteAxes = parasite_axes_class_factory ()
57
52
58
53
59
- class ParasiteAxesAuxTransBase ( object ) :
54
+ class ParasiteAxesAuxTransBase :
60
55
def __init__ (self , parent_axes , aux_transform , viewlim_mode = None ,
61
56
** kwargs ):
62
57
@@ -80,14 +75,13 @@ def _set_lim_and_transforms(self):
80
75
81
76
def set_viewlim_mode (self , mode ):
82
77
if mode not in [None , "equal" , "transform" ]:
83
- raise ValueError ("Unknown mode : %s" % (mode ,))
78
+ raise ValueError ("Unknown mode: %s" % (mode ,))
84
79
else :
85
80
self ._viewlim_mode = mode
86
81
87
82
def get_viewlim_mode (self ):
88
83
return self ._viewlim_mode
89
84
90
-
91
85
def update_viewlim (self ):
92
86
viewlim = self ._parent_axes .viewLim .frozen ()
93
87
mode = self .get_viewlim_mode ()
@@ -96,86 +90,80 @@ def update_viewlim(self):
96
90
elif mode == "equal" :
97
91
self .axes .viewLim .set (viewlim )
98
92
elif mode == "transform" :
99
- self .axes .viewLim .set (viewlim .transformed (self .transAux .inverted ()))
93
+ self .axes .viewLim .set (
94
+ viewlim .transformed (self .transAux .inverted ()))
100
95
else :
101
- raise ValueError ("Unknown mode : %s" % (self ._viewlim_mode ,))
102
-
96
+ raise ValueError ("Unknown mode: %s" % (self ._viewlim_mode ,))
103
97
104
- def _pcolor (self , method_name , * XYC , ** kwargs ):
98
+ def _pcolor (self , super_pcolor , * XYC , ** kwargs ):
105
99
if len (XYC ) == 1 :
106
100
C = XYC [0 ]
107
101
ny , nx = C .shape
108
102
109
- gx = np .arange (- 0.5 , nx , 1. )
110
- gy = np .arange (- 0.5 , ny , 1. )
103
+ gx = np .arange (- 0.5 , nx )
104
+ gy = np .arange (- 0.5 , ny )
111
105
112
106
X , Y = np .meshgrid (gx , gy )
113
107
else :
114
108
X , Y , C = XYC
115
109
116
- pcolor_routine = self ._get_base_axes_attr (method_name )
117
-
118
110
if "transform" in kwargs :
119
- mesh = pcolor_routine (self , X , Y , C , ** kwargs )
111
+ mesh = super_pcolor (self , X , Y , C , ** kwargs )
120
112
else :
121
113
orig_shape = X .shape
122
- xy = np .vstack ([X .flat , Y .flat ])
123
- xyt = xy .transpose ()
114
+ xyt = np .column_stack ([X .flat , Y .flat ])
124
115
wxy = self .transAux .transform (xyt )
125
- gx , gy = wxy [:,0 ].reshape (orig_shape ), wxy [:,1 ].reshape (orig_shape )
126
- mesh = pcolor_routine (self , gx , gy , C , ** kwargs )
116
+ gx = wxy [:, 0 ].reshape (orig_shape )
117
+ gy = wxy [:, 1 ].reshape (orig_shape )
118
+ mesh = super_pcolor (self , gx , gy , C , ** kwargs )
127
119
mesh .set_transform (self ._parent_axes .transData )
128
120
129
121
return mesh
130
122
131
123
def pcolormesh (self , * XYC , ** kwargs ):
132
- return self ._pcolor (" pcolormesh" , * XYC , ** kwargs )
124
+ return self ._pcolor (super (). pcolormesh , * XYC , ** kwargs )
133
125
134
126
def pcolor (self , * XYC , ** kwargs ):
135
- return self ._pcolor ("pcolor" , * XYC , ** kwargs )
136
-
127
+ return self ._pcolor (super ().pcolor , * XYC , ** kwargs )
137
128
138
- def _contour (self , method_name , * XYCL , ** kwargs ):
129
+ def _contour (self , super_contour , * XYCL , ** kwargs ):
139
130
140
131
if len (XYCL ) <= 2 :
141
132
C = XYCL [0 ]
142
133
ny , nx = C .shape
143
134
144
- gx = np .arange (0. , nx , 1. )
145
- gy = np .arange (0. , ny , 1. )
135
+ gx = np .arange (0. , nx )
136
+ gy = np .arange (0. , ny )
146
137
147
- X ,Y = np .meshgrid (gx , gy )
138
+ X , Y = np .meshgrid (gx , gy )
148
139
CL = XYCL
149
140
else :
150
141
X , Y = XYCL [:2 ]
151
142
CL = XYCL [2 :]
152
143
153
- contour_routine = self ._get_base_axes_attr (method_name )
154
-
155
144
if "transform" in kwargs :
156
- cont = contour_routine (self , X , Y , * CL , ** kwargs )
145
+ cont = super_contour (self , X , Y , * CL , ** kwargs )
157
146
else :
158
147
orig_shape = X .shape
159
- xy = np .vstack ([X .flat , Y .flat ])
160
- xyt = xy .transpose ()
148
+ xyt = np .column_stack ([X .flat , Y .flat ])
161
149
wxy = self .transAux .transform (xyt )
162
- gx , gy = wxy [:,0 ].reshape (orig_shape ), wxy [:,1 ].reshape (orig_shape )
163
- cont = contour_routine (self , gx , gy , * CL , ** kwargs )
150
+ gx = wxy [:, 0 ].reshape (orig_shape )
151
+ gy = wxy [:, 1 ].reshape (orig_shape )
152
+ cont = super_contour (self , gx , gy , * CL , ** kwargs )
164
153
for c in cont .collections :
165
154
c .set_transform (self ._parent_axes .transData )
166
155
167
156
return cont
168
157
169
158
def contour (self , * XYCL , ** kwargs ):
170
- return self ._contour (" contour" , * XYCL , ** kwargs )
159
+ return self ._contour (super (). contour , * XYCL , ** kwargs )
171
160
172
161
def contourf (self , * XYCL , ** kwargs ):
173
- return self ._contour (" contourf" , * XYCL , ** kwargs )
162
+ return self ._contour (super (). contourf , * XYCL , ** kwargs )
174
163
175
164
def apply_aspect (self , position = None ):
176
165
self .update_viewlim ()
177
- self ._get_base_axes_attr ("apply_aspect" )(self )
178
- #ParasiteAxes.apply_aspect()
166
+ super ().apply_aspect ()
179
167
180
168
181
169
@functools .lru_cache (None )
@@ -196,23 +184,10 @@ def parasite_axes_auxtrans_class_factory(axes_class=None):
196
184
axes_class = ParasiteAxes )
197
185
198
186
199
- def _get_handles (ax ):
200
- handles = ax .lines [:]
201
- handles .extend (ax .patches )
202
- handles .extend ([c for c in ax .collections
203
- if isinstance (c , mcoll .LineCollection )])
204
- handles .extend ([c for c in ax .collections
205
- if isinstance (c , mcoll .RegularPolyCollection )])
206
- handles .extend ([c for c in ax .collections
207
- if isinstance (c , mcoll .CircleCollection )])
208
-
209
- return handles
210
-
211
-
212
- class HostAxesBase (object ):
187
+ class HostAxesBase :
213
188
def __init__ (self , * args , ** kwargs ):
214
189
self .parasites = []
215
- self . _get_base_axes_attr ( "__init__" )( self , * args , ** kwargs )
190
+ super (). __init__ ( * args , ** kwargs )
216
191
217
192
def get_aux_axes (self , tr , viewlim_mode = "equal" , axes_class = None ):
218
193
parasite_axes_class = parasite_axes_auxtrans_class_factory (axes_class )
@@ -224,13 +199,9 @@ def get_aux_axes(self, tr, viewlim_mode="equal", axes_class=None):
224
199
return ax2
225
200
226
201
def _get_legend_handles (self , legend_handler_map = None ):
227
- # don't use this!
228
- Axes_get_legend_handles = self ._get_base_axes_attr ("_get_legend_handles" )
229
- all_handles = list (Axes_get_legend_handles (self , legend_handler_map ))
230
-
202
+ all_handles = super ()._get_legend_handles ()
231
203
for ax in self .parasites :
232
204
all_handles .extend (ax ._get_legend_handles (legend_handler_map ))
233
-
234
205
return all_handles
235
206
236
207
def draw (self , renderer ):
@@ -257,14 +228,14 @@ def draw(self, renderer):
257
228
self .images .extend (images )
258
229
self .artists .extend (artists )
259
230
260
- self . _get_base_axes_attr ( "draw" )( self , renderer )
231
+ super (). draw ( renderer )
261
232
self .artists = orig_artists
262
233
self .images = orig_images
263
234
264
235
def cla (self ):
265
236
for ax in self .parasites :
266
237
ax .cla ()
267
- self . _get_base_axes_attr ( "cla" )( self )
238
+ super (). cla ( )
268
239
269
240
def twinx (self , axes_class = None ):
270
241
"""
@@ -361,15 +332,10 @@ def _remove_method(h):
361
332
return ax2
362
333
363
334
def get_tightbbox (self , renderer , call_axes_locator = True ):
364
-
365
335
bbs = [ax .get_tightbbox (renderer , call_axes_locator )
366
336
for ax in self .parasites ]
367
- get_tightbbox = self ._get_base_axes_attr ("get_tightbbox" )
368
- bbs .append (get_tightbbox (self , renderer , call_axes_locator ))
369
-
370
- _bbox = Bbox .union ([b for b in bbs if b .width != 0 or b .height != 0 ])
371
-
372
- return _bbox
337
+ bbs .append (super ().get_tightbbox (renderer , call_axes_locator ))
338
+ return Bbox .union ([b for b in bbs if b .width != 0 or b .height != 0 ])
373
339
374
340
375
341
@functools .lru_cache (None )
@@ -380,13 +346,9 @@ def host_axes_class_factory(axes_class=None):
380
346
def _get_base_axes (self ):
381
347
return axes_class
382
348
383
- def _get_base_axes_attr (self , attrname ):
384
- return getattr (axes_class , attrname )
385
-
386
349
return type ("%sHostAxes" % axes_class .__name__ ,
387
350
(HostAxesBase , axes_class ),
388
- {'_get_base_axes_attr' : _get_base_axes_attr ,
389
- '_get_base_axes' : _get_base_axes })
351
+ {'_get_base_axes' : _get_base_axes })
390
352
391
353
392
354
def host_subplot_class_factory (axes_class ):
@@ -421,6 +383,7 @@ def host_axes(*args, axes_class=None, figure=None, **kwargs):
421
383
plt .draw_if_interactive ()
422
384
return ax
423
385
386
+
424
387
def host_subplot (* args , axes_class = None , figure = None , ** kwargs ):
425
388
"""
426
389
Create a subplot that can act as a host to parasitic axes.
0 commit comments