diff --git a/Cargo.lock b/Cargo.lock index 8d6cc9ae76..67bf7fcdfe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1796,6 +1796,7 @@ dependencies = [ name = "rustpython-common" version = "0.0.0" dependencies = [ + "ascii", "cfg-if 1.0.0", "hexf-parse", "lexical-core", diff --git a/common/Cargo.toml b/common/Cargo.toml index de991b8a06..e766b867b2 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -22,3 +22,4 @@ rand = "0.8" volatile = "0.3" radium = "0.6" libc = "0.2.101" +ascii = "1.0" diff --git a/common/src/encodings.rs b/common/src/encodings.rs index 83cbcace6c..7a6cd16069 100644 --- a/common/src/encodings.rs +++ b/common/src/encodings.rs @@ -4,13 +4,20 @@ pub type EncodeErrorResult = Result<(EncodeReplace, usize), E>; pub type DecodeErrorResult = Result<(S, Option, usize), E>; +pub trait StrBuffer: AsRef { + fn is_ascii(&self) -> bool { + self.as_ref().is_ascii() + } +} + pub trait ErrorHandler { type Error; - type StrBuf: AsRef; + type StrBuf: StrBuffer; type BytesBuf: AsRef<[u8]>; fn handle_encode_error( &self, - byte_range: Range, + data: &str, + char_range: Range, reason: &str, ) -> EncodeErrorResult; fn handle_decode_error( @@ -20,12 +27,95 @@ pub trait ErrorHandler { reason: &str, ) -> DecodeErrorResult; fn error_oob_restart(&self, i: usize) -> Self::Error; + fn error_encoding(&self, data: &str, char_range: Range, reason: &str) -> Self::Error; } pub enum EncodeReplace { Str(S), Bytes(B), } +struct DecodeError<'a> { + valid_prefix: &'a str, + rest: &'a [u8], + err_len: Option, +} +/// # Safety +/// `v[..valid_up_to]` must be valid utf8 +unsafe fn make_decode_err(v: &[u8], valid_up_to: usize, err_len: Option) -> DecodeError<'_> { + let valid_prefix = core::str::from_utf8_unchecked(v.get_unchecked(..valid_up_to)); + let rest = v.get_unchecked(valid_up_to..); + DecodeError { + valid_prefix, + rest, + err_len, + } +} + +enum HandleResult<'a> { + Done, + Error { + err_len: Option, + reason: &'a str, + }, +} +fn decode_utf8_compatible( + data: &[u8], + errors: &E, + decode: DecodeF, + handle_error: ErrF, +) -> Result<(String, usize), E::Error> +where + DecodeF: Fn(&[u8]) -> Result<&str, DecodeError<'_>>, + ErrF: Fn(&[u8], Option) -> HandleResult<'_>, +{ + if data.is_empty() { + return Ok((String::new(), 0)); + } + // we need to coerce the lifetime to that of the function body rather than the + // anonymous input lifetime, so that we can assign it data borrowed from data_from_err + let mut data = &*data; + let mut data_from_err: E::BytesBuf; + let mut out = String::with_capacity(data.len()); + let mut remaining_index = 0; + let mut remaining_data = data; + loop { + match decode(remaining_data) { + Ok(decoded) => { + out.push_str(decoded); + remaining_index += decoded.len(); + break; + } + Err(e) => { + out.push_str(e.valid_prefix); + match handle_error(e.rest, e.err_len) { + HandleResult::Done => { + remaining_index += e.valid_prefix.len(); + break; + } + HandleResult::Error { err_len, reason } => { + let err_idx = remaining_index + e.valid_prefix.len(); + let err_range = + err_idx..err_len.map_or_else(|| data.len(), |len| err_idx + len); + let (replace, new_data, restart) = + errors.handle_decode_error(data, err_range, reason)?; + out.push_str(replace.as_ref()); + if let Some(new_data) = new_data { + data_from_err = new_data; + data = data_from_err.as_ref(); + } + remaining_data = data + .get(restart..) + .ok_or_else(|| errors.error_oob_restart(restart))?; + remaining_index = restart; + continue; + } + } + } + } + } + Ok((out, remaining_index)) +} + pub mod utf8 { use super::*; @@ -41,75 +131,118 @@ pub mod utf8 { errors: &E, final_decode: bool, ) -> Result<(String, usize), E::Error> { - if data.is_empty() { - return Ok((String::new(), 0)); - } - // we need to coerce the lifetime to that of the function body rather than the - // anonymous input lifetime, so that we can assign it data borrowed from data_from_err - let mut data = &*data; - let mut data_from_err: E::BytesBuf; - let mut out = String::with_capacity(data.len()); - let mut remaining_index = 0; - let mut remaining_data = data; - macro_rules! handle_error { - ($range:expr, $reason:expr) => {{ - let (replace, new_data, restart) = - errors.handle_decode_error(data, $range, $reason)?; - out.push_str(replace.as_ref()); - if let Some(new_data) = new_data { - data_from_err = new_data; - data = data_from_err.as_ref(); + decode_utf8_compatible( + data, + errors, + |v| { + core::str::from_utf8(v).map_err(|e| { + // SAFETY: as specified in valid_up_to's documentation, input[..e.valid_up_to()] + // is valid utf8 + unsafe { make_decode_err(v, e.valid_up_to(), e.error_len()) } + }) + }, + |rest, err_len| { + let first_err = rest[0]; + if matches!(first_err, 0x80..=0xc1 | 0xf5..=0xff) { + HandleResult::Error { + err_len: Some(1), + reason: "invalid start byte", + } + } else if err_len.is_none() { + // error_len() == None means unexpected eof + if final_decode { + HandleResult::Error { + err_len, + reason: "unexpected end of data", + } + } else { + HandleResult::Done + } + } else if !final_decode && matches!(rest, [0xed, 0xa0..=0xbf]) { + // truncated surrogate + HandleResult::Done + } else { + HandleResult::Error { + err_len, + reason: "invalid continuation byte", + } } - remaining_data = data - .get(restart..) - .ok_or_else(|| errors.error_oob_restart(restart))?; - remaining_index = restart; - continue; - }}; - } + }, + ) + } +} + +pub mod ascii { + use super::*; + use ::ascii::AsciiStr; + + pub const ENCODING_NAME: &str = "ascii"; + + const ERR_REASON: &str = "ordinal not in range(128)"; + + #[inline] + pub fn encode(s: &str, errors: &E) -> Result, E::Error> { + let full_data = s; + let mut data = s; + let mut char_data_index = 0; + let mut out = Vec::::new(); loop { - match core::str::from_utf8(remaining_data) { - Ok(decoded) => { - out.push_str(decoded); - remaining_index += decoded.len(); + match data + .char_indices() + .enumerate() + .find(|(_, (_, c))| !c.is_ascii()) + { + None => { + out.extend_from_slice(data.as_bytes()); break; } - Err(e) => { - let (valid_prefix, rest, first_err) = unsafe { - let index = e.valid_up_to(); - // SAFETY: as specified in valid_up_to's documentation, from_utf8(&input[..index]) will return Ok(_) - let valid = - std::str::from_utf8_unchecked(remaining_data.get_unchecked(..index)); - let rest = remaining_data.get_unchecked(index..); - // SAFETY: if index didn't have something at it, this wouldn't be an error - let first_err = *remaining_data.get_unchecked(index); - (valid, rest, first_err) - }; - out.push_str(valid_prefix); - let err_idx = remaining_index + e.valid_up_to(); - remaining_data = rest; - remaining_index += valid_prefix.len(); - if (0x80..0xc2).contains(&first_err) || (0xf5..=0xff).contains(&first_err) { - handle_error!(err_idx..err_idx + 1, "invalid start byte"); - } - let err_len = match e.error_len() { - Some(l) => l, - // error_len() == None means unexpected eof - None => { - if !final_decode { - break; + Some((char_i, (byte_i, _))) => { + out.extend_from_slice(&data.as_bytes()[..byte_i]); + let char_start = char_data_index + char_i; + // number of non-ascii chars between the first non-ascii char and the next ascii char + let non_ascii_run_length = + data[byte_i..].chars().take_while(|c| !c.is_ascii()).count(); + let char_range = char_start..char_start + non_ascii_run_length; + let (replace, char_restart) = + errors.handle_encode_error(full_data, char_range.clone(), ERR_REASON)?; + match replace { + EncodeReplace::Str(s) => { + if !s.is_ascii() { + return Err( + errors.error_encoding(full_data, char_range, ERR_REASON) + ); } - handle_error!(err_idx..data.len(), "unexpected end of data"); + out.extend_from_slice(s.as_ref().as_bytes()); + } + EncodeReplace::Bytes(b) => { + out.extend_from_slice(b.as_ref()); } - }; - if !final_decode && matches!(remaining_data, [0xed, 0xa0..=0xbf]) { - // truncated surrogate - break; } - handle_error!(err_idx..err_idx + err_len, "invalid continuation byte"); + data = crate::str::try_get_chars(full_data, char_restart..) + .ok_or_else(|| errors.error_oob_restart(char_restart))?; + char_data_index = char_restart; + continue; } } } - Ok((out, remaining_index)) + Ok(out) + } + + pub fn decode(data: &[u8], errors: &E) -> Result<(String, usize), E::Error> { + decode_utf8_compatible( + data, + errors, + |v| { + AsciiStr::from_ascii(v).map(|s| s.as_str()).map_err(|e| { + // SAFETY: as specified in valid_up_to's documentation, input[..e.valid_up_to()] + // is valid ascii & therefore valid utf8 + unsafe { make_decode_err(v, e.valid_up_to(), Some(1)) } + }) + }, + |_rest, err_len| HandleResult::Error { + err_len, + reason: ERR_REASON, + }, + ) } } diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index 2fd135755b..ac63e4e1d4 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -4,7 +4,7 @@ pub(crate) use _codecs::make_module; mod _codecs { use crate::common::encodings; use crate::{ - builtins::{PyBytesRef, PyStr, PyStrRef, PyTuple}, + builtins::{PyBytes, PyBytesRef, PyStr, PyStrRef, PyTuple}, byteslike::ArgBytesLike, codecs, exceptions::PyBaseExceptionRef, @@ -100,6 +100,11 @@ mod _codecs { }) } } + impl encodings::StrBuffer for PyStrRef { + fn is_ascii(&self) -> bool { + PyStr::is_ascii(self) + } + } impl<'vm> encodings::ErrorHandler for ErrorsHandler<'vm> { type Error = PyBaseExceptionRef; type StrBuf = PyStrRef; @@ -107,12 +112,45 @@ mod _codecs { fn handle_encode_error( &self, - _byte_range: Range, - _reason: &str, + data: &str, + char_range: Range, + reason: &str, ) -> PyResult<(encodings::EncodeReplace, usize)> { - // we don't use common::encodings to really encode anything yet (utf8 can't error - // because PyStr is always utf8), so this can't get called until we do - todo!() + let vm = self.vm; + let data_str = vm.ctx.new_utf8_str(data); + let encode_exc = vm.new_exception( + vm.ctx.exceptions.unicode_encode_error.clone(), + vec![ + vm.ctx.new_utf8_str(self.encoding), + data_str, + vm.ctx.new_int(char_range.start), + vm.ctx.new_int(char_range.end), + vm.ctx.new_utf8_str(reason), + ], + ); + let res = vm.invoke(self.handler_func()?, (encode_exc,))?; + let tuple_err = || { + vm.new_type_error( + "encoding error handler must return (str/bytes, int) tuple".to_owned(), + ) + }; + let (replace, restart) = match res.payload::().map(|tup| tup.as_slice()) { + Some([replace, restart]) => (replace.clone(), restart), + _ => return Err(tuple_err()), + }; + let replace = match_class!(match replace { + s @ PyStr => encodings::EncodeReplace::Str(s), + b @ PyBytes => encodings::EncodeReplace::Bytes(b), + _ => return Err(tuple_err()), + }); + let restart = isize::try_from_borrowed_object(vm, restart).map_err(|_| tuple_err())?; + let restart = if restart < 0 { + // will still be out of bounds if it underflows ¯\_(ツ)_/¯ + data.len().wrapping_sub(restart.unsigned_abs()) + } else { + restart as usize + }; + Ok((replace, restart)) } fn handle_decode_error( @@ -173,6 +211,25 @@ mod _codecs { self.vm .new_index_error(format!("position {} from error handler out of bounds", i)) } + + fn error_encoding( + &self, + data: &str, + char_range: Range, + reason: &str, + ) -> Self::Error { + let vm = self.vm; + vm.new_exception( + vm.ctx.exceptions.unicode_encode_error.clone(), + vec![ + vm.ctx.new_utf8_str(self.encoding), + vm.ctx.new_utf8_str(data), + vm.ctx.new_int(char_range.start), + vm.ctx.new_int(char_range.end), + vm.ctx.new_utf8_str(reason), + ], + ) + } } type EncodeResult = PyResult<(Vec, usize)>; @@ -221,6 +278,26 @@ mod _codecs { } } + #[derive(FromArgs)] + struct DecodeArgsNoFinal { + #[pyarg(positional)] + data: ArgBytesLike, + #[pyarg(positional, optional)] + errors: Option, + } + + impl DecodeArgsNoFinal { + #[inline] + fn decode<'a, F>(self, name: &'a str, decode: F, vm: &'a VirtualMachine) -> DecodeResult + where + F: FnOnce(&[u8], &ErrorsHandler<'a>) -> DecodeResult, + { + let data = self.data.borrow_buf(); + let errors = ErrorsHandler::new(name, self.errors, vm); + decode(&data, &errors) + } + } + macro_rules! do_codec { ($module:ident :: $func:ident, $args: expr, $vm:expr) => {{ use encodings::$module as codec; @@ -238,6 +315,19 @@ mod _codecs { do_codec!(utf8::decode, args, vm) } + #[pyfunction] + fn ascii_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { + if args.s.is_ascii() { + return Ok((args.s.as_str().as_bytes().to_vec(), args.s.byte_len())); + } + do_codec!(ascii::encode, args, vm) + } + + #[pyfunction] + fn ascii_decode(args: DecodeArgsNoFinal, vm: &VirtualMachine) -> DecodeResult { + do_codec!(ascii::decode, args, vm) + } + // TODO: implement these codecs in Rust! use crate::common::static_cell::StaticCell; @@ -324,14 +414,6 @@ mod _codecs { delegate_pycodecs!(utf_16_decode, args, vm) } #[pyfunction] - fn ascii_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult { - delegate_pycodecs!(ascii_encode, args, vm) - } - #[pyfunction] - fn ascii_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult { - delegate_pycodecs!(ascii_decode, args, vm) - } - #[pyfunction] fn charmap_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult { delegate_pycodecs!(charmap_encode, args, vm) }