Skip to content

Commit 1574b5e

Browse files
authored
Merge pull request #19654 from anntzer/pm
Dedupe various method implementations using functools.partialmethod.
2 parents 6fb2f7b + 26253f7 commit 1574b5e

File tree

5 files changed

+70
-197
lines changed

5 files changed

+70
-197
lines changed

lib/matplotlib/backends/backend_cairo.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
This backend depends on cairocffi or pycairo.
77
"""
88

9+
import functools
910
import gzip
1011
import math
1112

@@ -467,20 +468,8 @@ def _get_printed_image_surface(self):
467468
self.figure.draw(renderer)
468469
return surface
469470

470-
def print_pdf(self, fobj, *args, **kwargs):
471-
return self._save(fobj, 'pdf', *args, **kwargs)
472-
473-
def print_ps(self, fobj, *args, **kwargs):
474-
return self._save(fobj, 'ps', *args, **kwargs)
475-
476-
def print_svg(self, fobj, *args, **kwargs):
477-
return self._save(fobj, 'svg', *args, **kwargs)
478-
479-
def print_svgz(self, fobj, *args, **kwargs):
480-
return self._save(fobj, 'svgz', *args, **kwargs)
481-
482471
@_check_savefig_extra_args
483-
def _save(self, fo, fmt, *, orientation='portrait'):
472+
def _save(self, fmt, fobj, *, orientation='portrait'):
484473
# save PDF/PS/SVG
485474

486475
dpi = 72
@@ -496,22 +485,22 @@ def _save(self, fo, fmt, *, orientation='portrait'):
496485
if not hasattr(cairo, 'PSSurface'):
497486
raise RuntimeError('cairo has not been compiled with PS '
498487
'support enabled')
499-
surface = cairo.PSSurface(fo, width_in_points, height_in_points)
488+
surface = cairo.PSSurface(fobj, width_in_points, height_in_points)
500489
elif fmt == 'pdf':
501490
if not hasattr(cairo, 'PDFSurface'):
502491
raise RuntimeError('cairo has not been compiled with PDF '
503492
'support enabled')
504-
surface = cairo.PDFSurface(fo, width_in_points, height_in_points)
493+
surface = cairo.PDFSurface(fobj, width_in_points, height_in_points)
505494
elif fmt in ('svg', 'svgz'):
506495
if not hasattr(cairo, 'SVGSurface'):
507496
raise RuntimeError('cairo has not been compiled with SVG '
508497
'support enabled')
509498
if fmt == 'svgz':
510-
if isinstance(fo, str):
511-
fo = gzip.GzipFile(fo, 'wb')
499+
if isinstance(fobj, str):
500+
fobj = gzip.GzipFile(fobj, 'wb')
512501
else:
513-
fo = gzip.GzipFile(None, 'wb', fileobj=fo)
514-
surface = cairo.SVGSurface(fo, width_in_points, height_in_points)
502+
fobj = gzip.GzipFile(None, 'wb', fileobj=fobj)
503+
surface = cairo.SVGSurface(fobj, width_in_points, height_in_points)
515504
else:
516505
raise ValueError("Unknown format: {!r}".format(fmt))
517506

@@ -531,7 +520,12 @@ def _save(self, fo, fmt, *, orientation='portrait'):
531520
ctx.show_page()
532521
surface.finish()
533522
if fmt == 'svgz':
534-
fo.close()
523+
fobj.close()
524+
525+
print_pdf = functools.partialmethod(_save, "pdf")
526+
print_ps = functools.partialmethod(_save, "ps")
527+
print_svg = functools.partialmethod(_save, "svg")
528+
print_svgz = functools.partialmethod(_save, "svgz")
535529

536530

537531
@_Backend.export

lib/matplotlib/backends/backend_wx.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Copyright (C) Jeremy O'Donoghue & John Hunter, 2003-4.
88
"""
99

10+
import functools
1011
import logging
1112
import math
1213
import pathlib
@@ -829,29 +830,8 @@ def draw(self, drawDC=None):
829830
self._isDrawn = True
830831
self.gui_repaint(drawDC=drawDC)
831832

832-
def print_bmp(self, filename, *args, **kwargs):
833-
return self._print_image(filename, wx.BITMAP_TYPE_BMP, *args, **kwargs)
834-
835-
def print_jpeg(self, filename, *args, **kwargs):
836-
return self._print_image(filename, wx.BITMAP_TYPE_JPEG,
837-
*args, **kwargs)
838-
print_jpg = print_jpeg
839-
840-
def print_pcx(self, filename, *args, **kwargs):
841-
return self._print_image(filename, wx.BITMAP_TYPE_PCX, *args, **kwargs)
842-
843-
def print_png(self, filename, *args, **kwargs):
844-
return self._print_image(filename, wx.BITMAP_TYPE_PNG, *args, **kwargs)
845-
846-
def print_tiff(self, filename, *args, **kwargs):
847-
return self._print_image(filename, wx.BITMAP_TYPE_TIF, *args, **kwargs)
848-
print_tif = print_tiff
849-
850-
def print_xpm(self, filename, *args, **kwargs):
851-
return self._print_image(filename, wx.BITMAP_TYPE_XPM, *args, **kwargs)
852-
853833
@_check_savefig_extra_args
854-
def _print_image(self, filename, filetype, *, quality=None):
834+
def _print_image(self, filetype, filename, *, quality=None):
855835
origBitmap = self.bitmap
856836

857837
self.bitmap = wx.Bitmap(math.ceil(self.figure.bbox.width),
@@ -897,6 +877,19 @@ def _print_image(self, filename, filetype, *, quality=None):
897877
if self:
898878
self.Refresh()
899879

880+
print_bmp = functools.partialmethod(
881+
_print_image, wx.BITMAP_TYPE_BMP)
882+
print_jpeg = print_jpg = functools.partialmethod(
883+
_print_image, wx.BITMAP_TYPE_JPEG)
884+
print_pcx = functools.partialmethod(
885+
_print_image, wx.BITMAP_TYPE_PCX)
886+
print_png = functools.partialmethod(
887+
_print_image, wx.BITMAP_TYPE_PNG)
888+
print_tiff = print_tif = functools.partialmethod(
889+
_print_image, wx.BITMAP_TYPE_TIF)
890+
print_xpm = functools.partialmethod(
891+
_print_image, wx.BITMAP_TYPE_XPM)
892+
900893

901894
class FigureFrameWx(wx.Frame):
902895
def __init__(self, num, fig):

lib/matplotlib/testing/jpl_units/Duration.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Duration module."""
22

3+
import functools
34
import operator
45

56
from matplotlib import _api
@@ -44,38 +45,20 @@ def seconds(self):
4445
def __bool__(self):
4546
return self._seconds != 0
4647

47-
def __eq__(self, rhs):
48-
return self._cmp(rhs, operator.eq)
49-
50-
def __ne__(self, rhs):
51-
return self._cmp(rhs, operator.ne)
52-
53-
def __lt__(self, rhs):
54-
return self._cmp(rhs, operator.lt)
55-
56-
def __le__(self, rhs):
57-
return self._cmp(rhs, operator.le)
58-
59-
def __gt__(self, rhs):
60-
return self._cmp(rhs, operator.gt)
61-
62-
def __ge__(self, rhs):
63-
return self._cmp(rhs, operator.ge)
64-
65-
def _cmp(self, rhs, op):
48+
def _cmp(self, op, rhs):
6649
"""
67-
Compare two Durations.
68-
69-
= INPUT VARIABLES
70-
- rhs The Duration to compare against.
71-
- op The function to do the comparison
72-
73-
= RETURN VALUE
74-
- Returns op(self, rhs)
50+
Check that *self* and *rhs* share frames; compare them using *op*.
7551
"""
7652
self.checkSameFrame(rhs, "compare")
7753
return op(self._seconds, rhs._seconds)
7854

55+
__eq__ = functools.partialmethod(_cmp, operator.eq)
56+
__ne__ = functools.partialmethod(_cmp, operator.ne)
57+
__lt__ = functools.partialmethod(_cmp, operator.lt)
58+
__le__ = functools.partialmethod(_cmp, operator.le)
59+
__gt__ = functools.partialmethod(_cmp, operator.gt)
60+
__ge__ = functools.partialmethod(_cmp, operator.ge)
61+
7962
def __add__(self, rhs):
8063
"""
8164
Add two Durations.
@@ -126,17 +109,7 @@ def __mul__(self, rhs):
126109
"""
127110
return Duration(self._frame, self._seconds * float(rhs))
128111

129-
def __rmul__(self, lhs):
130-
"""
131-
Scale a Duration by a value.
132-
133-
= INPUT VARIABLES
134-
- lhs The scalar to multiply by.
135-
136-
= RETURN VALUE
137-
- Returns the scaled Duration.
138-
"""
139-
return Duration(self._frame, self._seconds * float(lhs))
112+
__rmul__ = __mul__
140113

141114
def __str__(self):
142115
"""Print the Duration."""

lib/matplotlib/testing/jpl_units/Epoch.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Epoch module."""
22

3+
import functools
34
import operator
45
import math
56
import datetime as DT
@@ -106,44 +107,22 @@ def secondsPast(self, frame, jd):
106107
delta = t._jd - jd
107108
return t._seconds + delta * 86400
108109

109-
def __eq__(self, rhs):
110-
return self._cmp(rhs, operator.eq)
111-
112-
def __ne__(self, rhs):
113-
return self._cmp(rhs, operator.ne)
114-
115-
def __lt__(self, rhs):
116-
return self._cmp(rhs, operator.lt)
117-
118-
def __le__(self, rhs):
119-
return self._cmp(rhs, operator.le)
120-
121-
def __gt__(self, rhs):
122-
return self._cmp(rhs, operator.gt)
123-
124-
def __ge__(self, rhs):
125-
return self._cmp(rhs, operator.ge)
126-
127-
def _cmp(self, rhs, op):
128-
"""
129-
Compare two Epoch's.
130-
131-
= INPUT VARIABLES
132-
- rhs The Epoch to compare against.
133-
- op The function to do the comparison
134-
135-
= RETURN VALUE
136-
- Returns op(self, rhs)
137-
"""
110+
def _cmp(self, op, rhs):
111+
"""Compare Epochs *self* and *rhs* using operator *op*."""
138112
t = self
139113
if self._frame != rhs._frame:
140114
t = self.convert(rhs._frame)
141-
142115
if t._jd != rhs._jd:
143116
return op(t._jd, rhs._jd)
144-
145117
return op(t._seconds, rhs._seconds)
146118

119+
__eq__ = functools.partialmethod(_cmp, operator.eq)
120+
__ne__ = functools.partialmethod(_cmp, operator.ne)
121+
__lt__ = functools.partialmethod(_cmp, operator.lt)
122+
__le__ = functools.partialmethod(_cmp, operator.le)
123+
__gt__ = functools.partialmethod(_cmp, operator.gt)
124+
__ge__ = functools.partialmethod(_cmp, operator.ge)
125+
147126
def __add__(self, rhs):
148127
"""
149128
Add a duration to an Epoch.

lib/matplotlib/testing/jpl_units/UnitDbl.py

Lines changed: 20 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""UnitDbl module."""
22

3+
import functools
34
import operator
45

56
from matplotlib import _api
@@ -88,99 +89,32 @@ def __bool__(self):
8889
"""Return the truth value of a UnitDbl."""
8990
return bool(self._value)
9091

91-
def __eq__(self, rhs):
92-
return self._cmp(rhs, operator.eq)
93-
94-
def __ne__(self, rhs):
95-
return self._cmp(rhs, operator.ne)
96-
97-
def __lt__(self, rhs):
98-
return self._cmp(rhs, operator.lt)
99-
100-
def __le__(self, rhs):
101-
return self._cmp(rhs, operator.le)
102-
103-
def __gt__(self, rhs):
104-
return self._cmp(rhs, operator.gt)
105-
106-
def __ge__(self, rhs):
107-
return self._cmp(rhs, operator.ge)
108-
109-
def _cmp(self, rhs, op):
110-
"""
111-
Compare two UnitDbl's.
112-
113-
= ERROR CONDITIONS
114-
- If the input rhs units are not the same as our units,
115-
an error is thrown.
116-
117-
= INPUT VARIABLES
118-
- rhs The UnitDbl to compare against.
119-
- op The function to do the comparison
120-
121-
= RETURN VALUE
122-
- Returns op(self, rhs)
123-
"""
92+
def _cmp(self, op, rhs):
93+
"""Check that *self* and *rhs* share units; compare them using *op*."""
12494
self.checkSameUnits(rhs, "compare")
12595
return op(self._value, rhs._value)
12696

127-
def __add__(self, rhs):
128-
"""
129-
Add two UnitDbl's.
130-
131-
= ERROR CONDITIONS
132-
- If the input rhs units are not the same as our units,
133-
an error is thrown.
134-
135-
= INPUT VARIABLES
136-
- rhs The UnitDbl to add.
137-
138-
= RETURN VALUE
139-
- Returns the sum of ourselves and the input UnitDbl.
140-
"""
141-
self.checkSameUnits(rhs, "add")
142-
return UnitDbl(self._value + rhs._value, self._units)
143-
144-
def __sub__(self, rhs):
145-
"""
146-
Subtract two UnitDbl's.
147-
148-
= ERROR CONDITIONS
149-
- If the input rhs units are not the same as our units,
150-
an error is thrown.
151-
152-
= INPUT VARIABLES
153-
- rhs The UnitDbl to subtract.
154-
155-
= RETURN VALUE
156-
- Returns the difference of ourselves and the input UnitDbl.
157-
"""
158-
self.checkSameUnits(rhs, "subtract")
159-
return UnitDbl(self._value - rhs._value, self._units)
97+
__eq__ = functools.partialmethod(_cmp, operator.eq)
98+
__ne__ = functools.partialmethod(_cmp, operator.ne)
99+
__lt__ = functools.partialmethod(_cmp, operator.lt)
100+
__le__ = functools.partialmethod(_cmp, operator.le)
101+
__gt__ = functools.partialmethod(_cmp, operator.gt)
102+
__ge__ = functools.partialmethod(_cmp, operator.ge)
160103

161-
def __mul__(self, rhs):
162-
"""
163-
Scale a UnitDbl by a value.
104+
def _binop_unit_unit(self, op, rhs):
105+
"""Check that *self* and *rhs* share units; combine them using *op*."""
106+
self.checkSameUnits(rhs, op.__name__)
107+
return UnitDbl(op(self._value, rhs._value), self._units)
164108

165-
= INPUT VARIABLES
166-
- rhs The scalar to multiply by.
109+
__add__ = functools.partialmethod(_binop_unit_unit, operator.add)
110+
__sub__ = functools.partialmethod(_binop_unit_unit, operator.sub)
167111

168-
= RETURN VALUE
169-
- Returns the scaled UnitDbl.
170-
"""
171-
return UnitDbl(self._value * rhs, self._units)
112+
def _binop_unit_scalar(self, op, scalar):
113+
"""Combine *self* and *scalar* using *op*."""
114+
return UnitDbl(op(self._value, scalar), self._units)
172115

173-
def __rmul__(self, lhs):
174-
"""
175-
Scale a UnitDbl by a value.
176-
177-
= INPUT VARIABLES
178-
- lhs The scalar to multiply by.
179-
180-
= RETURN VALUE
181-
- Returns the scaled UnitDbl.
182-
"""
183-
return UnitDbl(self._value * lhs, self._units)
116+
__mul__ = functools.partialmethod(_binop_unit_scalar, operator.mul)
117+
__rmul__ = functools.partialmethod(_binop_unit_scalar, operator.mul)
184118

185119
def __str__(self):
186120
"""Print the UnitDbl."""

0 commit comments

Comments
 (0)