Skip to content

Fix bytes.{start|end}swith #1863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Lib/test/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,6 @@ def test_count(self):
self.assertEqual(b.count(i, 1, 3), 1)
self.assertEqual(b.count(p, 7, 9), 1)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_startswith(self):
b = self.type2test(b'hello')
self.assertFalse(self.type2test().startswith(b"anything"))
Expand All @@ -564,8 +562,6 @@ def test_startswith(self):
self.assertIn('bytes', exc)
self.assertIn('tuple', exc)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_endswith(self):
b = self.type2test(b'hello')
self.assertFalse(bytearray().endswith(b"anything"))
Expand Down
10 changes: 7 additions & 3 deletions vm/src/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,14 @@ fn builtin_input(prompt: OptionalArg<PyStringRef>, vm: &VirtualMachine) -> PyRes
}
}

fn builtin_isinstance(obj: PyObjectRef, typ: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
pub fn builtin_isinstance(
obj: PyObjectRef,
typ: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<bool> {
single_or_tuple_any(
typ,
|cls: PyClassRef| vm.isinstance(&obj, &cls),
|cls: &PyClassRef| vm.isinstance(&obj, cls),
|o| {
format!(
"isinstance() arg 2 must be a type or tuple of types, not {}",
Expand All @@ -384,7 +388,7 @@ fn builtin_issubclass(
) -> PyResult<bool> {
single_or_tuple_any(
typ,
|cls: PyClassRef| vm.issubclass(&subclass, &cls),
|cls: &PyClassRef| vm.issubclass(&subclass, cls),
|o| {
format!(
"issubclass() arg 2 must be a class or tuple of classes, not {}",
Expand Down
26 changes: 5 additions & 21 deletions vm/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use std::sync::Mutex;
use indexmap::IndexMap;
use itertools::Itertools;

use crate::builtins::builtin_isinstance;
use crate::bytecode;
use crate::exceptions::{self, ExceptionCtor, PyBaseExceptionRef};
use crate::function::{single_or_tuple_any, PyFuncArgs};
use crate::function::PyFuncArgs;
use crate::obj::objasyncgenerator::PyAsyncGenWrappedValue;
use crate::obj::objbool;
use crate::obj::objcode::PyCodeRef;
Expand Down Expand Up @@ -1353,25 +1354,6 @@ impl ExecutingFrame<'_> {
!a.is(&b)
}

fn exc_match(
&self,
vm: &VirtualMachine,
exc: PyObjectRef,
exc_type: PyObjectRef,
) -> PyResult<bool> {
single_or_tuple_any(
exc_type,
|cls: PyClassRef| vm.isinstance(&exc, &cls),
|o| {
format!(
"isinstance() arg 2 must be a type or tuple of types, not {}",
o.class()
)
},
vm,
)
}

#[cfg_attr(feature = "flame-it", flame("Frame"))]
fn execute_compare(
&mut self,
Expand All @@ -1391,7 +1373,9 @@ impl ExecutingFrame<'_> {
bytecode::ComparisonOperator::IsNot => vm.new_bool(self._is_not(a, b)),
bytecode::ComparisonOperator::In => vm.new_bool(self._in(vm, a, b)?),
bytecode::ComparisonOperator::NotIn => vm.new_bool(self._not_in(vm, a, b)?),
bytecode::ComparisonOperator::ExceptionMatch => vm.new_bool(self.exc_match(vm, a, b)?),
bytecode::ComparisonOperator::ExceptionMatch => {
vm.new_bool(builtin_isinstance(a, b, vm)?)
}
};

self.push_value(value);
Expand Down
54 changes: 33 additions & 21 deletions vm/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use result_like::impl_option_like;
use smallbox::{smallbox, space::S1, SmallBox};

use crate::exceptions::PyBaseExceptionRef;
use crate::obj::objtuple::PyTuple;
use crate::obj::objtuple::PyTupleRef;
use crate::obj::objtype::{isinstance, PyClassRef};
use crate::pyobject::{
IntoPyObject, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
Expand Down Expand Up @@ -564,43 +564,55 @@ into_py_native_func_tuple!((a, A), (b, B), (c, C), (d, D), (e, E));
/// test that any of the values contained within the tuples satisfies the predicate. Type parameter
/// T specifies the type that is expected, if the input value is not of that type or a tuple of
/// values of that type, then a TypeError is raised.
pub fn single_or_tuple_any<T: PyValue, F: Fn(PyRef<T>) -> PyResult<bool>>(
pub fn single_or_tuple_any<T, F, M>(
obj: PyObjectRef,
predicate: F,
message: fn(&PyObjectRef) -> String,
message: M,
vm: &VirtualMachine,
) -> PyResult<bool> {
) -> PyResult<bool>
where
T: TryFromObject,
F: Fn(&T) -> PyResult<bool>,
M: Fn(&PyObjectRef) -> String,
{
// TODO: figure out some way to have recursive calls without... this
use std::marker::PhantomData;
struct Checker<'vm, T: PyValue, F: Fn(PyRef<T>) -> PyResult<bool>> {
struct Checker<T, F, M>
where
F: Fn(&T) -> PyResult<bool>,
M: Fn(&PyObjectRef) -> String,
{
predicate: F,
message: fn(&PyObjectRef) -> String,
vm: &'vm VirtualMachine,
t: PhantomData<T>,
}
impl<T: PyValue, F: Fn(PyRef<T>) -> PyResult<bool>> Checker<'_, T, F> {
fn check(&self, obj: PyObjectRef) -> PyResult<bool> {
match_class!(match obj {
obj @ T => (self.predicate)(obj),
tuple @ PyTuple => {
message: M,
t: std::marker::PhantomData<T>,
}
impl<T, F, M> Checker<T, F, M>
where
T: TryFromObject,
F: Fn(&T) -> PyResult<bool>,
M: Fn(&PyObjectRef) -> String,
{
fn check(&self, obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
match T::try_from_object(vm, obj.clone()) {
Ok(single) => (self.predicate)(&single),
Err(_) => {
let tuple = PyTupleRef::try_from_object(vm, obj.clone())
.map_err(|_| vm.new_type_error((self.message)(&obj)))?;
for obj in tuple.as_slice().iter() {
if self.check(obj.clone())? {
if self.check(&obj, vm)? {
return Ok(true);
}
}
Ok(false)
}
obj => Err(self.vm.new_type_error((self.message)(&obj))),
})
}
}
}
let checker = Checker {
predicate,
message,
vm,
t: PhantomData,
t: std::marker::PhantomData,
};
checker.check(obj)
checker.check(&obj, vm)
}

#[cfg(test)]
Expand Down
36 changes: 25 additions & 11 deletions vm/src/obj/objbytearray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use super::objint::PyIntRef;
use super::objiter;
use super::objslice::PySliceRef;
use super::objstr::{PyString, PyStringRef};
use super::objtuple::PyTupleRef;
use super::objtype::PyClassRef;
use super::pystr::PyCommonString;
use crate::cformat::CFormatString;
use crate::function::{OptionalArg, OptionalOption};
use crate::obj::objstr::do_cformat_string;
Expand Down Expand Up @@ -303,25 +303,39 @@ impl PyByteArray {
#[pymethod(name = "endswith")]
fn endswith(
&self,
suffix: Either<PyByteInner, PyTupleRef>,
start: OptionalArg<PyObjectRef>,
end: OptionalArg<PyObjectRef>,
suffix: PyObjectRef,
start: OptionalArg<Option<isize>>,
end: OptionalArg<Option<isize>>,
vm: &VirtualMachine,
) -> PyResult<bool> {
self.borrow_value()
.startsendswith(suffix, start, end, true, vm)
self.borrow_value().elements[..].py_startsendswith(
suffix,
start,
end,
"endswith",
"bytes",
|s, x: &PyByteInner| s.ends_with(&x.elements[..]),
vm,
)
}

#[pymethod(name = "startswith")]
fn startswith(
&self,
prefix: Either<PyByteInner, PyTupleRef>,
start: OptionalArg<PyObjectRef>,
end: OptionalArg<PyObjectRef>,
prefix: PyObjectRef,
start: OptionalArg<Option<isize>>,
end: OptionalArg<Option<isize>>,
vm: &VirtualMachine,
) -> PyResult<bool> {
self.borrow_value()
.startsendswith(prefix, start, end, false, vm)
self.borrow_value().elements[..].py_startsendswith(
prefix,
start,
end,
"startswith",
"bytes",
|s, x: &PyByteInner| s.starts_with(&x.elements[..]),
vm,
)
}

#[pymethod(name = "find")]
Expand Down
58 changes: 12 additions & 46 deletions vm/src/obj/objbyteinner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ use super::objint::{self, PyInt, PyIntRef};
use super::objlist::PyList;
use super::objmemory::PyMemoryView;
use super::objnone::PyNoneRef;
use super::objsequence::{is_valid_slice_arg, PySliceableSequence};
use super::objsequence::PySliceableSequence;
use super::objslice::PySliceRef;
use super::objstr::{self, adjust_indices, PyString, PyStringRef, StringRange};
use super::objtuple::PyTupleRef;
use super::pystr::PyCommonString;
use super::objstr::{self, PyString, PyStringRef};
use super::pystr::{self, PyCommonString, StringRange};
use crate::function::{OptionalArg, OptionalOption};
use crate::pyhash;
use crate::pyobject::{
Expand Down Expand Up @@ -172,7 +171,7 @@ impl ByteInnerFindOptions {
Either::A(v) => v.elements.to_vec(),
Either::B(int) => vec![int.as_bigint().byte_or(vm)?],
};
let range = adjust_indices(self.start, self.end, len);
let range = pystr::adjust_indices(self.start, self.end, len);
Ok((sub, range))
}
}
Expand Down Expand Up @@ -856,47 +855,6 @@ impl PyByteInner {
Ok(refs)
}

#[inline]
pub fn startsendswith(
&self,
arg: Either<PyByteInner, PyTupleRef>,
start: OptionalArg<PyObjectRef>,
end: OptionalArg<PyObjectRef>,
endswith: bool, // true for endswith, false for startswith
vm: &VirtualMachine,
) -> PyResult<bool> {
let suff = match arg {
Either::A(byte) => byte.elements,
Either::B(tuple) => {
let mut flatten = vec![];
for v in tuple.as_slice() {
flatten.extend(PyByteInner::try_from_object(vm, v.clone())?.elements)
}
flatten
}
};

if suff.is_empty() {
return Ok(true);
}
let range = self.elements.get_slice_range(
&is_valid_slice_arg(start, vm)?,
&is_valid_slice_arg(end, vm)?,
);

if range.end - range.start < suff.len() {
return Ok(false);
}

let offset = if endswith {
(range.end - suff.len())..range.end
} else {
0..suff.len()
};

Ok(suff.as_slice() == &self.elements.do_slice(range)[offset])
}

#[inline]
pub fn find(
&self,
Expand Down Expand Up @@ -1342,6 +1300,14 @@ pub fn bytes_zfill(bytes: &[u8], width: usize) -> Vec<u8> {
const ASCII_WHITESPACES: [u8; 6] = [0x20, 0x09, 0x0a, 0x0c, 0x0d, 0x0b];

impl PyCommonString<'_, u8> for [u8] {
fn get_slice(&self, range: std::ops::Range<usize>) -> &Self {
&self[range]
}

fn len(&self) -> usize {
Self::len(self)
}

fn py_split_whitespace<F>(&self, maxsplit: isize, convert: F) -> Vec<PyObjectRef>
where
F: Fn(&Self) -> PyObjectRef,
Expand Down
36 changes: 26 additions & 10 deletions vm/src/obj/objbytes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crossbeam_utils::atomic::AtomicCell;
use std::mem::size_of;
use std::ops::Deref;
use std::str::FromStr;

use super::objbyteinner::{
ByteInnerExpandtabsOptions, ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions,
Expand All @@ -10,8 +11,8 @@ use super::objint::PyIntRef;
use super::objiter;
use super::objslice::PySliceRef;
use super::objstr::{PyString, PyStringRef};
use super::objtuple::PyTupleRef;
use super::objtype::PyClassRef;
use super::pystr::PyCommonString;
use crate::cformat::CFormatString;
use crate::function::{OptionalArg, OptionalOption};
use crate::obj::objstr::do_cformat_string;
Expand All @@ -23,7 +24,6 @@ use crate::pyobject::{
ThreadSafe, TryFromObject, TypeProtocol,
};
use crate::vm::VirtualMachine;
use std::str::FromStr;

/// "bytes(iterable_of_ints) -> bytes\n\
/// bytes(string, encoding[, errors]) -> bytes\n\
Expand Down Expand Up @@ -275,23 +275,39 @@ impl PyBytes {
#[pymethod(name = "endswith")]
fn endswith(
&self,
suffix: Either<PyByteInner, PyTupleRef>,
start: OptionalArg<PyObjectRef>,
end: OptionalArg<PyObjectRef>,
suffix: PyObjectRef,
start: OptionalArg<Option<isize>>,
end: OptionalArg<Option<isize>>,
vm: &VirtualMachine,
) -> PyResult<bool> {
self.inner.startsendswith(suffix, start, end, true, vm)
self.inner.elements[..].py_startsendswith(
suffix,
start,
end,
"endswith",
"bytes",
|s, x: &PyByteInner| s.ends_with(&x.elements[..]),
vm,
)
}

#[pymethod(name = "startswith")]
fn startswith(
&self,
prefix: Either<PyByteInner, PyTupleRef>,
start: OptionalArg<PyObjectRef>,
end: OptionalArg<PyObjectRef>,
prefix: PyObjectRef,
start: OptionalArg<Option<isize>>,
end: OptionalArg<Option<isize>>,
vm: &VirtualMachine,
) -> PyResult<bool> {
self.inner.startsendswith(prefix, start, end, false, vm)
self.inner.elements[..].py_startsendswith(
prefix,
start,
end,
"startswith",
"bytes",
|s, x: &PyByteInner| s.starts_with(&x.elements[..]),
vm,
)
}

#[pymethod(name = "find")]
Expand Down
Loading