diff --git a/doc/users/whats_new/2015-07-30_ensure_ax.rst b/doc/users/whats_new/2015-07-30_ensure_ax.rst new file mode 100644 index 000000000000..c830f8e81d48 --- /dev/null +++ b/doc/users/whats_new/2015-07-30_ensure_ax.rst @@ -0,0 +1,18 @@ +Decorators to ensure an Axes exists & ``plt.gna`` +------------------------------------------------- + +Added a top-level function to `pyplot` to create a single axes +figure and return the `Axes` object. + +Added decorators :: + + ensure_ax + ensure_ax_meth + ensure_new_ax + +which take a function or method that expects an `Axes` as the first +positional argument and returns a function which allows the axes to be +passed either as the first positional argument or as a kwarg. If the +`Axes` is not provided either gets the current axes (via `plt.gca()`) +for `ensure_ax` and `ensure_ax_meth` or creating a new single-axes figure +(via `plt.gna`) for `ensure_new_ax` diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 39d4c3858c61..bc28d2cf2055 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -25,6 +25,9 @@ import types from cycler import cycler + +from functools import wraps + import matplotlib import matplotlib.colorbar from matplotlib import style @@ -189,6 +192,146 @@ def uninstall_repl_displayhook(): draw_all = _pylab_helpers.Gcf.draw_all +_ENSURE_AX_DOC = """ + +This function has been decorated by pyplot to have +an implicit reference to the `plt.gca()` passed as the first argument. + +The wrapped function can be called as any of :: + + {obj}{func}(*args, **kwargs) + {obj}{func}(ax, *args, **kwargs) + {obj}{func}(.., ax=ax) + +""" + + +_ENSURE_AX_NEW_DOC = """ + +This function has been decorated by pyplot to create a new +axes if one is not explicitly passed. + +The wrapped function can be called as any of :: + + {obj}{func}(*args, **kwargs) + {obj}{func}(ax, *args, **kwargs) + {obj}{func}(.., ax=ax) + +The first will make a new figure and axes, the other two +will add to the axes passed in. + +""" + + +def ensure_ax(func): + """Decorator to ensure that the function gets an `Axes` object. + + + The intent of this decorator is to simplify the writing of helper + plotting functions that are useful for both interactive and + programmatic usage. + + The encouraged signature for third-party and user functions :: + + def my_function(ax, data, style) + + explicitly expects an Axes object as input rather than using + plt.gca() or creating axes with in the function body. This + allows for great flexibility, but some find it verbose for + interactive use. This decorator allows the Axes input to be + omitted in which case `plt.gca()` is passed into the function. + Thus :: + + wrapped = ensure_ax(my_function) + + can be called as any of :: + + wrapped(data, style) + wrapped(ax, data, style) + wrapped(data, style, ax=plt.gca()) + + + """ + @wraps(func) + def inner(*args, **kwargs): + if 'ax' in kwargs: + ax = kwargs.pop('ax') + elif len(args) > 0 and isinstance(args[0], Axes): + ax = args[0] + args = args[1:] + else: + ax = gca() + return func(ax, *args, **kwargs) + pre_doc = inner.__doc__ + if pre_doc is None: + pre_doc = '' + else: + pre_doc = dedent(pre_doc) + inner.__doc__ = pre_doc + _ENSURE_AX_DOC.format(func=func.__name__, obj='') + + return inner + + +def ensure_new_ax(func): + """Decorator to ensure that the function gets a new `Axes` object. + + Same as ensure_ax expect that a new figure and axes are created + if an Axes is not explicitly passed. + + """ + @wraps(func) + def inner(*args, **kwargs): + if 'ax' in kwargs: + ax = kwargs.pop('ax') + elif len(args) > 0 and isinstance(args[0], Axes): + ax = args[0] + args = args[1:] + else: + ax = gna() + return func(ax, *args, **kwargs) + pre_doc = inner.__doc__ + if pre_doc is None: + pre_doc = '' + else: + pre_doc = dedent(pre_doc) + inner.__doc__ = (pre_doc + + _ENSURE_AX_NEW_DOC.format(func=func.__name__, obj='')) + + return inner + + +def ensure_ax_meth(func): + """ + The same as ensure axes, but for class methods :: + + class foo(object): + @ensure_ax_meth + def my_function(self, ax, style): + + will allow you to call your objects plotting methods with + out explicitly passing in an `Axes` object. + """ + @wraps(func) + def inner(*args, **kwargs): + s = args[0] + args = args[1:] + if 'ax' in kwargs: + ax = kwargs.pop('ax') + elif len(args) > 1 and isinstance(args[0], Axes): + ax = args[0] + args = args[1:] + else: + ax = gca() + return func(s, ax, *args, **kwargs) + pre_doc = inner.__doc__ + if pre_doc is None: + pre_doc = '' + else: + pre_doc = dedent(pre_doc) + inner.__doc__ = pre_doc + _ENSURE_AX_DOC.format(func=func.__name__, + obj='obj.') + return inner + @docstring.copy_dedent(Artist.findobj) def findobj(o=None, match=None, include_self=True): @@ -1242,6 +1385,30 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, return ret +def gna(figsize=None, tight_layout=False): + """ + Create a single new axes in a new figure. + + This is a convenience function for working interactively + and should not be used in scripts. + + Parameters + ---------- + figsize : tuple, optional + Figure size in inches (w, h) + + tight_layout : bool, optional + If tight layout shoudl be used. + + Returns + ------- + ax : Axes + New axes + """ + _, ax = subplots(figsize=figsize, tight_layout=tight_layout) + return ax + + def subplot2grid(shape, loc, rowspan=1, colspan=1, **kwargs): """ Create a subplot in a grid. The grid is specified by *shape*, at