diff --git a/tests/snippets/stdlib_os.py b/tests/snippets/stdlib_os.py index 4d64fc7930..0de5b52b61 100644 --- a/tests/snippets/stdlib_os.py +++ b/tests/snippets/stdlib_os.py @@ -56,6 +56,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): FILE_NAME = "test1" +FILE_NAME2 = "test2" +FOLDER = "dir1" CONTENT = b"testing" CONTENT2 = b"rustpython" CONTENT3 = b"BOYA" @@ -73,3 +75,26 @@ def __exit__(self, exc_type, exc_val, exc_tb): assert os.read(fd, len(CONTENT2)) == CONTENT2 assert os.read(fd, len(CONTENT3)) == CONTENT3 os.close(fd) + + fname2 = tmpdir + os.sep + FILE_NAME2 + with open(fname2, "wb"): + pass + folder = tmpdir + os.sep + FOLDER + os.mkdir(folder) + + names = set() + paths = set() + dirs = set() + files = set() + for dir_entry in os.scandir(tmpdir): + names.add(dir_entry.name) + paths.add(dir_entry.path) + if dir_entry.is_dir(): + dirs.add(dir_entry.name) + if dir_entry.is_file(): + files.add(dir_entry.name) + + assert names == set([FILE_NAME, FILE_NAME2, FOLDER]) + assert paths == set([fname, fname2, folder]) + assert dirs == set([FOLDER]) + assert files == set([FILE_NAME, FILE_NAME2]) diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index fae9d6adb7..9717c8f22e 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::fs::File; use std::fs::OpenOptions; use std::io::{ErrorKind, Read, Write}; @@ -10,9 +11,11 @@ use crate::obj::objbytes::PyBytesRef; use crate::obj::objdict::PyDictRef; use crate::obj::objint; use crate::obj::objint::PyIntRef; +use crate::obj::objiter; use crate::obj::objstr; use crate::obj::objstr::PyStringRef; -use crate::pyobject::{ItemProtocol, PyObjectRef, PyResult}; +use crate::obj::objtype::PyClassRef; +use crate::pyobject::{ItemProtocol, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; #[cfg(unix)] @@ -190,6 +193,85 @@ fn _os_environ(vm: &VirtualMachine) -> PyDictRef { environ } +#[derive(Debug)] +struct DirEntry { + entry: fs::DirEntry, +} + +type DirEntryRef = PyRef; + +impl PyValue for DirEntry { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("os", "DirEntry") + } +} + +impl DirEntryRef { + fn name(self, _vm: &VirtualMachine) -> String { + self.entry.file_name().into_string().unwrap() + } + + fn path(self, _vm: &VirtualMachine) -> String { + self.entry.path().to_str().unwrap().to_string() + } + + fn is_dir(self, vm: &VirtualMachine) -> PyResult { + Ok(self + .entry + .file_type() + .map_err(|s| vm.new_os_error(s.to_string()))? + .is_dir()) + } + + fn is_file(self, vm: &VirtualMachine) -> PyResult { + Ok(self + .entry + .file_type() + .map_err(|s| vm.new_os_error(s.to_string()))? + .is_file()) + } +} + +#[derive(Debug)] +pub struct ScandirIterator { + entries: RefCell, +} + +impl PyValue for ScandirIterator { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("os", "ScandirIter") + } +} + +type ScandirIteratorRef = PyRef; + +impl ScandirIteratorRef { + fn next(self, vm: &VirtualMachine) -> PyResult { + match self.entries.borrow_mut().next() { + Some(entry) => match entry { + Ok(entry) => Ok(DirEntry { entry }.into_ref(vm).into_object()), + Err(s) => Err(vm.new_os_error(s.to_string())), + }, + None => Err(objiter::new_stop_iteration(vm)), + } + } + + fn iter(self, _vm: &VirtualMachine) -> Self { + self + } +} + +fn os_scandir(path: PyStringRef, vm: &VirtualMachine) -> PyResult { + match fs::read_dir(&path.value) { + Ok(iter) => Ok(ScandirIterator { + entries: RefCell::new(iter), + } + .into_ref(vm) + .into_object()), + Err(s) => Err(vm.new_os_error(s.to_string())), + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -201,6 +283,18 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let environ = _os_environ(vm); + let scandir_iter = py_class!(ctx, "ScandirIter", ctx.object(), { + "__iter__" => ctx.new_rustfunc(ScandirIteratorRef::iter), + "__next__" => ctx.new_rustfunc(ScandirIteratorRef::next), + }); + + let dir_entry = py_class!(ctx, "DirEntry", ctx.object(), { + "name" => ctx.new_property(DirEntryRef::name), + "path" => ctx.new_property(DirEntryRef::path), + "is_dir" => ctx.new_rustfunc(DirEntryRef::is_dir), + "is_file" => ctx.new_rustfunc(DirEntryRef::is_file), + }); + py_module!(vm, "_os", { "open" => ctx.new_rustfunc(os_open), "close" => ctx.new_rustfunc(os_close), @@ -217,6 +311,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "unsetenv" => ctx.new_rustfunc(os_unsetenv), "environ" => environ, "name" => ctx.new_str(os_name), + "scandir" => ctx.new_rustfunc(os_scandir), + "ScandirIter" => scandir_iter, + "DirEntry" => dir_entry, "O_RDONLY" => ctx.new_int(0), "O_WRONLY" => ctx.new_int(1), "O_RDWR" => ctx.new_int(2),