diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index 7ea7bc08069c..ba04570c65ad 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -283,8 +283,7 @@ def draw(self, renderer): outeredgep = edgep2 outerindex = 1 - pos = outeredgep.copy() - pos = move_from_center(pos, centers, labeldeltas, axmask) + pos = move_from_center(outeredgep, centers, labeldeltas, axmask) olx, oly, olz = proj3d.proj_transform(*pos, renderer.M) self.offsetText.set_text(self.major.formatter.get_offset()) self.offsetText.set_position((olx, oly)) @@ -339,19 +338,14 @@ def draw(self, renderer): self.offsetText.set_ha(align) self.offsetText.draw(renderer) - # Draw grid lines if self.axes._draw_grid and len(ticks): - # Grid points at end of one plane - xyz1 = xyz0.copy() - newindex = (index + 1) % 3 - xyz1[:, newindex] = maxmin[newindex] - - # Grid points at end of the other plane - xyz2 = xyz0.copy() - newindex = (index + 2) % 3 - xyz2[:, newindex] = maxmin[newindex] - - lines = np.stack([xyz1, xyz0, xyz2], axis=1) + # Grid lines go from the end of one plane through the plane + # intersection (at xyz0) to the end of the other plane. The first + # point (0) differs along dimension index-2 and the last (2) along + # dimension index-1. + lines = np.stack([xyz0, xyz0, xyz0], axis=1) + lines[:, 0, index - 2] = maxmin[index - 2] + lines[:, 2, index - 1] = maxmin[index - 1] self.gridlines.set_segments(lines) self.gridlines.set_color(info['grid']['color']) self.gridlines.set_linewidth(info['grid']['linewidth'])