Skip to content

Rcparam validation fix #3564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Oct 14, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions lib/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
import six
import sys
import distutils.version
from itertools import chain

__version__ = '1.4.x'
__version__numpy__ = '1.6' # minimum required numpy version
Expand Down Expand Up @@ -244,6 +245,7 @@ def _is_writable_dir(p):

return True


class Verbose:
"""
A class to handle reporting. Set the fileo attribute to any file
Expand Down Expand Up @@ -803,6 +805,18 @@ def matplotlib_fname():
_deprecated_ignore_map = {
}

_obsolete_set = set(['tk.pythoninspect', ])
_all_deprecated = set(chain(_deprecated_ignore_map,
_deprecated_map, _obsolete_set))

_rcparam_warn_str = ("Trying to set {key} to {value} via the {func} "
"method of RcParams which does not validate cleanly. "
"This warning will turn into an Exception in 1.5. "
"If you think {value} should validate correctly for "
"rcParams[{key}] "
"please create an issue on github."
)


class RcParams(dict):

Expand All @@ -814,14 +828,27 @@ class RcParams(dict):
"""

validate = dict((key, converter) for key, (default, converter) in
six.iteritems(defaultParams))
six.iteritems(defaultParams)
if key not in _all_deprecated)
msg_depr = "%s is deprecated and replaced with %s; please use the latter."
msg_depr_ignore = "%s is deprecated and ignored. Use %s"

# validate values on the way in
def __init__(self, *args, **kwargs):
for k, v in six.iteritems(dict(*args, **kwargs)):
try:
self[k] = v
except (ValueError, RuntimeError):
# force the issue
warnings.warn(_rcparam_warn_str.format(key=repr(k),
value=repr(v),
func='__init__'))
dict.__setitem__(self, k, v)

def __setitem__(self, key, val):
try:
if key in _deprecated_map:
alt_key, alt_val = _deprecated_map[key]
alt_key, alt_val = _deprecated_map[key]
warnings.warn(self.msg_depr % (key, alt_key))
key = alt_key
val = alt_val(val)
Expand All @@ -840,7 +867,7 @@ def __setitem__(self, key, val):

def __getitem__(self, key):
if key in _deprecated_map:
alt_key, alt_val = _deprecated_map[key]
alt_key, alt_val = _deprecated_map[key]
warnings.warn(self.msg_depr % (key, alt_key))
key = alt_key
elif key in _deprecated_ignore_map:
Expand All @@ -849,6 +876,22 @@ def __getitem__(self, key):
key = alt
return dict.__getitem__(self, key)

# http://stackoverflow.com/questions/2390827/how-to-properly-subclass-dict-and-override-get-set
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the comment is necessary (at all), but a test would be very welcome.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment could be shortened to one or two lines, but noting why update needs to be overridden is helpful; it's not obvious.

# the default dict `update` does not use __setitem__
# so rcParams.update(...) (such as in seaborn) side-steps
# all of the validation over-ride update to force
# through __setitem__
def update(self, *args, **kwargs):
for k, v in six.iteritems(dict(*args, **kwargs)):
try:
self[k] = v
except (ValueError, RuntimeError):
# force the issue
warnings.warn(_rcparam_warn_str.format(key=repr(k),
value=repr(v),
func='update'))
dict.__setitem__(self, k, v)

def __repr__(self):
import pprint
class_name = self.__class__.__name__
Expand Down Expand Up @@ -902,8 +945,9 @@ def rc_params(fail_on_error=False):
if not os.path.exists(fname):
# this should never happen, default in mpl-data should always be found
message = 'could not find rc file; returning defaults'
ret = RcParams([(key, default) for key, (default, _) in \
six.iteritems(defaultParams)])
ret = RcParams([(key, default) for key, (default, _) in
six.iteritems(defaultParams)
if key not in _all_deprecated])
warnings.warn(message)
return ret

Expand Down Expand Up @@ -1025,7 +1069,8 @@ def rc_params_from_file(fname, fail_on_error=False, use_default_template=True):
return config_from_file

iter_params = six.iteritems(defaultParams)
config = RcParams([(key, default) for key, (default, _) in iter_params])
config = RcParams([(key, default) for key, (default, _) in iter_params
if key not in _all_deprecated])
config.update(config_from_file)

verbose.set_level(config['verbose.level'])
Expand Down Expand Up @@ -1067,16 +1112,20 @@ def rc_params_from_file(fname, fail_on_error=False, use_default_template=True):

rcParamsOrig = rcParams.copy()

rcParamsDefault = RcParams([(key, default) for key, (default, converter) in \
six.iteritems(defaultParams)])
rcParamsDefault = RcParams([(key, default) for key, (default, converter) in
six.iteritems(defaultParams)
if key not in _all_deprecated])

rcParams['ps.usedistiller'] = checkdep_ps_distiller(rcParams['ps.usedistiller'])

rcParams['ps.usedistiller'] = checkdep_ps_distiller(
rcParams['ps.usedistiller'])
rcParams['text.usetex'] = checkdep_usetex(rcParams['text.usetex'])

if rcParams['axes.formatter.use_locale']:
import locale
locale.setlocale(locale.LC_ALL, '')


def rc(group, **kwargs):
"""
Set the current rc params. Group is the grouping for the rc, e.g.,
Expand Down
88 changes: 47 additions & 41 deletions lib/matplotlib/rcsetup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def validate_any(s):

def validate_path_exists(s):
"""If s is a path, return s, else False"""
if s is None:
return None
if os.path.exists(s):
return s
else:
Expand Down Expand Up @@ -172,50 +174,54 @@ def validate_maskedarray(v):
' please delete it from your matplotlibrc file')


class validate_nseq_float:
_seq_err_msg = ('You must supply exactly {n:d} values, you provided '
'{num:d} values: {s}')

_str_err_msg = ('You must supply exactly {n:d} comma-separated values, '
'you provided '
'{num:d} comma-separated values: {s}')


class validate_nseq_float(object):
def __init__(self, n):
self.n = n

def __call__(self, s):
"""return a seq of n floats or raise"""
if isinstance(s, six.string_types):
ss = s.split(',')
if len(ss) != self.n:
raise ValueError(
'You must supply exactly %d comma separated values' %
self.n)
try:
return [float(val) for val in ss]
except ValueError:
raise ValueError('Could not convert all entries to floats')
s = s.split(',')
err_msg = _str_err_msg
else:
assert type(s) in (list, tuple)
if len(s) != self.n:
raise ValueError('You must supply exactly %d values' % self.n)
err_msg = _seq_err_msg

if len(s) != self.n:
raise ValueError(err_msg.format(n=self.n, num=len(s), s=s))

try:
return [float(val) for val in s]
except ValueError:
raise ValueError('Could not convert all entries to floats')


class validate_nseq_int:
class validate_nseq_int(object):
def __init__(self, n):
self.n = n

def __call__(self, s):
"""return a seq of n ints or raise"""
if isinstance(s, six.string_types):
ss = s.split(',')
if len(ss) != self.n:
raise ValueError(
'You must supply exactly %d comma separated values' %
self.n)
try:
return [int(val) for val in ss]
except ValueError:
raise ValueError('Could not convert all entries to ints')
s = s.split(',')
err_msg = _str_err_msg
else:
assert type(s) in (list, tuple)
if len(s) != self.n:
raise ValueError('You must supply exactly %d values' % self.n)
err_msg = _seq_err_msg

if len(s) != self.n:
raise ValueError(err_msg.format(n=self.n, num=len(s), s=s))

try:
return [int(val) for val in s]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Time for some factorisation? This is class is exactly the same as the float version, except for the float/int right?

except ValueError:
raise ValueError('Could not convert all entries to ints')


def validate_color(s):
Expand Down Expand Up @@ -263,10 +269,10 @@ def validate_colorlist(s):
def validate_stringlist(s):
'return a list'
if isinstance(s, six.string_types):
return [v.strip() for v in s.split(',')]
return [six.text_type(v.strip()) for v in s.split(',') if v.strip()]
else:
assert type(s) in [list, tuple]
return [six.text_type(v) for v in s]
return [six.text_type(v) for v in s if v]


validate_orientation = ValidateInStrings(
Expand Down Expand Up @@ -517,7 +523,7 @@ def __call__(self, s):


## font props
'font.family': ['sans-serif', validate_stringlist], # used by text object
'font.family': [['sans-serif'], validate_stringlist], # used by text object
'font.style': ['normal', six.text_type],
'font.variant': ['normal', six.text_type],
'font.stretch': ['normal', six.text_type],
Expand Down Expand Up @@ -776,14 +782,14 @@ def __call__(self, s):
'keymap.home': [['h', 'r', 'home'], validate_stringlist],
'keymap.back': [['left', 'c', 'backspace'], validate_stringlist],
'keymap.forward': [['right', 'v'], validate_stringlist],
'keymap.pan': ['p', validate_stringlist],
'keymap.zoom': ['o', validate_stringlist],
'keymap.save': [('s', 'ctrl+s'), validate_stringlist],
'keymap.quit': [('ctrl+w', 'cmd+w'), validate_stringlist],
'keymap.grid': ['g', validate_stringlist],
'keymap.yscale': ['l', validate_stringlist],
'keymap.pan': [['p'], validate_stringlist],
'keymap.zoom': [['o'], validate_stringlist],
'keymap.save': [['s', 'ctrl+s'], validate_stringlist],
'keymap.quit': [['ctrl+w', 'cmd+w'], validate_stringlist],
'keymap.grid': [['g'], validate_stringlist],
'keymap.yscale': [['l'], validate_stringlist],
'keymap.xscale': [['k', 'L'], validate_stringlist],
'keymap.all_axes': ['a', validate_stringlist],
'keymap.all_axes': [['a'], validate_stringlist],

# sample data
'examples.directory': ['', six.text_type],
Expand All @@ -797,21 +803,21 @@ def __call__(self, s):
# Path to FFMPEG binary. If just binary name, subprocess uses $PATH.
'animation.ffmpeg_path': ['ffmpeg', six.text_type],

## Additional arguments for ffmpeg movie writer (using pipes)
'animation.ffmpeg_args': ['', validate_stringlist],
# Additional arguments for ffmpeg movie writer (using pipes)
'animation.ffmpeg_args': [[], validate_stringlist],
# Path to AVConv binary. If just binary name, subprocess uses $PATH.
'animation.avconv_path': ['avconv', six.text_type],
# Additional arguments for avconv movie writer (using pipes)
'animation.avconv_args': ['', validate_stringlist],
'animation.avconv_args': [[], validate_stringlist],
# Path to MENCODER binary. If just binary name, subprocess uses $PATH.
'animation.mencoder_path': ['mencoder', six.text_type],
# Additional arguments for mencoder movie writer (using pipes)
'animation.mencoder_args': ['', validate_stringlist],
'animation.mencoder_args': [[], validate_stringlist],
# Path to convert binary. If just binary name, subprocess uses $PATH
'animation.convert_path': ['convert', six.text_type],
# Additional arguments for mencoder movie writer (using pipes)

'animation.convert_args': ['', validate_stringlist]}
'animation.convert_args': [[], validate_stringlist]}


if __name__ == '__main__':
Expand Down
5 changes: 4 additions & 1 deletion lib/matplotlib/tests/test_animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_save_animation_smoketest():
yield check_save_animation, writer, extension


@with_setup(CleanupTest.setup_class, CleanupTest.teardown_class)
@cleanup
def check_save_animation(writer, extension='mp4'):
if not animation.writers.is_available(writer):
raise KnownFailureTest("writer '%s' not available on this system"
Expand All @@ -39,6 +39,9 @@ def check_save_animation(writer, extension='mp4'):
fig, ax = plt.subplots()
line, = ax.plot([], [])

ax.set_xlim(0, 10)
ax.set_ylim(-1, 1)

def init():
line.set_data([], [])
return line,
Expand Down
Loading