Skip to content

Commit 08f74bb

Browse files
committed
Make PyDict ThreadSafe
1 parent 38cb24d commit 08f74bb

File tree

1 file changed

+35
-43
lines changed

1 file changed

+35
-43
lines changed

vm/src/obj/objdict.rs

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::cell::{Cell, RefCell};
1+
use std::cell::Cell;
22
use std::fmt;
33

44
use super::objiter;
@@ -9,7 +9,7 @@ use crate::exceptions::PyBaseExceptionRef;
99
use crate::function::{KwArgs, OptionalArg, PyFuncArgs};
1010
use crate::pyobject::{
1111
IdProtocol, IntoPyObject, ItemProtocol, PyAttributes, PyClassImpl, PyContext, PyIterable,
12-
PyObjectRef, PyRef, PyResult, PyValue,
12+
PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe,
1313
};
1414
use crate::vm::{ReprGuard, VirtualMachine};
1515

@@ -20,9 +20,10 @@ pub type DictContentType = dictdatatype::Dict;
2020
#[pyclass]
2121
#[derive(Default)]
2222
pub struct PyDict {
23-
entries: RefCell<DictContentType>,
23+
entries: DictContentType,
2424
}
2525
pub type PyDictRef = PyRef<PyDict>;
26+
impl ThreadSafe for PyDict {}
2627

2728
impl fmt::Debug for PyDict {
2829
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -43,7 +44,7 @@ impl PyDictRef {
4344
#[pyslot]
4445
fn tp_new(class: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult<PyDictRef> {
4546
PyDict {
46-
entries: RefCell::new(DictContentType::default()),
47+
entries: DictContentType::default(),
4748
}
4849
.into_ref_with_type(vm, class)
4950
}
@@ -59,7 +60,7 @@ impl PyDictRef {
5960
}
6061

6162
fn merge(
62-
dict: &RefCell<DictContentType>,
63+
dict: &DictContentType,
6364
dict_obj: OptionalArg<PyObjectRef>,
6465
kwargs: KwArgs,
6566
vm: &VirtualMachine,
@@ -68,13 +69,13 @@ impl PyDictRef {
6869
let dicted: Result<PyDictRef, _> = dict_obj.clone().downcast();
6970
if let Ok(dict_obj) = dicted {
7071
for (key, value) in dict_obj {
71-
dict.borrow_mut().insert(vm, &key, value)?;
72+
dict.insert(vm, &key, value)?;
7273
}
7374
} else if let Some(keys) = vm.get_method(dict_obj.clone(), "keys") {
7475
let keys = objiter::get_iter(vm, &vm.invoke(&keys?, vec![])?)?;
7576
while let Some(key) = objiter::get_next_object(vm, &keys)? {
7677
let val = dict_obj.get_item(&key, vm)?;
77-
dict.borrow_mut().insert(vm, &key, val)?;
78+
dict.insert(vm, &key, val)?;
7879
}
7980
} else {
8081
let iter = objiter::get_iter(vm, &dict_obj)?;
@@ -92,14 +93,13 @@ impl PyDictRef {
9293
if objiter::get_next_object(vm, &elem_iter)?.is_some() {
9394
return Err(err(vm));
9495
}
95-
dict.borrow_mut().insert(vm, &key, value)?;
96+
dict.insert(vm, &key, value)?;
9697
}
9798
}
9899
}
99100

100-
let mut dict_borrowed = dict.borrow_mut();
101101
for (key, value) in kwargs.into_iter() {
102-
dict_borrowed.insert(vm, &vm.new_str(key), value)?;
102+
dict.insert(vm, &vm.new_str(key), value)?;
103103
}
104104
Ok(())
105105
}
@@ -111,27 +111,26 @@ impl PyDictRef {
111111
value: OptionalArg<PyObjectRef>,
112112
vm: &VirtualMachine,
113113
) -> PyResult<PyDictRef> {
114-
let mut dict = DictContentType::default();
114+
let dict = DictContentType::default();
115115
let value = value.unwrap_or_else(|| vm.ctx.none());
116116
for elem in iterable.iter(vm)? {
117117
let elem = elem?;
118118
dict.insert(vm, &elem, value.clone())?;
119119
}
120-
let entries = RefCell::new(dict);
121-
PyDict { entries }.into_ref_with_type(vm, class)
120+
PyDict { entries: dict }.into_ref_with_type(vm, class)
122121
}
123122

124123
#[pymethod(magic)]
125124
fn bool(self) -> bool {
126-
!self.entries.borrow().is_empty()
125+
!self.entries.is_empty()
127126
}
128127

129128
fn inner_eq(self, other: &PyDict, vm: &VirtualMachine) -> PyResult<bool> {
130-
if other.entries.borrow().len() != self.entries.borrow().len() {
129+
if other.entries.len() != self.entries.len() {
131130
return Ok(false);
132131
}
133132
for (k, v1) in self {
134-
match other.entries.borrow().get(vm, &k)? {
133+
match other.entries.get(vm, &k)? {
135134
Some(v2) => {
136135
if v1.is(&v2) {
137136
continue;
@@ -170,12 +169,12 @@ impl PyDictRef {
170169

171170
#[pymethod(magic)]
172171
fn len(self) -> usize {
173-
self.entries.borrow().len()
172+
self.entries.len()
174173
}
175174

176175
#[pymethod(magic)]
177176
fn sizeof(self) -> usize {
178-
size_of::<Self>() + self.entries.borrow().sizeof()
177+
size_of::<Self>() + self.entries.sizeof()
179178
}
180179

181180
#[pymethod(magic)]
@@ -197,17 +196,17 @@ impl PyDictRef {
197196

198197
#[pymethod(magic)]
199198
fn contains(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
200-
self.entries.borrow().contains(vm, &key)
199+
self.entries.contains(vm, &key)
201200
}
202201

203202
#[pymethod(magic)]
204203
fn delitem(self, key: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
205-
self.entries.borrow_mut().delete(vm, &key)
204+
self.entries.delete(vm, &key)
206205
}
207206

208207
#[pymethod]
209208
fn clear(self) {
210-
self.entries.borrow_mut().clear()
209+
self.entries.clear()
211210
}
212211

213212
#[pymethod(magic)]
@@ -243,7 +242,7 @@ impl PyDictRef {
243242
value: PyObjectRef,
244243
vm: &VirtualMachine,
245244
) -> PyResult<()> {
246-
self.entries.borrow_mut().insert(vm, key, value)
245+
self.entries.insert(vm, key, value)
247246
}
248247

249248
#[pymethod(magic)]
@@ -262,7 +261,7 @@ impl PyDictRef {
262261
key: K,
263262
vm: &VirtualMachine,
264263
) -> PyResult<Option<PyObjectRef>> {
265-
if let Some(value) = self.entries.borrow().get(vm, key)? {
264+
if let Some(value) = self.entries.get(vm, key)? {
266265
return Ok(Some(value));
267266
}
268267

@@ -281,7 +280,7 @@ impl PyDictRef {
281280
default: OptionalArg<PyObjectRef>,
282281
vm: &VirtualMachine,
283282
) -> PyResult {
284-
match self.entries.borrow().get(vm, &key)? {
283+
match self.entries.get(vm, &key)? {
285284
Some(value) => Ok(value),
286285
None => Ok(default.unwrap_or_else(|| vm.ctx.none())),
287286
}
@@ -294,12 +293,11 @@ impl PyDictRef {
294293
default: OptionalArg<PyObjectRef>,
295294
vm: &VirtualMachine,
296295
) -> PyResult {
297-
let mut entries = self.entries.borrow_mut();
298-
match entries.get(vm, &key)? {
296+
match self.entries.get(vm, &key)? {
299297
Some(value) => Ok(value),
300298
None => {
301299
let set_value = default.unwrap_or_else(|| vm.ctx.none());
302-
entries.insert(vm, &key, set_value.clone())?;
300+
self.entries.insert(vm, &key, set_value.clone())?;
303301
Ok(set_value)
304302
}
305303
}
@@ -329,7 +327,7 @@ impl PyDictRef {
329327
default: OptionalArg<PyObjectRef>,
330328
vm: &VirtualMachine,
331329
) -> PyResult {
332-
match self.entries.borrow_mut().pop(vm, &key)? {
330+
match self.entries.pop(vm, &key)? {
333331
Some(value) => Ok(value),
334332
None => match default {
335333
OptionalArg::Present(default) => Ok(default),
@@ -340,8 +338,7 @@ impl PyDictRef {
340338

341339
#[pymethod]
342340
fn popitem(self, vm: &VirtualMachine) -> PyResult {
343-
let mut entries = self.entries.borrow_mut();
344-
if let Some((key, value)) = entries.pop_front() {
341+
if let Some((key, value)) = self.entries.pop_front() {
345342
Ok(vm.ctx.new_tuple(vec![key, value]))
346343
} else {
347344
let err_msg = vm.new_str("popitem(): dictionary is empty".to_owned());
@@ -360,14 +357,13 @@ impl PyDictRef {
360357
}
361358

362359
pub fn from_attributes(attrs: PyAttributes, vm: &VirtualMachine) -> PyResult<Self> {
363-
let mut dict = DictContentType::default();
360+
let dict = DictContentType::default();
364361

365362
for (key, value) in attrs {
366363
dict.insert(vm, &vm.ctx.new_str(key), value)?;
367364
}
368365

369-
let entries = RefCell::new(dict);
370-
Ok(PyDict { entries }.into_ref(vm))
366+
Ok(PyDict { entries: dict }.into_ref(vm))
371367
}
372368

373369
#[pymethod(magic)]
@@ -377,11 +373,11 @@ impl PyDictRef {
377373

378374
pub fn contains_key<T: IntoPyObject>(&self, key: T, vm: &VirtualMachine) -> bool {
379375
let key = key.into_pyobject(vm).unwrap();
380-
self.entries.borrow().contains(vm, &key).unwrap()
376+
self.entries.contains(vm, &key).unwrap()
381377
}
382378

383379
pub fn size(&self) -> dictdatatype::DictSize {
384-
self.entries.borrow().size()
380+
self.entries.size()
385381
}
386382

387383
/// This function can be used to get an item without raising the
@@ -487,7 +483,7 @@ impl Iterator for DictIter {
487483
type Item = (PyObjectRef, PyObjectRef);
488484

489485
fn next(&mut self) -> Option<Self::Item> {
490-
match self.dict.entries.borrow().next_entry(&mut self.position) {
486+
match self.dict.entries.next_entry(&mut self.position) {
491487
Some((key, value)) => Some((key, value)),
492488
None => None,
493489
}
@@ -563,13 +559,12 @@ macro_rules! dict_iterator {
563559
#[allow(clippy::redundant_closure_call)]
564560
fn next(&self, vm: &VirtualMachine) -> PyResult {
565561
let mut position = self.position.get();
566-
let dict = self.dict.entries.borrow();
567-
if dict.has_changed_size(&self.size) {
562+
if self.dict.entries.has_changed_size(&self.size) {
568563
return Err(
569564
vm.new_runtime_error("dictionary changed size during iteration".to_owned())
570565
);
571566
}
572-
match dict.next_entry(&mut position) {
567+
match self.dict.entries.next_entry(&mut position) {
573568
Some((key, value)) => {
574569
self.position.set(position);
575570
Ok($result_fn(vm, key, value))
@@ -585,10 +580,7 @@ macro_rules! dict_iterator {
585580

586581
#[pymethod(name = "__length_hint__")]
587582
fn length_hint(&self) -> usize {
588-
self.dict
589-
.entries
590-
.borrow()
591-
.len_from_entry_index(self.position.get())
583+
self.dict.entries.len_from_entry_index(self.position.get())
592584
}
593585
}
594586

0 commit comments

Comments
 (0)