diff --git a/common/src/encodings.rs b/common/src/encodings.rs index 7a6cd16069..bbbc5e8202 100644 --- a/common/src/encodings.rs +++ b/common/src/encodings.rs @@ -172,6 +172,81 @@ pub mod utf8 { } } +pub mod latin_1 { + use super::*; + + pub const ENCODING_NAME: &str = "latin-1"; + + const ERR_REASON: &str = "ordinal not in range(256)"; + + #[inline] + pub fn encode(s: &str, errors: &E) -> Result, E::Error> { + let full_data = s; + let mut data = s; + let mut char_data_index = 0; + let mut out = Vec::::new(); + loop { + match data + .char_indices() + .enumerate() + .find(|(_, (_, c))| !c.is_ascii()) + { + None => { + out.extend_from_slice(data.as_bytes()); + break; + } + Some((char_i, (byte_i, ch))) => { + out.extend_from_slice(&data.as_bytes()[..byte_i]); + let char_start = char_data_index + char_i; + if (ch as u32) <= 255 { + out.push(ch as u8); + let char_restart = char_start + 1; + data = crate::str::try_get_chars(full_data, char_restart..) + .ok_or_else(|| errors.error_oob_restart(char_restart))?; + char_data_index = char_restart; + } else { + // number of non-latin_1 chars between the first non-latin_1 char and the next latin_1 char + let non_latin_1_run_length = data[byte_i..] + .chars() + .take_while(|c| (*c as u32) > 255) + .count(); + let char_range = char_start..char_start + non_latin_1_run_length; + let (replace, char_restart) = errors.handle_encode_error( + full_data, + char_range.clone(), + ERR_REASON, + )?; + match replace { + EncodeReplace::Str(s) => { + if s.as_ref().chars().any(|c| (c as u32) > 255) { + return Err( + errors.error_encoding(full_data, char_range, ERR_REASON) + ); + } + out.extend_from_slice(s.as_ref().as_bytes()); + } + EncodeReplace::Bytes(b) => { + out.extend_from_slice(b.as_ref()); + } + } + data = crate::str::try_get_chars(full_data, char_restart..) + .ok_or_else(|| errors.error_oob_restart(char_restart))?; + char_data_index = char_restart; + } + continue; + } + } + } + Ok(out) + } + + pub fn decode(data: &[u8], _errors: &E) -> Result<(String, usize), E::Error> { + let out: String = data.iter().map(|c| *c as char).collect(); + let out_len = out.len(); + Ok((out, out_len)) + } +} + pub mod ascii { use super::*; use ::ascii::AsciiStr; diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index ac63e4e1d4..355661852a 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -315,6 +315,19 @@ mod _codecs { do_codec!(utf8::decode, args, vm) } + #[pyfunction] + fn latin_1_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { + if args.s.is_ascii() { + return Ok((args.s.as_str().as_bytes().to_vec(), args.s.byte_len())); + } + do_codec!(latin_1::encode, args, vm) + } + + #[pyfunction] + fn latin_1_decode(args: DecodeArgsNoFinal, vm: &VirtualMachine) -> DecodeResult { + do_codec!(latin_1::decode, args, vm) + } + #[pyfunction] fn ascii_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { if args.s.is_ascii() { @@ -353,14 +366,6 @@ mod _codecs { }}; } - #[pyfunction] - fn latin_1_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult { - delegate_pycodecs!(latin_1_encode, args, vm) - } - #[pyfunction] - fn latin_1_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult { - delegate_pycodecs!(latin_1_decode, args, vm) - } #[pyfunction] fn mbcs_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult { delegate_pycodecs!(mbcs_encode, args, vm)