From 90711f10a09f62062a358041fbf3ecf7c6f49506 Mon Sep 17 00:00:00 2001 From: ben Date: Sun, 12 May 2019 13:14:27 +1200 Subject: [PATCH] Accept tuple for first arg in str.startswith and str.endswith --- tests/snippets/strings.py | 7 ++++++ vm/src/builtins.rs | 48 +++++++++++++++++++-------------------- vm/src/function.rs | 26 +++++++++++++++++++++ vm/src/obj/objstr.rs | 44 ++++++++++++++++++++++++++--------- 4 files changed, 89 insertions(+), 36 deletions(-) diff --git a/tests/snippets/strings.py b/tests/snippets/strings.py index 0cb3782b17..4a480cbf54 100644 --- a/tests/snippets/strings.py +++ b/tests/snippets/strings.py @@ -42,9 +42,15 @@ assert a.upper() == 'HALLO' assert a.split('al') == ['H', 'lo'] assert a.startswith('H') +assert a.startswith(('H', 1)) +assert a.startswith(('A', 'H')) assert not a.startswith('f') +assert not a.startswith(('A', 'f')) assert a.endswith('llo') +assert a.endswith(('lo', 1)) +assert a.endswith(('A', 'lo')) assert not a.endswith('on') +assert not a.endswith(('A', 'll')) assert a.zfill(8) == '000Hallo' assert a.isalnum() assert not a.isdigit() @@ -144,6 +150,7 @@ assert '___a__'.find('a', 4, 3) == -1 assert 'abcd'.startswith('b', 1) +assert 'abcd'.startswith(('b', 'z'), 1) assert not 'abcd'.startswith('b', -4) assert 'abcd'.startswith('b', -3) diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index a48a3b4412..e9056e2eaf 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -17,11 +17,10 @@ use crate::obj::objdict::PyDictRef; use crate::obj::objint::{self, PyIntRef}; use crate::obj::objiter; use crate::obj::objstr::{self, PyString, PyStringRef}; -use crate::obj::objtuple::PyTuple; -use crate::obj::objtype::{self, PyClass, PyClassRef}; +use crate::obj::objtype::{self, PyClassRef}; use crate::frame::Scope; -use crate::function::{Args, KwArgs, OptionalArg, PyFuncArgs}; +use crate::function::{single_or_tuple_any, Args, KwArgs, OptionalArg, PyFuncArgs}; use crate::pyobject::{ IdProtocol, IntoPyObject, ItemProtocol, PyIterable, PyObjectRef, PyResult, PyValue, TryFromObject, TypeProtocol, @@ -317,37 +316,36 @@ fn builtin_id(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { // builtin_input -fn type_test( - vm: &VirtualMachine, - typ: PyObjectRef, - test: impl Fn(&PyClassRef) -> PyResult, - test_name: &str, -) -> PyResult { - match_class!(typ, - cls @ PyClass => test(&cls), - tuple @ PyTuple => { - for cls_obj in tuple.elements.borrow().iter() { - let cls = PyClassRef::try_from_object(vm, cls_obj.clone())?; - if test(&cls)? { - return Ok(true); - } - } - Ok(false) +fn builtin_isinstance(obj: PyObjectRef, typ: PyObjectRef, vm: &VirtualMachine) -> PyResult { + single_or_tuple_any( + typ, + |cls: PyClassRef| vm.isinstance(&obj, &cls), + |o| { + format!( + "isinstance() arg 2 must be a type or tuple of types, not {}", + o.class() + ) }, - _ => Err(vm.new_type_error(format!("{}() arg 2 must be a type or tuple of types", test_name))) + vm, ) } -fn builtin_isinstance(obj: PyObjectRef, typ: PyObjectRef, vm: &VirtualMachine) -> PyResult { - type_test(vm, typ, |cls| vm.isinstance(&obj, cls), "isinstance") -} - fn builtin_issubclass( subclass: PyClassRef, typ: PyObjectRef, vm: &VirtualMachine, ) -> PyResult { - type_test(vm, typ, |cls| vm.issubclass(&subclass, cls), "issubclass") + single_or_tuple_any( + typ, + |cls: PyClassRef| vm.issubclass(&subclass, &cls), + |o| { + format!( + "issubclass() arg 2 must be a class or tuple of classes, not {}", + o.class() + ) + }, + vm, + ) } fn builtin_iter(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/function.rs b/vm/src/function.rs index d3f26950f0..170c395f6e 100644 --- a/vm/src/function.rs +++ b/vm/src/function.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::mem; use std::ops::RangeInclusive; +use crate::obj::objtuple::PyTuple; use crate::obj::objtype::{isinstance, PyClassRef}; use crate::pyobject::{ IntoPyObject, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, @@ -523,3 +524,28 @@ into_py_native_func_tuple!((a, A), (b, B)); into_py_native_func_tuple!((a, A), (b, B), (c, C)); into_py_native_func_tuple!((a, A), (b, B), (c, C), (d, D)); into_py_native_func_tuple!((a, A), (b, B), (c, C), (d, D), (e, E)); + +/// Tests that the predicate is True on a single value, or if the value is a tuple a tuple, then +/// test that any of the values contained within the tuples satisfies the predicate. Type parameter +/// T specifies the type that is expected, if the input value is not of that type or a tuple of +/// values of that type, then a TypeError is raised. +pub fn single_or_tuple_any) -> PyResult>( + obj: PyObjectRef, + predicate: F, + message: fn(&PyObjectRef) -> String, + vm: &VirtualMachine, +) -> PyResult { + match_class!(obj, + obj @ T => predicate(obj), + tuple @ PyTuple => { + for obj in tuple.elements.borrow().iter() { + let inner_val = PyRef::::try_from_object(vm, obj.clone())?; + if predicate(inner_val)? { + return Ok(true); + } + } + Ok(false) + }, + obj => Err(vm.new_type_error(message(&obj))) + ) +} diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 7b2fcbe9a1..621cbf5bcd 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -12,7 +12,7 @@ use unicode_segmentation::UnicodeSegmentation; use unicode_xid::UnicodeXID; use crate::format::{FormatParseError, FormatPart, FormatString}; -use crate::function::{OptionalArg, PyFuncArgs}; +use crate::function::{single_or_tuple_any, OptionalArg, PyFuncArgs}; use crate::pyobject::{ IdProtocol, IntoPyObject, PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TryIntoRef, TypeProtocol, @@ -342,30 +342,52 @@ impl PyString { #[pymethod] fn endswith( &self, - suffix: PyStringRef, + suffix: PyObjectRef, start: OptionalArg, end: OptionalArg, - _vm: &VirtualMachine, - ) -> bool { + vm: &VirtualMachine, + ) -> PyResult { if let Some((start, end)) = adjust_indices(start, end, self.value.len()) { - self.value[start..end].ends_with(&suffix.value) + let value = &self.value[start..end]; + single_or_tuple_any( + suffix, + |s: PyStringRef| Ok(value.ends_with(&s.value)), + |o| { + format!( + "endswith first arg must be str or a tuple of str, not {}", + o.class(), + ) + }, + vm, + ) } else { - false + Ok(false) } } #[pymethod] fn startswith( &self, - prefix: PyStringRef, + prefix: PyObjectRef, start: OptionalArg, end: OptionalArg, - _vm: &VirtualMachine, - ) -> bool { + vm: &VirtualMachine, + ) -> PyResult { if let Some((start, end)) = adjust_indices(start, end, self.value.len()) { - self.value[start..end].starts_with(&prefix.value) + let value = &self.value[start..end]; + single_or_tuple_any( + prefix, + |s: PyStringRef| Ok(value.starts_with(&s.value)), + |o| { + format!( + "startswith first arg must be str or a tuple of str, not {}", + o.class(), + ) + }, + vm, + ) } else { - false + Ok(false) } }