Skip to content

Use super() instead of manually fetching supermethods for parasite axes. #11678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ per-file-ignores =
mpl_toolkits/axes_grid1/colorbar.py: E225, E231, E261, E262, E302, E303, E501, E701
mpl_toolkits/axes_grid1/inset_locator.py: E501
mpl_toolkits/axes_grid1/mpl_axes.py: E303, E501
mpl_toolkits/axes_grid1/parasite_axes.py: E225, E231, E302, E303, E501
mpl_toolkits/axisartist/angle_helper.py: E201, E203, E221, E222, E225, E231, E251, E261, E262, E302, E303, E501
mpl_toolkits/axisartist/axis_artist.py: E201, E202, E221, E225, E228, E231, E251, E261, E262, E302, E303, E402, E501, E701, E711
mpl_toolkits/axisartist/axisline_style.py: E231, E261, E262, E302, E303
Expand Down
117 changes: 40 additions & 77 deletions lib/mpl_toolkits/axes_grid1/parasite_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np


class ParasiteAxesBase(object):
class ParasiteAxesBase:

def get_images_artists(self):
artists = {a for a in self.get_children() if a.get_visible()}
Expand All @@ -21,11 +21,10 @@ def get_images_artists(self):
def __init__(self, parent_axes, **kwargs):
self._parent_axes = parent_axes
kwargs["frameon"] = False
self._get_base_axes_attr("__init__")(
self, parent_axes.figure, parent_axes._position, **kwargs)
super().__init__(parent_axes.figure, parent_axes._position, **kwargs)

def cla(self):
self._get_base_axes_attr("cla")(self)
super().cla()

martist.setp(self.get_children(), visible=False)
self._get_lines = self._parent_axes._get_lines
Expand All @@ -45,18 +44,14 @@ def parasite_axes_class_factory(axes_class=None):
if axes_class is None:
axes_class = Axes

def _get_base_axes_attr(self, attrname):
return getattr(axes_class, attrname)

return type("%sParasite" % axes_class.__name__,
(ParasiteAxesBase, axes_class),
{'_get_base_axes_attr': _get_base_axes_attr})
(ParasiteAxesBase, axes_class), {})


ParasiteAxes = parasite_axes_class_factory()


class ParasiteAxesAuxTransBase(object):
class ParasiteAxesAuxTransBase:
def __init__(self, parent_axes, aux_transform, viewlim_mode=None,
**kwargs):

Expand All @@ -80,14 +75,13 @@ def _set_lim_and_transforms(self):

def set_viewlim_mode(self, mode):
if mode not in [None, "equal", "transform"]:
raise ValueError("Unknown mode : %s" % (mode,))
raise ValueError("Unknown mode: %s" % (mode,))
else:
self._viewlim_mode = mode

def get_viewlim_mode(self):
return self._viewlim_mode


def update_viewlim(self):
viewlim = self._parent_axes.viewLim.frozen()
mode = self.get_viewlim_mode()
Expand All @@ -96,86 +90,80 @@ def update_viewlim(self):
elif mode == "equal":
self.axes.viewLim.set(viewlim)
elif mode == "transform":
self.axes.viewLim.set(viewlim.transformed(self.transAux.inverted()))
self.axes.viewLim.set(
viewlim.transformed(self.transAux.inverted()))
else:
raise ValueError("Unknown mode : %s" % (self._viewlim_mode,))

raise ValueError("Unknown mode: %s" % (self._viewlim_mode,))

def _pcolor(self, method_name, *XYC, **kwargs):
def _pcolor(self, super_pcolor, *XYC, **kwargs):
if len(XYC) == 1:
C = XYC[0]
ny, nx = C.shape

gx = np.arange(-0.5, nx, 1.)
gy = np.arange(-0.5, ny, 1.)
gx = np.arange(-0.5, nx)
gy = np.arange(-0.5, ny)

X, Y = np.meshgrid(gx, gy)
else:
X, Y, C = XYC

pcolor_routine = self._get_base_axes_attr(method_name)

if "transform" in kwargs:
mesh = pcolor_routine(self, X, Y, C, **kwargs)
mesh = super_pcolor(self, X, Y, C, **kwargs)
else:
orig_shape = X.shape
xy = np.vstack([X.flat, Y.flat])
xyt=xy.transpose()
xyt = np.column_stack([X.flat, Y.flat])
wxy = self.transAux.transform(xyt)
gx, gy = wxy[:,0].reshape(orig_shape), wxy[:,1].reshape(orig_shape)
mesh = pcolor_routine(self, gx, gy, C, **kwargs)
gx = wxy[:, 0].reshape(orig_shape)
gy = wxy[:, 1].reshape(orig_shape)
mesh = super_pcolor(self, gx, gy, C, **kwargs)
mesh.set_transform(self._parent_axes.transData)

return mesh

def pcolormesh(self, *XYC, **kwargs):
return self._pcolor("pcolormesh", *XYC, **kwargs)
return self._pcolor(super().pcolormesh, *XYC, **kwargs)

def pcolor(self, *XYC, **kwargs):
return self._pcolor("pcolor", *XYC, **kwargs)

return self._pcolor(super().pcolor, *XYC, **kwargs)

def _contour(self, method_name, *XYCL, **kwargs):
def _contour(self, super_contour, *XYCL, **kwargs):

if len(XYCL) <= 2:
C = XYCL[0]
ny, nx = C.shape

gx = np.arange(0., nx, 1.)
gy = np.arange(0., ny, 1.)
gx = np.arange(0., nx)
gy = np.arange(0., ny)

X,Y = np.meshgrid(gx, gy)
X, Y = np.meshgrid(gx, gy)
CL = XYCL
else:
X, Y = XYCL[:2]
CL = XYCL[2:]

contour_routine = self._get_base_axes_attr(method_name)

if "transform" in kwargs:
cont = contour_routine(self, X, Y, *CL, **kwargs)
cont = super_contour(self, X, Y, *CL, **kwargs)
else:
orig_shape = X.shape
xy = np.vstack([X.flat, Y.flat])
xyt=xy.transpose()
xyt = np.column_stack([X.flat, Y.flat])
wxy = self.transAux.transform(xyt)
gx, gy = wxy[:,0].reshape(orig_shape), wxy[:,1].reshape(orig_shape)
cont = contour_routine(self, gx, gy, *CL, **kwargs)
gx = wxy[:, 0].reshape(orig_shape)
gy = wxy[:, 1].reshape(orig_shape)
cont = super_contour(self, gx, gy, *CL, **kwargs)
for c in cont.collections:
c.set_transform(self._parent_axes.transData)

return cont

def contour(self, *XYCL, **kwargs):
return self._contour("contour", *XYCL, **kwargs)
return self._contour(super().contour, *XYCL, **kwargs)

def contourf(self, *XYCL, **kwargs):
return self._contour("contourf", *XYCL, **kwargs)
return self._contour(super().contourf, *XYCL, **kwargs)

def apply_aspect(self, position=None):
self.update_viewlim()
self._get_base_axes_attr("apply_aspect")(self)
#ParasiteAxes.apply_aspect()
super().apply_aspect()


@functools.lru_cache(None)
Expand All @@ -196,23 +184,10 @@ def parasite_axes_auxtrans_class_factory(axes_class=None):
axes_class=ParasiteAxes)


def _get_handles(ax):
handles = ax.lines[:]
handles.extend(ax.patches)
handles.extend([c for c in ax.collections
if isinstance(c, mcoll.LineCollection)])
handles.extend([c for c in ax.collections
if isinstance(c, mcoll.RegularPolyCollection)])
handles.extend([c for c in ax.collections
if isinstance(c, mcoll.CircleCollection)])

return handles


class HostAxesBase(object):
class HostAxesBase:
def __init__(self, *args, **kwargs):
self.parasites = []
self._get_base_axes_attr("__init__")(self, *args, **kwargs)
super().__init__(*args, **kwargs)

def get_aux_axes(self, tr, viewlim_mode="equal", axes_class=None):
parasite_axes_class = parasite_axes_auxtrans_class_factory(axes_class)
Expand All @@ -224,13 +199,9 @@ def get_aux_axes(self, tr, viewlim_mode="equal", axes_class=None):
return ax2

def _get_legend_handles(self, legend_handler_map=None):
# don't use this!
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jklymak You put this coment in last year. Why? Should it be removed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#9324 (comment) I think the whole method should be removed, but maybe I'll put that in as a separate PR and not burden @anntzer with sorting out this stuff as part of an otherwise straightforward PR

Axes_get_legend_handles = self._get_base_axes_attr("_get_legend_handles")
all_handles = list(Axes_get_legend_handles(self, legend_handler_map))

all_handles = super()._get_legend_handles()
for ax in self.parasites:
all_handles.extend(ax._get_legend_handles(legend_handler_map))

return all_handles

def draw(self, renderer):
Expand All @@ -257,14 +228,14 @@ def draw(self, renderer):
self.images.extend(images)
self.artists.extend(artists)

self._get_base_axes_attr("draw")(self, renderer)
super().draw(renderer)
self.artists = orig_artists
self.images = orig_images

def cla(self):
for ax in self.parasites:
ax.cla()
self._get_base_axes_attr("cla")(self)
super().cla()

def twinx(self, axes_class=None):
"""
Expand Down Expand Up @@ -361,15 +332,10 @@ def _remove_method(h):
return ax2

def get_tightbbox(self, renderer, call_axes_locator=True):

bbs = [ax.get_tightbbox(renderer, call_axes_locator)
for ax in self.parasites]
get_tightbbox = self._get_base_axes_attr("get_tightbbox")
bbs.append(get_tightbbox(self, renderer, call_axes_locator))

_bbox = Bbox.union([b for b in bbs if b.width!=0 or b.height!=0])

return _bbox
bbs.append(super().get_tightbbox(renderer, call_axes_locator))
return Bbox.union([b for b in bbs if b.width != 0 or b.height != 0])


@functools.lru_cache(None)
Expand All @@ -380,13 +346,9 @@ def host_axes_class_factory(axes_class=None):
def _get_base_axes(self):
return axes_class

def _get_base_axes_attr(self, attrname):
return getattr(axes_class, attrname)

return type("%sHostAxes" % axes_class.__name__,
(HostAxesBase, axes_class),
{'_get_base_axes_attr': _get_base_axes_attr,
'_get_base_axes': _get_base_axes})
{'_get_base_axes': _get_base_axes})


def host_subplot_class_factory(axes_class):
Expand Down Expand Up @@ -421,6 +383,7 @@ def host_axes(*args, axes_class=None, figure=None, **kwargs):
plt.draw_if_interactive()
return ax


def host_subplot(*args, axes_class=None, figure=None, **kwargs):
"""
Create a subplot that can act as a host to parasitic axes.
Expand Down