Skip to content

Commit ea968ed

Browse files
committed
Broadcasting for fun and profit.
1 parent 5b1d671 commit ea968ed

File tree

20 files changed

+63
-150
lines changed

20 files changed

+63
-150
lines changed

examples/pylab_examples/fancybox_demo2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
styles = mpatch.BoxStyle.get_styles()
55
spacing = 1.2
66

7-
figheight = (spacing * len(styles) + .5)
7+
figheight = spacing * len(styles) + .5
88
fig1 = plt.figure(1, (4/1.5, figheight/1.5))
99
fontsize = 0.3 * 72
1010

examples/pylab_examples/table_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
bar_width = 0.4
2626

2727
# Initialize the vertical-offset for the stacked bar chart.
28-
y_offset = np.array([0.0] * len(columns))
28+
y_offset = np.zeros(len(columns))
2929

3030
# Plot bars and create text labels for the table
3131
cell_text = []

lib/matplotlib/axes/_axes.py

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import matplotlib
1515
from matplotlib import _preprocess_data
1616

17+
from matplotlib._backports import numpy as _backports_np
1718
import matplotlib.cbook as cbook
1819
from matplotlib.cbook import (mplDeprecation, STEP_LOOKUP_MAP,
1920
iterable, is_string_like,
@@ -1997,20 +1998,14 @@ def bar(self, left, height, width=0.8, bottom=None, **kwargs):
19971998
label = kwargs.pop('label', '')
19981999
tick_labels = kwargs.pop('tick_label', None)
19992000

2000-
def make_iterable(x):
2001-
if not iterable(x):
2002-
return [x]
2003-
else:
2004-
return x
2005-
20062001
# make them safe to take len() of
20072002
_left = left
2008-
left = make_iterable(left)
2009-
height = make_iterable(height)
2010-
width = make_iterable(width)
2003+
left = np.atleast_1d(left)
2004+
height = np.atleast_1d(height)
2005+
width = np.atleast_1d(width)
20112006
_bottom = bottom
2012-
bottom = make_iterable(bottom)
2013-
linewidth = make_iterable(linewidth)
2007+
bottom = np.atleast_1d(bottom)
2008+
linewidth = np.atleast_1d(linewidth)
20142009

20152010
adjust_ylim = False
20162011
adjust_xlim = False
@@ -2025,10 +2020,8 @@ def make_iterable(x):
20252020
bottom = [0]
20262021

20272022
nbars = len(left)
2028-
if len(width) == 1:
2029-
width *= nbars
2030-
if len(bottom) == 1:
2031-
bottom *= nbars
2023+
width = _backports_np.broadcast_to(width, nbars)
2024+
bottom = _backports_np.broadcast_to(bottom, nbars)
20322025

20332026
tick_label_axis = self.xaxis
20342027
tick_label_position = left
@@ -2043,18 +2036,16 @@ def make_iterable(x):
20432036
left = [0]
20442037

20452038
nbars = len(bottom)
2046-
if len(left) == 1:
2047-
left *= nbars
2048-
if len(height) == 1:
2049-
height *= nbars
2039+
left = _backports_np.broadcast_to(left, nbars)
2040+
height = _backports_np.broadcast_to(height, nbars)
20502041

20512042
tick_label_axis = self.yaxis
20522043
tick_label_position = bottom
20532044
else:
20542045
raise ValueError('invalid orientation: %s' % orientation)
20552046

20562047
if len(linewidth) < nbars:
2057-
linewidth *= nbars
2048+
linewidth = np.tile(linewidth, nbars)
20582049

20592050
color = list(mcolors.to_rgba_array(color))
20602051
if len(color) == 0: # until to_rgba_array is changed
@@ -2181,15 +2172,7 @@ def make_iterable(x):
21812172
self.add_container(bar_container)
21822173

21832174
if tick_labels is not None:
2184-
tick_labels = make_iterable(tick_labels)
2185-
if isinstance(tick_labels, six.string_types):
2186-
tick_labels = [tick_labels]
2187-
if len(tick_labels) == 1:
2188-
tick_labels *= nbars
2189-
if len(tick_labels) != nbars:
2190-
raise ValueError("incompatible sizes: argument 'tick_label' "
2191-
"must be length %d or string" % nbars)
2192-
2175+
tick_labels = _backports_np.broadcast_to(tick_labels, nbars)
21932176
tick_label_axis.set_ticks(tick_label_position)
21942177
tick_label_axis.set_ticklabels(tick_labels)
21952178

lib/matplotlib/backends/backend_agg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ def draw_text(self, gc, x, y, s, prop, angle, ismath=False, mtext=None):
192192
flags = get_hinting_flag()
193193
font = self._get_agg_font(prop)
194194

195-
if font is None: return None
195+
if font is None:
196+
return None
196197
if len(s) == 1 and ord(s) > 127:
197198
font.load_char(ord(s), flags=flags)
198199
else:

lib/matplotlib/backends/backend_qt5.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,8 @@ def edit_parameters(self):
642642
QtWidgets.QMessageBox.warning(
643643
self.parent, "Error", "There are no axes to edit.")
644644
return
645-
if len(allaxes) == 1:
646-
axes = allaxes[0]
645+
elif len(allaxes) == 1:
646+
axes, = allaxes
647647
else:
648648
titles = []
649649
for axes in allaxes:

lib/matplotlib/cbook.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import collections
1616
import datetime
1717
import errno
18-
import functools
1918
import glob
2019
import gzip
2120
import io
@@ -1799,7 +1798,7 @@ def delete_masked_points(*args):
17991798
except: # Fixme: put in tuple of possible exceptions?
18001799
pass
18011800
if len(masks):
1802-
mask = functools.reduce(np.logical_and, masks)
1801+
mask = np.logical_and.reduce(masks)
18031802
igood = mask.nonzero()[0]
18041803
if len(igood) < nrecs:
18051804
for i, x in enumerate(margs):

lib/matplotlib/colors.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,12 @@
5858

5959
from __future__ import (absolute_import, division, print_function,
6060
unicode_literals)
61-
import re
61+
6262
import six
6363
from six.moves import zip
64+
65+
import itertools
66+
import re
6467
import warnings
6568

6669
import numpy as np
@@ -774,26 +777,23 @@ def __init__(self, colors, name='from_list', N=None):
774777
775778
the list will be extended by repetition.
776779
"""
777-
self.colors = colors
778780
self.monochrome = False # True only if all colors in map are
779781
# identical; needed for contouring.
780782
if N is None:
781-
N = len(self.colors)
783+
self.colors = colors
784+
N = len(colors)
782785
else:
783-
if (cbook.is_string_like(self.colors) and
784-
cbook.is_hashable(self.colors)):
785-
self.colors = [self.colors] * N
786+
if cbook.is_string_like(colors) and cbook.is_hashable(colors):
787+
self.colors = [colors] * N
786788
self.monochrome = True
787-
elif cbook.iterable(self.colors):
788-
self.colors = list(self.colors) # in case it was a tuple
789-
if len(self.colors) == 1:
789+
elif cbook.iterable(colors):
790+
if len(colors) == 1:
790791
self.monochrome = True
791-
if len(self.colors) < N:
792-
self.colors = list(self.colors) * N
793-
del(self.colors[N:])
792+
self.colors = list(
793+
itertools.islice(itertools.cycle(colors), N))
794794
else:
795795
try:
796-
gray = float(self.colors)
796+
gray = float(colors)
797797
except TypeError:
798798
pass
799799
else:

lib/matplotlib/gridspec.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,40 +100,32 @@ def get_grid_positions(self, fig):
100100
# calculate accumulated heights of columns
101101
cellH = totHeight/(nrows + hspace*(nrows-1))
102102
sepH = hspace*cellH
103-
104103
if self._row_height_ratios is not None:
105104
netHeight = cellH * nrows
106105
tr = float(sum(self._row_height_ratios))
107106
cellHeights = [netHeight*r/tr for r in self._row_height_ratios]
108107
else:
109108
cellHeights = [cellH] * nrows
110-
111109
sepHeights = [0] + ([sepH] * (nrows-1))
112-
cellHs = np.add.accumulate(np.ravel(list(zip(sepHeights, cellHeights))))
113-
110+
cellHs = np.cumsum(np.column_stack([sepHeights, cellHeights]))
114111

115112
# calculate accumulated widths of rows
116113
cellW = totWidth/(ncols + wspace*(ncols-1))
117114
sepW = wspace*cellW
118-
119115
if self._col_width_ratios is not None:
120116
netWidth = cellW * ncols
121117
tr = float(sum(self._col_width_ratios))
122118
cellWidths = [netWidth*r/tr for r in self._col_width_ratios]
123119
else:
124120
cellWidths = [cellW] * ncols
125-
126121
sepWidths = [0] + ([sepW] * (ncols-1))
127-
cellWs = np.add.accumulate(np.ravel(list(zip(sepWidths, cellWidths))))
128-
129-
122+
cellWs = np.cumsum(np.column_stack([sepWidths, cellWidths]))
130123

131124
figTops = [top - cellHs[2*rowNum] for rowNum in range(nrows)]
132125
figBottoms = [top - cellHs[2*rowNum+1] for rowNum in range(nrows)]
133126
figLefts = [left + cellWs[2*colNum] for colNum in range(ncols)]
134127
figRights = [left + cellWs[2*colNum+1] for colNum in range(ncols)]
135128

136-
137129
return figBottoms, figTops, figLefts, figRights
138130

139131
def __getitem__(self, key):

lib/matplotlib/image.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ def flush_images():
146146
if len(image_group) == 1:
147147
image_group[0].draw(renderer)
148148
elif len(image_group) > 1:
149-
data, l, b = composite_images(
150-
image_group, renderer, mag)
149+
data, l, b = composite_images(image_group, renderer, mag)
151150
if data.size != 0:
152151
gc = renderer.new_gc()
153152
gc.set_clip_rectangle(parent.bbox)

lib/matplotlib/legend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def __init__(self, parent, handles, labels,
294294
self._scatteryoffsets = np.array([3. / 8., 4. / 8., 2.5 / 8.])
295295
else:
296296
self._scatteryoffsets = np.asarray(scatteryoffsets)
297-
reps = int(self.scatterpoints / len(self._scatteryoffsets)) + 1
297+
reps = self.scatterpoints // len(self._scatteryoffsets) + 1
298298
self._scatteryoffsets = np.tile(self._scatteryoffsets,
299299
reps)[:self.scatterpoints]
300300

0 commit comments

Comments
 (0)