Skip to content

Commit 26253f7

Browse files
committed
Dedupe various method implementations using functools.partialmethod.
This is shorter and yields better signatures on the resulting methods (as partialmethod forwards to the signature of the underlying helper). For example, the signature of FigureCanvasCairo.print_ps is now `(self, fobj, *, orientation='portrait')` rather than `(self, fobj, *args, **kwargs)`. For the cairo/wx print_foos, also note that we can delete the `*args` in the various print methods with no deprecation as they were not supported to start with (the underlying call to `_save` and `_print_image` would have raised if such `*args` were passed in).
1 parent 71c9f09 commit 26253f7

File tree

5 files changed

+70
-197
lines changed

5 files changed

+70
-197
lines changed

lib/matplotlib/backends/backend_cairo.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
This backend depends on cairocffi or pycairo.
77
"""
88

9+
import functools
910
import gzip
1011
import math
1112

@@ -467,20 +468,8 @@ def _get_printed_image_surface(self):
467468
self.figure.draw(renderer)
468469
return surface
469470

470-
def print_pdf(self, fobj, *args, **kwargs):
471-
return self._save(fobj, 'pdf', *args, **kwargs)
472-
473-
def print_ps(self, fobj, *args, **kwargs):
474-
return self._save(fobj, 'ps', *args, **kwargs)
475-
476-
def print_svg(self, fobj, *args, **kwargs):
477-
return self._save(fobj, 'svg', *args, **kwargs)
478-
479-
def print_svgz(self, fobj, *args, **kwargs):
480-
return self._save(fobj, 'svgz', *args, **kwargs)
481-
482471
@_check_savefig_extra_args
483-
def _save(self, fo, fmt, *, orientation='portrait'):
472+
def _save(self, fmt, fobj, *, orientation='portrait'):
484473
# save PDF/PS/SVG
485474

486475
dpi = 72
@@ -496,22 +485,22 @@ def _save(self, fo, fmt, *, orientation='portrait'):
496485
if not hasattr(cairo, 'PSSurface'):
497486
raise RuntimeError('cairo has not been compiled with PS '
498487
'support enabled')
499-
surface = cairo.PSSurface(fo, width_in_points, height_in_points)
488+
surface = cairo.PSSurface(fobj, width_in_points, height_in_points)
500489
elif fmt == 'pdf':
501490
if not hasattr(cairo, 'PDFSurface'):
502491
raise RuntimeError('cairo has not been compiled with PDF '
503492
'support enabled')
504-
surface = cairo.PDFSurface(fo, width_in_points, height_in_points)
493+
surface = cairo.PDFSurface(fobj, width_in_points, height_in_points)
505494
elif fmt in ('svg', 'svgz'):
506495
if not hasattr(cairo, 'SVGSurface'):
507496
raise RuntimeError('cairo has not been compiled with SVG '
508497
'support enabled')
509498
if fmt == 'svgz':
510-
if isinstance(fo, str):
511-
fo = gzip.GzipFile(fo, 'wb')
499+
if isinstance(fobj, str):
500+
fobj = gzip.GzipFile(fobj, 'wb')
512501
else:
513-
fo = gzip.GzipFile(None, 'wb', fileobj=fo)
514-
surface = cairo.SVGSurface(fo, width_in_points, height_in_points)
502+
fobj = gzip.GzipFile(None, 'wb', fileobj=fobj)
503+
surface = cairo.SVGSurface(fobj, width_in_points, height_in_points)
515504
else:
516505
raise ValueError("Unknown format: {!r}".format(fmt))
517506

@@ -531,7 +520,12 @@ def _save(self, fo, fmt, *, orientation='portrait'):
531520
ctx.show_page()
532521
surface.finish()
533522
if fmt == 'svgz':
534-
fo.close()
523+
fobj.close()
524+
525+
print_pdf = functools.partialmethod(_save, "pdf")
526+
print_ps = functools.partialmethod(_save, "ps")
527+
print_svg = functools.partialmethod(_save, "svg")
528+
print_svgz = functools.partialmethod(_save, "svgz")
535529

536530

537531
@_Backend.export

lib/matplotlib/backends/backend_wx.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Copyright (C) Jeremy O'Donoghue & John Hunter, 2003-4.
88
"""
99

10+
import functools
1011
import logging
1112
import math
1213
import pathlib
@@ -829,29 +830,8 @@ def draw(self, drawDC=None):
829830
self._isDrawn = True
830831
self.gui_repaint(drawDC=drawDC)
831832

832-
def print_bmp(self, filename, *args, **kwargs):
833-
return self._print_image(filename, wx.BITMAP_TYPE_BMP, *args, **kwargs)
834-
835-
def print_jpeg(self, filename, *args, **kwargs):
836-
return self._print_image(filename, wx.BITMAP_TYPE_JPEG,
837-
*args, **kwargs)
838-
print_jpg = print_jpeg
839-
840-
def print_pcx(self, filename, *args, **kwargs):
841-
return self._print_image(filename, wx.BITMAP_TYPE_PCX, *args, **kwargs)
842-
843-
def print_png(self, filename, *args, **kwargs):
844-
return self._print_image(filename, wx.BITMAP_TYPE_PNG, *args, **kwargs)
845-
846-
def print_tiff(self, filename, *args, **kwargs):
847-
return self._print_image(filename, wx.BITMAP_TYPE_TIF, *args, **kwargs)
848-
print_tif = print_tiff
849-
850-
def print_xpm(self, filename, *args, **kwargs):
851-
return self._print_image(filename, wx.BITMAP_TYPE_XPM, *args, **kwargs)
852-
853833
@_check_savefig_extra_args
854-
def _print_image(self, filename, filetype, *, quality=None):
834+
def _print_image(self, filetype, filename, *, quality=None):
855835
origBitmap = self.bitmap
856836

857837
self.bitmap = wx.Bitmap(math.ceil(self.figure.bbox.width),
@@ -897,6 +877,19 @@ def _print_image(self, filename, filetype, *, quality=None):
897877
if self:
898878
self.Refresh()
899879

880+
print_bmp = functools.partialmethod(
881+
_print_image, wx.BITMAP_TYPE_BMP)
882+
print_jpeg = print_jpg = functools.partialmethod(
883+
_print_image, wx.BITMAP_TYPE_JPEG)
884+
print_pcx = functools.partialmethod(
885+
_print_image, wx.BITMAP_TYPE_PCX)
886+
print_png = functools.partialmethod(
887+
_print_image, wx.BITMAP_TYPE_PNG)
888+
print_tiff = print_tif = functools.partialmethod(
889+
_print_image, wx.BITMAP_TYPE_TIF)
890+
print_xpm = functools.partialmethod(
891+
_print_image, wx.BITMAP_TYPE_XPM)
892+
900893

901894
class FigureFrameWx(wx.Frame):
902895
def __init__(self, num, fig):

lib/matplotlib/testing/jpl_units/Duration.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Duration module."""
22

3+
import functools
34
import operator
45

56
from matplotlib import _api
@@ -44,38 +45,20 @@ def seconds(self):
4445
def __bool__(self):
4546
return self._seconds != 0
4647

47-
def __eq__(self, rhs):
48-
return self._cmp(rhs, operator.eq)
49-
50-
def __ne__(self, rhs):
51-
return self._cmp(rhs, operator.ne)
52-
53-
def __lt__(self, rhs):
54-
return self._cmp(rhs, operator.lt)
55-
56-
def __le__(self, rhs):
57-
return self._cmp(rhs, operator.le)
58-
59-
def __gt__(self, rhs):
60-
return self._cmp(rhs, operator.gt)
61-
62-
def __ge__(self, rhs):
63-
return self._cmp(rhs, operator.ge)
64-
65-
def _cmp(self, rhs, op):
48+
def _cmp(self, op, rhs):
6649
"""
67-
Compare two Durations.
68-
69-
= INPUT VARIABLES
70-
- rhs The Duration to compare against.
71-
- op The function to do the comparison
72-
73-
= RETURN VALUE
74-
- Returns op(self, rhs)
50+
Check that *self* and *rhs* share frames; compare them using *op*.
7551
"""
7652
self.checkSameFrame(rhs, "compare")
7753
return op(self._seconds, rhs._seconds)
7854

55+
__eq__ = functools.partialmethod(_cmp, operator.eq)
56+
__ne__ = functools.partialmethod(_cmp, operator.ne)
57+
__lt__ = functools.partialmethod(_cmp, operator.lt)
58+
__le__ = functools.partialmethod(_cmp, operator.le)
59+
__gt__ = functools.partialmethod(_cmp, operator.gt)
60+
__ge__ = functools.partialmethod(_cmp, operator.ge)
61+
7962
def __add__(self, rhs):
8063
"""
8164
Add two Durations.
@@ -126,17 +109,7 @@ def __mul__(self, rhs):
126109
"""
127110
return Duration(self._frame, self._seconds * float(rhs))
128111

129-
def __rmul__(self, lhs):
130-
"""
131-
Scale a Duration by a value.
132-
133-
= INPUT VARIABLES
134-
- lhs The scalar to multiply by.
135-
136-
= RETURN VALUE
137-
- Returns the scaled Duration.
138-
"""
139-
return Duration(self._frame, self._seconds * float(lhs))
112+
__rmul__ = __mul__
140113

141114
def __str__(self):
142115
"""Print the Duration."""

lib/matplotlib/testing/jpl_units/Epoch.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Epoch module."""
22

3+
import functools
34
import operator
45
import math
56
import datetime as DT
@@ -106,44 +107,22 @@ def secondsPast(self, frame, jd):
106107
delta = t._jd - jd
107108
return t._seconds + delta * 86400
108109

109-
def __eq__(self, rhs):
110-
return self._cmp(rhs, operator.eq)
111-
112-
def __ne__(self, rhs):
113-
return self._cmp(rhs, operator.ne)
114-
115-
def __lt__(self, rhs):
116-
return self._cmp(rhs, operator.lt)
117-
118-
def __le__(self, rhs):
119-
return self._cmp(rhs, operator.le)
120-
121-
def __gt__(self, rhs):
122-
return self._cmp(rhs, operator.gt)
123-
124-
def __ge__(self, rhs):
125-
return self._cmp(rhs, operator.ge)
126-
127-
def _cmp(self, rhs, op):
128-
"""
129-
Compare two Epoch's.
130-
131-
= INPUT VARIABLES
132-
- rhs The Epoch to compare against.
133-
- op The function to do the comparison
134-
135-
= RETURN VALUE
136-
- Returns op(self, rhs)
137-
"""
110+
def _cmp(self, op, rhs):
111+
"""Compare Epochs *self* and *rhs* using operator *op*."""
138112
t = self
139113
if self._frame != rhs._frame:
140114
t = self.convert(rhs._frame)
141-
142115
if t._jd != rhs._jd:
143116
return op(t._jd, rhs._jd)
144-
145117
return op(t._seconds, rhs._seconds)
146118

119+
__eq__ = functools.partialmethod(_cmp, operator.eq)
120+
__ne__ = functools.partialmethod(_cmp, operator.ne)
121+
__lt__ = functools.partialmethod(_cmp, operator.lt)
122+
__le__ = functools.partialmethod(_cmp, operator.le)
123+
__gt__ = functools.partialmethod(_cmp, operator.gt)
124+
__ge__ = functools.partialmethod(_cmp, operator.ge)
125+
147126
def __add__(self, rhs):
148127
"""
149128
Add a duration to an Epoch.

lib/matplotlib/testing/jpl_units/UnitDbl.py

Lines changed: 20 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""UnitDbl module."""
22

3+
import functools
34
import operator
45

56
from matplotlib import _api
@@ -88,99 +89,32 @@ def __bool__(self):
8889
"""Return the truth value of a UnitDbl."""
8990
return bool(self._value)
9091

91-
def __eq__(self, rhs):
92-
return self._cmp(rhs, operator.eq)
93-
94-
def __ne__(self, rhs):
95-
return self._cmp(rhs, operator.ne)
96-
97-
def __lt__(self, rhs):
98-
return self._cmp(rhs, operator.lt)
99-
100-
def __le__(self, rhs):
101-
return self._cmp(rhs, operator.le)
102-
103-
def __gt__(self, rhs):
104-
return self._cmp(rhs, operator.gt)
105-
106-
def __ge__(self, rhs):
107-
return self._cmp(rhs, operator.ge)
108-
109-
def _cmp(self, rhs, op):
110-
"""
111-
Compare two UnitDbl's.
112-
113-
= ERROR CONDITIONS
114-
- If the input rhs units are not the same as our units,
115-
an error is thrown.
116-
117-
= INPUT VARIABLES
118-
- rhs The UnitDbl to compare against.
119-
- op The function to do the comparison
120-
121-
= RETURN VALUE
122-
- Returns op(self, rhs)
123-
"""
92+
def _cmp(self, op, rhs):
93+
"""Check that *self* and *rhs* share units; compare them using *op*."""
12494
self.checkSameUnits(rhs, "compare")
12595
return op(self._value, rhs._value)
12696

127-
def __add__(self, rhs):
128-
"""
129-
Add two UnitDbl's.
130-
131-
= ERROR CONDITIONS
132-
- If the input rhs units are not the same as our units,
133-
an error is thrown.
134-
135-
= INPUT VARIABLES
136-
- rhs The UnitDbl to add.
137-
138-
= RETURN VALUE
139-
- Returns the sum of ourselves and the input UnitDbl.
140-
"""
141-
self.checkSameUnits(rhs, "add")
142-
return UnitDbl(self._value + rhs._value, self._units)
143-
144-
def __sub__(self, rhs):
145-
"""
146-
Subtract two UnitDbl's.
147-
148-
= ERROR CONDITIONS
149-
- If the input rhs units are not the same as our units,
150-
an error is thrown.
151-
152-
= INPUT VARIABLES
153-
- rhs The UnitDbl to subtract.
154-
155-
= RETURN VALUE
156-
- Returns the difference of ourselves and the input UnitDbl.
157-
"""
158-
self.checkSameUnits(rhs, "subtract")
159-
return UnitDbl(self._value - rhs._value, self._units)
97+
__eq__ = functools.partialmethod(_cmp, operator.eq)
98+
__ne__ = functools.partialmethod(_cmp, operator.ne)
99+
__lt__ = functools.partialmethod(_cmp, operator.lt)
100+
__le__ = functools.partialmethod(_cmp, operator.le)
101+
__gt__ = functools.partialmethod(_cmp, operator.gt)
102+
__ge__ = functools.partialmethod(_cmp, operator.ge)
160103

161-
def __mul__(self, rhs):
162-
"""
163-
Scale a UnitDbl by a value.
104+
def _binop_unit_unit(self, op, rhs):
105+
"""Check that *self* and *rhs* share units; combine them using *op*."""
106+
self.checkSameUnits(rhs, op.__name__)
107+
return UnitDbl(op(self._value, rhs._value), self._units)
164108

165-
= INPUT VARIABLES
166-
- rhs The scalar to multiply by.
109+
__add__ = functools.partialmethod(_binop_unit_unit, operator.add)
110+
__sub__ = functools.partialmethod(_binop_unit_unit, operator.sub)
167111

168-
= RETURN VALUE
169-
- Returns the scaled UnitDbl.
170-
"""
171-
return UnitDbl(self._value * rhs, self._units)
112+
def _binop_unit_scalar(self, op, scalar):
113+
"""Combine *self* and *scalar* using *op*."""
114+
return UnitDbl(op(self._value, scalar), self._units)
172115

173-
def __rmul__(self, lhs):
174-
"""
175-
Scale a UnitDbl by a value.
176-
177-
= INPUT VARIABLES
178-
- lhs The scalar to multiply by.
179-
180-
= RETURN VALUE
181-
- Returns the scaled UnitDbl.
182-
"""
183-
return UnitDbl(self._value * lhs, self._units)
116+
__mul__ = functools.partialmethod(_binop_unit_scalar, operator.mul)
117+
__rmul__ = functools.partialmethod(_binop_unit_scalar, operator.mul)
184118

185119
def __str__(self):
186120
"""Print the UnitDbl."""

0 commit comments

Comments
 (0)