diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index d08576904d29..fff4a4f71701 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -322,10 +322,6 @@ def draw(self, renderer): self.line.set_data(pep[0], pep[1]) self.line.draw(renderer) - # Grid points where the planes meet - xyz0 = np.tile(minmax, (len(ticks), 1)) - xyz0[:, index] = [tick.get_loc() for tick in ticks] - # Draw labels # The transAxes transform is used because the Text object # rotates the text relative to the display coordinate system. @@ -414,10 +410,7 @@ def draw(self, renderer): if (centpt[index] > pep[index, outerindex] and np.count_nonzero(highs) % 2 == 0): # Usually mean align left, except if it is axis 2 - if index == 2: - align = 'right' - else: - align = 'left' + align = 'right' if index == 2 else 'left' else: # The TT case align = 'right' @@ -427,6 +420,10 @@ def draw(self, renderer): self.offsetText.draw(renderer) if self.axes._draw_grid and len(ticks): + # Grid points where the planes meet + xyz0 = np.tile(minmax, (len(ticks), 1)) + xyz0[:, index] = [tick.get_loc() for tick in ticks] + # Grid lines go from the end of one plane through the plane # intersection (at xyz0) to the end of the other plane. The first # point (0) differs along dimension index-2 and the last (2) along @@ -435,47 +432,45 @@ def draw(self, renderer): lines[:, 0, index - 2] = maxmin[index - 2] lines[:, 2, index - 1] = maxmin[index - 1] self.gridlines.set_segments(lines) - self.gridlines.set_color(info['grid']['color']) - self.gridlines.set_linewidth(info['grid']['linewidth']) - self.gridlines.set_linestyle(info['grid']['linestyle']) + gridinfo = info['grid'] + self.gridlines.set_color(gridinfo['color']) + self.gridlines.set_linewidth(gridinfo['linewidth']) + self.gridlines.set_linestyle(gridinfo['linestyle']) self.gridlines.do_3d_projection() self.gridlines.draw(renderer) # Draw ticks: tickdir = self._get_tickdir() - tickdelta = deltas[tickdir] - if highs[tickdir]: - ticksign = 1 - else: - ticksign = -1 - + tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir] + + tick_info = info['tick'] + tick_out = tick_info['outward_factor'] * tickdelta + tick_in = tick_info['inward_factor'] * tickdelta + tick_lw = tick_info['linewidth'] + edgep1_tickdir = edgep1[tickdir] + out_tickdir = edgep1_tickdir + tick_out + in_tickdir = edgep1_tickdir - tick_in + + default_label_offset = 8. # A rough estimate + points = deltas_per_point * deltas for tick in ticks: # Get tick line positions pos = edgep1.copy() pos[index] = tick.get_loc() - pos[tickdir] = ( - edgep1[tickdir] - + info['tick']['outward_factor'] * ticksign * tickdelta) + pos[tickdir] = out_tickdir x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M) - pos[tickdir] = ( - edgep1[tickdir] - - info['tick']['inward_factor'] * ticksign * tickdelta) + pos[tickdir] = in_tickdir x2, y2, z2 = proj3d.proj_transform(*pos, self.axes.M) # Get position of label - default_offset = 8. # A rough estimate - labeldeltas = ( - (tick.get_pad() + default_offset) * deltas_per_point * deltas) + labeldeltas = (tick.get_pad() + default_label_offset) * points - axmask = [True, True, True] - axmask[index] = False - pos[tickdir] = edgep1[tickdir] + pos[tickdir] = edgep1_tickdir pos = move_from_center(pos, centers, labeldeltas, axmask) lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M) tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly)) - tick.tick1line.set_linewidth( - info['tick']['linewidth'][tick._major]) + tick.tick1line.set_linewidth(tick_lw[tick._major]) tick.draw(renderer) renderer.close_group('axis3d')