diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 444150748208..1b0445a039a7 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -1,5 +1,6 @@ from collections import OrderedDict from contextlib import ExitStack +import functools import inspect import itertools import logging @@ -562,13 +563,10 @@ def __init__(self, fig, rect, self.update(kwargs) - if self.xaxis is not None: - self._xcid = self.xaxis.callbacks.connect( - 'units finalize', lambda: self._on_units_changed(scalex=True)) - - if self.yaxis is not None: - self._ycid = self.yaxis.callbacks.connect( - 'units finalize', lambda: self._on_units_changed(scaley=True)) + for name, axis in self._get_axis_map().items(): + axis.callbacks._pickled_cids.add( + axis.callbacks.connect( + 'units finalize', self._unit_change_handler(name))) rcParams = mpl.rcParams self.tick_params( @@ -2150,14 +2148,17 @@ def add_container(self, container): container._remove_method = self.containers.remove return container - def _on_units_changed(self, scalex=False, scaley=False): + def _unit_change_handler(self, axis_name, event=None): """ - Callback for processing changes to axis units. - - Currently requests updates of data limits and view limits. + Process axis units changes: requests updates to data and view limits. """ + if event is None: # Allow connecting `self._unit_change_handler(name)` + return functools.partial( + self._unit_change_handler, axis_name, event=object()) + _api.check_in_list(self._get_axis_map(), axis_name=axis_name) self.relim() - self._request_autoscale_view(scalex=scalex, scaley=scaley) + self._request_autoscale_view(scalex=(axis_name == "x"), + scaley=(axis_name == "y")) def relim(self, visible_only=False): """ diff --git a/lib/matplotlib/lines.py b/lib/matplotlib/lines.py index 479ae162112b..af92fadf79d9 100644 --- a/lib/matplotlib/lines.py +++ b/lib/matplotlib/lines.py @@ -627,13 +627,9 @@ def axes(self, ax): # call the set method from the base-class property Artist.axes.fset(self, ax) if ax is not None: - # connect unit-related callbacks - if ax.xaxis is not None: - self._xcid = ax.xaxis.callbacks.connect('units', - self.recache_always) - if ax.yaxis is not None: - self._ycid = ax.yaxis.callbacks.connect('units', - self.recache_always) + for axis in ax._get_axis_map().values(): + axis.callbacks._pickled_cids.add( + axis.callbacks.connect('units', self.recache_always)) def set_data(self, *args): """ diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 106a12b9f2bf..605bf8325291 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -11,6 +11,7 @@ """ from collections import defaultdict +import functools from functools import reduce from itertools import compress import math @@ -111,12 +112,6 @@ def __init__( # func used to format z -- fall back on major formatters self.fmt_zdata = None - if self.zaxis is not None: - self._zcid = self.zaxis.callbacks.connect( - 'units finalize', lambda: self._on_units_changed(scalez=True)) - else: - self._zcid = None - self.mouse_init() self.figure.canvas.callbacks._pickled_cids.update({ self.figure.canvas.mpl_connect( @@ -475,14 +470,16 @@ def get_axis_position(self): zhigh = tc[0][2] > tc[2][2] return xhigh, yhigh, zhigh - def _on_units_changed(self, scalex=False, scaley=False, scalez=False): - """ - Callback for processing changes to axis units. - - Currently forces updates of data limits and view limits. - """ + def _unit_change_handler(self, axis_name, event=None): + # docstring inherited + if event is None: # Allow connecting `self._unit_change_handler(name)` + return functools.partial( + self._unit_change_handler, axis_name, event=object()) + _api.check_in_list(self._get_axis_map(), axis_name=axis_name) self.relim() - self.autoscale_view(scalex=scalex, scaley=scaley, scalez=scalez) + self.autoscale_view(scalex=(axis_name == "x"), + scaley=(axis_name == "y"), + scalez=(axis_name == "z")) def update_datalim(self, xys, **kwargs): pass