@@ -462,26 +462,6 @@ def draw(self, renderer):
462
462
self .offsetText .set_ha (align )
463
463
self .offsetText .draw (renderer )
464
464
465
- if self .axes ._draw_grid and len (ticks ):
466
- # Grid points where the planes meet
467
- xyz0 = np .tile (minmax , (len (ticks ), 1 ))
468
- xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
469
-
470
- # Grid lines go from the end of one plane through the plane
471
- # intersection (at xyz0) to the end of the other plane. The first
472
- # point (0) differs along dimension index-2 and the last (2) along
473
- # dimension index-1.
474
- lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
475
- lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
476
- lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
477
- self .gridlines .set_segments (lines )
478
- gridinfo = info ['grid' ]
479
- self .gridlines .set_color (gridinfo ['color' ])
480
- self .gridlines .set_linewidth (gridinfo ['linewidth' ])
481
- self .gridlines .set_linestyle (gridinfo ['linestyle' ])
482
- self .gridlines .do_3d_projection ()
483
- self .gridlines .draw (renderer )
484
-
485
465
# Draw ticks:
486
466
tickdir = self ._get_tickdir ()
487
467
tickdelta = deltas [tickdir ] if highs [tickdir ] else - deltas [tickdir ]
@@ -519,6 +499,46 @@ def draw(self, renderer):
519
499
renderer .close_group ('axis3d' )
520
500
self .stale = False
521
501
502
+ @artist .allow_rasterization
503
+ def draw_grid (self , renderer ):
504
+ if not self .axes ._draw_grid :
505
+ return
506
+
507
+ self .label ._transform = self .axes .transData
508
+ renderer .open_group ("grid3d" , gid = self .get_gid ())
509
+
510
+ ticks = self ._update_ticks ()
511
+ if len (ticks ):
512
+ # Get general axis information:
513
+ info = self ._axinfo
514
+ index = info ["i" ]
515
+
516
+ mins , maxs , tc , highs = self ._get_coord_info ()
517
+
518
+ minmax = np .where (highs , maxs , mins )
519
+ maxmin = np .where (~ highs , maxs , mins )
520
+
521
+ # Grid points where the planes meet
522
+ xyz0 = np .tile (minmax , (len (ticks ), 1 ))
523
+ xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
524
+
525
+ # Grid lines go from the end of one plane through the plane
526
+ # intersection (at xyz0) to the end of the other plane. The first
527
+ # point (0) differs along dimension index-2 and the last (2) along
528
+ # dimension index-1.
529
+ lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
530
+ lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
531
+ lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
532
+ self .gridlines .set_segments (lines )
533
+ gridinfo = info ['grid' ]
534
+ self .gridlines .set_color (gridinfo ['color' ])
535
+ self .gridlines .set_linewidth (gridinfo ['linewidth' ])
536
+ self .gridlines .set_linestyle (gridinfo ['linestyle' ])
537
+ self .gridlines .do_3d_projection ()
538
+ self .gridlines .draw (renderer )
539
+
540
+ renderer .close_group ('grid3d' )
541
+
522
542
# TODO: Get this to work (more) properly when mplot3d supports the
523
543
# transforms framework.
524
544
def get_tightbbox (self , renderer = None , * , for_layout_only = False ):
0 commit comments