Skip to content

axes_grid1: ImageGrid respect the aspect ratio of axes. #2248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 9, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions examples/axes_grid/demo_imagegrid_aspect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import matplotlib.pyplot as plt

from mpl_toolkits.axes_grid1 import ImageGrid
fig = plt.figure(1)

grid1 = ImageGrid(fig, 121, (2,2), axes_pad=0.1,
aspect=True, share_all=True)

for i in [0, 1]:
grid1[i].set_aspect(2)


grid2 = ImageGrid(fig, 122, (2,2), axes_pad=0.1,
aspect=True, share_all=True)


for i in [1, 3]:
grid2[i].set_aspect(2)

plt.show()
42 changes: 25 additions & 17 deletions lib/mpl_toolkits/axes_grid1/axes_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,9 @@ def __init__(self, fig,
else:
axes_class, axes_class_args = axes_class

adjustable = axes_class_args.setdefault("adjustable", "box-forced")
if adjustable != "box-forced":
raise RuntimeError("adjustable parameter must not be set, or set to box-forced")


self.axes_all = []
Expand Down Expand Up @@ -582,28 +585,31 @@ def __init__(self, fig,
col, row = self._get_col_row(i)

if share_all:
sharex = self._refax
sharey = self._refax
if self.axes_all:
sharex = self.axes_all[0]
sharey = self.axes_all[0]
else:
sharex = None
sharey = None
else:
sharex = self._column_refax[col]
sharey = self._row_refax[row]

ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
**axes_class_args)

if share_all:
if self._refax is None:
self._refax = ax
else:
if sharex is None:
self._column_refax[col] = ax
if sharey is None:
self._row_refax[row] = ax

self.axes_all.append(ax)
self.axes_column[col].append(ax)
self.axes_row[row].append(ax)

if share_all:
if self._refax is None:
self._refax = ax
if sharex is None:
self._column_refax[col] = ax
if sharey is None:
self._row_refax[row] = ax

cax = self._defaultCbarAxesClass(fig, rect,
orientation=self._colorbar_location)
self.cbar_axes.append(cax)
Expand Down Expand Up @@ -653,13 +659,14 @@ def _update_locators(self):
self.cbar_axes[0].set_axes_locator(locator)
self.cbar_axes[0].set_visible(True)

for col,ax in enumerate(self._column_refax):
for col,ax in enumerate(self.axes_row[0]):
if h: h.append(self._horiz_pad_size) #Size.Fixed(self._axes_pad))

if ax:
sz = Size.AxesX(ax)
sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
else:
sz = Size.AxesX(self.axes_llc)
sz = Size.AxesX(self.axes_all[0],
aspect="axes", ref_ax=self.axes_all[0])

if (self._colorbar_mode == "each" or
(self._colorbar_mode == 'edge' and
Expand All @@ -682,13 +689,14 @@ def _update_locators(self):

v_ax_pos = []
v_cb_pos = []
for row,ax in enumerate(self._row_refax[::-1]):
for row,ax in enumerate(self.axes_column[0][::-1]):
if v: v.append(self._horiz_pad_size) #Size.Fixed(self._axes_pad))

if ax:
sz = Size.AxesY(ax)
sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
else:
sz = Size.AxesY(self.axes_llc)
sz = Size.AxesY(self.axes_all[0],
aspect="axes", ref_ax=self.axes_all[0])

if (self._colorbar_mode == "each" or
(self._colorbar_mode == 'edge' and
Expand Down
38 changes: 34 additions & 4 deletions lib/mpl_toolkits/axes_grid1/axes_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,39 @@ def get_size(self, renderer):

Scalable=Scaled

def _get_axes_aspect(ax):
aspect = ax.get_aspect()
# when aspec is "auto", consider it as 1.
if aspect in ('normal', 'auto'):
aspect = 1.
elif aspect == "equal":
aspect = 1
else:
aspect = float(aspect)

return aspect

class AxesX(_Base):
"""
Scaled size whose relative part corresponds to the data width
of the *axes* multiplied by the *aspect*.
"""
def __init__(self, axes, aspect=1.):
def __init__(self, axes, aspect=1., ref_ax=None):
self._axes = axes
self._aspect = aspect
if aspect == "axes" and ref_ax is None:
raise ValueError("ref_ax must be set when aspect='axes'")
self._ref_ax = ref_ax

def get_size(self, renderer):
l1, l2 = self._axes.get_xlim()
rel_size = abs(l2-l1)*self._aspect
if self._aspect == "axes":
ref_aspect = _get_axes_aspect(self._ref_ax)
aspect = ref_aspect/_get_axes_aspect(self._axes)
else:
aspect = self._aspect

rel_size = abs(l2-l1)*aspect
abs_size = 0.
return rel_size, abs_size

Expand All @@ -94,13 +114,23 @@ class AxesY(_Base):
Scaled size whose relative part corresponds to the data height
of the *axes* multiplied by the *aspect*.
"""
def __init__(self, axes, aspect=1.):
def __init__(self, axes, aspect=1., ref_ax=None):
self._axes = axes
self._aspect = aspect
if aspect == "axes" and ref_ax is None:
raise ValueError("ref_ax must be set when aspect='axes'")
self._ref_ax = ref_ax

def get_size(self, renderer):
l1, l2 = self._axes.get_ylim()
rel_size = abs(l2-l1)*self._aspect

if self._aspect == "axes":
ref_aspect = _get_axes_aspect(self._ref_ax)
aspect = _get_axes_aspect(self._axes)
else:
aspect = self._aspect

rel_size = abs(l2-l1)*aspect
abs_size = 0.
return rel_size, abs_size

Expand Down