Skip to content

Commit d0956f3

Browse files
committed
Merge pull request mwaskom#551 from mwaskom/default_linewidths
Some changes to help heatmap and clustermap plot large matrices
2 parents 325bd8d + 09060db commit d0956f3

File tree

3 files changed

+93
-24
lines changed

3 files changed

+93
-24
lines changed

doc/releases/v0.6.0.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,14 @@ Other additions
6363

6464
- Added a ``savefig`` method to :class:`JointGrid` that defaults to a tight bounding box to make it easier to save figures using this class.
6565

66+
- Changed the default ``linewidths`` in :func:`heatmap` and :func:`clustermap` to 0 so that larger matrices plot correctly. This parameter still exists and can be used to get the old effect of lines demarcating each cell in the heatmap (the old default ``linewidths`` was 0.5).
67+
68+
- You can now pass an integer to the ``xticklabels`` and ``yticklabels`` parameters of :func:`heatmap` (and, by extension, :func:`clustermap`). This will make the plot use the ticklabels inferred from the data, but only plot every ``n``-th label, where ``n`` is the number you pass. This can help when visualizing larger matrices with some sensible ordering to the rows or columns of the dataframe.
69+
6670
Bug fixes
6771
~~~~~~~~~
6872

69-
- Fixed a bug in :func:`clustermap` where the mask was not being reorganized using the dendrograms.
73+
- Fixed bugs in :func:`clustermap` where the mask and specified ticklabels were not being reorganized using the dendrograms.
7074

7175
- Fixed a bug in :class:`FacetGrid` and :class:`PairGrid` that led to incorrect legend labels when levels of the ``hue`` variable appeared in ``hue_order`` but not in the data.
7276

seaborn/matrix.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -107,25 +107,38 @@ def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt,
107107
plot_data = np.ma.masked_where(np.asarray(mask), plot_data)
108108

109109
# Get good names for the rows and columns
110-
if isinstance(xticklabels, bool) and xticklabels:
111-
self.xticklabels = _index_to_ticklabels(data.columns)
110+
xtickevery = 1
111+
if isinstance(xticklabels, int) and xticklabels > 1:
112+
xtickevery = xticklabels
113+
xticklabels = _index_to_ticklabels(data.columns)
114+
elif isinstance(xticklabels, bool) and xticklabels:
115+
xticklabels = _index_to_ticklabels(data.columns)
112116
elif isinstance(xticklabels, bool) and not xticklabels:
113-
self.xticklabels = ['' for _ in range(data.shape[1])]
114-
else:
115-
self.xticklabels = xticklabels
116-
117-
xlabel = _index_to_label(data.columns)
118-
119-
if isinstance(yticklabels, bool) and yticklabels:
120-
self.yticklabels = _index_to_ticklabels(data.index)
117+
xticklabels = ['' for _ in range(data.shape[1])]
118+
119+
ytickevery = 1
120+
if isinstance(yticklabels, int) and yticklabels > 1:
121+
ytickevery = yticklabels
122+
yticklabels = _index_to_ticklabels(data.index)
123+
elif isinstance(yticklabels, bool) and yticklabels:
124+
yticklabels = _index_to_ticklabels(data.index)
121125
elif isinstance(yticklabels, bool) and not yticklabels:
122-
self.yticklabels = ['' for _ in range(data.shape[0])]
126+
yticklabels = ['' for _ in range(data.shape[0])]
123127
else:
124-
self.yticklabels = yticklabels[::-1]
128+
yticklabels = yticklabels[::-1]
125129

126-
ylabel = _index_to_label(data.index)
130+
# Get the positions and used label for the ticks
131+
nx, ny = data.T.shape
132+
xstart, xend, xstep = 0, nx, xtickevery
133+
self.xticks = np.arange(xstart, xend, xstep) + .5
134+
self.xticklabels = xticklabels[xstart:xend:xstep]
135+
ystart, yend, ystep = (ny - 1) % ytickevery, ny, ytickevery
136+
self.yticks = np.arange(ystart, yend, ystep) + .5
137+
self.yticklabels = yticklabels[ystart:yend:ystep]
127138

128139
# Get good names for the axis labels
140+
xlabel = _index_to_label(data.columns)
141+
ylabel = _index_to_label(data.index)
129142
self.xlabel = xlabel if xlabel is not None else ""
130143
self.ylabel = ylabel if ylabel is not None else ""
131144

@@ -204,8 +217,7 @@ def plot(self, ax, cax, kws):
204217
ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))
205218

206219
# Add row and column labels
207-
nx, ny = self.data.T.shape
208-
ax.set(xticks=np.arange(nx) + .5, yticks=np.arange(ny) + .5)
220+
ax.set(xticks=self.xticks, yticks=self.yticks)
209221
xtl = ax.set_xticklabels(self.xticklabels)
210222
ytl = ax.set_yticklabels(self.yticklabels, rotation="vertical")
211223

@@ -233,7 +245,7 @@ def plot(self, ax, cax, kws):
233245

234246
def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False,
235247
annot=False, fmt=".2g", annot_kws=None,
236-
linewidths=.5, linecolor="white",
248+
linewidths=0, linecolor="white",
237249
cbar=True, cbar_kws=None, cbar_ax=None,
238250
square=False, ax=None, xticklabels=True, yticklabels=True,
239251
mask=None,
@@ -276,9 +288,9 @@ def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False,
276288
annot_kws : dict of key, value mappings, optional
277289
Keyword arguments for ``ax.text`` when ``annot`` is True.
278290
linewidths : float, optional
279-
Width of the lines that divide each cell.
291+
Width of the lines that will divide each cell.
280292
linecolor : color, optional
281-
Color of the lines that divide each cell.
293+
Color of the lines that will divide each cell.
282294
cbar : boolean, optional
283295
Whether to draw a colorbar.
284296
cbar_kws : dict of key, value mappings, optional
@@ -292,14 +304,16 @@ def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False,
292304
ax : matplotlib Axes, optional
293305
Axes in which to draw the plot, otherwise use the currently-active
294306
Axes.
295-
xticklabels : list-like or bool, optional
307+
xticklabels : list-like, int, or bool, optional
296308
If True, plot the column names of the dataframe. If False, don't plot
297309
the column names. If list-like, plot these alternate labels as the
298-
xticklabels
299-
yticklabels : list-like or bool, optional
310+
xticklabels. If an integer, use the column names but plot only every
311+
n label.
312+
yticklabels : list-like, int, or bool, optional
300313
If True, plot the row names of the dataframe. If False, don't plot
301314
the row names. If list-like, plot these alternate labels as the
302-
yticklabels
315+
yticklabels. If an integer, use the index names but plot only every
316+
n label.
303317
mask : boolean array or DataFrame, optional
304318
If passed, data will not be shown in cells where ``mask`` is True.
305319
Cells with missing values are automatically masked.
@@ -835,8 +849,22 @@ def plot_colors(self, xind, yind, **kws):
835849
def plot_matrix(self, colorbar_kws, xind, yind, **kws):
836850
self.data2d = self.data2d.iloc[yind, xind]
837851
self.mask = self.mask.iloc[yind, xind]
852+
853+
# Try to reorganize specified tick labels, if provided
854+
xtl = kws.pop("xticklabels", True)
855+
try:
856+
xtl = np.asarray(xtl)[xind]
857+
except (TypeError, IndexError):
858+
pass
859+
ytl = kws.pop("yticklabels", True)
860+
try:
861+
ytl = np.asarray(ytl)[yind]
862+
except (TypeError, IndexError):
863+
pass
864+
838865
heatmap(self.data2d, ax=self.ax_heatmap, cbar_ax=self.cax,
839-
cbar_kws=colorbar_kws, mask=self.mask, **kws)
866+
cbar_kws=colorbar_kws, mask=self.mask,
867+
xticklabels=xtl, yticklabels=ytl, **kws)
840868
self.ax_heatmap.yaxis.set_ticks_position('right')
841869
self.ax_heatmap.yaxis.set_label_position('right')
842870

seaborn/tests/test_matrix.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,22 @@ def test_custom_ticklabels(self):
200200
nt.assert_equal(p.xticklabels, xticklabels)
201201
nt.assert_equal(p.yticklabels, yticklabels[::-1])
202202

203+
def test_custom_ticklabel_interval(self):
204+
205+
kws = self.default_kws.copy()
206+
kws['xticklabels'] = 2
207+
kws['yticklabels'] = 3
208+
p = mat._HeatMapper(self.df_norm, **kws)
209+
210+
nx, ny = self.df_norm.T.shape
211+
ystart = (ny - 1) % 3
212+
npt.assert_array_equal(p.xticks, np.arange(0, nx, 2) + .5)
213+
npt.assert_array_equal(p.yticks, np.arange(ystart, ny, 3) + .5)
214+
npt.assert_array_equal(p.xticklabels,
215+
self.df_norm.columns[::2])
216+
npt.assert_array_equal(p.yticklabels,
217+
self.df_norm.index[::-1][ystart:ny:3])
218+
203219
def test_heatmap_annotation(self):
204220

205221
ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f",
@@ -879,3 +895,24 @@ def test_mask_reorganization(self):
879895
g.dendrogram_col.reordered_ind])
880896

881897
plt.close("all")
898+
899+
def test_ticklabel_reorganization(self):
900+
901+
kws = self.default_kws.copy()
902+
xtl = np.arange(self.df_norm.shape[1])
903+
kws["xticklabels"] = list(xtl)
904+
ytl = self.letters.ix[:self.df_norm.shape[0]]
905+
kws["yticklabels"] = ytl
906+
907+
g = mat.clustermap(self.df_norm, **kws)
908+
909+
xtl_actual = [t.get_text() for t in g.ax_heatmap.get_xticklabels()]
910+
ytl_actual = [t.get_text() for t in g.ax_heatmap.get_yticklabels()]
911+
912+
xtl_want = xtl[g.dendrogram_col.reordered_ind].astype("<U1")
913+
ytl_want = ytl[g.dendrogram_row.reordered_ind].astype("<U1")[::-1]
914+
915+
npt.assert_array_equal(xtl_actual, xtl_want)
916+
npt.assert_array_equal(ytl_actual, ytl_want)
917+
918+
plt.close("all")

0 commit comments

Comments
 (0)