Skip to content

Commit 8938125

Browse files
committed
Make DirEntry, ScandirIterator ThreadSafe
1 parent 59ce997 commit 8938125

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

vm/src/stdlib/os.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::cell::{Cell, RefCell};
21
use std::ffi;
32
use std::fs::File;
43
use std::fs::OpenOptions;
@@ -7,10 +6,12 @@ use std::io::{self, ErrorKind, Read, Write};
76
use std::os::unix::fs::OpenOptionsExt;
87
#[cfg(windows)]
98
use std::os::windows::fs::OpenOptionsExt;
9+
use std::sync::RwLock;
1010
use std::time::{Duration, SystemTime};
1111
use std::{env, fs};
1212

1313
use bitflags::bitflags;
14+
use crossbeam_utils::atomic::AtomicCell;
1415
#[cfg(unix)]
1516
use nix::errno::Errno;
1617
#[cfg(all(unix, not(target_os = "redox")))]
@@ -35,8 +36,8 @@ use crate::obj::objstr::{PyString, PyStringRef};
3536
use crate::obj::objtuple::PyTupleRef;
3637
use crate::obj::objtype::PyClassRef;
3738
use crate::pyobject::{
38-
Either, ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
39-
TypeProtocol,
39+
Either, ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe,
40+
TryFromObject, TypeProtocol,
4041
};
4142
use crate::vm::VirtualMachine;
4243

@@ -597,6 +598,8 @@ struct DirEntry {
597598
mode: OutputMode,
598599
}
599600

601+
impl ThreadSafe for DirEntry {}
602+
600603
type DirEntryRef = PyRef<DirEntry>;
601604

602605
impl PyValue for DirEntry {
@@ -683,11 +686,13 @@ impl DirEntryRef {
683686
#[pyclass]
684687
#[derive(Debug)]
685688
struct ScandirIterator {
686-
entries: RefCell<fs::ReadDir>,
687-
exhausted: Cell<bool>,
689+
entries: RwLock<fs::ReadDir>,
690+
exhausted: AtomicCell<bool>,
688691
mode: OutputMode,
689692
}
690693

694+
impl ThreadSafe for ScandirIterator {}
695+
691696
impl PyValue for ScandirIterator {
692697
fn class(vm: &VirtualMachine) -> PyClassRef {
693698
vm.class(MODULE_NAME, "ScandirIter")
@@ -698,11 +703,11 @@ impl PyValue for ScandirIterator {
698703
impl ScandirIterator {
699704
#[pymethod(name = "__next__")]
700705
fn next(&self, vm: &VirtualMachine) -> PyResult {
701-
if self.exhausted.get() {
706+
if self.exhausted.load() {
702707
return Err(objiter::new_stop_iteration(vm));
703708
}
704709

705-
match self.entries.borrow_mut().next() {
710+
match self.entries.write().unwrap().next() {
706711
Some(entry) => match entry {
707712
Ok(entry) => Ok(DirEntry {
708713
entry,
@@ -713,15 +718,15 @@ impl ScandirIterator {
713718
Err(s) => Err(convert_io_error(vm, s)),
714719
},
715720
None => {
716-
self.exhausted.set(true);
721+
self.exhausted.store(true);
717722
Err(objiter::new_stop_iteration(vm))
718723
}
719724
}
720725
}
721726

722727
#[pymethod]
723728
fn close(&self) {
724-
self.exhausted.set(true);
729+
self.exhausted.store(true);
725730
}
726731

727732
#[pymethod(name = "__iter__")]
@@ -748,8 +753,8 @@ fn os_scandir(path: OptionalArg<PyPathLike>, vm: &VirtualMachine) -> PyResult {
748753

749754
match fs::read_dir(path.path) {
750755
Ok(iter) => Ok(ScandirIterator {
751-
entries: RefCell::new(iter),
752-
exhausted: Cell::new(false),
756+
entries: RwLock::new(iter),
757+
exhausted: AtomicCell::new(false),
753758
mode: path.mode,
754759
}
755760
.into_ref(vm)

0 commit comments

Comments
 (0)