Skip to content

Commit eb13ba8

Browse files
Change draw order so that the 3D axis spines are not blocked by gridlines
1 parent 8b58763 commit eb13ba8

File tree

2 files changed

+43
-20
lines changed

2 files changed

+43
-20
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

+3
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,9 @@ def draw(self, renderer):
489489
# Draw panes first
490490
for axis in self._axis_map.values():
491491
axis.draw_pane(renderer)
492+
# Then gridlines
493+
for axis in self._axis_map.values():
494+
axis.draw_grid(renderer)
492495
# Then axes
493496
for axis in self._axis_map.values():
494497
axis.draw(renderer)

lib/mpl_toolkits/mplot3d/axis3d.py

+40-20
Original file line numberDiff line numberDiff line change
@@ -462,26 +462,6 @@ def draw(self, renderer):
462462
self.offsetText.set_ha(align)
463463
self.offsetText.draw(renderer)
464464

465-
if self.axes._draw_grid and len(ticks):
466-
# Grid points where the planes meet
467-
xyz0 = np.tile(minmax, (len(ticks), 1))
468-
xyz0[:, index] = [tick.get_loc() for tick in ticks]
469-
470-
# Grid lines go from the end of one plane through the plane
471-
# intersection (at xyz0) to the end of the other plane. The first
472-
# point (0) differs along dimension index-2 and the last (2) along
473-
# dimension index-1.
474-
lines = np.stack([xyz0, xyz0, xyz0], axis=1)
475-
lines[:, 0, index - 2] = maxmin[index - 2]
476-
lines[:, 2, index - 1] = maxmin[index - 1]
477-
self.gridlines.set_segments(lines)
478-
gridinfo = info['grid']
479-
self.gridlines.set_color(gridinfo['color'])
480-
self.gridlines.set_linewidth(gridinfo['linewidth'])
481-
self.gridlines.set_linestyle(gridinfo['linestyle'])
482-
self.gridlines.do_3d_projection()
483-
self.gridlines.draw(renderer)
484-
485465
# Draw ticks:
486466
tickdir = self._get_tickdir()
487467
tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir]
@@ -519,6 +499,46 @@ def draw(self, renderer):
519499
renderer.close_group('axis3d')
520500
self.stale = False
521501

502+
@artist.allow_rasterization
503+
def draw_grid(self, renderer):
504+
if not self.axes._draw_grid:
505+
return
506+
507+
self.label._transform = self.axes.transData
508+
renderer.open_group("grid3d", gid=self.get_gid())
509+
510+
ticks = self._update_ticks()
511+
if len(ticks):
512+
# Get general axis information:
513+
info = self._axinfo
514+
index = info["i"]
515+
516+
mins, maxs, tc, highs = self._get_coord_info()
517+
518+
minmax = np.where(highs, maxs, mins)
519+
maxmin = np.where(~highs, maxs, mins)
520+
521+
# Grid points where the planes meet
522+
xyz0 = np.tile(minmax, (len(ticks), 1))
523+
xyz0[:, index] = [tick.get_loc() for tick in ticks]
524+
525+
# Grid lines go from the end of one plane through the plane
526+
# intersection (at xyz0) to the end of the other plane. The first
527+
# point (0) differs along dimension index-2 and the last (2) along
528+
# dimension index-1.
529+
lines = np.stack([xyz0, xyz0, xyz0], axis=1)
530+
lines[:, 0, index - 2] = maxmin[index - 2]
531+
lines[:, 2, index - 1] = maxmin[index - 1]
532+
self.gridlines.set_segments(lines)
533+
gridinfo = info['grid']
534+
self.gridlines.set_color(gridinfo['color'])
535+
self.gridlines.set_linewidth(gridinfo['linewidth'])
536+
self.gridlines.set_linestyle(gridinfo['linestyle'])
537+
self.gridlines.do_3d_projection()
538+
self.gridlines.draw(renderer)
539+
540+
renderer.close_group('grid3d')
541+
522542
# TODO: Get this to work (more) properly when mplot3d supports the
523543
# transforms framework.
524544
def get_tightbbox(self, renderer=None, *, for_layout_only=False):

0 commit comments

Comments
 (0)