Skip to content

Add protocol object PyCallable #4654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions examples/call_between_rust_and_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@ pub fn main() {

let module = vm.import("call_between_rust_and_python", None, 0).unwrap();
let init_fn = module.get_attr("python_callback", vm).unwrap();
vm.invoke(&init_fn, ()).unwrap();
init_fn.call((), vm).unwrap();

let take_string_fn = module.get_attr("take_string", vm).unwrap();
vm.invoke(
&take_string_fn,
(String::from("Rust string sent to python"),),
)
.unwrap();
take_string_fn
.call((String::from("Rust string sent to python"),), vm)
.unwrap();
})
}

Expand Down
2 changes: 1 addition & 1 deletion examples/package_embed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ fn py_main(interp: &Interpreter) -> vm::PyResult<PyStrRef> {
.expect("add path");
let module = vm.import("package_embed", None, 0)?;
let name_func = module.get_attr("context", vm)?;
let result = vm.invoke(&name_func, ())?;
let result = name_func.call((), vm)?;
let result: PyStrRef = result.get_attr("name", vm)?.try_into_value(vm)?;
vm::PyResult::Ok(result)
})
Expand Down
13 changes: 6 additions & 7 deletions stdlib/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
.get_attr("MutableSequence", vm)
.expect("Expect collections.abc has MutableSequence type.");

vm.invoke(
&mutable_sequence
.get_attr("register", vm)
.expect("Expect collections.abc.MutableSequence has register method."),
(array,),
)
.expect("Expect collections.abc.MutableSequence.register(array.array) not fail.");
let register = &mutable_sequence
.get_attr("register", vm)
.expect("Expect collections.abc.MutableSequence has register method.");
register
.call((array,), vm)
.expect("Expect collections.abc.MutableSequence.register(array.array) not fail.");

module
}
Expand Down
8 changes: 4 additions & 4 deletions stdlib/src/bisect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ mod _bisect {
let mid = (lo + hi) / 2;
let a_mid = a.get_item(&mid, vm)?;
let comp = if let Some(ref key) = key {
vm.invoke(key, (a_mid,))?
key.call((a_mid,), vm)?
} else {
a_mid
};
Expand All @@ -96,7 +96,7 @@ mod _bisect {
let mid = (lo + hi) / 2;
let a_mid = a.get_item(&mid, vm)?;
let comp = if let Some(ref key) = key {
vm.invoke(key, (a_mid,))?
key.call((a_mid,), vm)?
} else {
a_mid
};
Expand All @@ -112,7 +112,7 @@ mod _bisect {
#[pyfunction]
fn insort_left(BisectArgs { a, x, lo, hi, key }: BisectArgs, vm: &VirtualMachine) -> PyResult {
let x = if let Some(ref key) = key {
vm.invoke(key, (x,))?
key.call((x,), vm)?
} else {
x
};
Expand All @@ -132,7 +132,7 @@ mod _bisect {
#[pyfunction]
fn insort_right(BisectArgs { a, x, lo, hi, key }: BisectArgs, vm: &VirtualMachine) -> PyResult {
let x = if let Some(ref key) = key {
vm.invoke(key, (x,))?
key.call((x,), vm)?
} else {
x
};
Expand Down
4 changes: 2 additions & 2 deletions stdlib/src/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ mod _csv {
) -> PyResult<Writer> {
let write = match vm.get_attribute_opt(file.clone(), "write")? {
Some(write_meth) => write_meth,
None if vm.is_callable(&file) => file,
None if file.is_callable() => file,
None => {
return Err(vm.new_type_error("argument 1 must have a \"write\" method".to_owned()))
}
Expand Down Expand Up @@ -309,7 +309,7 @@ mod _csv {
let s = std::str::from_utf8(&buffer[..buffer_offset])
.map_err(|_| vm.new_unicode_decode_error("csv not utf8".to_owned()))?;

vm.invoke(&self.write, (s.to_owned(),))
self.write.call((s,), vm)
}

#[pymethod]
Expand Down
37 changes: 16 additions & 21 deletions stdlib/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ mod machinery;
mod _json {
use super::machinery;
use crate::vm::{
builtins::{PyBaseExceptionRef, PyStrRef, PyTypeRef},
builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef},
convert::{ToPyObject, ToPyResult},
function::OptionalArg,
function::{IntoFuncArgs, OptionalArg},
protocol::PyIterReturn,
types::{Callable, Constructor},
AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
Expand Down Expand Up @@ -91,25 +91,23 @@ mod _json {
'{' => {
// TODO: parse the object in rust
let parse_obj = self.ctx.get_attr("parse_object", vm)?;
return PyIterReturn::from_pyresult(
vm.invoke(
&parse_obj,
(
(pystr, next_idx),
self.strict,
scan_once,
self.object_hook.clone(),
self.object_pairs_hook.clone(),
),
let result = parse_obj.call(
(
(pystr, next_idx),
self.strict,
scan_once,
self.object_hook.clone(),
self.object_pairs_hook.clone(),
),
vm,
);
return PyIterReturn::from_pyresult(result, vm);
}
'[' => {
// TODO: parse the array in rust
let parse_array = self.ctx.get_attr("parse_array", vm)?;
return PyIterReturn::from_pyresult(
vm.invoke(&parse_array, ((pystr, next_idx), scan_once)),
parse_array.call(((pystr, next_idx), scan_once), vm),
vm,
);
}
Expand Down Expand Up @@ -138,11 +136,8 @@ mod _json {
($s:literal) => {
if s.starts_with($s) {
return Ok(PyIterReturn::Return(
vm.new_tuple((
vm.invoke(&self.parse_constant, ($s.to_owned(),))?,
idx + $s.len(),
))
.into(),
vm.new_tuple((self.parse_constant.call(($s,), vm)?, idx + $s.len()))
.into(),
));
}
};
Expand Down Expand Up @@ -181,12 +176,12 @@ mod _json {
let ret = if has_decimal || has_exponent {
// float
if let Some(ref parse_float) = self.parse_float {
vm.invoke(parse_float, (buf.to_owned(),))
parse_float.call((buf,), vm)
} else {
Ok(vm.ctx.new_float(f64::from_str(buf).unwrap()).into())
}
} else if let Some(ref parse_int) = self.parse_int {
vm.invoke(parse_int, (buf.to_owned(),))
parse_int.call((buf,), vm)
} else {
Ok(vm.new_pyobj(BigInt::from_str(buf).unwrap()))
};
Expand Down Expand Up @@ -243,7 +238,7 @@ mod _json {
) -> PyBaseExceptionRef {
let get_error = || -> PyResult<_> {
let cls = vm.try_class("json", "JSONDecodeError")?;
let exc = vm.invoke(&cls, (e.msg, s, e.pos))?;
let exc = PyType::call(&cls, (e.msg, s, e.pos).into_args(vm), vm)?;
exc.try_into_value(vm)
};
match get_error() {
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ mod math {
func_name.as_str(),
)
})?;
vm.invoke(&method, ())
method.call((), vm)
}

#[pyfunction]
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/pyexpat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ mod _pyexpat {
where
T: IntoFuncArgs,
{
vm.invoke(&handler.read().clone(), args).ok();
handler.read().call(args, vm).ok();
}

#[pyclass]
Expand Down
2 changes: 1 addition & 1 deletion stdlib/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl TryFromObject for Selectable {
vm.ctx.interned_str("fileno").unwrap(),
|| "select arg must be an int or object with a fileno() method".to_owned(),
)?;
vm.invoke(&meth, ())?.try_into_value(vm)
meth.call((), vm)?.try_into_value(vm)
})?;
Ok(Selectable { obj, fno })
}
Expand Down
28 changes: 15 additions & 13 deletions stdlib/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ mod _sqlite {
.map(|val| value_to_object(val, db, vm))
.collect::<PyResult<Vec<PyObjectRef>>>()?;

let val = vm.invoke(func, args)?;
let val = func.call(args, vm)?;

context.result_from_object(&val, vm)
};
Expand All @@ -410,7 +410,7 @@ mod _sqlite {
let args = std::slice::from_raw_parts(argv, argc as usize);
let instance = context.aggregate_context::<*const PyObject>();
if (*instance).is_null() {
match vm.invoke(cls, ()) {
match cls.call((), vm) {
Ok(obj) => *instance = obj.into_raw(),
Err(exc) => {
return context.result_exception(
Expand Down Expand Up @@ -450,7 +450,7 @@ mod _sqlite {
let text2 = ptr_to_string(b_ptr.cast(), b_len, null_mut(), vm)?;
let text2 = vm.ctx.new_str(text2);

let val = vm.invoke(callable, (text1, text2))?;
let val = callable.call((text1, text2), vm)?;
let Some(val) = val.to_number().index(vm) else {
return Ok(0);
};
Expand Down Expand Up @@ -505,7 +505,7 @@ mod _sqlite {
let db_name = ptr_to_str(db_name, vm)?;
let access = ptr_to_str(access, vm)?;

let val = vm.invoke(callable, (action, arg1, arg2, db_name, access))?;
let val = callable.call((action, arg1, arg2, db_name, access), vm)?;
let Some(val) = val.payload::<PyInt>() else {
return Ok(SQLITE_DENY);
};
Expand All @@ -525,15 +525,16 @@ mod _sqlite {
let expanded = sqlite3_expanded_sql(stmt.cast());
let f = || -> PyResult<()> {
let stmt = ptr_to_str(expanded, vm).or_else(|_| ptr_to_str(sql.cast(), vm))?;
vm.invoke(callable, (stmt,)).map(drop)
callable.call((stmt,), vm)?;
Ok(())
};
let _ = f();
0
}

unsafe extern "C" fn progress_callback(data: *mut c_void) -> c_int {
let (callable, vm) = (*data.cast::<Self>()).retrive();
if let Ok(val) = vm.invoke(callable, ()) {
if let Ok(val) = callable.call((), vm) {
if let Ok(val) = val.is_true(vm) {
return val as c_int;
}
Expand Down Expand Up @@ -661,10 +662,10 @@ mod _sqlite {
.new_tuple(vec![obj.class().to_owned().into(), proto.clone()]);

if let Some(adapter) = adapters().get_item_opt(key.as_object(), vm)? {
return vm.invoke(&adapter, (obj,));
return adapter.call((obj,), vm);
}
if let Ok(adapter) = proto.get_attr("__adapt__", vm) {
match vm.invoke(&adapter, (obj,)) {
match adapter.call((obj,), vm) {
Ok(val) => return Ok(val),
Err(exc) => {
if !exc.fast_isinstance(vm.ctx.exceptions.type_error) {
Expand All @@ -674,7 +675,7 @@ mod _sqlite {
}
}
if let Ok(adapter) = obj.get_attr("__conform__", vm) {
match vm.invoke(&adapter, (proto,)) {
match adapter.call((proto,), vm) {
Ok(val) => return Ok(val),
Err(exc) => {
if !exc.fast_isinstance(vm.ctx.exceptions.type_error) {
Expand Down Expand Up @@ -1228,7 +1229,7 @@ mod _sqlite {
fn iterdump(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
let module = vm.import("sqlite3.dump", None, 0)?;
let func = module.get_attr("_iterdump", vm)?;
vm.invoke(&func, (zelf,))
func.call((zelf,), vm)
}

#[pymethod]
Expand Down Expand Up @@ -1699,7 +1700,7 @@ mod _sqlite {
std::slice::from_raw_parts(blob.cast::<u8>(), nbytes as usize)
};
let blob = vm.ctx.new_bytes(blob.to_vec());
vm.invoke(&converter, (blob,))?
converter.call((blob,), vm)?
}
} else {
let col_type = st.column_type(i);
Expand All @@ -1724,7 +1725,7 @@ mod _sqlite {
PyByteArray::from(text).into_ref(vm).into()
} else {
let bytes = vm.ctx.new_bytes(text);
vm.invoke(&text_factory, (bytes,))?
text_factory.call((bytes,), vm)?
}
}
SQLITE_BLOB => {
Expand Down Expand Up @@ -1765,7 +1766,8 @@ mod _sqlite {
let row = vm.ctx.new_tuple(row);

if let Some(row_factory) = zelf.row_factory.to_owned() {
vm.invoke(&row_factory, (zelf.to_owned(), row))
row_factory
.call((zelf.to_owned(), row), vm)
.map(PyIterReturn::Return)
} else {
Ok(PyIterReturn::Return(row.into()))
Expand Down
4 changes: 2 additions & 2 deletions vm/src/builtins/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl PyObjectRef {
Some(method_or_err) => {
// If descriptor returns Error, propagate it further
let method = method_or_err?;
let bool_obj = vm.invoke(&method, ())?;
let bool_obj = method.call((), vm)?;
if !bool_obj.fast_isinstance(vm.ctx.types.bool_type) {
return Err(vm.new_type_error(format!(
"__bool__ should return bool, returned type {}",
Expand All @@ -50,7 +50,7 @@ impl PyObjectRef {
None => match vm.get_method(self, identifier!(vm, __len__)) {
Some(method_or_err) => {
let method = method_or_err?;
let bool_obj = vm.invoke(&method, ())?;
let bool_obj = method.call((), vm)?;
let int_obj = bool_obj.payload::<PyInt>().ok_or_else(|| {
vm.new_type_error(format!(
"'{}' object cannot be interpreted as an integer",
Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/classmethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl GetDescriptor for PyClassMethod {
let call_descr_get: PyResult<PyObjectRef> = zelf.callable.lock().get_attr("__get__", vm);
match call_descr_get {
Err(_) => Ok(PyBoundMethod::new_ref(cls, zelf.callable.lock().clone(), &vm.ctx).into()),
Ok(call_descr_get) => vm.invoke(&call_descr_get, (cls.clone(), cls)),
Ok(call_descr_get) => call_descr_get.call((cls.clone(), cls), vm),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl PyObjectRef {
return Ok(Some((complex.value, true)));
}
if let Some(method) = vm.get_method(self.clone(), identifier!(vm, __complex__)) {
let result = vm.invoke(&method?, ())?;
let result = method?.call((), vm)?;
// TODO: returning strict subclasses of complex in __complex__ is deprecated
return match result.payload::<PyComplex>() {
Some(complex_obj) => Ok(Some((complex_obj.value, true))),
Expand Down
4 changes: 2 additions & 2 deletions vm/src/builtins/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl PyDict {
};
let dict = &self.entries;
if let Some(keys) = vm.get_method(other.clone(), vm.ctx.intern_str("keys")) {
let keys = vm.invoke(&keys?, ())?.get_iter(vm)?;
let keys = keys?.call((), vm)?.get_iter(vm)?;
while let PyIterReturn::Return(key) = keys.next(vm)? {
let val = other.get_item(&*key, vm)?;
dict.insert(vm, &*key, val)?;
Expand Down Expand Up @@ -511,7 +511,7 @@ impl Py<PyDict> {
vm: &VirtualMachine,
) -> PyResult<Option<PyObjectRef>> {
vm.get_method(self.to_owned().into(), identifier!(vm, __missing__))
.map(|methods| vm.invoke(&methods?, (key.to_pyobject(vm),)))
.map(|methods| methods?.call((key.to_pyobject(vm),), vm))
.transpose()
}

Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl IterNext for PyFilter {
} else {
// the predicate itself can raise StopIteration which does stop the filter
// iteration
match PyIterReturn::from_pyresult(vm.invoke(predicate, (next_obj.clone(),)), vm)? {
match PyIterReturn::from_pyresult(predicate.call((next_obj.clone(),), vm), vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)),
}
Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ impl Callable for PyBoundMethod {
#[inline]
fn call(zelf: &crate::Py<Self>, mut args: FuncArgs, vm: &VirtualMachine) -> PyResult {
args.prepend_arg(zelf.object.clone());
vm.invoke(&zelf.function, args)
zelf.function.call(args, vm)
}
}

Expand Down
Loading