1
1
use crate :: common:: { hash:: PyHash , lock:: PyRwLock } ;
2
2
use crate :: {
3
3
builtins:: { PyInt , PyStrRef , PyType , PyTypeRef } ,
4
- function:: { FromArgs , FuncArgs , IntoPyResult , OptionalArg } ,
5
- protocol:: { PyBuffer , PyIterReturn , PyMappingMethods } ,
4
+ function:: { FromArgs , FuncArgs , IntoPyObject , IntoPyResult , OptionalArg } ,
5
+ protocol:: { PyBuffer , PyIterReturn , PyMappingMethods , PySequence , PySequenceMethods } ,
6
6
utils:: Either ,
7
- IdProtocol , PyComparisonValue , PyObjectRef , PyRef , PyResult , PyValue , TypeProtocol ,
8
- VirtualMachine ,
7
+ IdProtocol , PyArithmeticValue , PyComparisonValue , PyObjectRef , PyRef , PyResult , PyValue ,
8
+ TypeProtocol , VirtualMachine ,
9
9
} ;
10
10
use crossbeam_utils:: atomic:: AtomicCell ;
11
11
use num_traits:: ToPrimitive ;
12
+ use std:: borrow:: Cow ;
12
13
use std:: cmp:: Ordering ;
13
14
14
15
// The corresponding field in CPython is `tp_` prefixed.
@@ -22,7 +23,7 @@ pub struct PyTypeSlots {
22
23
23
24
// Method suites for standard classes
24
25
// tp_as_number
25
- // tp_as_sequence
26
+ pub as_sequence : AtomicCell < Option < AsSequenceFunc > > ,
26
27
pub as_mapping : AtomicCell < Option < AsMappingFunc > > ,
27
28
28
29
// More standard operations (here for binary compatibility)
@@ -149,17 +150,20 @@ pub(crate) type DescrSetFunc =
149
150
fn ( PyObjectRef , PyObjectRef , Option < PyObjectRef > , & VirtualMachine ) -> PyResult < ( ) > ;
150
151
pub ( crate ) type NewFunc = fn ( PyTypeRef , FuncArgs , & VirtualMachine ) -> PyResult ;
151
152
pub ( crate ) type DelFunc = fn ( & PyObjectRef , & VirtualMachine ) -> PyResult < ( ) > ;
153
+ pub ( crate ) type AsSequenceFunc =
154
+ fn ( & PyObjectRef , & VirtualMachine ) -> Cow < ' static , PySequenceMethods > ;
155
+
156
+ macro_rules! then_some_closure {
157
+ ( $cond: expr, $closure: expr) => {
158
+ if $cond {
159
+ Some ( $closure)
160
+ } else {
161
+ None
162
+ }
163
+ } ;
164
+ }
152
165
153
166
fn as_mapping_wrapper ( zelf : & PyObjectRef , _vm : & VirtualMachine ) -> PyMappingMethods {
154
- macro_rules! then_some_closure {
155
- ( $cond: expr, $closure: expr) => {
156
- if $cond {
157
- Some ( $closure)
158
- } else {
159
- None
160
- }
161
- } ;
162
- }
163
167
PyMappingMethods {
164
168
length : then_some_closure ! ( zelf. has_class_attr( "__len__" ) , |zelf, vm| {
165
169
vm. call_special_method( zelf, "__len__" , ( ) ) . map( |obj| {
@@ -192,6 +196,90 @@ fn as_mapping_wrapper(zelf: &PyObjectRef, _vm: &VirtualMachine) -> PyMappingMeth
192
196
}
193
197
}
194
198
199
+ fn as_sequence_wrapper (
200
+ zelf : & PyObjectRef ,
201
+ _vm : & VirtualMachine ,
202
+ ) -> Cow < ' static , PySequenceMethods > {
203
+ Cow :: Owned ( PySequenceMethods {
204
+ length : then_some_closure ! ( zelf. has_class_attr( "__len__" ) , |zelf, vm| {
205
+ vm. obj_len_opt( zelf) . unwrap( )
206
+ } ) ,
207
+ concat : then_some_closure ! ( zelf. has_class_attr( "__add__" ) , |zelf, other, vm| {
208
+ if PySequence :: check( zelf, vm) && PySequence :: check( other, vm) {
209
+ let ret = vm. call_special_method( zelf. clone( ) , "__add__" , ( other. clone( ) , ) ) ?;
210
+ if let PyArithmeticValue :: Implemented ( obj) = PyArithmeticValue :: from_object( vm, ret)
211
+ {
212
+ return Ok ( obj) ;
213
+ }
214
+ }
215
+ Err ( vm. new_type_error( format!( "'{}' object can't be concatenated" , zelf) ) )
216
+ } ) ,
217
+ repeat : then_some_closure ! ( zelf. has_class_attr( "__mul__" ) , |zelf, n, vm| {
218
+ if PySequence :: check( zelf, vm) {
219
+ let ret =
220
+ vm. call_special_method( zelf. clone( ) , "__mul__" , ( n. into_pyobject( vm) , ) ) ?;
221
+ if let PyArithmeticValue :: Implemented ( obj) = PyArithmeticValue :: from_object( vm, ret)
222
+ {
223
+ return Ok ( obj) ;
224
+ }
225
+ }
226
+ Err ( vm. new_type_error( format!( "'{}' object can't be repeated" , zelf) ) )
227
+ } ) ,
228
+ inplace_concat : then_some_closure ! (
229
+ zelf. has_class_attr( "__iadd__" ) || zelf. has_class_attr( "__add__" ) ,
230
+ |zelf, other, vm| {
231
+ if PySequence :: check( & zelf, vm) && PySequence :: check( other, vm) {
232
+ if let Ok ( f) = vm. get_special_method( zelf. clone( ) , "__iadd__" ) ? {
233
+ let ret = f. invoke( ( other. clone( ) , ) , vm) ?;
234
+ if let PyArithmeticValue :: Implemented ( obj) =
235
+ PyArithmeticValue :: from_object( vm, ret)
236
+ {
237
+ return Ok ( obj) ;
238
+ }
239
+ }
240
+ if let Ok ( f) = vm. get_special_method( zelf. clone( ) , "__add__" ) ? {
241
+ let ret = f. invoke( ( other. clone( ) , ) , vm) ?;
242
+ if let PyArithmeticValue :: Implemented ( obj) =
243
+ PyArithmeticValue :: from_object( vm, ret)
244
+ {
245
+ return Ok ( obj) ;
246
+ }
247
+ }
248
+ }
249
+ Err ( vm. new_type_error( format!( "'{}' object can't be concatenated" , zelf) ) )
250
+ }
251
+ ) ,
252
+ inplace_repeat : then_some_closure ! (
253
+ zelf. has_class_attr( "__imul__" ) || zelf. has_class_attr( "__mul__" ) ,
254
+ |zelf, n, vm| {
255
+ if PySequence :: check( & zelf, vm) {
256
+ if let Ok ( f) = vm. get_special_method( zelf. clone( ) , "__imul__" ) ? {
257
+ let ret = f. invoke( ( n. into_pyobject( vm) , ) , vm) ?;
258
+ if let PyArithmeticValue :: Implemented ( obj) =
259
+ PyArithmeticValue :: from_object( vm, ret)
260
+ {
261
+ return Ok ( obj) ;
262
+ }
263
+ }
264
+ if let Ok ( f) = vm. get_special_method( zelf. clone( ) , "__mul__" ) ? {
265
+ let ret = f. invoke( ( n. into_pyobject( vm) , ) , vm) ?;
266
+ if let PyArithmeticValue :: Implemented ( obj) =
267
+ PyArithmeticValue :: from_object( vm, ret)
268
+ {
269
+ return Ok ( obj) ;
270
+ }
271
+ }
272
+ }
273
+ Err ( vm. new_type_error( format!( "'{}' object can't be repeated" , zelf) ) )
274
+ }
275
+ ) ,
276
+ item : None ,
277
+ ass_item : None ,
278
+ // TODO: IterSearch
279
+ contains : None ,
280
+ } )
281
+ }
282
+
195
283
fn hash_wrapper ( zelf : & PyObjectRef , vm : & VirtualMachine ) -> PyResult < PyHash > {
196
284
let hash_obj = vm. call_special_method ( zelf. clone ( ) , "__hash__" , ( ) ) ?;
197
285
match hash_obj. payload_if_subclass :: < PyInt > ( vm) {
@@ -291,7 +379,10 @@ impl PyType {
291
379
match name {
292
380
"__len__" | "__getitem__" | "__setitem__" | "__delitem__" => {
293
381
update_slot ! ( as_mapping, as_mapping_wrapper) ;
294
- // TODO: need to update sequence protocol too
382
+ update_slot ! ( as_sequence, as_sequence_wrapper) ;
383
+ }
384
+ "__add__" | "__iadd__" | "__mul__" | "__imul__" => {
385
+ update_slot ! ( as_sequence, as_sequence_wrapper) ;
295
386
}
296
387
"__hash__" => {
297
388
update_slot ! ( hash, hash_wrapper) ;
@@ -804,6 +895,20 @@ pub trait AsMapping: PyValue {
804
895
) -> PyResult < ( ) > ;
805
896
}
806
897
898
+ #[ pyimpl]
899
+ pub trait AsSequence : PyValue {
900
+ #[ inline]
901
+ #[ pyslot]
902
+ fn slot_as_sequence (
903
+ zelf : & PyObjectRef ,
904
+ vm : & VirtualMachine ,
905
+ ) -> Cow < ' static , PySequenceMethods > {
906
+ let zelf = unsafe { zelf. downcast_unchecked_ref :: < Self > ( ) } ;
907
+ Self :: as_sequence ( zelf, vm)
908
+ }
909
+ fn as_sequence ( zelf : & PyRef < Self > , vm : & VirtualMachine ) -> Cow < ' static , PySequenceMethods > ;
910
+ }
911
+
807
912
#[ pyimpl]
808
913
pub trait Iterable : PyValue {
809
914
#[ pyslot]
0 commit comments