From e3836c522941c00fdbdf5ef39954e55fc3c0ae99 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Tue, 5 Feb 2019 15:08:05 +0100 Subject: [PATCH] Numpyfy tick handling code in Axis3D. --- lib/mpl_toolkits/mplot3d/axis3d.py | 36 ++++++++++++------------------ 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index 6f33d83c1849..c34ca25b98e4 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -247,11 +247,8 @@ def draw(self, renderer): self.line.draw(renderer) # Grid points where the planes meet - xyz0 = [] - for tick in ticks: - coord = minmax.copy() - coord[index] = tick.get_loc() - xyz0.append(coord) + xyz0 = np.tile(minmax, (len(ticks), 1)) + xyz0[:, index] = [tick.get_loc() for tick in ticks] # Draw labels peparray = np.asanyarray(pep) @@ -357,30 +354,25 @@ def draw(self, renderer): self.offsetText.draw(renderer) # Draw grid lines - if len(xyz0) > 0: + if self.axes._draw_grid and len(ticks): # Grid points at end of one plane - xyz1 = copy.deepcopy(xyz0) + xyz1 = xyz0.copy() newindex = (index + 1) % 3 newval = get_flip_min_max(xyz1[0], newindex, mins, maxs) - for i in range(len(ticks)): - xyz1[i][newindex] = newval + xyz1[:, newindex] = newval # Grid points at end of the other plane - xyz2 = copy.deepcopy(xyz0) + xyz2 = xyz0.copy() newindex = (index + 2) % 3 newval = get_flip_min_max(xyz2[0], newindex, mins, maxs) - for i in range(len(ticks)): - xyz2[i][newindex] = newval - - lines = list(zip(xyz1, xyz0, xyz2)) - if self.axes._draw_grid: - self.gridlines.set_segments(lines) - self.gridlines.set_color([info['grid']['color']] * len(lines)) - self.gridlines.set_linewidth( - [info['grid']['linewidth']] * len(lines)) - self.gridlines.set_linestyle( - [info['grid']['linestyle']] * len(lines)) - self.gridlines.draw(renderer, project=True) + xyz2[:, newindex] = newval + + lines = np.stack([xyz1, xyz0, xyz2], axis=1) + self.gridlines.set_segments(lines) + self.gridlines.set_color(info['grid']['color']) + self.gridlines.set_linewidth(info['grid']['linewidth']) + self.gridlines.set_linestyle(info['grid']['linestyle']) + self.gridlines.draw(renderer, project=True) # Draw ticks tickdir = info['tickdir']