From db8eb7b8c340370fb8cd595ac9b362b9c9aea9f6 Mon Sep 17 00:00:00 2001 From: Marcin Pajkowski Date: Wed, 24 Jul 2019 01:07:08 +0200 Subject: [PATCH 1/4] Implement str.__iter__ and str.__next__ --- vm/src/obj/objstr.rs | 80 ++++++++++++++++++++++++++++++++++++++++++++ vm/src/pyobject.rs | 14 ++++++++ 2 files changed, 94 insertions(+) diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 3ffeef0a9e..890918beb1 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -1,6 +1,7 @@ extern crate unicode_categories; extern crate unicode_xid; +use std::cell::Cell; use std::char; use std::fmt; use std::ops::Range; @@ -28,6 +29,7 @@ use crate::vm::VirtualMachine; use super::objbytes::PyBytes; use super::objdict::PyDict; use super::objint::{self, PyInt}; +use super::objiter; use super::objnone::PyNone; use super::objsequence::PySliceableSequence; use super::objslice::PySlice; @@ -90,6 +92,73 @@ impl TryIntoRef for &str { } } +#[pyclass] +#[derive(Debug)] +pub struct PyStringIterator { + pub string: PyStringRef, + position: Cell, +} + +impl PyValue for PyStringIterator { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.ctx.striterator_type() + } +} + +#[pyimpl] +impl PyStringIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let pos = self.position.get(); + + if pos < self.string.value.chars().count() { + let value = self.string.value.do_slice(pos..pos + 1); + self.position.set(self.position.get() + 1); + value.into_pyobject(vm) + } else { + Err(objiter::new_stop_iteration(vm)) + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + +#[pyclass] +#[derive(Debug)] +pub struct PyStringReverseIterator { + pub position: Cell, + pub string: PyStringRef, +} + +impl PyValue for PyStringReverseIterator { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.ctx.strreverseiterator_type() + } +} + +#[pyimpl] +impl PyStringReverseIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if self.position.get() > 0 { + let position: usize = self.position.get() - 1; + let value = self.string.value.do_slice(position..position + 1); + self.position.set(position); + value.into_pyobject(vm) + } else { + Err(objiter::new_stop_iteration(vm)) + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + #[pyimpl] impl PyString { // TODO: should with following format @@ -1025,6 +1094,14 @@ impl PyString { let encoded = PyBytes::from_string(&self.value, &encoding, vm)?; Ok(encoded.into_pyobject(vm)?) } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyStringIterator { + PyStringIterator { + position: Cell::new(0), + string: zelf, + } + } } impl PyValue for PyString { @@ -1053,6 +1130,9 @@ impl IntoPyObject for &String { pub fn init(ctx: &PyContext) { PyString::extend_class(ctx, &ctx.str_type); + + PyStringIterator::extend_class(ctx, &ctx.striterator_type); + PyStringReverseIterator::extend_class(ctx, &ctx.strreverseiterator_type); } pub fn get_value(obj: &PyObjectRef) -> String { diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 74f5b2e664..60e8f42c8a 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -135,6 +135,8 @@ pub struct PyContext { pub list_type: PyClassRef, pub listiterator_type: PyClassRef, pub listreverseiterator_type: PyClassRef, + pub striterator_type: PyClassRef, + pub strreverseiterator_type: PyClassRef, pub dictkeyiterator_type: PyClassRef, pub dictvalueiterator_type: PyClassRef, pub dictitemiterator_type: PyClassRef, @@ -274,6 +276,8 @@ impl PyContext { let listiterator_type = create_type("list_iterator", &type_type, &object_type); let listreverseiterator_type = create_type("list_reverseiterator", &type_type, &object_type); + let striterator_type = create_type("str_iterator", &type_type, &object_type); + let strreverseiterator_type = create_type("str_reverseiterator", &type_type, &object_type); let dictkeys_type = create_type("dict_keys", &type_type, &object_type); let dictvalues_type = create_type("dict_values", &type_type, &object_type); let dictitems_type = create_type("dict_items", &type_type, &object_type); @@ -341,6 +345,8 @@ impl PyContext { list_type, listiterator_type, listreverseiterator_type, + striterator_type, + strreverseiterator_type, dictkeys_type, dictvalues_type, dictitems_type, @@ -476,6 +482,14 @@ impl PyContext { self.listreverseiterator_type.clone() } + pub fn striterator_type(&self) -> PyClassRef { + self.striterator_type.clone() + } + + pub fn strreverseiterator_type(&self) -> PyClassRef { + self.strreverseiterator_type.clone() + } + pub fn module_type(&self) -> PyClassRef { self.module_type.clone() } From 6aeec34df02adde598a6b86d4342c6b587b6b38b Mon Sep 17 00:00:00 2001 From: Marcin Pajkowski Date: Wed, 24 Jul 2019 01:07:08 +0200 Subject: [PATCH 2/4] Add tests for str.__iter__ and str.__next__ --- tests/snippets/strings.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index e073544099..ed5bfb9729 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -263,3 +263,33 @@ def try_mutate_str(): assert "\u00BE" == "¾" assert "\u9487" == "钇" assert "\U0001F609" == "😉" + +# test str iter +iterable_str = "123456789" +str_iter = iter(iterable_str) + +assert next(str_iter) == "1" +assert next(str_iter) == "2" +assert next(str_iter) == "3" +assert next(str_iter) == "4" +assert next(str_iter) == "5" +assert next(str_iter) == "6" +assert next(str_iter) == "7" +assert next(str_iter) == "8" +assert next(str_iter) == "9" +assert next(str_iter, None) == None +assert_raises(StopIteration, lambda: next(str_iter)) + +str_iter_reversed = reversed(iterable_str) + +assert next(str_iter_reversed) == "9" +assert next(str_iter_reversed) == "8" +assert next(str_iter_reversed) == "7" +assert next(str_iter_reversed) == "6" +assert next(str_iter_reversed) == "5" +assert next(str_iter_reversed) == "4" +assert next(str_iter_reversed) == "3" +assert next(str_iter_reversed) == "2" +assert next(str_iter_reversed) == "1" +assert next(str_iter_reversed, None) == None +assert_raises(StopIteration, lambda: next(str_iter_reversed)) From 0e033ad93f0ee15926e3910ee8269dff66c11cf5 Mon Sep 17 00:00:00 2001 From: Marcin Pajkowski Date: Wed, 24 Jul 2019 19:03:02 +0200 Subject: [PATCH 3/4] Add str.__reversed__ method --- vm/src/obj/objstr.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 890918beb1..5dd3adcd29 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -1102,6 +1102,16 @@ impl PyString { string: zelf, } } + + #[pymethod(name = "__reversed__")] + fn reversed(zelf: PyRef, _vm: &VirtualMachine) -> PyStringReverseIterator { + let begin = zelf.value.chars().count(); + + PyStringReverseIterator { + position: Cell::new(begin), + string: zelf, + } + } } impl PyValue for PyString { From 1f59532f8221e13f8df164814fc45b1917af7b28 Mon Sep 17 00:00:00 2001 From: Marcin Pajkowski Date: Wed, 24 Jul 2019 19:32:05 +0200 Subject: [PATCH 4/4] Silent clippy with allowing range_plus_one PySliceableSequence trait methods require Range as arguments --- vm/src/obj/objstr.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 5dd3adcd29..dab426077a 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -112,8 +112,11 @@ impl PyStringIterator { let pos = self.position.get(); if pos < self.string.value.chars().count() { - let value = self.string.value.do_slice(pos..pos + 1); self.position.set(self.position.get() + 1); + + #[allow(clippy::range_plus_one)] + let value = self.string.value.do_slice(pos..pos + 1); + value.into_pyobject(vm) } else { Err(objiter::new_stop_iteration(vm)) @@ -145,7 +148,10 @@ impl PyStringReverseIterator { fn next(&self, vm: &VirtualMachine) -> PyResult { if self.position.get() > 0 { let position: usize = self.position.get() - 1; + + #[allow(clippy::range_plus_one)] let value = self.string.value.do_slice(position..position + 1); + self.position.set(position); value.into_pyobject(vm) } else {