Skip to content

Commit 72fa9c8

Browse files
authored
Merge pull request #11872 from anntzer/picklable-cmaps
Make all builtin cmaps picklable.
2 parents 9af32e6 + 3711718 commit 72fa9c8

File tree

2 files changed

+72
-86
lines changed

2 files changed

+72
-86
lines changed

lib/matplotlib/_cm.py

Lines changed: 64 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
with the purpose and type of your colormap if you add data for one here.
77
"""
88

9+
from functools import partial
10+
911
import numpy as np
1012

1113
_binary_data = {
@@ -41,22 +43,30 @@
4143
'blue': ((0., 0., 0.),
4244
(1.0, 0.4975, 0.4975))}
4345

44-
4546
def _flag_red(x): return 0.75 * np.sin((x * 31.5 + 0.25) * np.pi) + 0.5
4647
def _flag_green(x): return np.sin(x * 31.5 * np.pi)
4748
def _flag_blue(x): return 0.75 * np.sin((x * 31.5 - 0.25) * np.pi) + 0.5
49+
_flag_data = {'red': _flag_red, 'green': _flag_green, 'blue': _flag_blue}
50+
4851
def _prism_red(x): return 0.75 * np.sin((x * 20.9 + 0.25) * np.pi) + 0.67
4952
def _prism_green(x): return 0.75 * np.sin((x * 20.9 - 0.25) * np.pi) + 0.33
5053
def _prism_blue(x): return -1.1 * np.sin((x * 20.9) * np.pi)
51-
52-
53-
_flag_data = {'red': _flag_red, 'green': _flag_green, 'blue': _flag_blue}
5454
_prism_data = {'red': _prism_red, 'green': _prism_green, 'blue': _prism_blue}
5555

56+
def _ch_helper(gamma, s, r, h, p0, p1, x):
57+
"""Helper function for generating picklable cubehelix color maps."""
58+
# Apply gamma factor to emphasise low or high intensity values
59+
xg = x ** gamma
60+
# Calculate amplitude and angle of deviation from the black to white
61+
# diagonal in the plane of constant perceived intensity.
62+
a = h * xg * (1 - xg) / 2
63+
phi = 2 * np.pi * (s / 3 + r * x)
64+
return xg + a * (p0 * np.cos(phi) + p1 * np.sin(phi))
5665

5766
def cubehelix(gamma=1.0, s=0.5, r=-1.5, h=1.0):
58-
"""Return custom data dictionary of (r,g,b) conversion functions, which
59-
can be used with :func:`register_cmap`, for the cubehelix color scheme.
67+
"""
68+
Return custom data dictionary of (r,g,b) conversion functions, which can be
69+
used with :func:`register_cmap`, for the cubehelix color scheme.
6070
6171
Unlike most other color schemes cubehelix was designed by D.A. Green to
6272
be monotonically increasing in terms of perceived brightness.
@@ -90,79 +100,50 @@ def cubehelix(gamma=1.0, s=0.5, r=-1.5, h=1.0):
90100
colors are. If this parameter is zero then the color
91101
scheme is purely a greyscale; defaults to 1.0.
92102
========= =======================================================
93-
94103
"""
95-
96-
def get_color_function(p0, p1):
97-
98-
def color(x):
99-
# Apply gamma factor to emphasise low or high intensity values
100-
xg = x ** gamma
101-
102-
# Calculate amplitude and angle of deviation from the black
103-
# to white diagonal in the plane of constant
104-
# perceived intensity.
105-
a = h * xg * (1 - xg) / 2
106-
107-
phi = 2 * np.pi * (s / 3 + r * x)
108-
109-
return xg + a * (p0 * np.cos(phi) + p1 * np.sin(phi))
110-
return color
111-
112-
return {
113-
'red': get_color_function(-0.14861, 1.78277),
114-
'green': get_color_function(-0.29227, -0.90649),
115-
'blue': get_color_function(1.97294, 0.0),
116-
}
104+
return {'red': partial(_ch_helper, gamma, s, r, h, -0.14861, 1.78277),
105+
'green': partial(_ch_helper, gamma, s, r, h, -0.29227, -0.90649),
106+
'blue': partial(_ch_helper, gamma, s, r, h, 1.97294, 0.0)}
117107

118108
_cubehelix_data = cubehelix()
119109

120110
_bwr_data = ((0.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 0.0, 0.0))
121111
_brg_data = ((0.0, 0.0, 1.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0))
122112

123113
# Gnuplot palette functions
124-
gfunc = {
125-
0: lambda x: 0,
126-
1: lambda x: 0.5,
127-
2: lambda x: 1,
128-
3: lambda x: x,
129-
4: lambda x: x ** 2,
130-
5: lambda x: x ** 3,
131-
6: lambda x: x ** 4,
132-
7: lambda x: np.sqrt(x),
133-
8: lambda x: np.sqrt(np.sqrt(x)),
134-
9: lambda x: np.sin(x * np.pi / 2),
135-
10: lambda x: np.cos(x * np.pi / 2),
136-
11: lambda x: np.abs(x - 0.5),
137-
12: lambda x: (2 * x - 1) ** 2,
138-
13: lambda x: np.sin(x * np.pi),
139-
14: lambda x: np.abs(np.cos(x * np.pi)),
140-
15: lambda x: np.sin(x * 2 * np.pi),
141-
16: lambda x: np.cos(x * 2 * np.pi),
142-
17: lambda x: np.abs(np.sin(x * 2 * np.pi)),
143-
18: lambda x: np.abs(np.cos(x * 2 * np.pi)),
144-
19: lambda x: np.abs(np.sin(x * 4 * np.pi)),
145-
20: lambda x: np.abs(np.cos(x * 4 * np.pi)),
146-
21: lambda x: 3 * x,
147-
22: lambda x: 3 * x - 1,
148-
23: lambda x: 3 * x - 2,
149-
24: lambda x: np.abs(3 * x - 1),
150-
25: lambda x: np.abs(3 * x - 2),
151-
26: lambda x: (3 * x - 1) / 2,
152-
27: lambda x: (3 * x - 2) / 2,
153-
28: lambda x: np.abs((3 * x - 1) / 2),
154-
29: lambda x: np.abs((3 * x - 2) / 2),
155-
30: lambda x: x / 0.32 - 0.78125,
156-
31: lambda x: 2 * x - 0.84,
157-
32: lambda x: gfunc32(x),
158-
33: lambda x: np.abs(2 * x - 0.5),
159-
34: lambda x: 2 * x,
160-
35: lambda x: 2 * x - 0.5,
161-
36: lambda x: 2 * x - 1.
162-
}
163-
164-
165-
def gfunc32(x):
114+
def _g0(x): return 0
115+
def _g1(x): return 0.5
116+
def _g2(x): return 1
117+
def _g3(x): return x
118+
def _g4(x): return x ** 2
119+
def _g5(x): return x ** 3
120+
def _g6(x): return x ** 4
121+
def _g7(x): return np.sqrt(x)
122+
def _g8(x): return np.sqrt(np.sqrt(x))
123+
def _g9(x): return np.sin(x * np.pi / 2)
124+
def _g10(x): return np.cos(x * np.pi / 2)
125+
def _g11(x): return np.abs(x - 0.5)
126+
def _g12(x): return (2 * x - 1) ** 2
127+
def _g13(x): return np.sin(x * np.pi)
128+
def _g14(x): return np.abs(np.cos(x * np.pi))
129+
def _g15(x): return np.sin(x * 2 * np.pi)
130+
def _g16(x): return np.cos(x * 2 * np.pi)
131+
def _g17(x): return np.abs(np.sin(x * 2 * np.pi))
132+
def _g18(x): return np.abs(np.cos(x * 2 * np.pi))
133+
def _g19(x): return np.abs(np.sin(x * 4 * np.pi))
134+
def _g20(x): return np.abs(np.cos(x * 4 * np.pi))
135+
def _g21(x): return 3 * x
136+
def _g22(x): return 3 * x - 1
137+
def _g23(x): return 3 * x - 2
138+
def _g24(x): return np.abs(3 * x - 1)
139+
def _g25(x): return np.abs(3 * x - 2)
140+
def _g26(x): return (3 * x - 1) / 2
141+
def _g27(x): return (3 * x - 2) / 2
142+
def _g28(x): return np.abs((3 * x - 1) / 2)
143+
def _g29(x): return np.abs((3 * x - 2) / 2)
144+
def _g30(x): return x / 0.32 - 0.78125
145+
def _g31(x): return 2 * x - 0.84
146+
def _g32(x):
166147
ret = np.zeros(len(x))
167148
m = (x < 0.25)
168149
ret[m] = 4 * x[m]
@@ -171,6 +152,12 @@ def gfunc32(x):
171152
m = (x >= 0.92)
172153
ret[m] = x[m] / 0.08 - 11.5
173154
return ret
155+
def _g33(x): return np.abs(2 * x - 0.5)
156+
def _g34(x): return 2 * x
157+
def _g35(x): return 2 * x - 0.5
158+
def _g36(x): return 2 * x - 1
159+
160+
gfunc = {i: globals()["_g{}".format(i)] for i in range(37)}
174161

175162
_gnuplot_data = {
176163
'red': gfunc[7],
@@ -1017,11 +1004,11 @@ def gfunc32(x):
10171004
'blue': gfunc[3],
10181005
}
10191006

1007+
def _gist_heat_red(x): return 1.5 * x
1008+
def _gist_heat_green(x): return 2 * x - 1
1009+
def _gist_heat_blue(x): return 4 * x - 3
10201010
_gist_heat_data = {
1021-
'red': lambda x: 1.5 * x,
1022-
'green': lambda x: 2 * x - 1,
1023-
'blue': lambda x: 4 * x - 3,
1024-
}
1011+
'red': _gist_heat_red, 'green': _gist_heat_green, 'blue': _gist_heat_blue}
10251012

10261013
_gist_ncar_data = \
10271014
{'red': (
@@ -1098,11 +1085,8 @@ def gfunc32(x):
10981085
(0.735, 0.000, 0.000), (1.000, 1.000, 1.000))
10991086
}
11001087

1101-
_gist_yarg_data = {
1102-
'red': lambda x: 1 - x,
1103-
'green': lambda x: 1 - x,
1104-
'blue': lambda x: 1 - x,
1105-
}
1088+
def _gist_yarg(x): return 1 - x
1089+
_gist_yarg_data = {'red': _gist_yarg, 'green': _gist_yarg, 'blue': _gist_yarg}
11061090

11071091
# This bipolar color map was generated from CoolWarmFloat33.csv of
11081092
# "Diverging Color Maps for Scientific Visualization" by Kenneth Moreland.

lib/matplotlib/tests/test_pickle.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1+
from io import BytesIO
12
import pickle
23
import platform
3-
from io import BytesIO
44

55
import numpy as np
6+
import pytest
67

8+
from matplotlib import cm
79
from matplotlib.testing.decorators import image_comparison
810
from matplotlib.dates import rrulewrapper
911
import matplotlib.pyplot as plt
1012
import matplotlib.transforms as mtransforms
1113

12-
try: # https://docs.python.org/3/library/exceptions.html#RecursionError
13-
RecursionError # Python 3.5+
14-
except NameError:
15-
RecursionError = RuntimeError # Python < 3.5
16-
1714

1815
def test_simple():
1916
fig = plt.figure()
@@ -194,3 +191,8 @@ def test_shared():
194191
fig = pickle.loads(pickle.dumps(fig))
195192
fig.axes[0].set_xlim(10, 20)
196193
assert fig.axes[1].get_xlim() == (10, 20)
194+
195+
196+
@pytest.mark.parametrize("cmap", cm.cmap_d.values())
197+
def test_cmap(cmap):
198+
pickle.dumps(cmap)

0 commit comments

Comments
 (0)