Skip to content

Commit ea968ed

Browse files
committed
Broadcasting for fun and profit.
1 parent 5b1d671 commit ea968ed

20 files changed

+63
-150
lines changed

examples/pylab_examples/fancybox_demo2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
styles = mpatch.BoxStyle.get_styles()
55
spacing = 1.2
66

7-
figheight = (spacing * len(styles) + .5)
7+
figheight = spacing * len(styles) + .5
88
fig1 = plt.figure(1, (4/1.5, figheight/1.5))
99
fontsize = 0.3 * 72
1010

examples/pylab_examples/table_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
bar_width = 0.4
2626

2727
# Initialize the vertical-offset for the stacked bar chart.
28-
y_offset = np.array([0.0] * len(columns))
28+
y_offset = np.zeros(len(columns))
2929

3030
# Plot bars and create text labels for the table
3131
cell_text = []

lib/matplotlib/axes/_axes.py

+12-29
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import matplotlib
1515
from matplotlib import _preprocess_data
1616

17+
from matplotlib._backports import numpy as _backports_np
1718
import matplotlib.cbook as cbook
1819
from matplotlib.cbook import (mplDeprecation, STEP_LOOKUP_MAP,
1920
iterable, is_string_like,
@@ -1997,20 +1998,14 @@ def bar(self, left, height, width=0.8, bottom=None, **kwargs):
19971998
label = kwargs.pop('label', '')
19981999
tick_labels = kwargs.pop('tick_label', None)
19992000

2000-
def make_iterable(x):
2001-
if not iterable(x):
2002-
return [x]
2003-
else:
2004-
return x
2005-
20062001
# make them safe to take len() of
20072002
_left = left
2008-
left = make_iterable(left)
2009-
height = make_iterable(height)
2010-
width = make_iterable(width)
2003+
left = np.atleast_1d(left)
2004+
height = np.atleast_1d(height)
2005+
width = np.atleast_1d(width)
20112006
_bottom = bottom
2012-
bottom = make_iterable(bottom)
2013-
linewidth = make_iterable(linewidth)
2007+
bottom = np.atleast_1d(bottom)
2008+
linewidth = np.atleast_1d(linewidth)
20142009

20152010
adjust_ylim = False
20162011
adjust_xlim = False
@@ -2025,10 +2020,8 @@ def make_iterable(x):
20252020
bottom = [0]
20262021

20272022
nbars = len(left)
2028-
if len(width) == 1:
2029-
width *= nbars
2030-
if len(bottom) == 1:
2031-
bottom *= nbars
2023+
width = _backports_np.broadcast_to(width, nbars)
2024+
bottom = _backports_np.broadcast_to(bottom, nbars)
20322025

20332026
tick_label_axis = self.xaxis
20342027
tick_label_position = left
@@ -2043,18 +2036,16 @@ def make_iterable(x):
20432036
left = [0]
20442037

20452038
nbars = len(bottom)
2046-
if len(left) == 1:
2047-
left *= nbars
2048-
if len(height) == 1:
2049-
height *= nbars
2039+
left = _backports_np.broadcast_to(left, nbars)
2040+
height = _backports_np.broadcast_to(height, nbars)
20502041

20512042
tick_label_axis = self.yaxis
20522043
tick_label_position = bottom
20532044
else:
20542045
raise ValueError('invalid orientation: %s' % orientation)
20552046

20562047
if len(linewidth) < nbars:
2057-
linewidth *= nbars
2048+
linewidth = np.tile(linewidth, nbars)
20582049

20592050
color = list(mcolors.to_rgba_array(color))
20602051
if len(color) == 0: # until to_rgba_array is changed
@@ -2181,15 +2172,7 @@ def make_iterable(x):
21812172
self.add_container(bar_container)
21822173

21832174
if tick_labels is not None:
2184-
tick_labels = make_iterable(tick_labels)
2185-
if isinstance(tick_labels, six.string_types):
2186-
tick_labels = [tick_labels]
2187-
if len(tick_labels) == 1:
2188-
tick_labels *= nbars
2189-
if len(tick_labels) != nbars:
2190-
raise ValueError("incompatible sizes: argument 'tick_label' "
2191-
"must be length %d or string" % nbars)
2192-
2175+
tick_labels = _backports_np.broadcast_to(tick_labels, nbars)
21932176
tick_label_axis.set_ticks(tick_label_position)
21942177
tick_label_axis.set_ticklabels(tick_labels)
21952178

lib/matplotlib/backends/backend_agg.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ def draw_text(self, gc, x, y, s, prop, angle, ismath=False, mtext=None):
192192
flags = get_hinting_flag()
193193
font = self._get_agg_font(prop)
194194

195-
if font is None: return None
195+
if font is None:
196+
return None
196197
if len(s) == 1 and ord(s) > 127:
197198
font.load_char(ord(s), flags=flags)
198199
else:

lib/matplotlib/backends/backend_qt5.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,8 @@ def edit_parameters(self):
642642
QtWidgets.QMessageBox.warning(
643643
self.parent, "Error", "There are no axes to edit.")
644644
return
645-
if len(allaxes) == 1:
646-
axes = allaxes[0]
645+
elif len(allaxes) == 1:
646+
axes, = allaxes
647647
else:
648648
titles = []
649649
for axes in allaxes:

lib/matplotlib/cbook.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import collections
1616
import datetime
1717
import errno
18-
import functools
1918
import glob
2019
import gzip
2120
import io
@@ -1799,7 +1798,7 @@ def delete_masked_points(*args):
17991798
except: # Fixme: put in tuple of possible exceptions?
18001799
pass
18011800
if len(masks):
1802-
mask = functools.reduce(np.logical_and, masks)
1801+
mask = np.logical_and.reduce(masks)
18031802
igood = mask.nonzero()[0]
18041803
if len(igood) < nrecs:
18051804
for i, x in enumerate(margs):

lib/matplotlib/colors.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,12 @@
5858

5959
from __future__ import (absolute_import, division, print_function,
6060
unicode_literals)
61-
import re
61+
6262
import six
6363
from six.moves import zip
64+
65+
import itertools
66+
import re
6467
import warnings
6568

6669
import numpy as np
@@ -774,26 +777,23 @@ def __init__(self, colors, name='from_list', N=None):
774777
775778
the list will be extended by repetition.
776779
"""
777-
self.colors = colors
778780
self.monochrome = False # True only if all colors in map are
779781
# identical; needed for contouring.
780782
if N is None:
781-
N = len(self.colors)
783+
self.colors = colors
784+
N = len(colors)
782785
else:
783-
if (cbook.is_string_like(self.colors) and
784-
cbook.is_hashable(self.colors)):
785-
self.colors = [self.colors] * N
786+
if cbook.is_string_like(colors) and cbook.is_hashable(colors):
787+
self.colors = [colors] * N
786788
self.monochrome = True
787-
elif cbook.iterable(self.colors):
788-
self.colors = list(self.colors) # in case it was a tuple
789-
if len(self.colors) == 1:
789+
elif cbook.iterable(colors):
790+
if len(colors) == 1:
790791
self.monochrome = True
791-
if len(self.colors) < N:
792-
self.colors = list(self.colors) * N
793-
del(self.colors[N:])
792+
self.colors = list(
793+
itertools.islice(itertools.cycle(colors), N))
794794
else:
795795
try:
796-
gray = float(self.colors)
796+
gray = float(colors)
797797
except TypeError:
798798
pass
799799
else:

lib/matplotlib/gridspec.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -100,40 +100,32 @@ def get_grid_positions(self, fig):
100100
# calculate accumulated heights of columns
101101
cellH = totHeight/(nrows + hspace*(nrows-1))
102102
sepH = hspace*cellH
103-
104103
if self._row_height_ratios is not None:
105104
netHeight = cellH * nrows
106105
tr = float(sum(self._row_height_ratios))
107106
cellHeights = [netHeight*r/tr for r in self._row_height_ratios]
108107
else:
109108
cellHeights = [cellH] * nrows
110-
111109
sepHeights = [0] + ([sepH] * (nrows-1))
112-
cellHs = np.add.accumulate(np.ravel(list(zip(sepHeights, cellHeights))))
113-
110+
cellHs = np.cumsum(np.column_stack([sepHeights, cellHeights]))
114111

115112
# calculate accumulated widths of rows
116113
cellW = totWidth/(ncols + wspace*(ncols-1))
117114
sepW = wspace*cellW
118-
119115
if self._col_width_ratios is not None:
120116
netWidth = cellW * ncols
121117
tr = float(sum(self._col_width_ratios))
122118
cellWidths = [netWidth*r/tr for r in self._col_width_ratios]
123119
else:
124120
cellWidths = [cellW] * ncols
125-
126121
sepWidths = [0] + ([sepW] * (ncols-1))
127-
cellWs = np.add.accumulate(np.ravel(list(zip(sepWidths, cellWidths))))
128-
129-
122+
cellWs = np.cumsum(np.column_stack([sepWidths, cellWidths]))
130123

131124
figTops = [top - cellHs[2*rowNum] for rowNum in range(nrows)]
132125
figBottoms = [top - cellHs[2*rowNum+1] for rowNum in range(nrows)]
133126
figLefts = [left + cellWs[2*colNum] for colNum in range(ncols)]
134127
figRights = [left + cellWs[2*colNum+1] for colNum in range(ncols)]
135128

136-
137129
return figBottoms, figTops, figLefts, figRights
138130

139131
def __getitem__(self, key):

lib/matplotlib/image.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ def flush_images():
146146
if len(image_group) == 1:
147147
image_group[0].draw(renderer)
148148
elif len(image_group) > 1:
149-
data, l, b = composite_images(
150-
image_group, renderer, mag)
149+
data, l, b = composite_images(image_group, renderer, mag)
151150
if data.size != 0:
152151
gc = renderer.new_gc()
153152
gc.set_clip_rectangle(parent.bbox)

lib/matplotlib/legend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def __init__(self, parent, handles, labels,
294294
self._scatteryoffsets = np.array([3. / 8., 4. / 8., 2.5 / 8.])
295295
else:
296296
self._scatteryoffsets = np.asarray(scatteryoffsets)
297-
reps = int(self.scatterpoints / len(self._scatteryoffsets)) + 1
297+
reps = self.scatterpoints // len(self._scatteryoffsets) + 1
298298
self._scatteryoffsets = np.tile(self._scatteryoffsets,
299299
reps)[:self.scatterpoints]
300300

lib/matplotlib/lines.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -665,20 +665,8 @@ def recache(self, always=False):
665665
else:
666666
y = self._y
667667

668-
if len(x) == 1 and len(y) > 1:
669-
x = x * np.ones(y.shape, float)
670-
if len(y) == 1 and len(x) > 1:
671-
y = y * np.ones(x.shape, float)
672-
673-
if len(x) != len(y):
674-
raise RuntimeError('xdata and ydata must be the same length')
675-
676-
self._xy = np.empty((len(x), 2), dtype=float)
677-
self._xy[:, 0] = x
678-
self._xy[:, 1] = y
679-
680-
self._x = self._xy[:, 0] # just a view
681-
self._y = self._xy[:, 1] # just a view
668+
self._xy = np.column_stack(np.broadcast_arrays(x, y)).astype(float)
669+
self._x, self._y = self._xy.T # views
682670

683671
self._subslice = False
684672
if (self.axes and len(x) > 1000 and self._is_sorted(x) and

lib/matplotlib/offsetbox.py

+5-10
Original file line numberDiff line numberDiff line change
@@ -71,33 +71,28 @@ def _get_packed_offsets(wd_list, total, sep, mode="fixed"):
7171
# d_list is currently not used.
7272

7373
if mode == "fixed":
74-
offsets_ = np.add.accumulate([0] + [w + sep for w in w_list])
74+
offsets_ = np.cumsum([0] + [w + sep for w in w_list])
7575
offsets = offsets_[:-1]
76-
7776
if total is None:
7877
total = offsets_[-1] - sep
79-
8078
return total, offsets
8179

8280
elif mode == "expand":
8381
if len(w_list) > 1:
8482
sep = (total - sum(w_list)) / (len(w_list) - 1.)
8583
else:
86-
sep = 0.
87-
offsets_ = np.add.accumulate([0] + [w + sep for w in w_list])
84+
sep = 0
85+
offsets_ = np.cumsum([0] + [w + sep for w in w_list])
8886
offsets = offsets_[:-1]
89-
9087
return total, offsets
9188

9289
elif mode == "equal":
9390
maxh = max(w_list)
9491
if total is None:
9592
total = (maxh + sep) * len(w_list)
9693
else:
97-
sep = float(total) / (len(w_list)) - maxh
98-
99-
offsets = np.array([(maxh + sep) * i for i in range(len(w_list))])
100-
94+
sep = total / len(w_list) - maxh
95+
offsets = (maxh + sep) * np.arange(len(w_list))
10196
return total, offsets
10297

10398
else:

lib/matplotlib/patches.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1601,8 +1601,7 @@ def draw(self, renderer):
16011601
# Get the width and height in pixels
16021602
width = self.convert_xunits(self.width)
16031603
height = self.convert_yunits(self.height)
1604-
width, height = self.get_transform().transform_point(
1605-
(width, height))
1604+
width, height = self.get_transform().transform_point((width, height))
16061605
inv_error = (1.0 / 1.89818e-6) * 0.5
16071606

16081607
if width < inv_error and height < inv_error:

lib/matplotlib/stackplot.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,7 @@ def stackplot(axes, x, *args, **kwargs):
5555
element in the stacked area plot.
5656
"""
5757

58-
if len(args) == 1:
59-
y = np.atleast_2d(*args)
60-
elif len(args) > 1:
61-
y = np.row_stack(args)
58+
y = np.row_stack(args)
6259

6360
labels = iter(kwargs.pop('labels', []))
6461

lib/matplotlib/tests/test_backend_bases.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,10 @@ def check(master_transform, paths, all_transforms,
3636
[], 'data')]
3737
uses = rb._iter_collection_uses_per_path(
3838
paths, all_transforms, offsets, facecolors, edgecolors)
39-
seen = [0] * len(raw_paths)
40-
for i in ids:
41-
seen[i] += 1
42-
for n in seen:
43-
assert n in (uses-1, uses)
39+
if raw_paths:
40+
seen = np.bincount(ids, minlength=len(raw_paths))
41+
for n in seen:
42+
assert n in [uses - 1, uses]
4443

4544
check(id, paths, tforms, offsets, facecolors, edgecolors)
4645
check(id, paths[0:1], tforms, offsets, facecolors, edgecolors)

lib/matplotlib/tests/test_bbox_tight.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_bbox_inches_tight():
3030
ind = np.arange(len(colLabels)) + 0.3 # the x locations for the groups
3131
cellText = []
3232
width = 0.4 # the width of the bars
33-
yoff = np.array([0.0] * len(colLabels))
33+
yoff = np.zeros(len(colLabels))
3434
# the bottom values for stacked bar chart
3535
fig, ax = plt.subplots(1, 1)
3636
for row in xrange(rows):

lib/matplotlib/tests/test_image.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -326,10 +326,8 @@ def test_image_edges():
326326

327327
data = np.tile(np.arange(12), 15).reshape(20, 9)
328328

329-
im = ax.imshow(data, origin='upper',
330-
extent=[-10, 10, -10, 10], interpolation='none',
331-
cmap='gray'
332-
)
329+
im = ax.imshow(data, origin='upper', extent=[-10, 10, -10, 10],
330+
interpolation='none', cmap='gray')
333331

334332
x = y = 2
335333
ax.set_xlim([-x, x])

lib/matplotlib/tests/test_table.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_zorder():
3030
rowLabels = ['%d year' % x for x in (100, 50)]
3131

3232
cellText = []
33-
yoff = np.array([0.0] * len(colLabels))
33+
yoff = np.zeros(len(colLabels))
3434
for row in reversed(data):
3535
yoff += row
3636
cellText.append(['%1.1f' % (x/1000.0) for x in yoff])

lib/matplotlib/textpath.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def get_glyphs_with_font(self, font, s, glyph_map=None,
225225
rects = []
226226

227227
return (list(zip(glyph_ids, xpositions, ypositions, sizes)),
228-
glyph_map_new, rects)
228+
glyph_map_new, rects)
229229

230230
def get_glyphs_mathtext(self, prop, s, glyph_map=None,
231231
return_new_glyphs_only=False):

0 commit comments

Comments
 (0)