@@ -1002,43 +1002,23 @@ def __init__(self, ax, labels, actives=None):
1002
1002
if actives is None :
1003
1003
actives = [False ] * len (labels )
1004
1004
1005
- if len (labels ) > 1 :
1006
- dy = 1. / (len (labels ) + 1 )
1007
- ys = np .linspace (1 - dy , dy , len (labels ))
1008
- else :
1009
- dy = 0.25
1010
- ys = [0.5 ]
1011
-
1012
- axcolor = ax .get_facecolor ()
1013
-
1014
- self .labels = []
1015
- self .lines = []
1016
- self .rectangles = []
1017
-
1018
- lineparams = {'color' : 'k' , 'linewidth' : 1.25 ,
1019
- 'transform' : ax .transAxes , 'solid_capstyle' : 'butt' }
1020
- for y , label , active in zip (ys , labels , actives ):
1021
- t = ax .text (0.25 , y , label , transform = ax .transAxes ,
1022
- horizontalalignment = 'left' ,
1023
- verticalalignment = 'center' )
1024
-
1025
- w , h = dy / 2 , dy / 2
1026
- x , y = 0.05 , y - h / 2
1027
-
1028
- p = Rectangle (xy = (x , y ), width = w , height = h , edgecolor = 'black' ,
1029
- facecolor = axcolor , transform = ax .transAxes )
1005
+ ys = np .linspace (1 , 0 , len (labels )+ 2 )[1 :- 1 ]
1006
+ text_size = mpl .rcParams ["font.size" ] / 2
1030
1007
1031
- l1 = Line2D ([x , x + w ], [y + h , y ], ** lineparams )
1032
- l2 = Line2D ([x , x + w ], [y , y + h ], ** lineparams )
1008
+ self .labels = [
1009
+ ax .text (0.25 , y , label , transform = ax .transAxes ,
1010
+ horizontalalignment = "left" , verticalalignment = "center" )
1011
+ for y , label in zip (ys , labels )]
1033
1012
1034
- l1 .set_visible (active )
1035
- l2 .set_visible (active )
1036
- self .labels .append (t )
1037
- self .rectangles .append (p )
1038
- self .lines .append ((l1 , l2 ))
1039
- ax .add_patch (p )
1040
- ax .add_line (l1 )
1041
- ax .add_line (l2 )
1013
+ self ._squares = ax .scatter (
1014
+ [0.15 ] * len (ys ), ys , marker = 's' , s = text_size ** 2 ,
1015
+ c = "none" , linewidth = 1 , transform = ax .transAxes , edgecolor = "k"
1016
+ )
1017
+ self ._crosses = ax .scatter (
1018
+ [0.15 ] * len (ys ), ys , marker = 'x' , linewidth = 1 , s = text_size ** 2 ,
1019
+ c = ["k" if active else "none" for active in actives ],
1020
+ transform = ax .transAxes
1021
+ )
1042
1022
1043
1023
self .connect_event ('button_press_event' , self ._clicked )
1044
1024
@@ -1047,11 +1027,27 @@ def __init__(self, ax, labels, actives=None):
1047
1027
def _clicked (self , event ):
1048
1028
if self .ignore (event ) or event .button != 1 or event .inaxes != self .ax :
1049
1029
return
1050
- for i , (p , t ) in enumerate (zip (self .rectangles , self .labels )):
1051
- if (t .get_window_extent ().contains (event .x , event .y ) or
1052
- p .get_window_extent ().contains (event .x , event .y )):
1053
- self .set_active (i )
1054
- break
1030
+ pclicked = self .ax .transAxes .inverted ().transform ((event .x , event .y ))
1031
+ distances = {}
1032
+ if hasattr (self , "_rectangles" ):
1033
+ for i , (p , t ) in enumerate (zip (self ._rectangles , self .labels )):
1034
+ x0 , y0 = p .get_xy ()
1035
+ if (t .get_window_extent ().contains (event .x , event .y )
1036
+ or (x0 <= pclicked [0 ] <= x0 + p .get_width ()
1037
+ and y0 <= pclicked [1 ] <= y0 + p .get_height ())):
1038
+ distances [i ] = np .linalg .norm (pclicked - p .get_center ())
1039
+ else :
1040
+ _ , square_inds = self ._squares .contains (event )
1041
+ coords = self ._squares .get_offset_transform ().transform (
1042
+ self ._squares .get_offsets ()
1043
+ )
1044
+ for i , t in enumerate (self .labels ):
1045
+ if (i in square_inds ["ind" ]
1046
+ or t .get_window_extent ().contains (event .x , event .y )):
1047
+ distances [i ] = np .linalg .norm (pclicked - coords [i ])
1048
+ if len (distances ) > 0 :
1049
+ closest = min (distances , key = distances .get )
1050
+ self .set_active (closest )
1055
1051
1056
1052
def set_active (self , index ):
1057
1053
"""
@@ -1072,9 +1068,20 @@ def set_active(self, index):
1072
1068
if index not in range (len (self .labels )):
1073
1069
raise ValueError (f'Invalid CheckButton index: { index } ' )
1074
1070
1075
- l1 , l2 = self .lines [index ]
1076
- l1 .set_visible (not l1 .get_visible ())
1077
- l2 .set_visible (not l2 .get_visible ())
1071
+ cross_facecolors = self ._crosses .get_facecolor ()
1072
+ cross_facecolors [index ] = colors .to_rgba (
1073
+ "black"
1074
+ if colors .same_color (
1075
+ cross_facecolors [index ], colors .to_rgba ("none" )
1076
+ )
1077
+ else "none"
1078
+ )
1079
+ self ._crosses .set_facecolor (cross_facecolors )
1080
+
1081
+ if hasattr (self , "_lines" ):
1082
+ l1 , l2 = self ._lines [index ]
1083
+ l1 .set_visible (not l1 .get_visible ())
1084
+ l2 .set_visible (not l2 .get_visible ())
1078
1085
1079
1086
if self .drawon :
1080
1087
self .ax .figure .canvas .draw ()
@@ -1086,7 +1093,8 @@ def get_status(self):
1086
1093
"""
1087
1094
Return a list of the status (True/False) of all of the check buttons.
1088
1095
"""
1089
- return [l1 .get_visible () for (l1 , l2 ) in self .lines ]
1096
+ return [not colors .same_color (color , colors .to_rgba ("none" ))
1097
+ for color in self ._crosses .get_facecolors ()]
1090
1098
1091
1099
def on_clicked (self , func ):
1092
1100
"""
@@ -1100,6 +1108,57 @@ def disconnect(self, cid):
1100
1108
"""Remove the observer with connection id *cid*."""
1101
1109
self ._observers .disconnect (cid )
1102
1110
1111
+ @_api .deprecated ("3.7" )
1112
+ @property
1113
+ def rectangles (self ):
1114
+ if not hasattr (self , "_rectangles" ):
1115
+ ys = np .linspace (1 , 0 , len (self .labels )+ 2 )[1 :- 1 ]
1116
+ dy = 1. / (len (self .labels ) + 1 )
1117
+ w , h = dy / 2 , dy / 2
1118
+ rectangles = self ._rectangles = [
1119
+ Rectangle (xy = (0.05 , ys [i ] - h / 2 ), width = w , height = h ,
1120
+ edgecolor = "black" ,
1121
+ facecolor = "none" ,
1122
+ transform = self .ax .transAxes
1123
+ )
1124
+ for i , y in enumerate (ys )
1125
+ ]
1126
+ self ._squares .set_visible (False )
1127
+ for rectangle in rectangles :
1128
+ self .ax .add_patch (rectangle )
1129
+ if not hasattr (self , "_lines" ):
1130
+ with _api .suppress_matplotlib_deprecation_warning ():
1131
+ _ = self .lines
1132
+ return self ._rectangles
1133
+
1134
+ @_api .deprecated ("3.7" )
1135
+ @property
1136
+ def lines (self ):
1137
+ if not hasattr (self , "_lines" ):
1138
+ ys = np .linspace (1 , 0 , len (self .labels )+ 2 )[1 :- 1 ]
1139
+ self ._crosses .set_visible (False )
1140
+ dy = 1. / (len (self .labels ) + 1 )
1141
+ w , h = dy / 2 , dy / 2
1142
+ self ._lines = []
1143
+ current_status = self .get_status ()
1144
+ lineparams = {'color' : 'k' , 'linewidth' : 1.25 ,
1145
+ 'transform' : self .ax .transAxes ,
1146
+ 'solid_capstyle' : 'butt' }
1147
+ for i , y in enumerate (ys ):
1148
+ x , y = 0.05 , y - h / 2
1149
+ l1 = Line2D ([x , x + w ], [y + h , y ], ** lineparams )
1150
+ l2 = Line2D ([x , x + w ], [y , y + h ], ** lineparams )
1151
+
1152
+ l1 .set_visible (current_status [i ])
1153
+ l2 .set_visible (current_status [i ])
1154
+ self ._lines .append ((l1 , l2 ))
1155
+ self .ax .add_patch (l1 )
1156
+ self .ax .add_patch (l2 )
1157
+ if not hasattr (self , "_rectangles" ):
1158
+ with _api .suppress_matplotlib_deprecation_warning ():
1159
+ _ = self .rectangles
1160
+ return self ._lines
1161
+
1103
1162
1104
1163
class TextBox (AxesWidget ):
1105
1164
"""
@@ -1457,8 +1516,10 @@ def set_active(self, index):
1457
1516
if index not in range (len (self .labels )):
1458
1517
raise ValueError (f'Invalid RadioButton index: { index } ' )
1459
1518
self .value_selected = self .labels [index ].get_text ()
1460
- self ._buttons .get_facecolor ()[:] = colors .to_rgba ("none" )
1461
- self ._buttons .get_facecolor ()[index ] = colors .to_rgba (self .activecolor )
1519
+ button_facecolors = self ._buttons .get_facecolor ()
1520
+ button_facecolors [:] = colors .to_rgba ("none" )
1521
+ button_facecolors [index ] = colors .to_rgba (self .activecolor )
1522
+ self ._buttons .set_facecolor (button_facecolors )
1462
1523
if hasattr (self , "_circles" ): # Remove once circles is removed.
1463
1524
for i , p in enumerate (self ._circles ):
1464
1525
p .set_facecolor (self .activecolor if i == index else "none" )
0 commit comments