diff --git a/examples/axes_grid/demo_imagegrid_aspect.py b/examples/axes_grid/demo_imagegrid_aspect.py new file mode 100644 index 000000000000..5a4af10cd458 --- /dev/null +++ b/examples/axes_grid/demo_imagegrid_aspect.py @@ -0,0 +1,20 @@ +import matplotlib.pyplot as plt + +from mpl_toolkits.axes_grid1 import ImageGrid +fig = plt.figure(1) + +grid1 = ImageGrid(fig, 121, (2,2), axes_pad=0.1, + aspect=True, share_all=True) + +for i in [0, 1]: + grid1[i].set_aspect(2) + + +grid2 = ImageGrid(fig, 122, (2,2), axes_pad=0.1, + aspect=True, share_all=True) + + +for i in [1, 3]: + grid2[i].set_aspect(2) + +plt.show() diff --git a/lib/mpl_toolkits/axes_grid1/axes_grid.py b/lib/mpl_toolkits/axes_grid1/axes_grid.py index f8a4b70a5cb0..f706ff78471a 100644 --- a/lib/mpl_toolkits/axes_grid1/axes_grid.py +++ b/lib/mpl_toolkits/axes_grid1/axes_grid.py @@ -544,6 +544,9 @@ def __init__(self, fig, else: axes_class, axes_class_args = axes_class + adjustable = axes_class_args.setdefault("adjustable", "box-forced") + if adjustable != "box-forced": + raise RuntimeError("adjustable parameter must not be set, or set to box-forced") self.axes_all = [] @@ -582,8 +585,12 @@ def __init__(self, fig, col, row = self._get_col_row(i) if share_all: - sharex = self._refax - sharey = self._refax + if self.axes_all: + sharex = self.axes_all[0] + sharey = self.axes_all[0] + else: + sharex = None + sharey = None else: sharex = self._column_refax[col] sharey = self._row_refax[row] @@ -591,19 +598,18 @@ def __init__(self, fig, ax = axes_class(fig, rect, sharex=sharex, sharey=sharey, **axes_class_args) - if share_all: - if self._refax is None: - self._refax = ax - else: - if sharex is None: - self._column_refax[col] = ax - if sharey is None: - self._row_refax[row] = ax - self.axes_all.append(ax) self.axes_column[col].append(ax) self.axes_row[row].append(ax) + if share_all: + if self._refax is None: + self._refax = ax + if sharex is None: + self._column_refax[col] = ax + if sharey is None: + self._row_refax[row] = ax + cax = self._defaultCbarAxesClass(fig, rect, orientation=self._colorbar_location) self.cbar_axes.append(cax) @@ -653,13 +659,14 @@ def _update_locators(self): self.cbar_axes[0].set_axes_locator(locator) self.cbar_axes[0].set_visible(True) - for col,ax in enumerate(self._column_refax): + for col,ax in enumerate(self.axes_row[0]): if h: h.append(self._horiz_pad_size) #Size.Fixed(self._axes_pad)) if ax: - sz = Size.AxesX(ax) + sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0]) else: - sz = Size.AxesX(self.axes_llc) + sz = Size.AxesX(self.axes_all[0], + aspect="axes", ref_ax=self.axes_all[0]) if (self._colorbar_mode == "each" or (self._colorbar_mode == 'edge' and @@ -682,13 +689,14 @@ def _update_locators(self): v_ax_pos = [] v_cb_pos = [] - for row,ax in enumerate(self._row_refax[::-1]): + for row,ax in enumerate(self.axes_column[0][::-1]): if v: v.append(self._horiz_pad_size) #Size.Fixed(self._axes_pad)) if ax: - sz = Size.AxesY(ax) + sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0]) else: - sz = Size.AxesY(self.axes_llc) + sz = Size.AxesY(self.axes_all[0], + aspect="axes", ref_ax=self.axes_all[0]) if (self._colorbar_mode == "each" or (self._colorbar_mode == 'edge' and diff --git a/lib/mpl_toolkits/axes_grid1/axes_size.py b/lib/mpl_toolkits/axes_grid1/axes_size.py index 33f40ea92e28..86bf928e1ed5 100644 --- a/lib/mpl_toolkits/axes_grid1/axes_size.py +++ b/lib/mpl_toolkits/axes_grid1/axes_size.py @@ -73,19 +73,39 @@ def get_size(self, renderer): Scalable=Scaled +def _get_axes_aspect(ax): + aspect = ax.get_aspect() + # when aspec is "auto", consider it as 1. + if aspect in ('normal', 'auto'): + aspect = 1. + elif aspect == "equal": + aspect = 1 + else: + aspect = float(aspect) + + return aspect class AxesX(_Base): """ Scaled size whose relative part corresponds to the data width of the *axes* multiplied by the *aspect*. """ - def __init__(self, axes, aspect=1.): + def __init__(self, axes, aspect=1., ref_ax=None): self._axes = axes self._aspect = aspect + if aspect == "axes" and ref_ax is None: + raise ValueError("ref_ax must be set when aspect='axes'") + self._ref_ax = ref_ax def get_size(self, renderer): l1, l2 = self._axes.get_xlim() - rel_size = abs(l2-l1)*self._aspect + if self._aspect == "axes": + ref_aspect = _get_axes_aspect(self._ref_ax) + aspect = ref_aspect/_get_axes_aspect(self._axes) + else: + aspect = self._aspect + + rel_size = abs(l2-l1)*aspect abs_size = 0. return rel_size, abs_size @@ -94,13 +114,23 @@ class AxesY(_Base): Scaled size whose relative part corresponds to the data height of the *axes* multiplied by the *aspect*. """ - def __init__(self, axes, aspect=1.): + def __init__(self, axes, aspect=1., ref_ax=None): self._axes = axes self._aspect = aspect + if aspect == "axes" and ref_ax is None: + raise ValueError("ref_ax must be set when aspect='axes'") + self._ref_ax = ref_ax def get_size(self, renderer): l1, l2 = self._axes.get_ylim() - rel_size = abs(l2-l1)*self._aspect + + if self._aspect == "axes": + ref_aspect = _get_axes_aspect(self._ref_ax) + aspect = _get_axes_aspect(self._axes) + else: + aspect = self._aspect + + rel_size = abs(l2-l1)*aspect abs_size = 0. return rel_size, abs_size