@@ -2,10 +2,10 @@ use crate::function::{OptionalArg, PyFuncArgs};
2
2
use crate :: obj:: objbytes:: { PyBytes , PyBytesRef } ;
3
3
use crate :: obj:: objstr:: PyStringRef ;
4
4
use crate :: obj:: objtype:: PyClassRef ;
5
- use crate :: pyobject:: { PyClassImpl , PyObjectRef , PyResult , PyValue } ;
5
+ use crate :: pyobject:: { PyClassImpl , PyObjectRef , PyResult , PyValue , ThreadSafe } ;
6
6
use crate :: vm:: VirtualMachine ;
7
- use std:: cell:: RefCell ;
8
7
use std:: fmt;
8
+ use std:: sync:: { RwLock , RwLockReadGuard , RwLockWriteGuard } ;
9
9
10
10
use blake2:: { Blake2b , Blake2s } ;
11
11
use digest:: DynDigest ;
@@ -17,9 +17,11 @@ use sha3::{Sha3_224, Sha3_256, Sha3_384, Sha3_512}; // TODO: , Shake128, Shake25
17
17
#[ pyclass( name = "hasher" ) ]
18
18
struct PyHasher {
19
19
name : String ,
20
- buffer : RefCell < HashWrapper > ,
20
+ buffer : RwLock < HashWrapper > ,
21
21
}
22
22
23
+ impl ThreadSafe for PyHasher { }
24
+
23
25
impl fmt:: Debug for PyHasher {
24
26
fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
25
27
write ! ( f, "hasher {}" , self . name)
@@ -37,10 +39,18 @@ impl PyHasher {
37
39
fn new ( name : & str , d : HashWrapper ) -> Self {
38
40
PyHasher {
39
41
name : name. to_owned ( ) ,
40
- buffer : RefCell :: new ( d) ,
42
+ buffer : RwLock :: new ( d) ,
41
43
}
42
44
}
43
45
46
+ fn borrow_value ( & self ) -> RwLockReadGuard < ' _ , HashWrapper > {
47
+ self . buffer . read ( ) . unwrap ( )
48
+ }
49
+
50
+ fn borrow_value_mut ( & self ) -> RwLockWriteGuard < ' _ , HashWrapper > {
51
+ self . buffer . write ( ) . unwrap ( )
52
+ }
53
+
44
54
#[ pyslot]
45
55
fn tp_new ( _cls : PyClassRef , _args : PyFuncArgs , vm : & VirtualMachine ) -> PyResult {
46
56
Ok ( PyHasher :: new ( "md5" , HashWrapper :: md5 ( ) )
@@ -55,12 +65,12 @@ impl PyHasher {
55
65
56
66
#[ pyproperty( name = "digest_size" ) ]
57
67
fn digest_size ( & self , vm : & VirtualMachine ) -> PyResult {
58
- Ok ( vm. ctx . new_int ( self . buffer . borrow ( ) . digest_size ( ) ) )
68
+ Ok ( vm. ctx . new_int ( self . borrow_value ( ) . digest_size ( ) ) )
59
69
}
60
70
61
71
#[ pymethod( name = "update" ) ]
62
72
fn update ( & self , data : PyBytesRef , vm : & VirtualMachine ) -> PyResult {
63
- self . buffer . borrow_mut ( ) . input ( data. get_value ( ) ) ;
73
+ self . borrow_value_mut ( ) . input ( data. get_value ( ) ) ;
64
74
Ok ( vm. get_none ( ) )
65
75
}
66
76
@@ -77,7 +87,7 @@ impl PyHasher {
77
87
}
78
88
79
89
fn get_digest ( & self ) -> Vec < u8 > {
80
- self . buffer . borrow ( ) . get_digest ( )
90
+ self . borrow_value ( ) . get_digest ( )
81
91
}
82
92
}
83
93
@@ -200,15 +210,18 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
200
210
} )
201
211
}
202
212
213
+ trait ThreadSafeDynDigest : DynDigest + Sync + Send { }
214
+ impl < T > ThreadSafeDynDigest for T where T : DynDigest + Sync + Send { }
215
+
203
216
/// Generic wrapper patching around the hashing libraries.
204
217
struct HashWrapper {
205
- inner : Box < dyn DynDigest > ,
218
+ inner : Box < dyn ThreadSafeDynDigest > ,
206
219
}
207
220
208
221
impl HashWrapper {
209
222
fn new < D : ' static > ( d : D ) -> Self
210
223
where
211
- D : DynDigest + Sized ,
224
+ D : ThreadSafeDynDigest ,
212
225
{
213
226
HashWrapper { inner : Box :: new ( d) }
214
227
}
@@ -279,7 +292,7 @@ impl HashWrapper {
279
292
}
280
293
281
294
fn get_digest ( & self ) -> Vec < u8 > {
282
- let cloned = self . inner . clone ( ) ;
295
+ let cloned = self . inner . box_clone ( ) ;
283
296
cloned. result ( ) . to_vec ( )
284
297
}
285
298
}
0 commit comments