11
11
12
12
import numpy as np
13
13
import matplotlib .pyplot as plt
14
+ import matplotlib .colors as mcolors
14
15
15
16
# Fixing random state for reproducibility
16
17
np .random .seed (19680801 )
@@ -26,15 +27,19 @@ def plot_scatter(ax, prng, nb_samples=100):
26
27
return ax
27
28
28
29
29
- def plot_colored_sinusoidal_lines (ax ):
30
- """Plot sinusoidal lines with colors following the style color cycle."""
31
- L = 2 * np .pi
32
- x = np .linspace (0 , L )
30
+ def plot_colored_lines (ax ):
31
+ """Plot lines with colors following the style color cycle."""
32
+ t = np .linspace (- 10 , 10 , 100 )
33
+
34
+ def sigmoid (t , t0 ):
35
+ return 1 / (1 + np .exp (- (t - t0 )))
36
+
33
37
nb_colors = len (plt .rcParams ['axes.prop_cycle' ])
34
- shift = np .linspace (0 , L , nb_colors , endpoint = False )
35
- for s in shift :
36
- ax .plot (x , np .sin (x + s ), '-' )
37
- ax .set_xlim ([x [0 ], x [- 1 ]])
38
+ shifts = np .linspace (- 5 , 5 , nb_colors )
39
+ amplitudes = np .linspace (1 , 1.5 , nb_colors )
40
+ for t0 , a in zip (shifts , amplitudes ):
41
+ ax .plot (t , a * sigmoid (t , t0 ), '-' )
42
+ ax .set_xlim (- 10 , 10 )
38
43
return ax
39
44
40
45
@@ -108,23 +113,30 @@ def plot_figure(style_label=""):
108
113
# double the width and halve the height. NB: use relative changes because
109
114
# some styles may have a figure size different from the default one.
110
115
(fig_width , fig_height ) = plt .rcParams ['figure.figsize' ]
111
- fig_size = [fig_width * 2 , fig_height / 2 ]
116
+ fig_size = [fig_width * 2 , fig_height / 1.75 ]
112
117
113
118
fig , axs = plt .subplots (ncols = 6 , nrows = 1 , num = style_label ,
114
- figsize = fig_size , squeeze = True )
115
- axs [0 ].set_ylabel (style_label , fontsize = 13 , fontweight = 'bold' )
119
+ figsize = fig_size , constrained_layout = True )
120
+
121
+ # make a suptitle, in the same style for all subfigures,
122
+ # except those with dark backgrounds, which get a lighter
123
+ # color:
124
+ col = np .array ([19 , 6 , 84 ])/ 256
125
+ back = mcolors .rgb_to_hsv (
126
+ mcolors .to_rgb (plt .rcParams ['figure.facecolor' ]))[2 ]
127
+ if back < 0.5 :
128
+ col = [0.8 , 0.8 , 1 ]
129
+ fig .suptitle (style_label , x = 0.01 , fontsize = 14 , ha = 'left' ,
130
+ color = col , fontfamily = 'DejaVu Sans' ,
131
+ fontweight = 'normal' )
116
132
117
133
plot_scatter (axs [0 ], prng )
118
134
plot_image_and_patch (axs [1 ], prng )
119
135
plot_bar_graphs (axs [2 ], prng )
120
136
plot_colored_circles (axs [3 ], prng )
121
- plot_colored_sinusoidal_lines (axs [4 ])
137
+ plot_colored_lines (axs [4 ])
122
138
plot_histograms (axs [5 ], prng )
123
139
124
- fig .tight_layout ()
125
-
126
- return fig
127
-
128
140
129
141
if __name__ == "__main__" :
130
142
@@ -141,6 +153,6 @@ def plot_figure(style_label=""):
141
153
for style_label in style_list :
142
154
with plt .rc_context ({"figure.max_open_warning" : len (style_list )}):
143
155
with plt .style .context (style_label ):
144
- fig = plot_figure (style_label = style_label )
156
+ plot_figure (style_label = style_label )
145
157
146
158
plt .show ()
0 commit comments