diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 786622e9f5..aaeaed5f3e 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -206,3 +206,13 @@ def try_mutate_str(): word[0] = 'x' assert_raises(TypeError, try_mutate_str) + +ss = ['Hello', '안녕', '👋'] +bs = [b'Hello', b'\xec\x95\x88\xeb\x85\x95', b'\xf0\x9f\x91\x8b'] + +for s, b in zip(ss, bs): + assert s.encode() == b + +for s, b, e in zip(ss, bs, ['u8', 'U8', 'utf-8', 'UTF-8', 'utf_8']): + assert s.encode(e) == b + # assert s.encode(encoding=e) == b diff --git a/vm/src/function.rs b/vm/src/function.rs index 47fc051d59..22802ac1fa 100644 --- a/vm/src/function.rs +++ b/vm/src/function.rs @@ -374,6 +374,17 @@ impl OptionalArg { Missing => f(), } } + + pub fn map_or_else(self, default: D, f: F) -> U + where + D: FnOnce() -> U, + F: FnOnce(T) -> U, + { + match self { + Present(value) => f(value), + Missing => default(), + } + } } impl FromArgs for OptionalArg diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index e420fa038b..32a58255f8 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -74,24 +74,33 @@ pub struct ByteInnerNewOptions { encoding: OptionalArg, } +//same algorithm as cpython +pub fn normalize_encoding(encoding: &str) -> String { + let mut res = String::new(); + let mut punct = false; + + for c in encoding.chars() { + if c.is_alphanumeric() || c == '.' { + if punct && !res.is_empty() { + res.push('_') + } + res.push(c.to_ascii_lowercase()); + punct = false; + } else { + punct = true; + } + } + res +} + impl ByteInnerNewOptions { pub fn get_value(self, vm: &VirtualMachine) -> PyResult { // First handle bytes(string, encoding[, errors]) if let OptionalArg::Present(enc) = self.encoding { if let OptionalArg::Present(eval) = self.val_option { if let Ok(input) = eval.downcast::() { - let encoding = enc.as_str(); - if encoding.to_lowercase() == "utf8" || encoding.to_lowercase() == "utf-8" - // TODO: different encoding - { - return Ok(PyByteInner { - elements: input.value.as_bytes().to_vec(), - }); - } else { - return Err( - vm.new_value_error(format!("unknown encoding: {}", encoding)), //should be lookup error - ); - } + let inner = PyByteInner::from_string(&input.value, enc.as_str(), vm)?; + return Ok(inner); } else { return Err(vm.new_type_error("encoding without a string argument".to_string())); } @@ -311,6 +320,20 @@ impl ByteInnerSplitlinesOptions { } impl PyByteInner { + pub fn from_string(value: &str, encoding: &str, vm: &VirtualMachine) -> PyResult { + let normalized = normalize_encoding(encoding); + if normalized == "utf_8" || normalized == "utf8" || normalized == "u8" { + Ok(PyByteInner { + elements: value.as_bytes().to_vec(), + }) + } else { + // TODO: different encoding + Err( + vm.new_value_error(format!("unknown encoding: {}", encoding)), // should be lookup error + ) + } + } + pub fn repr(&self) -> PyResult { let mut res = String::with_capacity(self.elements.len()); for i in self.elements.iter() { diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index dd2e1518c0..33016e511a 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -45,6 +45,13 @@ impl PyBytes { inner: PyByteInner { elements }, } } + + pub fn from_string(value: &str, encoding: &str, vm: &VirtualMachine) -> PyResult { + Ok(PyBytes { + inner: PyByteInner::from_string(value, encoding, vm)?, + }) + } + pub fn get_value(&self) -> &[u8] { &self.inner.elements } diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 6545c99858..b0d19e9c3f 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -19,6 +19,7 @@ use crate::pyobject::{ }; use crate::vm::VirtualMachine; +use super::objbytes::PyBytes; use super::objdict::PyDict; use super::objint::{self, PyInt}; use super::objnone::PyNone; @@ -957,6 +958,31 @@ impl PyString { } } } + + #[pymethod] + fn encode( + &self, + encoding: OptionalArg, + _errors: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let encoding = encoding.map_or_else( + || Ok("utf-8".to_string()), + |v| { + if objtype::isinstance(&v, &vm.ctx.str_type()) { + Ok(get_value(&v)) + } else { + Err(vm.new_type_error(format!( + "encode() argument 1 must be str, not {}", + v.class().name + ))) + } + }, + )?; + + let encoded = PyBytes::from_string(&self.value, &encoding, vm)?; + Ok(encoded.into_pyobject(vm)?) + } } impl PyValue for PyString {