Skip to content

Picklable figures #1020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
First pass pickle support for figures, transforms & artists.
  • Loading branch information
pelson committed Aug 20, 2012
commit 0db9429b5a3fd57b19b62e6eb23dcff1f12b5d03
9 changes: 8 additions & 1 deletion lib/matplotlib/artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def __init__(self):
self.y_isdata = True # with y
self._snap = None

def __getstate__(self):
d = self.__dict__.copy()
# remove the unpicklable remove method, this will get re-added on load
d.pop('_remove_method')
# axes_artist_collections = ['lines', 'collections', 'tables', '']
return d

def remove(self):
"""
Remove the artist from the figure if possible. The effect
Expand All @@ -123,7 +130,7 @@ def remove(self):
# the _remove_method attribute directly. This would be a protected
# attribute if Python supported that sort of thing. The callback
# has one parameter, which is the child to be removed.
if self._remove_method != None:
if self._remove_method is not None:
self._remove_method(self)
else:
raise NotImplementedError('cannot remove artist')
Expand Down
80 changes: 66 additions & 14 deletions lib/matplotlib/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,13 @@ def __init__(self, axes, command='plot'):
self.command = command
self.set_color_cycle()

def __getinitargs__(self):
# means that the color cycle will be lost.
return (self.axes, self.command)

def __getstate__(self):
return False

def set_color_cycle(self, clist=None):
if clist is None:
clist = rcParams['axes.color_cycle']
Expand Down Expand Up @@ -332,7 +339,7 @@ def _grab_next_args(self, *args, **kwargs):
for seg in self._plot_args(remaining[:isplit], kwargs):
yield seg
remaining=remaining[isplit:]


class Axes(martist.Artist):
"""
Expand All @@ -352,9 +359,10 @@ class Axes(martist.Artist):

_shared_x_axes = cbook.Grouper()
_shared_y_axes = cbook.Grouper()

def __str__(self):
return "Axes(%g,%g;%gx%g)" % tuple(self._position.bounds)

def __init__(self, fig, rect,
axisbg = None, # defaults to rc axes.facecolor
frameon = True,
Expand Down Expand Up @@ -1423,7 +1431,9 @@ def add_artist(self, a):
self.artists.append(a)
self._set_artist_props(a)
a.set_clip_path(self.patch)
a._remove_method = lambda h: self.artists.remove(h)
def remove_fn(artist):
self.artists.remove(artist)
a._remove_method = remove_fn #lambda h: self.artists.remove(h)
return a

def add_collection(self, collection, autolim=True):
Expand All @@ -1445,7 +1455,11 @@ def add_collection(self, collection, autolim=True):
if collection._paths and len(collection._paths):
self.update_datalim(collection.get_datalim(self.transData))

collection._remove_method = lambda h: self.collections.remove(h)
# XXX back to start
def remove_fn(artist):
self.collections.remove(artist)

collection._remove_method = remove_fn #lambda h: self.collections.remove(h)
return collection

def add_line(self, line):
Expand All @@ -1463,7 +1477,10 @@ def add_line(self, line):
if not line.get_label():
line.set_label('_line%d'%len(self.lines))
self.lines.append(line)
line._remove_method = lambda h: self.lines.remove(h)
# def remove_fn(artist):
# self.lines.remove(artist)
# line._remove_method = remove_fn #lambda h: self.lines.remove(h)
line._remove_method = self.lines.remove
return line

def _update_line_limits(self, line):
Expand All @@ -1489,7 +1506,9 @@ def add_patch(self, p):
p.set_clip_path(self.patch)
self._update_patch_limits(p)
self.patches.append(p)
p._remove_method = lambda h: self.patches.remove(h)
def remove_fn(artist):
self.patches.remove(artist)
p._remove_method = remove_fn #lambda h: self.patches.remove(h)
return p

def _update_patch_limits(self, patch):
Expand Down Expand Up @@ -1524,7 +1543,9 @@ def add_table(self, tab):
self._set_artist_props(tab)
self.tables.append(tab)
tab.set_clip_path(self.patch)
tab._remove_method = lambda h: self.tables.remove(h)
def remove_fn(artist):
self.tables.remove(artist)
tab._remove_method = remove_fn #lambda h: self.tables.remove(h)
return tab

def add_container(self, container):
Expand All @@ -1538,7 +1559,9 @@ def add_container(self, container):
if not label:
container.set_label('_container%d'%len(self.containers))
self.containers.append(container)
container.set_remove_method(lambda h: self.containers.remove(container))
def remove_fn(artist):
self.containers.remove(artist)
container.set_remove_method(remove_fn)
return container


Expand Down Expand Up @@ -1599,13 +1622,13 @@ def _process_unit_info(self, xdata=None, ydata=None, kwargs=None):
if xdata is not None:
# we only need to update if there is nothing set yet.
if not self.xaxis.have_units():
self.xaxis.update_units(xdata)
self.xaxis.update_units(xdata)
#print '\tset from xdata', self.xaxis.units

if ydata is not None:
# we only need to update if there is nothing set yet.
if not self.yaxis.have_units():
self.yaxis.update_units(ydata)
self.yaxis.update_units(ydata)
#print '\tset from ydata', self.yaxis.units

# process kwargs 2nd since these will override default units
Expand Down Expand Up @@ -3330,7 +3353,9 @@ def text(self, x, y, s, fontdict=None,
if fontdict is not None: t.update(fontdict)
t.update(kwargs)
self.texts.append(t)
t._remove_method = lambda h: self.texts.remove(h)
def remove_fn(artist):
self.texts.remove(artist)
t._remove_method = remove_fn #lambda h: self.texts.remove(h)


#if t.get_clip_on(): t.set_clip_box(self.bbox)
Expand Down Expand Up @@ -3359,7 +3384,9 @@ def annotate(self, *args, **kwargs):
self._set_artist_props(a)
if kwargs.has_key('clip_on'): a.set_clip_path(self.patch)
self.texts.append(a)
a._remove_method = lambda h: self.texts.remove(h)
def remove_fn(artist):
self.texts.remove(artist)
a._remove_method = remove_fn #lambda h: self.texts.remove(h)
return a

#### Lines and spans
Expand Down Expand Up @@ -7022,7 +7049,9 @@ def imshow(self, X, cmap=None, norm=None, aspect=None,
im.set_extent(im.get_extent())

self.images.append(im)
im._remove_method = lambda h: self.images.remove(h)
def remove_fn(artist):
self.images.remove(artist)
im._remove_method = remove_fn #lambda h: self.images.remove(h)

return im

Expand Down Expand Up @@ -8770,7 +8799,15 @@ def __init__(self, fig, *args, **kwargs):
# _axes_class is set in the subplot_class_factory
self._axes_class.__init__(self, fig, self.figbox, **kwargs)


def __reduce__(self):
# get the first axes class which does not inherit from a subplotbase
axes_class = filter(lambda klass: (issubclass(klass, Axes) and
not issubclass(klass, SubplotBase)),
self.__class__.mro())[0]
r = [_PicklableSubplotClassConstructor(),
(axes_class,),
self.__getstate__()]
return tuple(r)

def get_geometry(self):
"""get the subplot geometry, eg 2,2,3"""
Expand Down Expand Up @@ -8852,6 +8889,21 @@ def subplot_class_factory(axes_class=None):
# This is provided for backward compatibility
Subplot = subplot_class_factory()


class _PicklableSubplotClassConstructor(object):
"""
This stub class exists to return the appropriate subplot
class when __call__-ed with an axes class. This is purely to
allow Picking of Axes and Subplots."""
def __call__(self, axes_class):
# create a dummy object instance
subplot_instance = _PicklableSubplotClassConstructor()
subplot_class = subplot_class_factory(axes_class)
# update the class to the desired subplot class
subplot_instance.__class__ = subplot_class
return subplot_instance


docstring.interpd.update(Axes=martist.kwdoc(Axes))
docstring.interpd.update(Subplot=martist.kwdoc(Axes))

Expand Down
1 change: 0 additions & 1 deletion lib/matplotlib/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,6 @@ class Ticker:
formatter = None



class Axis(artist.Artist):

"""
Expand Down
Loading