diff --git a/tests/snippets/subclass_str.py b/tests/snippets/subclass_str.py new file mode 100644 index 0000000000..706566ec6d --- /dev/null +++ b/tests/snippets/subclass_str.py @@ -0,0 +1,22 @@ +from testutils import assertRaises + +x = "An interesting piece of text" +assert x is str(x) + +class Stringy(str): + def __new__(cls, value=""): + return str.__new__(cls, value) + + def __init__(self, value): + self.x = "substr" + +y = Stringy(1) +assert type(y) is Stringy, "Type of Stringy should be stringy" +assert type(str(y)) is str, "Str of a str-subtype should be a str." + +assert y + " other" == "1 other" +assert y.x == "substr" + +## Base strings currently get an attribute dict, but shouldn't. +# with assertRaises(AttributeError): +# "hello".x = 5 diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index a0e2e110d7..1ce1cdc747 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -7,15 +7,15 @@ use unicode_segmentation::UnicodeSegmentation; use crate::format::{FormatParseError, FormatPart, FormatString}; use crate::pyobject::{ - IntoPyObject, OptionalArg, PyContext, PyFuncArgs, PyIterable, PyObjectRef, PyRef, PyResult, - PyValue, TypeProtocol, + IdProtocol, IntoPyObject, OptionalArg, PyContext, PyFuncArgs, PyIterable, PyObjectRef, PyRef, + PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; use super::objint; use super::objsequence::PySliceableSequence; use super::objslice::PySlice; -use super::objtype; +use super::objtype::{self, PyClassRef}; #[derive(Clone, Debug)] pub struct PyString { @@ -788,16 +788,21 @@ fn perform_format( // TODO: should with following format // class str(object='') // class str(object=b'', encoding='utf-8', errors='strict') -fn str_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - if args.args.len() == 1 { - return Ok(vm.new_str("".to_string())); - } - - if args.args.len() > 2 { - panic!("str expects exactly one parameter"); +fn str_new( + cls: PyClassRef, + object: OptionalArg, + vm: &mut VirtualMachine, +) -> PyResult { + let string = match object { + OptionalArg::Present(ref input) => vm.to_str(input)?, + OptionalArg::Missing => vm.new_str("".to_string()), }; - - vm.to_str(&args.args[1]) + if string.typ().is(&cls) { + TryFromObject::try_from_object(vm, string) + } else { + let payload = string.payload::().unwrap(); + PyRef::new_with_type(vm, payload.clone(), cls) + } } impl PySliceableSequence for String {