Skip to content

Commit 8ff947e

Browse files
authored
Merge pull request #4677 from youknowone/arg-index
apply ArgIndex
2 parents d5fd7ff + 10ccbc6 commit 8ff947e

File tree

18 files changed

+225
-154
lines changed

18 files changed

+225
-154
lines changed

stdlib/src/array.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,23 @@ mod array {
509509
fn to_object(self, vm: &VirtualMachine) -> PyObjectRef;
510510
}
511511

512-
macro_rules! impl_array_element {
512+
macro_rules! impl_int_element {
513+
($($t:ty,)*) => {$(
514+
impl ArrayElement for $t {
515+
fn try_into_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
516+
obj.try_index(vm)?.try_to_primitive(vm)
517+
}
518+
fn byteswap(self) -> Self {
519+
<$t>::swap_bytes(self)
520+
}
521+
fn to_object(self, vm: &VirtualMachine) -> PyObjectRef {
522+
self.to_pyobject(vm)
523+
}
524+
}
525+
)*};
526+
}
527+
528+
macro_rules! impl_float_element {
513529
($(($t:ty, $f_from:path, $f_swap:path, $f_to:path),)*) => {$(
514530
impl ArrayElement for $t {
515531
fn try_into_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
@@ -525,15 +541,8 @@ mod array {
525541
)*};
526542
}
527543

528-
impl_array_element!(
529-
(i8, i8::try_from_object, i8::swap_bytes, PyInt::from),
530-
(u8, u8::try_from_object, u8::swap_bytes, PyInt::from),
531-
(i16, i16::try_from_object, i16::swap_bytes, PyInt::from),
532-
(u16, u16::try_from_object, u16::swap_bytes, PyInt::from),
533-
(i32, i32::try_from_object, i32::swap_bytes, PyInt::from),
534-
(u32, u32::try_from_object, u32::swap_bytes, PyInt::from),
535-
(i64, i64::try_from_object, i64::swap_bytes, PyInt::from),
536-
(u64, u64::try_from_object, u64::swap_bytes, PyInt::from),
544+
impl_int_element!(i8, u8, i16, u16, i32, u32, i64, u64,);
545+
impl_float_element!(
537546
(
538547
f32,
539548
f32_try_into_from_object,

stdlib/src/bisect.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,28 @@ pub(crate) use _bisect::make_module;
33
#[pymodule]
44
mod _bisect {
55
use crate::vm::{
6-
function::OptionalArg, types::PyComparisonOp, PyObjectRef, PyResult, VirtualMachine,
6+
function::{ArgIndex, OptionalArg},
7+
types::PyComparisonOp,
8+
PyObjectRef, PyResult, VirtualMachine,
79
};
810

911
#[derive(FromArgs)]
1012
struct BisectArgs {
1113
a: PyObjectRef,
1214
x: PyObjectRef,
1315
#[pyarg(any, optional)]
14-
lo: OptionalArg<PyObjectRef>,
16+
lo: OptionalArg<ArgIndex>,
1517
#[pyarg(any, optional)]
16-
hi: OptionalArg<PyObjectRef>,
18+
hi: OptionalArg<ArgIndex>,
1719
#[pyarg(named, default)]
1820
key: Option<PyObjectRef>,
1921
}
2022

2123
// Handles objects that implement __index__ and makes sure index fits in needed isize.
2224
#[inline]
23-
fn handle_default(
24-
arg: OptionalArg<PyObjectRef>,
25-
vm: &VirtualMachine,
26-
) -> PyResult<Option<isize>> {
25+
fn handle_default(arg: OptionalArg<ArgIndex>, vm: &VirtualMachine) -> PyResult<Option<isize>> {
2726
arg.into_option()
28-
.map(|v| v.try_index(vm)?.try_to_primitive(vm))
27+
.map(|v| v.try_to_primitive(vm))
2928
.transpose()
3029
}
3130

@@ -38,8 +37,8 @@ mod _bisect {
3837
// input sequence.
3938
#[inline]
4039
fn as_usize(
41-
lo: OptionalArg<PyObjectRef>,
42-
hi: OptionalArg<PyObjectRef>,
40+
lo: OptionalArg<ArgIndex>,
41+
hi: OptionalArg<ArgIndex>,
4342
seq_len: usize,
4443
vm: &VirtualMachine,
4544
) -> PyResult<(usize, usize)> {

stdlib/src/math.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ pub(crate) use math::make_module;
44
mod math {
55
use crate::vm::{
66
builtins::{try_bigint_to_f64, try_f64_to_bigint, PyFloat, PyInt, PyIntRef, PyStrInterned},
7-
function::{ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs},
7+
function::{ArgIndex, ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs},
88
identifier, PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine,
99
};
1010
use itertools::Itertools;
@@ -218,9 +218,8 @@ mod math {
218218
}
219219

220220
#[pyfunction]
221-
fn isqrt(x: PyObjectRef, vm: &VirtualMachine) -> PyResult<BigInt> {
222-
let index = x.try_index(vm)?;
223-
let value = index.as_bigint();
221+
fn isqrt(x: ArgIndex, vm: &VirtualMachine) -> PyResult<BigInt> {
222+
let value = x.as_bigint();
224223

225224
if value.is_negative() {
226225
return Err(vm.new_value_error("isqrt() argument must be nonnegative".to_owned()));
@@ -575,7 +574,7 @@ mod math {
575574
}
576575
}
577576

578-
fn math_perf_arb_len_int_op<F>(args: PosArgs<PyIntRef>, op: F, default: BigInt) -> BigInt
577+
fn math_perf_arb_len_int_op<F>(args: PosArgs<ArgIndex>, op: F, default: BigInt) -> BigInt
579578
where
580579
F: Fn(&BigInt, &PyInt) -> BigInt,
581580
{
@@ -595,13 +594,13 @@ mod math {
595594
}
596595

597596
#[pyfunction]
598-
fn gcd(args: PosArgs<PyIntRef>) -> BigInt {
597+
fn gcd(args: PosArgs<ArgIndex>) -> BigInt {
599598
use num_integer::Integer;
600599
math_perf_arb_len_int_op(args, |x, y| x.gcd(y.as_bigint()), BigInt::zero())
601600
}
602601

603602
#[pyfunction]
604-
fn lcm(args: PosArgs<PyIntRef>) -> BigInt {
603+
fn lcm(args: PosArgs<ArgIndex>) -> BigInt {
605604
use num_integer::Integer;
606605
math_perf_arb_len_int_op(args, |x, y| x.lcm(y.as_bigint()), BigInt::one())
607606
}
@@ -733,8 +732,8 @@ mod math {
733732

734733
#[pyfunction]
735734
fn perm(
736-
n: PyIntRef,
737-
k: OptionalArg<Option<PyIntRef>>,
735+
n: ArgIndex,
736+
k: OptionalArg<Option<ArgIndex>>,
738737
vm: &VirtualMachine,
739738
) -> PyResult<BigInt> {
740739
let n = n.as_bigint();
@@ -764,7 +763,7 @@ mod math {
764763
}
765764

766765
#[pyfunction]
767-
fn comb(n: PyIntRef, k: PyIntRef, vm: &VirtualMachine) -> PyResult<BigInt> {
766+
fn comb(n: ArgIndex, k: ArgIndex, vm: &VirtualMachine) -> PyResult<BigInt> {
768767
let mut k = k.as_bigint();
769768
let n = n.as_bigint();
770769
let one = BigInt::one();

stdlib/src/zlib.rs

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ pub(crate) use zlib::make_module;
22

33
#[pymodule]
44
mod zlib {
5-
use crate::common::lock::PyMutex;
65
use crate::vm::{
76
builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyIntRef, PyTypeRef},
8-
function::{ArgBytesLike, OptionalArg, OptionalOption},
7+
common::lock::PyMutex,
8+
function::{ArgBytesLike, ArgPrimitiveIndex, ArgSize, OptionalArg, OptionalOption},
99
PyPayload, PyResult, VirtualMachine,
1010
};
1111
use adler32::RollingAdler32 as Adler32;
@@ -233,9 +233,9 @@ mod zlib {
233233
#[pyarg(positional)]
234234
data: ArgBytesLike,
235235
#[pyarg(any, optional)]
236-
wbits: OptionalArg<i8>,
236+
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
237237
#[pyarg(any, optional)]
238-
bufsize: OptionalArg<usize>,
238+
bufsize: OptionalArg<ArgPrimitiveIndex<usize>>,
239239
}
240240

241241
/// Returns a bytes object containing the uncompressed data.
@@ -245,9 +245,9 @@ mod zlib {
245245
let wbits = arg.wbits;
246246
let bufsize = arg.bufsize;
247247
data.with_ref(|data| {
248-
let bufsize = bufsize.unwrap_or(DEF_BUF_SIZE);
248+
let bufsize = bufsize.into_primitive().unwrap_or(DEF_BUF_SIZE);
249249

250-
let mut d = header_from_wbits(wbits, vm)?.decompress();
250+
let mut d = header_from_wbits(wbits.into_primitive(), vm)?.decompress();
251251

252252
_decompress(data, &mut d, bufsize, None, false, vm).and_then(|(buf, stream_end)| {
253253
if stream_end {
@@ -265,7 +265,7 @@ mod zlib {
265265
#[pyfunction]
266266
fn decompressobj(args: DecompressobjArgs, vm: &VirtualMachine) -> PyResult<PyDecompress> {
267267
#[allow(unused_mut)]
268-
let mut decompress = header_from_wbits(args.wbits, vm)?.decompress();
268+
let mut decompress = header_from_wbits(args.wbits.into_primitive(), vm)?.decompress();
269269
#[cfg(feature = "zlib")]
270270
if let OptionalArg::Present(dict) = args.zdict {
271271
dict.with_ref(|d| decompress.set_dictionary(d).unwrap());
@@ -325,11 +325,8 @@ mod zlib {
325325

326326
#[pymethod]
327327
fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
328-
let max_length = if args.max_length == 0 {
329-
None
330-
} else {
331-
Some(args.max_length)
332-
};
328+
let max_length = args.max_length.value;
329+
let max_length = (max_length != 0).then_some(max_length);
333330
let data = args.data.borrow_buf();
334331
let data = &*data;
335332

@@ -362,12 +359,18 @@ mod zlib {
362359
}
363360

364361
#[pymethod]
365-
fn flush(&self, length: OptionalArg<isize>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
362+
fn flush(&self, length: OptionalArg<ArgSize>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
366363
let length = match length {
367-
OptionalArg::Present(l) if l <= 0 => {
368-
return Err(vm.new_value_error("length must be greater than zero".to_owned()));
364+
OptionalArg::Present(l) => {
365+
let l: isize = l.into();
366+
if l <= 0 {
367+
return Err(
368+
vm.new_value_error("length must be greater than zero".to_owned())
369+
);
370+
} else {
371+
l as usize
372+
}
369373
}
370-
OptionalArg::Present(l) => l as usize,
371374
OptionalArg::Missing => DEF_BUF_SIZE,
372375
};
373376

@@ -396,14 +399,17 @@ mod zlib {
396399
struct DecompressArgs {
397400
#[pyarg(positional)]
398401
data: ArgBytesLike,
399-
#[pyarg(any, default = "0")]
400-
max_length: usize,
402+
#[pyarg(
403+
any,
404+
default = "rustpython_vm::function::ArgPrimitiveIndex { value: 0 }"
405+
)]
406+
max_length: ArgPrimitiveIndex<usize>,
401407
}
402408

403409
#[derive(FromArgs)]
404410
struct DecompressobjArgs {
405411
#[pyarg(any, optional)]
406-
wbits: OptionalArg<i8>,
412+
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
407413
#[cfg(feature = "zlib")]
408414
#[pyarg(any, optional)]
409415
zdict: OptionalArg<ArgBytesLike>,
@@ -414,7 +420,7 @@ mod zlib {
414420
level: OptionalArg<i32>,
415421
// only DEFLATED is valid right now, it's w/e
416422
_method: OptionalArg<i32>,
417-
wbits: OptionalArg<i8>,
423+
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
418424
// these aren't used.
419425
_mem_level: OptionalArg<i32>, // this is memLevel in CPython
420426
_strategy: OptionalArg<i32>,
@@ -423,7 +429,7 @@ mod zlib {
423429
) -> PyResult<PyCompress> {
424430
let level = compression_from_int(level.into_option())
425431
.ok_or_else(|| vm.new_value_error("invalid initialization option".to_owned()))?;
426-
let compress = header_from_wbits(wbits, vm)?.compress(level);
432+
let compress = header_from_wbits(wbits.into_primitive(), vm)?.compress(level);
427433
Ok(PyCompress {
428434
inner: PyMutex::new(CompressInner {
429435
compress,

vm/src/builtins/bytearray.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ use crate::{
2020
},
2121
},
2222
convert::{ToPyObject, ToPyResult},
23-
function::Either,
2423
function::{
25-
ArgBytesLike, ArgIterable, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue,
24+
ArgBytesLike, ArgIterable, ArgSize, Either, FuncArgs, OptionalArg, OptionalOption,
25+
PyComparisonValue,
2626
},
2727
protocol::{
2828
BufferDescriptor, BufferMethods, BufferResizeGuard, PyBuffer, PyIterReturn,
@@ -67,6 +67,10 @@ impl PyByteArray {
6767
pub fn borrow_buf_mut(&self) -> PyMappedRwLockWriteGuard<'_, Vec<u8>> {
6868
PyRwLockWriteGuard::map(self.inner.write(), |inner| &mut inner.elements)
6969
}
70+
71+
fn repeat(&self, value: isize, vm: &VirtualMachine) -> PyResult<Self> {
72+
self.inner().mul(value, vm).map(|x| x.into())
73+
}
7074
}
7175

7276
impl From<PyBytesInner> for PyByteArray {
@@ -658,13 +662,13 @@ impl PyByteArray {
658662

659663
#[pymethod(name = "__rmul__")]
660664
#[pymethod(magic)]
661-
fn mul(&self, value: isize, vm: &VirtualMachine) -> PyResult<Self> {
662-
self.inner().mul(value, vm).map(|x| x.into())
665+
fn mul(&self, value: ArgSize, vm: &VirtualMachine) -> PyResult<Self> {
666+
self.repeat(value.into(), vm)
663667
}
664668

665669
#[pymethod(magic)]
666-
fn imul(zelf: PyRef<Self>, value: isize, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
667-
Self::irepeat(&zelf, value, vm)?;
670+
fn imul(zelf: PyRef<Self>, value: ArgSize, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
671+
Self::irepeat(&zelf, value.into(), vm)?;
668672
Ok(zelf)
669673
}
670674

@@ -820,7 +824,7 @@ impl AsSequence for PyByteArray {
820824
}),
821825
repeat: atomic_func!(|seq, n, vm| {
822826
PyByteArray::sequence_downcast(seq)
823-
.mul(n, vm)
827+
.repeat(n, vm)
824828
.map(|x| x.into_pyobject(vm))
825829
}),
826830
item: atomic_func!(|seq, i, vm| {
@@ -849,7 +853,8 @@ impl AsSequence for PyByteArray {
849853
}),
850854
inplace_repeat: atomic_func!(|seq, n, vm| {
851855
let zelf = PyByteArray::sequence_downcast(seq).to_owned();
852-
PyByteArray::imul(zelf, n, vm).map(|x| x.into())
856+
PyByteArray::irepeat(&zelf, n, vm)?;
857+
Ok(zelf.into())
853858
}),
854859
};
855860
&AS_SEQUENCE

vm/src/builtins/bytes.rs

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ use crate::{
1111
class::PyClassImpl,
1212
common::{hash::PyHash, lock::PyMutex},
1313
convert::{ToPyObject, ToPyResult},
14-
function::Either,
15-
function::{ArgBytesLike, ArgIterable, OptionalArg, OptionalOption, PyComparisonValue},
14+
function::{
15+
ArgBytesLike, ArgIndex, ArgIterable, Either, OptionalArg, OptionalOption, PyComparisonValue,
16+
},
1617
protocol::{
1718
BufferDescriptor, BufferMethods, PyBuffer, PyIterReturn, PyMappingMethods, PyNumberMethods,
1819
PySequenceMethods,
@@ -99,6 +100,19 @@ impl PyBytes {
99100
pub fn new_ref(data: Vec<u8>, ctx: &Context) -> PyRef<Self> {
100101
PyRef::new_ref(Self::from(data), ctx.types.bytes_type.to_owned(), None)
101102
}
103+
104+
fn repeat(zelf: PyRef<Self>, count: isize, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
105+
if count == 1 && zelf.class().is(vm.ctx.types.bytes_type) {
106+
// Special case: when some `bytes` is multiplied by `1`,
107+
// nothing really happens, we need to return an object itself
108+
// with the same `id()` to be compatible with CPython.
109+
// This only works for `bytes` itself, not its subclasses.
110+
return Ok(zelf);
111+
}
112+
zelf.inner
113+
.mul(count, vm)
114+
.map(|x| Self::from(x).into_ref(vm))
115+
}
102116
}
103117

104118
#[pyclass(
@@ -497,17 +511,8 @@ impl PyBytes {
497511

498512
#[pymethod(name = "__rmul__")]
499513
#[pymethod(magic)]
500-
fn mul(zelf: PyRef<Self>, value: isize, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
501-
if value == 1 && zelf.class().is(vm.ctx.types.bytes_type) {
502-
// Special case: when some `bytes` is multiplied by `1`,
503-
// nothing really happens, we need to return an object itself
504-
// with the same `id()` to be compatible with CPython.
505-
// This only works for `bytes` itself, not its subclasses.
506-
return Ok(zelf);
507-
}
508-
zelf.inner
509-
.mul(value, vm)
510-
.map(|x| Self::from(x).into_ref(vm))
514+
fn mul(zelf: PyRef<Self>, value: ArgIndex, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
515+
Self::repeat(zelf, value.try_to_primitive(vm)?, vm)
511516
}
512517

513518
#[pymethod(name = "__mod__")]
@@ -605,7 +610,7 @@ impl AsSequence for PyBytes {
605610
}),
606611
repeat: atomic_func!(|seq, n, vm| {
607612
if let Ok(zelf) = seq.obj.to_owned().downcast::<PyBytes>() {
608-
PyBytes::mul(zelf, n, vm).to_pyresult(vm)
613+
PyBytes::repeat(zelf, n, vm).to_pyresult(vm)
609614
} else {
610615
Err(vm.new_type_error("bad argument type for built-in operation".to_owned()))
611616
}

0 commit comments

Comments
 (0)