Skip to content

Commit 5450f8e

Browse files
committed
use vm.try_class, early return on unsupported op error
1 parent 03b7299 commit 5450f8e

File tree

1 file changed

+41
-50
lines changed

1 file changed

+41
-50
lines changed

vm/src/stdlib/io.rs

+41-50
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::obj::objbytes::PyBytes;
1717
use crate::obj::objint;
1818
use crate::obj::objstr;
1919
use crate::obj::objtype;
20-
use crate::obj::objtype::{PyClass, PyClassRef};
20+
use crate::obj::objtype::PyClassRef;
2121
use crate::pyobject::TypeProtocol;
2222
use crate::pyobject::{BufferProtocol, PyObjectRef, PyRef, PyResult, PyValue};
2323
use crate::vm::VirtualMachine;
@@ -442,32 +442,27 @@ fn text_io_wrapper_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
442442
fn text_io_base_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
443443
arg_check!(vm, args, required = [(text_io_base, None)]);
444444

445-
let io_module = vm.import("_io", &vm.ctx.new_tuple(vec![]), 0)?;
446-
let buffered_reader_class = vm
447-
.get_attribute(io_module.clone(), "BufferedReader")
448-
.unwrap()
449-
.downcast::<PyClass>()
450-
.unwrap();
445+
let buffered_reader_class = vm.try_class("_io", "BufferedReader")?;
451446
let raw = vm.get_attribute(text_io_base.clone(), "buffer").unwrap();
452447

453-
if objtype::isinstance(&raw, &buffered_reader_class) {
454-
if let Ok(bytes) = vm.call_method(&raw, "read", PyFuncArgs::default()) {
455-
let value = objbytes::get_value(&bytes).to_vec();
456-
457-
//format bytes into string
458-
let rust_string = String::from_utf8(value).map_err(|e| {
459-
vm.new_unicode_decode_error(format!(
460-
"cannot decode byte at index: {}",
461-
e.utf8_error().valid_up_to()
462-
))
463-
})?;
464-
Ok(vm.ctx.new_str(rust_string))
465-
} else {
466-
Err(vm.new_value_error("Error unpacking Bytes".to_string()))
467-
}
468-
} else {
448+
if !objtype::isinstance(&raw, &buffered_reader_class) {
469449
// TODO: this should be io.UnsupportedOperation error which derives both from ValueError *and* OSError
470-
Err(vm.new_value_error("not readable".to_string()))
450+
return Err(vm.new_value_error("not readable".to_string()));
451+
}
452+
453+
if let Ok(bytes) = vm.call_method(&raw, "read", PyFuncArgs::default()) {
454+
let value = objbytes::get_value(&bytes).to_vec();
455+
456+
//format bytes into string
457+
let rust_string = String::from_utf8(value).map_err(|e| {
458+
vm.new_unicode_decode_error(format!(
459+
"cannot decode byte at index: {}",
460+
e.utf8_error().valid_up_to()
461+
))
462+
})?;
463+
Ok(vm.ctx.new_str(rust_string))
464+
} else {
465+
Err(vm.new_value_error("Error unpacking Bytes".to_string()))
471466
}
472467
}
473468

@@ -478,36 +473,32 @@ fn text_io_base_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
478473
required = [(text_io_base, None), (obj, Some(vm.ctx.str_type()))]
479474
);
480475

481-
let io_module = vm.import("_io", &vm.ctx.new_tuple(vec![]), 0)?;
482-
let buffered_writer_class = vm
483-
.get_attribute(io_module.clone(), "BufferedWriter")
484-
.unwrap()
485-
.downcast::<PyClass>()
486-
.unwrap();
476+
let buffered_writer_class = vm.try_class("_io", "BufferedWriter")?;
487477
let raw = vm.get_attribute(text_io_base.clone(), "buffer").unwrap();
488-
if objtype::isinstance(&raw, &buffered_writer_class) {
489-
let write = vm
490-
.get_method(raw.clone(), "write")
491-
.ok_or_else(|| vm.new_attribute_error("BufferedWriter has no write method".to_owned()))
492-
.and_then(|it| it)?;
493-
let bytes = objstr::get_value(obj).into_bytes();
494-
495-
let len = vm.invoke(
496-
write,
497-
PyFuncArgs::new(vec![vm.ctx.new_bytes(bytes.clone())], vec![]),
498-
)?;
499-
let len = objint::get_value(&len).to_usize().ok_or_else(|| {
500-
vm.new_overflow_error("int to large to convert to Rust usize".to_string())
501-
})?;
502478

503-
// returns the count of unicode code points written
504-
Ok(vm
505-
.ctx
506-
.new_int(String::from_utf8_lossy(&bytes[0..len]).chars().count()))
507-
} else {
479+
if !objtype::isinstance(&raw, &buffered_writer_class) {
508480
// TODO: this should be io.UnsupportedOperation error which derives from ValueError and OSError
509-
Err(vm.new_value_error("not writable".to_string()))
481+
return Err(vm.new_value_error("not writable".to_string()));
510482
}
483+
484+
let write = vm
485+
.get_method(raw.clone(), "write")
486+
.ok_or_else(|| vm.new_attribute_error("BufferedWriter has no write method".to_owned()))
487+
.and_then(|it| it)?;
488+
let bytes = objstr::get_value(obj).into_bytes();
489+
490+
let len = vm.invoke(
491+
write,
492+
PyFuncArgs::new(vec![vm.ctx.new_bytes(bytes.clone())], vec![]),
493+
)?;
494+
let len = objint::get_value(&len).to_usize().ok_or_else(|| {
495+
vm.new_overflow_error("int to large to convert to Rust usize".to_string())
496+
})?;
497+
498+
// returns the count of unicode code points written
499+
Ok(vm
500+
.ctx
501+
.new_int(String::from_utf8_lossy(&bytes[0..len]).chars().count()))
511502
}
512503

513504
fn split_mode_string(mode_string: String) -> Result<(String, String), String> {

0 commit comments

Comments
 (0)