diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index e073544099..ed5bfb9729 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -263,3 +263,33 @@ def try_mutate_str(): assert "\u00BE" == "¾" assert "\u9487" == "钇" assert "\U0001F609" == "😉" + +# test str iter +iterable_str = "123456789" +str_iter = iter(iterable_str) + +assert next(str_iter) == "1" +assert next(str_iter) == "2" +assert next(str_iter) == "3" +assert next(str_iter) == "4" +assert next(str_iter) == "5" +assert next(str_iter) == "6" +assert next(str_iter) == "7" +assert next(str_iter) == "8" +assert next(str_iter) == "9" +assert next(str_iter, None) == None +assert_raises(StopIteration, lambda: next(str_iter)) + +str_iter_reversed = reversed(iterable_str) + +assert next(str_iter_reversed) == "9" +assert next(str_iter_reversed) == "8" +assert next(str_iter_reversed) == "7" +assert next(str_iter_reversed) == "6" +assert next(str_iter_reversed) == "5" +assert next(str_iter_reversed) == "4" +assert next(str_iter_reversed) == "3" +assert next(str_iter_reversed) == "2" +assert next(str_iter_reversed) == "1" +assert next(str_iter_reversed, None) == None +assert_raises(StopIteration, lambda: next(str_iter_reversed)) diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 3ffeef0a9e..dab426077a 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -1,6 +1,7 @@ extern crate unicode_categories; extern crate unicode_xid; +use std::cell::Cell; use std::char; use std::fmt; use std::ops::Range; @@ -28,6 +29,7 @@ use crate::vm::VirtualMachine; use super::objbytes::PyBytes; use super::objdict::PyDict; use super::objint::{self, PyInt}; +use super::objiter; use super::objnone::PyNone; use super::objsequence::PySliceableSequence; use super::objslice::PySlice; @@ -90,6 +92,79 @@ impl TryIntoRef for &str { } } +#[pyclass] +#[derive(Debug)] +pub struct PyStringIterator { + pub string: PyStringRef, + position: Cell, +} + +impl PyValue for PyStringIterator { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.ctx.striterator_type() + } +} + +#[pyimpl] +impl PyStringIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let pos = self.position.get(); + + if pos < self.string.value.chars().count() { + self.position.set(self.position.get() + 1); + + #[allow(clippy::range_plus_one)] + let value = self.string.value.do_slice(pos..pos + 1); + + value.into_pyobject(vm) + } else { + Err(objiter::new_stop_iteration(vm)) + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + +#[pyclass] +#[derive(Debug)] +pub struct PyStringReverseIterator { + pub position: Cell, + pub string: PyStringRef, +} + +impl PyValue for PyStringReverseIterator { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.ctx.strreverseiterator_type() + } +} + +#[pyimpl] +impl PyStringReverseIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if self.position.get() > 0 { + let position: usize = self.position.get() - 1; + + #[allow(clippy::range_plus_one)] + let value = self.string.value.do_slice(position..position + 1); + + self.position.set(position); + value.into_pyobject(vm) + } else { + Err(objiter::new_stop_iteration(vm)) + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + #[pyimpl] impl PyString { // TODO: should with following format @@ -1025,6 +1100,24 @@ impl PyString { let encoded = PyBytes::from_string(&self.value, &encoding, vm)?; Ok(encoded.into_pyobject(vm)?) } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyStringIterator { + PyStringIterator { + position: Cell::new(0), + string: zelf, + } + } + + #[pymethod(name = "__reversed__")] + fn reversed(zelf: PyRef, _vm: &VirtualMachine) -> PyStringReverseIterator { + let begin = zelf.value.chars().count(); + + PyStringReverseIterator { + position: Cell::new(begin), + string: zelf, + } + } } impl PyValue for PyString { @@ -1053,6 +1146,9 @@ impl IntoPyObject for &String { pub fn init(ctx: &PyContext) { PyString::extend_class(ctx, &ctx.str_type); + + PyStringIterator::extend_class(ctx, &ctx.striterator_type); + PyStringReverseIterator::extend_class(ctx, &ctx.strreverseiterator_type); } pub fn get_value(obj: &PyObjectRef) -> String { diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 74f5b2e664..60e8f42c8a 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -135,6 +135,8 @@ pub struct PyContext { pub list_type: PyClassRef, pub listiterator_type: PyClassRef, pub listreverseiterator_type: PyClassRef, + pub striterator_type: PyClassRef, + pub strreverseiterator_type: PyClassRef, pub dictkeyiterator_type: PyClassRef, pub dictvalueiterator_type: PyClassRef, pub dictitemiterator_type: PyClassRef, @@ -274,6 +276,8 @@ impl PyContext { let listiterator_type = create_type("list_iterator", &type_type, &object_type); let listreverseiterator_type = create_type("list_reverseiterator", &type_type, &object_type); + let striterator_type = create_type("str_iterator", &type_type, &object_type); + let strreverseiterator_type = create_type("str_reverseiterator", &type_type, &object_type); let dictkeys_type = create_type("dict_keys", &type_type, &object_type); let dictvalues_type = create_type("dict_values", &type_type, &object_type); let dictitems_type = create_type("dict_items", &type_type, &object_type); @@ -341,6 +345,8 @@ impl PyContext { list_type, listiterator_type, listreverseiterator_type, + striterator_type, + strreverseiterator_type, dictkeys_type, dictvalues_type, dictitems_type, @@ -476,6 +482,14 @@ impl PyContext { self.listreverseiterator_type.clone() } + pub fn striterator_type(&self) -> PyClassRef { + self.striterator_type.clone() + } + + pub fn strreverseiterator_type(&self) -> PyClassRef { + self.strreverseiterator_type.clone() + } + pub fn module_type(&self) -> PyClassRef { self.module_type.clone() }