@@ -285,21 +285,17 @@ def test_plot_kernel():
285
285
plt .plot (kernel )
286
286
287
287
288
- @check_figures_equal (extensions = ['png' ])
289
- def test_unit_axis_label (fig_test , fig_ref ):
288
+ @pytest .mark .parametrize ('plot_meth_name' , ['scatter' , 'plot' ])
289
+ def test_unit_axis_label (plot_meth_name ):
290
+ # Check that the correct Axis labels are set on plots with units
290
291
import matplotlib .testing .jpl_units as units
291
292
units .register ()
292
293
293
- data = [0 * units .km , 1 * units .km , 2 * units .km ]
294
-
295
- ax_test = fig_test .subplots ()
296
- ax_ref = fig_ref .subplots ()
297
- axs = [ax_test , ax_ref ]
298
-
299
- for ax in axs :
300
- ax .yaxis .set_units ('km' )
301
- ax .set_xlim (10 , 20 )
302
- ax .set_ylim (10 , 20 )
303
-
304
- ax_test .scatter ([1 , 2 , 3 ], data , edgecolors = 'none' )
305
- ax_ref .plot ([1 , 2 , 3 ], data , marker = 'o' , linewidth = 0 )
294
+ fig , ax = plt .subplots ()
295
+ ax .xaxis .set_units ('m' )
296
+ ax .yaxis .set_units ('sec' )
297
+ plot_method = getattr (ax , plot_meth_name )
298
+ plot_method (np .arange (3 ) * units .m , np .arange (3 ) * units .sec )
299
+ assert ax .get_xlabel () == 'm'
300
+ assert ax .get_ylabel () == 'sec'
301
+ plt .close ('all' )
0 commit comments