1
1
use std:: cell:: Cell ;
2
+ use std:: collections:: hash_map:: DefaultHasher ;
2
3
use std:: hash:: { Hash , Hasher } ;
3
4
use std:: ops:: Deref ;
4
5
5
6
use num_traits:: ToPrimitive ;
6
7
7
- use crate :: function:: { OptionalArg , PyFuncArgs } ;
8
- use crate :: pyobject:: {
9
- PyContext , PyIteratorValue , PyObjectRef , PyRef , PyResult , PyValue , TypeProtocol ,
10
- } ;
8
+ use crate :: function:: OptionalArg ;
9
+ use crate :: pyobject:: { PyContext , PyIteratorValue , PyObjectRef , PyRef , PyResult , PyValue } ;
11
10
use crate :: vm:: VirtualMachine ;
12
11
13
12
use super :: objint;
14
- use super :: objtype:: { self , PyClassRef } ;
13
+ use super :: objtype:: PyClassRef ;
15
14
16
15
#[ derive( Debug ) ]
17
16
pub struct PyBytes {
@@ -57,16 +56,16 @@ pub fn init(context: &PyContext) {
57
56
- an integer";
58
57
59
58
extend_class ! ( context, bytes_type, {
60
- "__eq__" => context. new_rustfunc( bytes_eq) ,
61
- "__lt__" => context. new_rustfunc( bytes_lt) ,
62
- "__le__" => context. new_rustfunc( bytes_le) ,
63
- "__gt__" => context. new_rustfunc( bytes_gt) ,
64
- "__ge__" => context. new_rustfunc( bytes_ge) ,
65
- "__hash__" => context. new_rustfunc( bytes_hash) ,
66
59
"__new__" => context. new_rustfunc( bytes_new) ,
67
- "__repr__" => context. new_rustfunc( bytes_repr) ,
68
- "__len__" => context. new_rustfunc( bytes_len) ,
69
- "__iter__" => context. new_rustfunc( bytes_iter) ,
60
+ "__eq__" => context. new_rustfunc( PyBytesRef :: eq) ,
61
+ "__lt__" => context. new_rustfunc( PyBytesRef :: lt) ,
62
+ "__le__" => context. new_rustfunc( PyBytesRef :: le) ,
63
+ "__gt__" => context. new_rustfunc( PyBytesRef :: gt) ,
64
+ "__ge__" => context. new_rustfunc( PyBytesRef :: ge) ,
65
+ "__hash__" => context. new_rustfunc( PyBytesRef :: hash) ,
66
+ "__repr__" => context. new_rustfunc( PyBytesRef :: repr) ,
67
+ "__len__" => context. new_rustfunc( PyBytesRef :: len) ,
68
+ "__iter__" => context. new_rustfunc( PyBytesRef :: iter) ,
70
69
"__doc__" => context. new_str( bytes_doc. to_string( ) )
71
70
} ) ;
72
71
}
@@ -93,111 +92,71 @@ fn bytes_new(
93
92
PyBytes :: new ( value) . into_ref_with_type ( vm, cls)
94
93
}
95
94
96
- fn bytes_eq ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
97
- arg_check ! (
98
- vm,
99
- args,
100
- required = [ ( a, Some ( vm. ctx. bytes_type( ) ) ) , ( b, None ) ]
101
- ) ;
102
-
103
- let result = if objtype:: isinstance ( b, & vm. ctx . bytes_type ( ) ) {
104
- get_value ( a) . to_vec ( ) == get_value ( b) . to_vec ( )
105
- } else {
106
- false
107
- } ;
108
- Ok ( vm. ctx . new_bool ( result) )
109
- }
110
-
111
- fn bytes_ge ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
112
- arg_check ! (
113
- vm,
114
- args,
115
- required = [ ( a, Some ( vm. ctx. bytes_type( ) ) ) , ( b, None ) ]
116
- ) ;
117
-
118
- let result = if objtype:: isinstance ( b, & vm. ctx . bytes_type ( ) ) {
119
- get_value ( a) . to_vec ( ) >= get_value ( b) . to_vec ( )
120
- } else {
121
- return Err ( vm. new_type_error ( format ! ( "Cannot compare {} and {} using '>'" , a, b) ) ) ;
122
- } ;
123
- Ok ( vm. ctx . new_bool ( result) )
124
- }
125
-
126
- fn bytes_gt ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
127
- arg_check ! (
128
- vm,
129
- args,
130
- required = [ ( a, Some ( vm. ctx. bytes_type( ) ) ) , ( b, None ) ]
131
- ) ;
95
+ impl PyBytesRef {
96
+ fn eq ( self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
97
+ if let Ok ( other) = other. downcast :: < PyBytes > ( ) {
98
+ vm. ctx . new_bool ( self . value == other. value )
99
+ } else {
100
+ vm. ctx . not_implemented ( )
101
+ }
102
+ }
132
103
133
- let result = if objtype :: isinstance ( b , & vm . ctx . bytes_type ( ) ) {
134
- get_value ( a ) . to_vec ( ) > get_value ( b ) . to_vec ( )
135
- } else {
136
- return Err ( vm . new_type_error ( format ! ( "Cannot compare {} and {} using '>='" , a , b ) ) ) ;
137
- } ;
138
- Ok ( vm . ctx . new_bool ( result ) )
139
- }
104
+ fn ge ( self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
105
+ if let Ok ( other ) = other . downcast :: < PyBytes > ( ) {
106
+ vm . ctx . new_bool ( self . value >= other . value )
107
+ } else {
108
+ vm . ctx . not_implemented ( )
109
+ }
110
+ }
140
111
141
- fn bytes_le ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
142
- arg_check ! (
143
- vm,
144
- args,
145
- required = [ ( a, Some ( vm. ctx. bytes_type( ) ) ) , ( b, None ) ]
146
- ) ;
112
+ fn gt ( self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
113
+ if let Ok ( other) = other. downcast :: < PyBytes > ( ) {
114
+ vm. ctx . new_bool ( self . value > other. value )
115
+ } else {
116
+ vm. ctx . not_implemented ( )
117
+ }
118
+ }
147
119
148
- let result = if objtype :: isinstance ( b , & vm . ctx . bytes_type ( ) ) {
149
- get_value ( a ) . to_vec ( ) <= get_value ( b ) . to_vec ( )
150
- } else {
151
- return Err ( vm . new_type_error ( format ! ( "Cannot compare {} and {} using '<'" , a , b ) ) ) ;
152
- } ;
153
- Ok ( vm . ctx . new_bool ( result ) )
154
- }
120
+ fn le ( self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
121
+ if let Ok ( other ) = other . downcast :: < PyBytes > ( ) {
122
+ vm . ctx . new_bool ( self . value <= other . value )
123
+ } else {
124
+ vm . ctx . not_implemented ( )
125
+ }
126
+ }
155
127
156
- fn bytes_lt ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
157
- arg_check ! (
158
- vm,
159
- args,
160
- required = [ ( a, Some ( vm. ctx. bytes_type( ) ) ) , ( b, None ) ]
161
- ) ;
128
+ fn lt ( self , other : PyObjectRef , vm : & VirtualMachine ) -> PyObjectRef {
129
+ if let Ok ( other) = other. downcast :: < PyBytes > ( ) {
130
+ vm. ctx . new_bool ( self . value < other. value )
131
+ } else {
132
+ vm. ctx . not_implemented ( )
133
+ }
134
+ }
162
135
163
- let result = if objtype:: isinstance ( b, & vm. ctx . bytes_type ( ) ) {
164
- get_value ( a) . to_vec ( ) < get_value ( b) . to_vec ( )
165
- } else {
166
- return Err ( vm. new_type_error ( format ! ( "Cannot compare {} and {} using '<='" , a, b) ) ) ;
167
- } ;
168
- Ok ( vm. ctx . new_bool ( result) )
169
- }
136
+ fn len ( self , _vm : & VirtualMachine ) -> usize {
137
+ self . value . len ( )
138
+ }
170
139
171
- fn bytes_len ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
172
- arg_check ! ( vm, args, required = [ ( a, Some ( vm. ctx. bytes_type( ) ) ) ] ) ;
140
+ fn hash ( self , _vm : & VirtualMachine ) -> u64 {
141
+ let mut hasher = DefaultHasher :: new ( ) ;
142
+ self . value . hash ( & mut hasher) ;
143
+ hasher. finish ( )
144
+ }
173
145
174
- let byte_vec = get_value ( a) . to_vec ( ) ;
175
- Ok ( vm. ctx . new_int ( byte_vec. len ( ) ) )
176
- }
146
+ fn repr ( self , _vm : & VirtualMachine ) -> String {
147
+ // TODO: don't just unwrap
148
+ let data = String :: from_utf8 ( self . value . clone ( ) ) . unwrap ( ) ;
149
+ format ! ( "b'{}'" , data)
150
+ }
177
151
178
- fn bytes_hash ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
179
- arg_check ! ( vm, args, required = [ ( zelf, Some ( vm. ctx. bytes_type( ) ) ) ] ) ;
180
- let data = get_value ( zelf) ;
181
- let mut hasher = std:: collections:: hash_map:: DefaultHasher :: new ( ) ;
182
- data. hash ( & mut hasher) ;
183
- let hash = hasher. finish ( ) ;
184
- Ok ( vm. ctx . new_int ( hash) )
152
+ fn iter ( obj : PyBytesRef , _vm : & VirtualMachine ) -> PyIteratorValue {
153
+ PyIteratorValue {
154
+ position : Cell :: new ( 0 ) ,
155
+ iterated_obj : obj. into_object ( ) ,
156
+ }
157
+ }
185
158
}
186
159
187
160
pub fn get_value < ' a > ( obj : & ' a PyObjectRef ) -> impl Deref < Target = Vec < u8 > > + ' a {
188
161
& obj. payload :: < PyBytes > ( ) . unwrap ( ) . value
189
162
}
190
-
191
- fn bytes_repr ( vm : & VirtualMachine , args : PyFuncArgs ) -> PyResult {
192
- arg_check ! ( vm, args, required = [ ( obj, Some ( vm. ctx. bytes_type( ) ) ) ] ) ;
193
- let value = get_value ( obj) ;
194
- let data = String :: from_utf8 ( value. to_vec ( ) ) . unwrap ( ) ;
195
- Ok ( vm. new_str ( format ! ( "b'{}'" , data) ) )
196
- }
197
-
198
- fn bytes_iter ( obj : PyBytesRef , _vm : & VirtualMachine ) -> PyIteratorValue {
199
- PyIteratorValue {
200
- position : Cell :: new ( 0 ) ,
201
- iterated_obj : obj. into_object ( ) ,
202
- }
203
- }
0 commit comments