From 68d43e5096ce154b5d3199709c2dee30524f2f8f Mon Sep 17 00:00:00 2001 From: Paul Ganssle Date: Sun, 13 Aug 2017 17:37:10 -0700 Subject: [PATCH] Modified rrulewraper to handle timezone-aware datetimes. --- lib/matplotlib/dates.py | 98 ++++++++++++++++++++++++++++-- lib/matplotlib/tests/test_dates.py | 18 ++++++ 2 files changed, 110 insertions(+), 6 deletions(-) diff --git a/lib/matplotlib/dates.py b/lib/matplotlib/dates.py index bd27bef3759a..ef536e9c1740 100644 --- a/lib/matplotlib/dates.py +++ b/lib/matplotlib/dates.py @@ -118,6 +118,7 @@ import time import math import datetime +import functools import warnings @@ -732,20 +733,105 @@ def __call__(self, x, pos=None): class rrulewrapper(object): + def __init__(self, freq, tzinfo=None, **kwargs): + kwargs['freq'] = freq + self._base_tzinfo = tzinfo - def __init__(self, freq, **kwargs): - self._construct = kwargs.copy() - self._construct["freq"] = freq - self._rrule = rrule(**self._construct) + self._update_rrule(**kwargs) def set(self, **kwargs): self._construct.update(kwargs) + + self._update_rrule(**self._construct) + + def _update_rrule(self, **kwargs): + tzinfo = self._base_tzinfo + + # rrule does not play nicely with time zones - especially pytz time + # zones, it's best to use naive zones and attach timezones once the + # datetimes are returned + if 'dtstart' in kwargs: + dtstart = kwargs['dtstart'] + if dtstart.tzinfo is not None: + if tzinfo is None: + tzinfo = dtstart.tzinfo + else: + dtstart = dtstart.astimezone(tzinfo) + + kwargs['dtstart'] = dtstart.replace(tzinfo=None) + + if 'until' in kwargs: + until = kwargs['until'] + if until.tzinfo is not None: + if tzinfo is not None: + until = until.astimezone(tzinfo) + else: + raise ValueError('until cannot be aware if dtstart ' + 'is naive and tzinfo is None') + + kwargs['until'] = until.replace(tzinfo=None) + + self._construct = kwargs.copy() + self._tzinfo = tzinfo self._rrule = rrule(**self._construct) + def _attach_tzinfo(self, dt, tzinfo): + # pytz zones are attached by "localizing" the datetime + if hasattr(tzinfo, 'localize'): + return tzinfo.localize(dt, is_dst=True) + + return dt.replace(tzinfo=tzinfo) + + def _aware_return_wrapper(self, f, returns_list=False): + """Decorator function that allows rrule methods to handle tzinfo.""" + # This is only necessary if we're actually attaching a tzinfo + if self._tzinfo is None: + return f + + # All datetime arguments must be naive. If they are not naive, they are + # converted to the _tzinfo zone before dropping the zone. + def normalize_arg(arg): + if isinstance(arg, datetime.datetime) and arg.tzinfo is not None: + if arg.tzinfo is not self._tzinfo: + arg = arg.astimezone(self._tzinfo) + + return arg.replace(tzinfo=None) + + return arg + + def normalize_args(args, kwargs): + args = tuple(normalize_arg(arg) for arg in args) + kwargs = {kw: normalize_arg(arg) for kw, arg in kwargs.items()} + + return args, kwargs + + # There are two kinds of functions we care about - ones that return + # dates and ones that return lists of dates. + if not returns_list: + def inner_func(*args, **kwargs): + args, kwargs = normalize_args(args, kwargs) + dt = f(*args, **kwargs) + return self._attach_tzinfo(dt, self._tzinfo) + else: + def inner_func(*args, **kwargs): + args, kwargs = normalize_args(args, kwargs) + dts = f(*args, **kwargs) + return [self._attach_tzinfo(dt, self._tzinfo) for dt in dts] + + return functools.wraps(f)(inner_func) + def __getattr__(self, name): if name in self.__dict__: return self.__dict__[name] - return getattr(self._rrule, name) + + f = getattr(self._rrule, name) + + if name in {'after', 'before'}: + return self._aware_return_wrapper(f) + elif name in {'xafter', 'xbefore', 'between'}: + return self._aware_return_wrapper(f, returns_list=True) + else: + return f def __setstate__(self, state): self.__dict__.update(state) @@ -1226,7 +1312,7 @@ def __init__(self, bymonth=None, bymonthday=1, interval=1, tz=None): bymonth = [x.item() for x in bymonth.astype(int)] rule = rrulewrapper(MONTHLY, bymonth=bymonth, bymonthday=bymonthday, - interval=interval, **self.hms0d) + interval=interval, **self.hms0d) RRuleLocator.__init__(self, rule, tz) diff --git a/lib/matplotlib/tests/test_dates.py b/lib/matplotlib/tests/test_dates.py index 5a25e6182b7e..792341ee1527 100644 --- a/lib/matplotlib/tests/test_dates.py +++ b/lib/matplotlib/tests/test_dates.py @@ -442,6 +442,24 @@ def tz_convert(*args): _test_date2num_dst(pd.date_range, tz_convert) +@pytest.mark.parametrize("attach_tz, get_tz", [ + (lambda dt, zi: zi.localize(dt), lambda n: pytz.timezone(n)), + (lambda dt, zi: dt.replace(tzinfo=zi), lambda n: dateutil.tz.gettz(n))]) +def test_rrulewrapper(attach_tz, get_tz): + SYD = get_tz('Australia/Sydney') + + dtstart = attach_tz(datetime.datetime(2017, 4, 1, 0), SYD) + dtend = attach_tz(datetime.datetime(2017, 4, 4, 0), SYD) + + rule = mdates.rrulewrapper(freq=dateutil.rrule.DAILY, dtstart=dtstart) + + act = rule.between(dtstart, dtend) + exp = [datetime.datetime(2017, 4, 1, 13, tzinfo=dateutil.tz.tzutc()), + datetime.datetime(2017, 4, 2, 14, tzinfo=dateutil.tz.tzutc())] + + assert act == exp + + def test_DayLocator(): with pytest.raises(ValueError): mdates.DayLocator(interval=-1)