diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 3364a293c5..531e8f83c1 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -157,6 +157,13 @@ assert 'z' >= 'b' assert 'a' >= 'a' +# str.translate +assert "abc".translate({97: '🎅', 98: None, 99: "xd"}) == "🎅xd" + +# str.maketrans +assert str.maketrans({"a": "abc", "b": None, "c": 33}) == {97: "abc", 98: None, 99: 33} +assert str.maketrans("hello", "world", "rust") == {104: 119, 101: 111, 108: 108, 111: 100, 114: None, 117: None, 115: None, 116: None} + def try_mutate_str(): word = "word" word[0] = 'x' diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 56fc373c39..d81b69f93e 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -11,12 +11,14 @@ use unicode_segmentation::UnicodeSegmentation; use crate::format::{FormatParseError, FormatPart, FormatString}; use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyobject::{ - IdProtocol, IntoPyObject, PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, - PyValue, TryFromObject, TryIntoRef, TypeProtocol, + IdProtocol, IntoPyObject, ItemProtocol, PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, + PyResult, PyValue, TryFromObject, TryIntoRef, TypeProtocol, }; use crate::vm::VirtualMachine; -use super::objint; +use super::objdict::PyDict; +use super::objint::{self, PyInt}; +use super::objnone::PyNone; use super::objsequence::PySliceableSequence; use super::objslice::PySlice; use super::objtype::{self, PyClassRef}; @@ -829,6 +831,99 @@ impl PyString { false } } + + // https://docs.python.org/3/library/stdtypes.html#str.translate + #[pymethod] + fn translate(&self, table: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let mut translated = String::new(); + // It throws a type error if it is not subscribtable + vm.get_method(table.clone(), "__getitem__")?; + for c in self.value.chars() { + match table.get_item(c as u32, vm) { + Ok(value) => { + if let Some(text) = value.payload::() { + translated.extend(text.value.chars()); + } else if let Some(_) = value.payload::() { + // Do Nothing + } else if let Some(bigint) = value.payload::() { + match bigint.as_bigint().to_u32().and_then(std::char::from_u32) { + Some(ch) => translated.push(ch as char), + None => { + return Err(vm.new_value_error(format!( + "character mapping must be in range(0x110000)" + ))); + } + } + } else { + return Err(vm.new_type_error( + "character mapping must return integer, None or str".to_owned(), + )); + } + } + _ => translated.push(c), + } + } + Ok(translated) + } + + #[pymethod] + fn maketrans( + dict_or_str: PyObjectRef, + to_str: OptionalArg, + none_str: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + let new_dict = vm.context().new_dict(); + if let OptionalArg::Present(to_str) = to_str { + match dict_or_str.downcast::() { + Ok(from_str) => { + if to_str.len(vm) == from_str.len(vm) { + for (c1, c2) in from_str.value.chars().zip(to_str.value.chars()) { + new_dict.set_item(c1 as u32, vm.new_int(c2 as u32), vm)?; + } + if let OptionalArg::Present(none_str) = none_str { + for c in none_str.value.chars() { + new_dict.set_item(c as u32, vm.get_none(), vm)?; + } + } + new_dict.into_pyobject(vm) + } else { + Err(vm.new_value_error( + "the first two maketrans arguments must have equal length".to_owned(), + )) + } + } + _ => Err(vm.new_type_error( + "first maketrans argument must be a string if there is a second argument" + .to_owned(), + )), + } + } else { + // dict_str must be a dict + match dict_or_str.downcast::() { + Ok(dict) => { + for (key, val) in dict { + if let Some(num) = key.payload::() { + new_dict.set_item(num.as_bigint().to_i32(), val, vm)?; + } else if let Some(string) = key.payload::() { + if string.len(vm) == 1 { + let num_value = string.value.chars().next().unwrap() as u32; + new_dict.set_item(num_value, val, vm)?; + } else { + return Err(vm.new_value_error( + "string keys in translate table must be of length 1".to_owned(), + )); + } + } + } + new_dict.into_pyobject(vm) + } + _ => Err(vm.new_value_error( + "if you give only one argument to maketrans it must be a dict".to_owned(), + )), + } + } + } } impl PyValue for PyString { @@ -1104,4 +1199,30 @@ mod tests { assert!(!PyString::from(s).istitle(&vm)); } } + + #[test] + fn str_maketrans_and_translate() { + let vm = VirtualMachine::new(); + + let table = vm.context().new_dict(); + table + .set_item("a", vm.new_str("🎅".to_owned()), &vm) + .unwrap(); + table.set_item("b", vm.get_none(), &vm).unwrap(); + table + .set_item("c", vm.new_str("xda".to_owned()), &vm) + .unwrap(); + let translated = PyString::maketrans( + table.into_object(), + OptionalArg::Missing, + OptionalArg::Missing, + &vm, + ) + .unwrap(); + let text = PyString::from("abc"); + let translated = text.translate(translated, &vm).unwrap(); + assert_eq!(translated, "🎅xda".to_owned()); + let translated = text.translate(vm.new_int(3), &vm); + assert_eq!(translated.unwrap_err().class().name, "TypeError".to_owned()); + } } diff --git a/whats_left.sh b/whats_left.sh index e05051d283..66e64f1a37 100755 --- a/whats_left.sh +++ b/whats_left.sh @@ -8,4 +8,4 @@ python3 not_impl_gen.py cd .. -cargo run -- tests/snippets/whats_left_to_implement.py +cargo run -- tests/snippets/whats_left_to_implement.py \ No newline at end of file