Skip to content

Commit 43fd6be

Browse files
committed
BUG: ensure integer shape, strides for as_strided
1 parent 300ee75 commit 43fd6be

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

lib/matplotlib/mlab.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,11 @@ def stride_windows(x, n, noverlap=None, axis=0):
590590
if n > x.size:
591591
raise ValueError('n cannot be greater than the length of x')
592592

593+
# np.lib.stride_tricks.as_strided easily leads to memory corruption for
594+
# non integer shape and strides, i.e. noverlap or n. See #3845.
595+
noverlap = int(noverlap)
596+
n = int(n)
597+
593598
step = n - noverlap
594599
if axis == 0:
595600
shape = (n, (x.shape[-1]-noverlap)//step)
@@ -642,6 +647,10 @@ def stride_repeat(x, n, axis=0):
642647
if n < 1:
643648
raise ValueError('n cannot be less than 1')
644649

650+
# np.lib.stride_tricks.as_strided easily leads to memory corruption for
651+
# non integer shape and strides, i.e. n. See #3845.
652+
n = int(n)
653+
645654
if axis == 0:
646655
shape = (n, x.size)
647656
strides = (0, x.strides[0])

lib/matplotlib/tests/test_mlab.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,25 @@ def test_stride_windows_n32_noverlap0_axis1_unflatten(self):
283283
assert_equal(y.shape, x1.shape)
284284
assert_array_equal(y, x1)
285285

286+
def test_stride_ensure_integer_type(self):
287+
N = 1000
288+
n_sin_periods = 10
289+
x = np.empty(3*N, dtype='>f4')
290+
x.fill(np.NaN)
291+
y = x[N:2*N]
292+
y.fill(0.3)
293+
# previous to #3845 lead to corrupt access
294+
y_strided = mlab.stride_windows(y, n=330, noverlap=0.6)
295+
assert_array_equal(y_strided, 0.3)
296+
# previous to #3845 lead to corrupt access
297+
y_strided = mlab.stride_windows(y, n=333.3, noverlap=0)
298+
assert_array_equal(y_strided, 0.3)
299+
# even previous to #3845 could not find any problematic
300+
# configuration however, let's be sure it's not accidentally
301+
# introduced
302+
y_strided = mlab.stride_repeat(y, n=33.815)
303+
assert_array_equal(y_strided, 0.3)
304+
286305

287306
class csv_testcase(CleanupTestCase):
288307
def setUp(self):

0 commit comments

Comments
 (0)