From 0eff7d043f619ef4faa2f601f06700dea479ff77 Mon Sep 17 00:00:00 2001 From: Erik Rigtorp Date: Fri, 31 Dec 2010 23:03:13 -0500 Subject: [PATCH] ENH: Add stride trick to create rolling windows --- numpy/lib/stride_tricks.py | 39 +++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py index 7358be222687..8eeeee877112 100644 --- a/numpy/lib/stride_tricks.py +++ b/numpy/lib/stride_tricks.py @@ -7,7 +7,7 @@ """ import numpy as np -__all__ = ['broadcast_arrays'] +__all__ = ['broadcast_arrays', 'rolling_window'] class DummyArray(object): """ Dummy object that just exists to hang __array_interface__ dictionaries @@ -113,3 +113,40 @@ def broadcast_arrays(*args): broadcasted = [as_strided(x, shape=sh, strides=st) for (x,sh,st) in zip(args, shapes, strides)] return broadcasted + +def rolling_window(a, window): + """ + Make an ndarray with a rolling window of the last dimension + + Parameters + ---------- + a : array_like + Array to add rolling window to + window : int + Size of rolling window + + Returns + ------- + Array that is a view of the original array with a added dimension + of size w. + + Examples + -------- + >>> x=np.arange(10).reshape((2,5)) + >>> np.rolling_window(x, 3) + array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]], + [[5, 6, 7], [6, 7, 8], [7, 8, 9]]]) + + Calculate rolling mean of last dimension: + >>> np.mean(np.rolling_window(x, 3), -1) + array([[ 1., 2., 3.], + [ 6., 7., 8.]]) + + """ + if window < 1: + raise ValueError, "`window` must be at least 1." + if window > a.shape[-1]: + raise ValueError, "`window` is too long." + shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) + strides = a.strides + (a.strides[-1],) + return as_strided(a, shape=shape, strides=strides)