Skip to content

Commit cd66c66

Browse files
nils-braun and MaxBenChrist
authored and committed
Maximum rolling (blue-yonder#184)
* Added a maximum number of shifts parameter * Added a test for this * Make documentation more precise. * Added more tests
1 parent 5bc7b83 commit cd66c66

File tree

2 files changed

+82
-18
lines changed

2 files changed

+82
-18
lines changed

tests/utilities/test_dataframe_functions.py

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,53 @@ def test_positive_rolling(self):
213213

214214
df_full = pd.concat([first_class, second_class], ignore_index=True)
215215

216-
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
217-
column_kind=None, rolling_direction=1)
218-
219216
correct_indices = (["id=1, shift=3"] * 1 +
220217
["id=1, shift=2"] * 2 +
221218
["id=1, shift=1"] * 3 +
222219
["id=2, shift=1"] * 1 +
223220
["id=1, shift=0"] * 4 +
224221
["id=2, shift=0"] * 2)
222+
correct_values_a = [1, 1, 2, 1, 2, 3, 10, 1, 2, 3, 4, 10, 11]
223+
correct_values_b = [5, 5, 6, 5, 6, 7, 12, 5, 6, 7, 8, 12, 13]
224+
225+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
226+
column_kind=None, rolling_direction=1)
227+
228+
self.assertListEqual(list(df["id"]), correct_indices)
229+
self.assertListEqual(list(df["a"].values), correct_values_a)
230+
self.assertListEqual(list(df["b"].values), correct_values_b)
231+
232+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
233+
column_kind=None, rolling_direction=1,
234+
maximum_number_of_timeshifts=None)
225235

226236
self.assertListEqual(list(df["id"]), correct_indices)
237+
self.assertListEqual(list(df["a"].values), correct_values_a)
238+
self.assertListEqual(list(df["b"].values), correct_values_b)
227239

228-
self.assertListEqual(list(df["a"].values),
229-
[1, 1, 2, 1, 2, 3, 10, 1, 2, 3, 4, 10, 11])
230-
self.assertListEqual(list(df["b"].values),
231-
[5, 5, 6, 5, 6, 7, 12, 5, 6, 7, 8, 12, 13])
240+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
241+
column_kind=None, rolling_direction=1,
242+
maximum_number_of_timeshifts=1)
243+
244+
self.assertListEqual(list(df["id"]), correct_indices[3:])
245+
self.assertListEqual(list(df["a"].values), correct_values_a[3:])
246+
self.assertListEqual(list(df["b"].values), correct_values_b[3:])
247+
248+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
249+
column_kind=None, rolling_direction=1,
250+
maximum_number_of_timeshifts=2)
251+
252+
self.assertListEqual(list(df["id"]), correct_indices[1:])
253+
self.assertListEqual(list(df["a"].values), correct_values_a[1:])
254+
self.assertListEqual(list(df["b"].values), correct_values_b[1:])
255+
256+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
257+
column_kind=None, rolling_direction=1,
258+
maximum_number_of_timeshifts=4)
259+
260+
self.assertListEqual(list(df["id"]), correct_indices[:])
261+
self.assertListEqual(list(df["a"].values), correct_values_a[:])
262+
self.assertListEqual(list(df["b"].values), correct_values_b[:])
232263

233264
def test_negative_rolling(self):
234265
first_class = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "time": range(4)})
@@ -239,22 +270,53 @@ def test_negative_rolling(self):
239270

240271
df_full = pd.concat([first_class, second_class], ignore_index=True)
241272

242-
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
243-
column_kind=None, rolling_direction=-1)
244-
245273
correct_indices = (["id=1, shift=-3"] * 1 +
246274
["id=1, shift=-2"] * 2 +
247275
["id=1, shift=-1"] * 3 +
248276
["id=2, shift=-1"] * 1 +
249277
["id=1, shift=0"] * 4 +
250278
["id=2, shift=0"] * 2)
279+
correct_values_a = [4, 3, 4, 2, 3, 4, 11, 1, 2, 3, 4, 10, 11]
280+
correct_values_b = [8, 7, 8, 6, 7, 8, 13, 5, 6, 7, 8, 12, 13]
281+
282+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
283+
column_kind=None, rolling_direction=-1)
251284

252285
self.assertListEqual(list(df["id"].values), correct_indices)
286+
self.assertListEqual(list(df["a"].values), correct_values_a)
287+
self.assertListEqual(list(df["b"].values), correct_values_b)
253288

254-
self.assertListEqual(list(df["a"].values),
255-
[4, 3, 4, 2, 3, 4, 11, 1, 2, 3, 4, 10, 11])
256-
self.assertListEqual(list(df["b"].values),
257-
[8, 7, 8, 6, 7, 8, 13, 5, 6, 7, 8, 12, 13])
289+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
290+
column_kind=None, rolling_direction=-1,
291+
maximum_number_of_timeshifts=None)
292+
293+
self.assertListEqual(list(df["id"].values), correct_indices)
294+
self.assertListEqual(list(df["a"].values), correct_values_a)
295+
self.assertListEqual(list(df["b"].values), correct_values_b)
296+
297+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
298+
column_kind=None, rolling_direction=-1,
299+
maximum_number_of_timeshifts=1)
300+
301+
self.assertListEqual(list(df["id"].values), correct_indices[3:])
302+
self.assertListEqual(list(df["a"].values), correct_values_a[3:])
303+
self.assertListEqual(list(df["b"].values), correct_values_b[3:])
304+
305+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
306+
column_kind=None, rolling_direction=-1,
307+
maximum_number_of_timeshifts=2)
308+
309+
self.assertListEqual(list(df["id"].values), correct_indices[1:])
310+
self.assertListEqual(list(df["a"].values), correct_values_a[1:])
311+
self.assertListEqual(list(df["b"].values), correct_values_b[1:])
312+
313+
df = dataframe_functions.roll_time_series(df_full, column_id="id", column_sort="time",
314+
column_kind=None, rolling_direction=-1,
315+
maximum_number_of_timeshifts=4)
316+
317+
self.assertListEqual(list(df["id"].values), correct_indices[:])
318+
self.assertListEqual(list(df["a"].values), correct_values_a[:])
319+
self.assertListEqual(list(df["b"].values), correct_values_b[:])
258320

259321
def test_stacked_rolling(self):
260322
first_class = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "time": range(4)})
@@ -309,8 +371,6 @@ def test_dict_rolling(self):
309371
self.assertListEqual(list(df["b"]["_value"].values),
310372
[8, 7, 8, 6, 7, 8, 13, 5, 6, 7, 8, 12, 13])
311373

312-
313-
314374
def test_warning_on_non_uniform_time_steps(self):
315375
with warnings.catch_warnings(record=True) as w:
316376
first_class = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8], "time": [1, 2, 4, 5]})

tsfresh/utilities/dataframe_functions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ def normalize_input_to_internal_representation(df_or_dict, column_id, column_sor
328328
return kind_to_df_map, column_id, column_value
329329

330330

331-
def roll_time_series(df_or_dict, column_id, column_sort, column_kind, rolling_direction):
331+
def roll_time_series(df_or_dict, column_id, column_sort, column_kind, rolling_direction,
332+
maximum_number_of_timeshifts=None):
332333
"""
333334
Roll the (sorted) data frames for each kind and each id separately in the "time" domain
334335
(which is represented by the sort order of the sort column given by `column_sort`).
@@ -360,6 +361,9 @@ def roll_time_series(df_or_dict, column_id, column_sort, column_kind, rolling_di
360361
:type column_kind: basestring or None
361362
:param rolling_direction: The sign decides, if to roll backwards or forwards in "time"
362363
:type rolling_direction: int
364+
:param maximum_number_of_timeshifts: If not None, shift only up to maximum_number_of_timeshifts.
365+
If None, shift as often as possible.
366+
:type maximum_number_of_timeshifts: int
363367
364368
:return: The rolled data frame or dictionary of data frames
365369
:rtype: the one from df_or_dict
@@ -416,7 +420,7 @@ def roll_time_series(df_or_dict, column_id, column_sort, column_kind, rolling_di
416420
rolling_direction = np.sign(rolling_direction)
417421

418422
grouped_data = df.groupby(grouper)
419-
maximum_number_of_timeshifts = grouped_data.count().max().max()
423+
maximum_number_of_timeshifts = maximum_number_of_timeshifts or grouped_data.count().max().max()
420424

421425
if np.isnan(maximum_number_of_timeshifts):
422426
raise ValueError("Somehow the maximum length of your time series is NaN (Does your time series container have "

0 commit comments

Comments
 (0)