diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..f848263ced --- /dev/null +++ b/.flake8 @@ -0,0 +1,3 @@ +[flake8] +# black's line length +max-line-length = 88 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9097dc8acf..df88d45d44 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -6,7 +6,7 @@ on: name: CI env: - CARGO_ARGS: --features "ssl jit" + CARGO_ARGS: --features ssl,jit NON_WASM_PACKAGES: > -p rustpython-bytecode -p rustpython-common diff --git a/Cargo.toml b/Cargo.toml index 85fe251bea..390d16e35a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,11 +15,12 @@ members = [ ] [features] -default = ["threading", "pylib"] +default = ["threading", "pylib", "zlib"] flame-it = ["rustpython-vm/flame-it", "flame", "flamescope"] freeze-stdlib = ["rustpython-vm/freeze-stdlib"] jit = ["rustpython-vm/jit"] threading = ["rustpython-vm/threading"] +zlib = ["rustpython-vm/zlib"] ssl = ["rustpython-vm/ssl"] diff --git a/Lib/_dummy_os.py b/Lib/_dummy_os.py new file mode 100644 index 0000000000..5bd5ec0a13 --- /dev/null +++ b/Lib/_dummy_os.py @@ -0,0 +1,66 @@ +""" +A shim of the os module containing only simple path-related utilities +""" + +try: + from os import * +except ImportError: + import abc + + def __getattr__(name): + raise OSError("no os specific module found") + + def _shim(): + import _dummy_os, sys + sys.modules['os'] = _dummy_os + sys.modules['os.path'] = _dummy_os.path + + import posixpath as path + import sys + sys.modules['os.path'] = path + del sys + + sep = path.sep + + + def fspath(path): + """Return the path representation of a path-like object. + + If str or bytes is passed in, it is returned unchanged. Otherwise the + os.PathLike interface is used to get the path representation. If the + path representation is not str or bytes, TypeError is raised. If the + provided path is not str, bytes, or os.PathLike, TypeError is raised. + """ + if isinstance(path, (str, bytes)): + return path + + # Work from the object's type to match method resolution of other magic + # methods. + path_type = type(path) + try: + path_repr = path_type.__fspath__(path) + except AttributeError: + if hasattr(path_type, '__fspath__'): + raise + else: + raise TypeError("expected str, bytes or os.PathLike object, " + "not " + path_type.__name__) + if isinstance(path_repr, (str, bytes)): + return path_repr + else: + raise TypeError("expected {}.__fspath__() to return str or bytes, " + "not {}".format(path_type.__name__, + type(path_repr).__name__)) + + class PathLike(abc.ABC): + + """Abstract base class for implementing the file system path protocol.""" + + @abc.abstractmethod + def __fspath__(self): + """Return the file system path representation of the object.""" + raise NotImplementedError + + @classmethod + def __subclasshook__(cls, subclass): + return hasattr(subclass, '__fspath__') diff --git a/Lib/fnmatch.py b/Lib/fnmatch.py index b98e641329..af0dbcd092 100644 --- a/Lib/fnmatch.py +++ b/Lib/fnmatch.py @@ -9,7 +9,10 @@ The function translate(PATTERN) returns a regular expression corresponding to PATTERN. (It does not compile it.) """ -import os +try: + import os +except ImportError: + import _dummy_os as os import posixpath import re import functools diff --git a/Lib/genericpath.py b/Lib/genericpath.py index 5dd703d736..e790d74681 100644 --- a/Lib/genericpath.py +++ b/Lib/genericpath.py @@ -3,7 +3,10 @@ Do not use directly. The OS specific modules import the appropriate functions from this module themselves. """ -import os +try: + import os +except ImportError: + import _dummy_os as os import stat __all__ = ['commonprefix', 'exists', 'getatime', 'getctime', 'getmtime', diff --git a/Lib/io.py b/Lib/io.py index ee701d2c20..3f497b3116 100644 --- a/Lib/io.py +++ b/Lib/io.py @@ -52,12 +52,17 @@ import abc from _io import (DEFAULT_BUFFER_SIZE, BlockingIOError, UnsupportedOperation, - open, open_code, FileIO, BytesIO, StringIO, BufferedReader, + open, open_code, BytesIO, StringIO, BufferedReader, BufferedWriter, BufferedRWPair, BufferedRandom, # XXX RUSTPYTHON TODO: IncrementalNewlineDecoder # IncrementalNewlineDecoder, TextIOWrapper) TextIOWrapper) +try: + from _io import FileIO +except ImportError: + pass + OpenWrapper = _io.open # for compatibility with _pyio # Pretend this exception was created here. diff --git a/Lib/linecache.py b/Lib/linecache.py index 47885bfd54..bb09280d63 100644 --- a/Lib/linecache.py +++ b/Lib/linecache.py @@ -7,7 +7,10 @@ import functools import sys -import os +try: + import os +except ImportError: + import _dummy_os as os import tokenize __all__ = ["getline", "clearcache", "checkcache"] diff --git a/Lib/posixpath.py b/Lib/posixpath.py index ecb4e5a8f7..8bd078c2b9 100644 --- a/Lib/posixpath.py +++ b/Lib/posixpath.py @@ -22,7 +22,10 @@ altsep = None devnull = '/dev/null' -import os +try: + import os +except ImportError: + import _dummy_os as os import sys import stat import genericpath diff --git a/Lib/zipfile.py b/Lib/zipfile.py index 5dc6516cc4..cd309da01b 100644 --- a/Lib/zipfile.py +++ b/Lib/zipfile.py @@ -8,13 +8,22 @@ import importlib.util import io import itertools -import os +try: + import os +except ImportError: + import _dummy_os as os import posixpath -import shutil +try: + import shutil +except ImportError: + pass import stat import struct import sys -import threading +try: + import threading +except ImportError: + import _dummy_thread as threading import time import contextlib from collections import OrderedDict diff --git a/common/src/lock/cell_lock.rs b/common/src/lock/cell_lock.rs index 32c4116cb4..edf968b3ff 100644 --- a/common/src/lock/cell_lock.rs +++ b/common/src/lock/cell_lock.rs @@ -121,7 +121,7 @@ unsafe impl RawRwLock for RawCellRwLock { unsafe impl RawRwLockDowngrade for RawCellRwLock { unsafe fn downgrade(&self) { - // no-op -- we're always exclusively locked for this thread + self.state.set(ONE_READER); } } @@ -170,7 +170,7 @@ unsafe impl RawRwLockUpgradeDowngrade for RawCellRwLock { #[inline] unsafe fn downgrade_to_upgradable(&self) { - // no-op -- we're always exclusively locked for this thread + self.state.set(ONE_READER); } } diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 29e06964f1..66472be5d0 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -10,6 +10,8 @@ include = ["src/**/*.rs", "Cargo.toml", "build.rs", "Lib/**/*.py"] [features] default = ["compile-parse", "threading"] +# TODO: use resolver = "2" instead of features +zlib = ["libz-sys", "flate2/zlib"] vm-tracing-logging = [] flame-it = ["flame", "flamer"] freeze-stdlib = ["rustpython-pylib"] @@ -84,6 +86,10 @@ atty = "0.2" static_assertions = "1.1" half = "1.6" memchr = "2" +crc32fast = "1.2.0" +adler32 = "1.0.3" +flate2 = "1.0.20" +libz-sys = { version = "1.0", optional = true } # RustPython crates implementing functionality based on CPython mt19937 = "2.0" @@ -114,8 +120,6 @@ exitcode = "1.1.2" uname = "0.1.1" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -crc32fast = "1.2.0" -adler32 = "1.0.3" gethostname = "0.2.0" socket2 = "0.3.19" rustyline = "6.0" @@ -129,8 +133,6 @@ num_cpus = "1" [target.'cfg(not(any(target_arch = "wasm32", target_os = "redox")))'.dependencies] dns-lookup = "1.0" -flate2 = { version = "1.0.20", features = ["zlib"], default-features = false } -libz-sys = "1.0" [target.'cfg(windows)'.dependencies] winreg = "0.7" diff --git a/vm/src/lib.rs b/vm/src/lib.rs index 9764a8ce53..c511b73b1e 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -54,7 +54,7 @@ pub mod frame; mod frozen; pub mod function; pub mod import; -mod iterator; +pub mod iterator; mod py_io; pub mod py_serde; pub mod pyobject; diff --git a/vm/src/pyobjectrc.rs b/vm/src/pyobjectrc.rs index bc973a2844..e7222eaaed 100644 --- a/vm/src/pyobjectrc.rs +++ b/vm/src/pyobjectrc.rs @@ -307,7 +307,7 @@ impl Drop for PyObjectRef { // CPython-compatible drop implementation let zelf = self.clone(); if let Some(del_slot) = self.class().mro_find_map(|cls| cls.slots.del.load()) { - crate::vm::thread::with_vm(&zelf, |vm| { + let ret = crate::vm::thread::with_vm(&zelf, |vm| { if let Err(e) = del_slot(&zelf, vm) { // exception in del will be ignored but printed print!("Exception ignored in: ",); @@ -327,6 +327,9 @@ impl Drop for PyObjectRef { } } }); + if ret.is_none() { + warn!("couldn't run __del__ method for object") + } } // __del__ might have resurrected the object at this point, but that's fine, diff --git a/vm/src/stdlib/mod.rs b/vm/src/stdlib/mod.rs index 37e6de4e1f..8aa50a3833 100644 --- a/vm/src/stdlib/mod.rs +++ b/vm/src/stdlib/mod.rs @@ -42,6 +42,7 @@ mod tokenize; mod unicodedata; mod warnings; mod weakref; +mod zlib; #[cfg(any(not(target_arch = "wasm32"), target_os = "wasi"))] #[macro_use] @@ -67,8 +68,6 @@ mod ssl; mod winapi; #[cfg(windows)] mod winreg; -#[cfg(not(any(target_arch = "wasm32", target_os = "redox")))] -mod zlib; pub type StdlibInitFunc = Box PyObjectRef)>; @@ -103,6 +102,7 @@ pub fn get_module_inits() -> HashMap "_imp".to_owned() => Box::new(imp::make_module), "unicodedata".to_owned() => Box::new(unicodedata::make_module), "_warnings".to_owned() => Box::new(warnings::make_module), + "zlib".to_owned() => Box::new(zlib::make_module), crate::sysmodule::sysconfigdata_name() => Box::new(sysconfigdata::make_module), }; @@ -144,8 +144,6 @@ pub fn get_module_inits() -> HashMap modules.insert("_ssl".to_owned(), Box::new(ssl::make_module)); #[cfg(feature = "threading")] modules.insert("_thread".to_owned(), Box::new(thread::make_module)); - #[cfg(not(target_os = "redox"))] - modules.insert("zlib".to_owned(), Box::new(zlib::make_module)); modules.insert( "faulthandler".to_owned(), Box::new(faulthandler::make_module), diff --git a/vm/src/stdlib/zlib.rs b/vm/src/stdlib/zlib.rs index 031a932f69..ffa1b9a63b 100644 --- a/vm/src/stdlib/zlib.rs +++ b/vm/src/stdlib/zlib.rs @@ -20,14 +20,35 @@ mod decl { write::ZlibEncoder, Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status, }; - use libz_sys as libz; use std::io::Write; + #[cfg(not(feature = "zlib"))] + mod constants { + pub const Z_NO_COMPRESSION: i32 = 0; + pub const Z_BEST_COMPRESSION: i32 = 9; + pub const Z_BEST_SPEED: i32 = 1; + pub const Z_DEFAULT_COMPRESSION: i32 = -1; + pub const Z_NO_FLUSH: i32 = 0; + pub const Z_PARTIAL_FLUSH: i32 = 1; + pub const Z_SYNC_FLUSH: i32 = 2; + pub const Z_FULL_FLUSH: i32 = 3; + // not sure what the value here means, but it's the only compression method zlibmodule + // supports, so it doesn't really matter + pub const Z_DEFLATED: i32 = 8; + } + #[cfg(feature = "zlib")] + use libz_sys as constants; + + #[pyattr] + use constants::{ + Z_BEST_COMPRESSION, Z_BEST_SPEED, Z_DEFAULT_COMPRESSION, Z_DEFLATED as DEFLATED, + Z_FULL_FLUSH, Z_NO_COMPRESSION, Z_NO_FLUSH, Z_PARTIAL_FLUSH, Z_SYNC_FLUSH, + }; + + #[cfg(feature = "zlib")] #[pyattr] - use libz::{ - Z_BEST_COMPRESSION, Z_BEST_SPEED, Z_BLOCK, Z_DEFAULT_COMPRESSION, Z_DEFAULT_STRATEGY, - Z_DEFLATED as DEFLATED, Z_FILTERED, Z_FINISH, Z_FIXED, Z_FULL_FLUSH, Z_HUFFMAN_ONLY, - Z_NO_COMPRESSION, Z_NO_FLUSH, Z_PARTIAL_FLUSH, Z_RLE, Z_SYNC_FLUSH, Z_TREES, + use libz_sys::{ + Z_BLOCK, Z_DEFAULT_STRATEGY, Z_FILTERED, Z_FINISH, Z_FIXED, Z_HUFFMAN_ONLY, Z_RLE, Z_TREES, }; // copied from zlibmodule.c (commit 530f506ac91338) @@ -69,18 +90,21 @@ mod decl { }) } + fn compression_from_int(level: Option) -> Option { + match level.unwrap_or(Z_DEFAULT_COMPRESSION) { + Z_DEFAULT_COMPRESSION => Some(Compression::default()), + valid_level @ Z_NO_COMPRESSION..=Z_BEST_COMPRESSION => { + Some(Compression::new(valid_level as u32)) + } + _ => None, + } + } + /// Returns a bytes object containing compressed data. #[pyfunction] fn compress(data: PyBytesLike, level: OptionalArg, vm: &VirtualMachine) -> PyResult { - let level = level.unwrap_or(libz::Z_DEFAULT_COMPRESSION); - - let compression = match level { - valid_level @ libz::Z_NO_COMPRESSION..=libz::Z_BEST_COMPRESSION => { - Compression::new(valid_level as u32) - } - libz::Z_DEFAULT_COMPRESSION => Compression::default(), - _ => return Err(new_zlib_error("Bad compression level", vm)), - }; + let compression = compression_from_int(level.into_option()) + .ok_or_else(|| new_zlib_error("Bad compression level", vm))?; let mut encoder = ZlibEncoder::new(Vec::new(), compression); data.with_ref(|input_bytes| encoder.write_all(input_bytes).unwrap()); @@ -89,39 +113,88 @@ mod decl { Ok(vm.ctx.new_bytes(encoded_bytes)) } - fn header_from_wbits( - wbits: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult<(Option, u8)> { + enum InitOptions { + Standard { + header: bool, + // [De]Compress::new_with_window_bits is only enabled for zlib; miniz_oxide doesn't + // support wbits (yet?) + #[cfg(feature = "zlib")] + wbits: u8, + }, + #[cfg(feature = "zlib")] + Gzip { wbits: u8 }, + } + + impl InitOptions { + fn decompress(self) -> Decompress { + match self { + #[cfg(not(feature = "zlib"))] + Self::Standard { header } => Decompress::new(header), + #[cfg(feature = "zlib")] + Self::Standard { header, wbits } => Decompress::new_with_window_bits(header, wbits), + #[cfg(feature = "zlib")] + Self::Gzip { wbits } => Decompress::new_gzip(wbits), + } + } + fn compress(self, level: Compression) -> Compress { + match self { + #[cfg(not(feature = "zlib"))] + Self::Standard { header } => Compress::new(level, header), + #[cfg(feature = "zlib")] + Self::Standard { header, wbits } => { + Compress::new_with_window_bits(level, header, wbits) + } + #[cfg(feature = "zlib")] + Self::Gzip { wbits } => Compress::new_gzip(level, wbits), + } + } + } + + fn header_from_wbits(wbits: OptionalArg, vm: &VirtualMachine) -> PyResult { let wbits = wbits.unwrap_or(MAX_WBITS as i8); let header = wbits > 0; let wbits = wbits.abs() as u8; match wbits { - 9..=15 => Ok((Some(header), wbits)), - 25..=31 => Ok((None, wbits - 16)), + 9..=15 => Ok(InitOptions::Standard { + header, + #[cfg(feature = "zlib")] + wbits, + }), + #[cfg(feature = "zlib")] + 25..=31 => Ok(InitOptions::Gzip { wbits: wbits - 16 }), _ => Err(vm.new_value_error("Invalid initialization option".to_owned())), } } fn _decompress( - data: &[u8], + mut data: &[u8], d: &mut Decompress, bufsize: usize, max_length: Option, + is_flush: bool, vm: &VirtualMachine, ) -> PyResult<(Vec, bool)> { if data.is_empty() { return Ok((Vec::new(), true)); } - let orig_in = d.total_in(); let mut buf = Vec::new(); - for mut chunk in data.chunks(CHUNKSIZE) { + loop { + let final_chunk = data.len() <= CHUNKSIZE; + let chunk = if final_chunk { + data + } else { + &data[..CHUNKSIZE] + }; // if this is the final chunk, finish it - let flush = if d.total_in() - orig_in == (data.len() - chunk.len()) as u64 { - FlushDecompress::Finish + let flush = if is_flush { + if final_chunk { + FlushDecompress::Finish + } else { + FlushDecompress::None + } } else { - FlushDecompress::None + FlushDecompress::Sync }; loop { let additional = if let Some(max_length) = max_length { @@ -129,46 +202,31 @@ mod decl { } else { bufsize }; + if additional == 0 { + return Ok((buf, false)); + } buf.reserve_exact(additional); let prev_in = d.total_in(); let status = d .decompress_vec(chunk, &mut buf, flush) .map_err(|_| new_zlib_error("invalid input data", vm))?; - match status { + let consumed = d.total_in() - prev_in; + data = &data[consumed as usize..]; + let stream_end = status == Status::StreamEnd; + if stream_end || data.is_empty() { // we've reached the end of the stream, we're done - Status::StreamEnd => { - buf.shrink_to_fit(); - return Ok((buf, true)); - } - // we have hit the maximum length that we can decompress, so stop - _ if max_length.map_or(false, |max_length| buf.len() == max_length) => { - return Ok((buf, false)); - } - _ => { - chunk = &chunk[(d.total_in() - prev_in) as usize..]; - - if !chunk.is_empty() { - // there is more input to process - continue; - } else if flush == FlushDecompress::Finish { - if buf.len() == buf.capacity() { - // we've run out of space, loop again and allocate more room - continue; - } else { - // we need more input to continue - buf.shrink_to_fit(); - return Ok((buf, false)); - } - } else { - // progress onto next chunk - break; - } - } + buf.shrink_to_fit(); + return Ok((buf, stream_end)); + } else if !chunk.is_empty() && consumed == 0 { + // we're gonna need a bigger buffer + continue; + } else { + // next chunk + break; } } } - unreachable!("Didn't reach end of stream or capacity limit") } /// Returns a bytes object containing the uncompressed data. @@ -180,14 +238,11 @@ mod decl { vm: &VirtualMachine, ) -> PyResult> { data.with_ref(|data| { - let (header, wbits) = header_from_wbits(wbits, vm)?; let bufsize = bufsize.unwrap_or(DEF_BUF_SIZE); - let mut d = match header { - Some(header) => Decompress::new_with_window_bits(header, wbits), - None => Decompress::new_gzip(wbits), - }; - _decompress(data, &mut d, bufsize, None, vm).and_then(|(buf, stream_end)| { + let mut d = header_from_wbits(wbits, vm)?.decompress(); + + _decompress(data, &mut d, bufsize, None, false, vm).and_then(|(buf, stream_end)| { if stream_end { Ok(buf) } else { @@ -198,12 +253,10 @@ mod decl { } #[pyfunction] - fn decompressobj(args: DecopmressobjArgs, vm: &VirtualMachine) -> PyResult { - let (header, wbits) = header_from_wbits(args.wbits, vm)?; - let mut decompress = match header { - Some(header) => Decompress::new_with_window_bits(header, wbits), - None => Decompress::new_gzip(wbits), - }; + fn decompressobj(args: DecompressobjArgs, vm: &VirtualMachine) -> PyResult { + #[allow(unused_mut)] + let mut decompress = header_from_wbits(args.wbits, vm)?.decompress(); + #[cfg(feature = "zlib")] if let OptionalArg::Present(dict) = args.zdict { dict.with_ref(|d| decompress.set_dictionary(d).unwrap()); } @@ -278,23 +331,25 @@ mod decl { let mut d = self.decompress.lock(); let orig_in = d.total_in(); - let (ret, stream_end) = match _decompress(data, &mut d, DEF_BUF_SIZE, max_length, vm) { - Ok((buf, true)) => { - self.eof.store(true); - (Ok(buf), true) - } - Ok((buf, false)) => (Ok(buf), false), - Err(err) => (Err(err), false), - }; + let (ret, stream_end) = + match _decompress(data, &mut d, DEF_BUF_SIZE, max_length, false, vm) { + Ok((buf, true)) => { + self.eof.store(true); + (Ok(buf), true) + } + Ok((buf, false)) => (Ok(buf), false), + Err(err) => (Err(err), false), + }; self.save_unused_input(&mut d, data, stream_end, orig_in, vm); - let leftover = if !stream_end { - &data[(d.total_in() - orig_in) as usize..] - } else { + let leftover = if stream_end { b"" + } else { + &data[(d.total_in() - orig_in) as usize..] }; + let mut unconsumed_tail = self.unconsumed_tail.lock(); - if !leftover.is_empty() || unconsumed_tail.len() > 0 { + if !leftover.is_empty() || !unconsumed_tail.is_empty() { *unconsumed_tail = PyBytes::from(leftover.to_owned()).into_ref(vm); } @@ -321,7 +376,7 @@ mod decl { let orig_in = d.total_in(); - let (ret, stream_end) = match _decompress(&data, &mut d, length, None, vm) { + let (ret, stream_end) = match _decompress(&data, &mut d, length, None, true, vm) { Ok((buf, stream_end)) => (Ok(buf), stream_end), Err(err) => (Err(err), false), }; @@ -346,9 +401,10 @@ mod decl { } #[derive(FromArgs)] - struct DecopmressobjArgs { + struct DecompressobjArgs { #[pyarg(any, optional)] wbits: OptionalArg, + #[cfg(feature = "zlib")] #[pyarg(any, optional)] zdict: OptionalArg, } @@ -365,19 +421,9 @@ mod decl { _zdict: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - let (header, wbits) = header_from_wbits(wbits, vm)?; - let level = level.unwrap_or(-1); - - let level = match level { - -1 => libz::Z_DEFAULT_COMPRESSION as u32, - n @ 0..=9 => n as u32, - _ => return Err(vm.new_value_error("invalid initialization option".to_owned())), - }; - let level = Compression::new(level); - let compress = match header { - Some(header) => Compress::new_with_window_bits(level, header, wbits), - None => Compress::new_gzip(level, wbits), - }; + let level = compression_from_int(level.into_option()) + .ok_or_else(|| vm.new_value_error("invalid initialization option".to_owned()))?; + let compress = header_from_wbits(wbits, vm)?.compress(level); Ok(PyCompress { inner: PyMutex::new(CompressInner { compress, @@ -428,7 +474,7 @@ mod decl { // } } - const CHUNKSIZE: usize = libc::c_uint::MAX as usize; + const CHUNKSIZE: usize = u32::MAX as usize; impl CompressInner { fn save_unconsumed_input(&mut self, data: &[u8], orig_in: u64) { diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 7ab6524219..ffafe0c8c4 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -86,7 +86,7 @@ pub(crate) mod thread { }) } - pub fn with_vm(obj: &PyObjectRef, f: F) -> R + pub fn with_vm(obj: &PyObjectRef, f: F) -> Option where F: Fn(&VirtualMachine) -> R, { @@ -101,14 +101,12 @@ pub(crate) mod thread { debug_assert!(vm_owns_obj(x)); x } - Err(mut others) => others - .find(|x| vm_owns_obj(*x)) - .unwrap_or_else(|| panic!("can't get a vm for {:?}; none on stack", obj)), + Err(mut others) => others.find(|x| vm_owns_obj(*x))?, }; // SAFETY: all references in VM_STACK should be valid, and should not be changed or moved // at least until this function returns and the stack unwinds to an enter_vm() call let vm = unsafe { intp.as_ref() }; - f(vm) + Some(f(vm)) }) } } diff --git a/wasm/demo/snippets/asyncbrowser.py b/wasm/demo/snippets/asyncbrowser.py index 5c4d1fcb6e..5cd2f7b0a0 100644 --- a/wasm/demo/snippets/asyncbrowser.py +++ b/wasm/demo/snippets/asyncbrowser.py @@ -1,62 +1,10 @@ import browser -import functools - - -# just setting up the framework, skip to the bottom to see the real code - -ready = object() -go = object() - - -def run(coro, *, payload=None, error=False): - send = coro.throw if error else coro.send - try: - cmd = send(payload) - except StopIteration: - return - if cmd is ready: - coro.send( - ( - lambda *args: run(coro, payload=args), - lambda *args: run(coro, payload=args, error=True), - ) - ) - elif cmd is go: - pass - else: - raise RuntimeError(f"expected cmd to be ready or go, got {cmd}") - - -class JSFuture: - def __init__(self, prom): - self._prom = prom - - def __await__(self): - done, error = yield ready - self._prom.then(done, error) - res, = yield go - return res - - -def wrap_prom_func(func): - @functools.wraps(func) - async def wrapper(*args, **kwargs): - return await JSFuture(func(*args, **kwargs)) - - return wrapper - - -fetch = wrap_prom_func(browser.fetch) - -################### -# Real code start # -################### - +import asyncweb async def main(delay): url = f"https://httpbin.org/delay/{delay}" print(f"fetching {url}...") - res = await fetch( + res = await browser.fetch( url, response_format="json", headers={"X-Header-Thing": "rustpython is neat!"} ) print(f"got res from {res['url']}:") @@ -64,5 +12,5 @@ async def main(delay): for delay in range(3): - run(main(delay)) + asyncweb.run(main(delay)) print() diff --git a/wasm/demo/snippets/import_pypi.py b/wasm/demo/snippets/import_pypi.py new file mode 100644 index 0000000000..cbe837cd05 --- /dev/null +++ b/wasm/demo/snippets/import_pypi.py @@ -0,0 +1,20 @@ +import asyncweb +import whlimport + +whlimport.setup() + +# make sys.modules['os'] a dumb version of the os module, which has posixpath +# available as os.path as well as a few other utilities, but will raise an +# OSError for anything that actually requires an OS +import _dummy_os +_dummy_os._shim() + +@asyncweb.main +async def main(): + await whlimport.load_package("pygments") + import pygments + import pygments.lexers + import pygments.formatters.html + lexer = pygments.lexers.get_lexer_by_name("python") + fmter = pygments.formatters.html.HtmlFormatter(noclasses=True, style="default") + print(pygments.highlight("print('hi, mom!')", lexer, fmter)) diff --git a/wasm/demo/src/browser_module.rs b/wasm/demo/src/browser_module.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/wasm/lib/Lib/_microdistlib.py b/wasm/lib/Lib/_microdistlib.py new file mode 100644 index 0000000000..b70106730d --- /dev/null +++ b/wasm/lib/Lib/_microdistlib.py @@ -0,0 +1,131 @@ +# taken from https://bitbucket.org/pypa/distlib/src/master/distlib/util.py +# flake8: noqa +# fmt: off + +from types import SimpleNamespace as Container +import re + +IDENTIFIER = re.compile(r'^([\w\.-]+)\s*') +VERSION_IDENTIFIER = re.compile(r'^([\w\.*+-]+)\s*') +COMPARE_OP = re.compile(r'^(<=?|>=?|={2,3}|[~!]=)\s*') +NON_SPACE = re.compile(r'(\S+)\s*') + +def parse_requirement(req): + """ + Parse a requirement passed in as a string. Return a Container + whose attributes contain the various parts of the requirement. + """ + remaining = req.strip() + if not remaining or remaining.startswith('#'): + return None + m = IDENTIFIER.match(remaining) + if not m: + raise SyntaxError('name expected: %s' % remaining) + distname = m.groups()[0] + remaining = remaining[m.end():] + extras = mark_expr = versions = uri = None + if remaining and remaining[0] == '[': + i = remaining.find(']', 1) + if i < 0: + raise SyntaxError('unterminated extra: %s' % remaining) + s = remaining[1:i] + remaining = remaining[i + 1:].lstrip() + extras = [] + while s: + m = IDENTIFIER.match(s) + if not m: + raise SyntaxError('malformed extra: %s' % s) + extras.append(m.groups()[0]) + s = s[m.end():] + if not s: + break + if s[0] != ',': + raise SyntaxError('comma expected in extras: %s' % s) + s = s[1:].lstrip() + if not extras: + extras = None + if remaining: + if remaining[0] == '@': + # it's a URI + remaining = remaining[1:].lstrip() + m = NON_SPACE.match(remaining) + if not m: + raise SyntaxError('invalid URI: %s' % remaining) + uri = m.groups()[0] + t = urlparse(uri) + # there are issues with Python and URL parsing, so this test + # is a bit crude. See bpo-20271, bpo-23505. Python doesn't + # always parse invalid URLs correctly - it should raise + # exceptions for malformed URLs + if not (t.scheme and t.netloc): + raise SyntaxError('Invalid URL: %s' % uri) + remaining = remaining[m.end():].lstrip() + else: + + def get_versions(ver_remaining): + """ + Return a list of operator, version tuples if any are + specified, else None. + """ + m = COMPARE_OP.match(ver_remaining) + versions = None + if m: + versions = [] + while True: + op = m.groups()[0] + ver_remaining = ver_remaining[m.end():] + m = VERSION_IDENTIFIER.match(ver_remaining) + if not m: + raise SyntaxError('invalid version: %s' % ver_remaining) + v = m.groups()[0] + versions.append((op, v)) + ver_remaining = ver_remaining[m.end():] + if not ver_remaining or ver_remaining[0] != ',': + break + ver_remaining = ver_remaining[1:].lstrip() + m = COMPARE_OP.match(ver_remaining) + if not m: + raise SyntaxError('invalid constraint: %s' % ver_remaining) + if not versions: + versions = None + return versions, ver_remaining + + if remaining[0] != '(': + versions, remaining = get_versions(remaining) + else: + i = remaining.find(')', 1) + if i < 0: + raise SyntaxError('unterminated parenthesis: %s' % remaining) + s = remaining[1:i] + remaining = remaining[i + 1:].lstrip() + # As a special diversion from PEP 508, allow a version number + # a.b.c in parentheses as a synonym for ~= a.b.c (because this + # is allowed in earlier PEPs) + if COMPARE_OP.match(s): + versions, _ = get_versions(s) + else: + m = VERSION_IDENTIFIER.match(s) + if not m: + raise SyntaxError('invalid constraint: %s' % s) + v = m.groups()[0] + s = s[m.end():].lstrip() + if s: + raise SyntaxError('invalid constraint: %s' % s) + versions = [('~=', v)] + + if remaining: + if remaining[0] != ';': + raise SyntaxError('invalid requirement: %s' % remaining) + remaining = remaining[1:].lstrip() + + mark_expr, remaining = parse_marker(remaining) + + if remaining and remaining[0] != '#': + raise SyntaxError('unexpected trailing data: %s' % remaining) + + if not versions: + rs = distname + else: + rs = '%s %s' % (distname, ', '.join(['%s %s' % con for con in versions])) + return Container(name=distname, extras=extras, constraints=versions, + marker=mark_expr, url=uri, requirement=rs) diff --git a/wasm/lib/Lib/asyncweb.py b/wasm/lib/Lib/asyncweb.py new file mode 100644 index 0000000000..40bd843499 --- /dev/null +++ b/wasm/lib/Lib/asyncweb.py @@ -0,0 +1,210 @@ +from _js import Promise +from collections.abc import Coroutine + +try: + import browser +except ImportError: + browser = None + + +def is_promise(prom): + return callable(getattr(prom, "then", None)) + + +def run(coro): + """ + Run a coroutine. The coroutine should yield promise objects with a + ``.then(on_success, on_error)`` method. + """ + _Runner(coro) + + +def spawn(coro): + """ + Run a coroutine. Like run(), but returns a promise that resolves with + the result of the coroutine. + """ + return _coro_promise(coro) + + +class _Runner: + def __init__(self, coro): + self._send = coro.send + self._throw = coro.throw + # start the coro + self.success(None) + + def _run(self, send, arg): + try: + ret = send(arg) + except StopIteration: + return + ret.then(self.success, self.error) + + def success(self, res): + self._run(self._send, res) + + def error(self, err): + self._run(self._throw, err) + + +def main(async_func): + """ + A decorator to mark a function as main. This calls run() on the + result of the function, and logs an error that occurs. + """ + run(_main_wrapper(async_func())) + return async_func + + +async def _main_wrapper(coro): + try: + await coro + except: # noqa: E722 + import traceback + import sys + + # TODO: sys.stderr on wasm + traceback.print_exc(file=sys.stdout) + + +def _resolve(prom): + if is_promise(prom): + return prom + elif isinstance(prom, Coroutine): + return _coro_promise(prom) + else: + return Promise.resolve(prom) + + +class CallbackPromise: + def __init__(self): + self.done = 0 + self.__successes = [] + self.__errors = [] + + def then(self, success=None, error=None): + if success and not callable(success): + raise TypeError("success callback must be callable") + if error and not callable(error): + raise TypeError("error callback must be callable") + + if not self.done: + if success: + self.__successes.append(success) + if error: + self.__errors.append(error) + return + + cb = success if self.done == 1 else error + if cb: + return _call_resolve(cb, self.__result) + else: + return self + + def __await__(self): + yield self + + def resolve(self, value): + if self.done: + return + self.__result = value + self.done = 1 + for f in self.__successes: + f(value) + del self.__successes, self.__errors + + def reject(self, err): + if self.done: + return + self.__result = err + self.done = -1 + for f in self.__errors: + f(err) + del self.__successes, self.__errors + + +def _coro_promise(coro): + prom = CallbackPromise() + + async def run_coro(): + try: + res = await coro + except BaseException as e: + prom.reject(e) + else: + prom.resolve(res) + + run(run_coro()) + + return prom + + +def _call_resolve(f, arg): + try: + ret = f(arg) + except BaseException as e: + return Promise.reject(e) + else: + return _resolve(ret) + + +# basically an implementation of Promise.all +def wait_all(proms): + cbs = CallbackPromise() + + if not isinstance(proms, (list, tuple)): + proms = tuple(proms) + num_completed = 0 + num_proms = len(proms) + + if num_proms == 0: + cbs.resolve(()) + return cbs + + results = [None] * num_proms + + # needs to be a separate function for creating a closure in a loop + def register_promise(i, prom): + prom_completed = False + + def promise_done(success, res): + nonlocal prom_completed, results, num_completed + if prom_completed or cbs.done: + return + prom_completed = True + if success: + results[i] = res + num_completed += 1 + if num_completed == num_proms: + result = tuple(results) + del results + cbs.resolve(result) + else: + del results + cbs.reject(res) + + _resolve(prom).then( + lambda res: promise_done(True, res), + lambda err: promise_done(False, err), + ) + + for i, prom in enumerate(proms): + register_promise(i, prom) + + return cbs + + +if browser: + _settimeout = browser.window.get_prop("setTimeout") + + def timeout(ms): + prom = CallbackPromise() + + @browser.jsclosure_once + def cb(this): + print("AAA") + prom.resolve(None) + + _settimeout.call(cb.detach(), browser.jsfloat(ms)) + return prom diff --git a/wasm/lib/Lib/browser.py b/wasm/lib/Lib/browser.py new file mode 100644 index 0000000000..515fe2e673 --- /dev/null +++ b/wasm/lib/Lib/browser.py @@ -0,0 +1,76 @@ +from _browser import ( + fetch, + request_animation_frame, + cancel_animation_frame, + Document, + Element, + load_module, +) + +from _js import JSValue, Promise +from _window import window + +__all__ = [ + "jsstr", + "jsclosure", + "jsclosure_once", + "jsfloat", + "NULL", + "UNDEFINED", + "alert", + "confirm", + "prompt", + "fetch", + "request_animation_frame", + "cancel_animation_frame", + "Document", + "Element", + "load_module", + "JSValue", + "Promise", +] + + +jsstr = window.new_from_str +jsclosure = window.new_closure +jsclosure_once = window.new_closure_once +_jsfloat = window.new_from_float + +UNDEFINED = window.undefined() +NULL = window.null() + + +def jsfloat(n): + return _jsfloat(float(n)) + + +_alert = window.get_prop("alert") + + +def alert(msg): + if type(msg) != str: + raise TypeError("msg must be a string") + _alert.call(jsstr(msg)) + + +_confirm = window.get_prop("confirm") + + +def confirm(msg): + if type(msg) != str: + raise TypeError("msg must be a string") + return _confirm.call(jsstr(msg)).as_bool() + + +_prompt = window.get_prop("prompt") + + +def prompt(msg, default_val=None): + if type(msg) != str: + raise TypeError("msg must be a string") + if default_val is not None and type(default_val) != str: + raise TypeError("default_val must be a string") + + return _prompt.call( + jsstr(msg), jsstr(default_val) if default_val else UNDEFINED + ).as_str() diff --git a/wasm/lib/Lib/whlimport.py b/wasm/lib/Lib/whlimport.py new file mode 100644 index 0000000000..16ed11edf0 --- /dev/null +++ b/wasm/lib/Lib/whlimport.py @@ -0,0 +1,168 @@ +import browser +import zipfile +import asyncweb +import io +import re +import posixpath +from urllib.parse import urlparse +import _frozen_importlib as _bootstrap +import _microdistlib + +_IS_SETUP = False + + +def setup(*, log=print): + global _IS_SETUP, LOG_FUNC + + if not _IS_SETUP: + import sys + + sys.meta_path.insert(0, ZipFinder) + _IS_SETUP = True + + if log: + + LOG_FUNC = log + else: + + def LOG_FUNC(log): + pass + + +async def load_package(*args): + await asyncweb.wait_all(_load_package(pkg) for pkg in args) + + +_loaded_packages = {} + +LOG_FUNC = print + +_http_url = re.compile("^http[s]?://") + + +async def _load_package(pkg): + if isinstance(pkg, str) and _http_url.match(pkg): + urlobj = urlparse(pkg) + fname = posixpath.basename(urlobj.path) + name, url, size, deps = fname, pkg, None, [] + else: + # TODO: load dependencies as well + name, fname, url, size, deps = await _load_info_pypi(pkg) + if name in _loaded_packages: + return + deps = asyncweb.spawn(asyncweb.wait_all(_load_package for dep in deps)) + size_str = format_size(size) if size is not None else "unknown size" + LOG_FUNC(f"Downloading {fname} ({size_str})...") + zip_data = io.BytesIO(await browser.fetch(url, response_format="array_buffer")) + size = len(zip_data.getbuffer()) + LOG_FUNC(f"{fname} done!") + _loaded_packages[name] = zipfile.ZipFile(zip_data) + await deps + + +async def _load_info_pypi(pkg): + pkg = _microdistlib.parse_requirement(pkg) + # TODO: use VersionMatcher from distlib + api_url = ( + f"https://pypi.org/pypi/{pkg.name}/json" + if not pkg.constraints + else f"https://pypi.org/pypi/{pkg.name}/{pkg.constraints[0][1]}/json" + ) + info = await browser.fetch(api_url, response_format="json") + name = info["info"]["name"] + ver = info["info"]["version"] + ver_downloads = info["releases"][ver] + try: + dl = next(dl for dl in ver_downloads if dl["packagetype"] == "bdist_wheel") + except StopIteration: + raise ValueError(f"no wheel available for package {name!r} {ver}") + return ( + name, + dl["filename"], + dl["url"], + dl["size"], + info["info"]["requires_dist"] or [], + ) + + +def format_size(bytes): + # type: (float) -> str + if bytes > 1000 * 1000: + return "{:.1f} MB".format(bytes / 1000.0 / 1000) + elif bytes > 10 * 1000: + return "{} kB".format(int(bytes / 1000)) + elif bytes > 1000: + return "{:.1f} kB".format(bytes / 1000.0) + else: + return "{} bytes".format(int(bytes)) + + +class ZipFinder: + _packages = _loaded_packages + + @classmethod + def find_spec(cls, fullname, path=None, target=None): + path = fullname.replace(".", "/") + for zname, z in cls._packages.items(): + mi, fullpath = _get_module_info(z, path) + if mi is not None: + return _bootstrap.spec_from_loader( + fullname, cls, origin=f"zip:{zname}/{fullpath}", is_package=mi + ) + return None + + @classmethod + def create_module(cls, spec): + return None + + @classmethod + def get_source(cls, fullname): + spec = cls.find_spec(fullname) + if spec: + return cls._get_source(spec) + else: + raise ImportError("cannot find source for module", name=fullname) + + @classmethod + def _get_source(cls, spec): + origin = spec.origin and remove_prefix(spec.origin, "zip:") + if not origin: + raise ImportError(f"{spec.name!r} is not a zip module") + + zipname, slash, path = origin.partition("/") + return cls._packages[zipname].read(path).decode() + + @classmethod + def exec_module(cls, module): + spec = module.__spec__ + source = cls._get_source(spec) + code = _bootstrap._call_with_frames_removed( + compile, source, spec.origin, "exec", dont_inherit=True + ) + _bootstrap._call_with_frames_removed(exec, code, module.__dict__) + + +def remove_prefix(s, prefix): + if s.startswith(prefix): + return s[len(prefix) :] # noqa: E203 + else: + return None + + +_zip_searchorder = ( + ("/__init__.pyc", True, True), + ("/__init__.py", False, True), + (".pyc", True, False), + (".py", False, False), +) + + +def _get_module_info(zf, path): + for suffix, isbytecode, ispackage in _zip_searchorder: + fullpath = path + suffix + try: + zf.getinfo(fullpath) + except KeyError: + continue + return ispackage, fullpath + return None, None diff --git a/wasm/lib/src/browser.py b/wasm/lib/src/browser.py deleted file mode 100644 index 9f6769ff70..0000000000 --- a/wasm/lib/src/browser.py +++ /dev/null @@ -1,38 +0,0 @@ -from _browser import * - -from _js import JSValue, Promise -from _window import window - - -jsstr = window.new_from_str -jsclosure = window.new_closure - - -_alert = window.get_prop("alert") - - -def alert(msg): - if type(msg) != str: - raise TypeError("msg must be a string") - _alert.call(jsstr(msg)) - - -_confirm = window.get_prop("confirm") - - -def confirm(msg): - if type(msg) != str: - raise TypeError("msg must be a string") - return _confirm.call(jsstr(msg)).as_bool() - - -_prompt = window.get_prop("prompt") - - -def prompt(msg, default_val=None): - if type(msg) != str: - raise TypeError("msg must be a string") - if default_val is not None and type(default_val) != str: - raise TypeError("default_val must be a string") - - return _prompt.call(*(jsstr(arg) for arg in [msg, default_val] if arg)).as_str() diff --git a/wasm/lib/src/browser_module.rs b/wasm/lib/src/browser_module.rs index 595ea8a3b9..64346ab933 100644 --- a/wasm/lib/src/browser_module.rs +++ b/wasm/lib/src/browser_module.rs @@ -276,5 +276,5 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { pub fn setup_browser_module(vm: &mut VirtualMachine) { vm.add_native_module("_browser".to_owned(), Box::new(make_module)); - vm.add_frozen(py_freeze!(file = "src/browser.py", module_name = "browser")); + vm.add_frozen(py_freeze!(dir = "Lib")); } diff --git a/wasm/lib/src/convert.rs b/wasm/lib/src/convert.rs index 7ce8b3ead4..b07838fb98 100644 --- a/wasm/lib/src/convert.rs +++ b/wasm/lib/src/convert.rs @@ -137,7 +137,7 @@ pub fn py_to_js(vm: &VirtualMachine, py_obj: PyObjectRef) -> JsValue { // the browser module might not be injected if vm.try_class("_js", "Promise").is_ok() { if let Some(py_prom) = py_obj.payload::() { - return py_prom.value().into(); + return py_prom.as_js(vm).into(); } } diff --git a/wasm/lib/src/js_module.rs b/wasm/lib/src/js_module.rs index ecc16524db..629f796619 100644 --- a/wasm/lib/src/js_module.rs +++ b/wasm/lib/src/js_module.rs @@ -8,11 +8,12 @@ use wasm_bindgen_futures::{future_to_promise, JsFuture}; use rustpython_vm::builtins::{PyFloatRef, PyStrRef, PyTypeRef}; use rustpython_vm::exceptions::PyBaseExceptionRef; -use rustpython_vm::function::{Args, OptionalArg}; +use rustpython_vm::function::{Args, OptionalArg, OptionalOption}; use rustpython_vm::pyobject::{ BorrowValue, IntoPyObject, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, StaticType, TryFromObject, }; +use rustpython_vm::slots::PyIter; use rustpython_vm::types::create_simple_type; use rustpython_vm::VirtualMachine; @@ -121,7 +122,12 @@ impl PyJsValue { #[pymethod] fn new_closure(&self, obj: PyObjectRef, vm: &VirtualMachine) -> JsClosure { - JsClosure::new(obj, vm) + JsClosure::new(obj, false, vm) + } + + #[pymethod] + fn new_closure_once(&self, obj: PyObjectRef, vm: &VirtualMachine) -> JsClosure { + JsClosure::new(obj, true, vm) } #[pymethod] @@ -272,7 +278,7 @@ struct NewObjectOptions { prototype: Option, } -type ClosureType = Closure) -> Result>; +type ClosureType = Closure) -> Result>; #[pyclass(module = "_js", name = "JSClosure")] struct JsClosure { @@ -295,7 +301,7 @@ impl PyValue for JsClosure { #[pyimpl] impl JsClosure { - fn new(obj: PyObjectRef, vm: &VirtualMachine) -> Self { + fn new(obj: PyObjectRef, once: bool, vm: &VirtualMachine) -> Self { let wasm_vm = WASMVirtualMachine { id: vm.wasm_id.clone().unwrap(), }; @@ -320,7 +326,11 @@ impl JsClosure { convert::pyresult_to_jsresult(vm, res) }) }; - let closure = Closure::wrap(Box::new(f) as _); + let closure: ClosureType = if once { + Closure::wrap(Box::new(f)) + } else { + Closure::once(Box::new(f)) + }; let wrapped = PyJsValue::new(wrap_closure(closure.as_ref())).into_ref(vm); JsClosure { closure: Some((closure, wrapped)).into(), @@ -369,13 +379,21 @@ impl JsClosure { } } -#[pyclass(module = "browser", name = "Promise")] -#[derive(Debug)] +#[pyclass(module = "_js", name = "Promise")] +#[derive(Debug, Clone)] pub struct PyPromise { - value: Promise, + value: PromiseKind, } pub type PyPromiseRef = PyRef; +#[derive(Debug, Clone)] +enum PromiseKind { + Js(Promise), + PyProm { then: PyObjectRef }, + PyResolved(PyObjectRef), + PyRejected(PyBaseExceptionRef), +} + impl PyValue for PyPromise { fn class(_vm: &VirtualMachine) -> &PyTypeRef { Self::static_type() @@ -385,7 +403,9 @@ impl PyValue for PyPromise { #[pyimpl] impl PyPromise { pub fn new(value: Promise) -> PyPromise { - PyPromise { value } + PyPromise { + value: PromiseKind::Js(value), + } } pub fn from_future(future: F) -> PyPromise where @@ -393,73 +413,198 @@ impl PyPromise { { PyPromise::new(future_to_promise(future)) } - pub fn value(&self) -> Promise { - self.value.clone() + pub fn as_js(&self, vm: &VirtualMachine) -> Promise { + match &self.value { + PromiseKind::Js(prom) => prom.clone(), + PromiseKind::PyProm { then } => Promise::new(&mut |js_resolve, js_reject| { + let resolve = move |res: PyObjectRef, vm: &VirtualMachine| { + let _ = js_resolve.call1(&JsValue::UNDEFINED, &convert::py_to_js(vm, res)); + }; + let reject = move |err: PyBaseExceptionRef, vm: &VirtualMachine| { + let _ = + js_reject.call1(&JsValue::UNDEFINED, &convert::py_err_to_js_err(vm, &err)); + }; + let _ = vm.invoke( + then, + ( + vm.ctx.new_function("resolve", resolve), + vm.ctx.new_function("reject", reject), + ), + ); + }), + PromiseKind::PyResolved(obj) => Promise::resolve(&convert::py_to_js(vm, obj.clone())), + PromiseKind::PyRejected(err) => Promise::reject(&convert::py_err_to_js_err(vm, err)), + } + } + + fn cast(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let then = vm.get_attribute_opt(obj.clone(), "then")?; + let value = if let Some(then) = then.filter(|obj| vm.is_callable(obj)) { + PromiseKind::PyProm { then } + } else { + PromiseKind::PyResolved(obj) + }; + Ok(Self { value }) + } + + fn cast_result(res: PyResult, vm: &VirtualMachine) -> PyResult { + match res { + Ok(res) => Self::cast(res, vm), + Err(e) => Ok(Self { + value: PromiseKind::PyRejected(e), + }), + } + } + + #[pyclassmethod] + fn resolve(cls: PyTypeRef, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult> { + Self::cast(obj, vm)?.into_ref_with_type(vm, cls) + } + + #[pyclassmethod] + fn reject( + cls: PyTypeRef, + err: PyBaseExceptionRef, + vm: &VirtualMachine, + ) -> PyResult> { + Self { + value: PromiseKind::PyRejected(err), + } + .into_ref_with_type(vm, cls) } #[pymethod] fn then( &self, - on_fulfill: PyCallable, - on_reject: OptionalArg, + on_fulfill: OptionalOption, + on_reject: OptionalOption, vm: &VirtualMachine, - ) -> PyPromiseRef { - let weak_vm = weak_vm(vm); - let prom = JsFuture::from(self.value.clone()); - - let ret_future = async move { - let stored_vm = &weak_vm - .upgrade() - .expect("that the vm is valid when the promise resolves"); - let res = prom.await; - match res { - Ok(val) => stored_vm.interp.enter(move |vm| { - let args = if val.is_null() { - vec![] - } else { - vec![convert::js_to_py(vm, val)] - }; - let res = vm.invoke(&on_fulfill.into_object(), args); - convert::pyresult_to_jsresult(vm, res) - }), - Err(err) => { - if let OptionalArg::Present(on_reject) = on_reject { - stored_vm.interp.enter(move |vm| { - let err = convert::js_to_py(vm, err); - let res = vm.invoke(&on_reject.into_object(), (err,)); - convert::pyresult_to_jsresult(vm, res) - }) - } else { - Err(err) + ) -> PyResult { + let (on_fulfill, on_reject) = (on_fulfill.flatten(), on_reject.flatten()); + if on_fulfill.is_none() && on_reject.is_none() { + return Ok(self.clone()); + } + match &self.value { + PromiseKind::Js(prom) => { + let weak_vm = weak_vm(vm); + let prom = JsFuture::from(prom.clone()); + + let ret_future = async move { + let stored_vm = &weak_vm + .upgrade() + .expect("that the vm is valid when the promise resolves"); + let res = prom.await; + match res { + Ok(val) => match on_fulfill { + Some(on_fulfill) => stored_vm.interp.enter(move |vm| { + let val = convert::js_to_py(vm, val); + let res = on_fulfill.invoke((val,), vm); + convert::pyresult_to_jsresult(vm, res) + }), + None => Ok(val), + }, + Err(err) => match on_reject { + Some(on_reject) => stored_vm.interp.enter(move |vm| { + let err = new_js_error(vm, err); + let res = on_reject.invoke((err,), vm); + convert::pyresult_to_jsresult(vm, res) + }), + None => Err(err), + }, } - } + }; + + Ok(PyPromise::from_future(ret_future)) } - }; + PromiseKind::PyProm { then } => { + Self::cast_result(vm.invoke(then, (on_fulfill, on_reject)), vm) + } + PromiseKind::PyResolved(res) => match on_fulfill { + Some(resolve) => Self::cast_result(resolve.invoke((res.clone(),), vm), vm), + None => Ok(self.clone()), + }, + PromiseKind::PyRejected(err) => match on_reject { + Some(reject) => Self::cast_result(reject.invoke((err.clone(),), vm), vm), + None => Ok(self.clone()), + }, + } + } - PyPromise::from_future(ret_future).into_ref(vm) + #[pymethod] + fn catch( + &self, + on_reject: OptionalOption, + vm: &VirtualMachine, + ) -> PyResult { + self.then(OptionalArg::Present(None), on_reject, vm) + } + + #[pymethod(name = "__await__")] + fn r#await(zelf: PyRef) -> AwaitPromise { + AwaitPromise { + obj: Some(zelf.into_object()).into(), + } + } +} + +#[pyclass(module = "_js", name = "AwaitPromise")] +struct AwaitPromise { + obj: cell::Cell>, +} + +impl fmt::Debug for AwaitPromise { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AwaitPromise").finish() + } +} + +impl PyValue for AwaitPromise { + fn class(_vm: &VirtualMachine) -> &PyTypeRef { + Self::static_type() } +} +#[pyimpl(with(PyIter))] +impl AwaitPromise { #[pymethod] - fn catch(&self, on_reject: PyCallable, vm: &VirtualMachine) -> PyPromiseRef { - let weak_vm = weak_vm(vm); - let prom = JsFuture::from(self.value.clone()); - - let ret_future = async move { - let err = match prom.await { - Ok(x) => return Ok(x), - Err(e) => e, - }; - let stored_vm = weak_vm - .upgrade() - .expect("that the vm is valid when the promise resolves"); - stored_vm.interp.enter(move |vm| { - let err = convert::js_to_py(vm, err); - let res = vm.invoke(&on_reject.into_object(), (err,)); - convert::pyresult_to_jsresult(vm, res) - }) - }; + fn send(&self, val: Option, vm: &VirtualMachine) -> PyResult { + match self.obj.take() { + Some(prom) => { + if val.is_some() { + Err(vm + .new_type_error("can't send non-None value to an awaitpromise".to_owned())) + } else { + Ok(prom) + } + } + None => Err(rustpython_vm::iterator::stop_iter_with_value( + vm.unwrap_or_none(val), + vm, + )), + } + } + + #[pymethod] + fn throw( + &self, + exc_type: PyObjectRef, + exc_val: OptionalArg, + exc_tb: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let err = rustpython_vm::exceptions::normalize( + exc_type, + exc_val.unwrap_or_none(vm), + exc_tb.unwrap_or_none(vm), + vm, + )?; + Err(err) + } +} - PyPromise::from_future(ret_future).into_ref(vm) +impl PyIter for AwaitPromise { + fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { + zelf.send(None, vm) } } @@ -478,6 +623,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "value" => ctx.new_readonly_getset("value", |exc: PyBaseExceptionRef| exc.get_arg(0)), }); + AwaitPromise::make_class(ctx); + py_module!(vm, "_js", { "JSError" => js_error, "JSValue" => PyJsValue::make_class(ctx),