diff --git a/derive/src/lib.rs b/derive/src/lib.rs index cfda0f7736..46547071e9 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -12,6 +12,7 @@ mod error; mod compile_bytecode; mod from_args; mod pyclass; +mod pymodule; mod util; use error::{extract_spans, Diagnostic}; @@ -44,6 +45,13 @@ pub fn pyimpl(attr: TokenStream, item: TokenStream) -> TokenStream { result_to_tokens(pyclass::impl_pyimpl(attr, item)) } +#[proc_macro_attribute] +pub fn pymodule(attr: TokenStream, item: TokenStream) -> TokenStream { + let attr = parse_macro_input!(attr as AttributeArgs); + let item = parse_macro_input!(item as Item); + result_to_tokens(pymodule::impl_pymodule(attr, item)) +} + #[proc_macro_attribute] pub fn pystruct_sequence(attr: TokenStream, item: TokenStream) -> TokenStream { let attr = parse_macro_input!(attr as AttributeArgs); diff --git a/derive/src/pyclass.rs b/derive/src/pyclass.rs index 42194395a0..1d64740480 100644 --- a/derive/src/pyclass.rs +++ b/derive/src/pyclass.rs @@ -1,11 +1,11 @@ use super::Diagnostic; -use crate::util::path_eq; +use crate::util::{def_to_name, path_eq, strip_prefix, ItemIdent, ItemMeta}; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, quote_spanned, ToTokens}; use std::collections::{HashMap, HashSet}; use syn::{ parse_quote, spanned::Spanned, Attribute, AttributeArgs, Ident, Index, Item, Lit, Meta, - NestedMeta, Signature, + NestedMeta, }; fn meta_to_vec(meta: Meta) -> Result<Vec<NestedMeta>, Meta> { @@ -42,181 +42,6 @@ enum ClassItem { }, } -struct ClassItemMeta<'a> { - sig: &'a Signature, - parent_type: &'static str, - meta: HashMap<String, Option<Lit>>, -} - -impl<'a> ClassItemMeta<'a> { - const METHOD_NAMES: &'static [&'static str] = &["name", "magic"]; - const PROPERTY_NAMES: &'static [&'static str] = &["name", "magic", "setter"]; - - fn from_nested_meta( - parent_type: &'static str, - sig: &'a Signature, - nested_meta: &[NestedMeta], - names: &[&'static str], - ) -> Result<Self, Diagnostic> { - let mut extracted = Self { - sig, - parent_type, - meta: HashMap::new(), - }; - - let validate_name = |name: &str, extracted: &Self| -> Result<(), Diagnostic> { - if names.contains(&name) { - if extracted.meta.contains_key(name) { - bail_span!( - &sig.ident, - "#[{}] must have only one '{}'", - parent_type, - name - ); - } else { - Ok(()) - } - } else { - bail_span!( - &sig.ident, - "#[{}({})] is not one of allowed attributes {}", - parent_type, - name, - names.join(", ") - ); - } - }; - - for meta in nested_meta { - let meta = match meta { - NestedMeta::Meta(meta) => meta, - NestedMeta::Lit(_) => continue, - }; - - match meta { - Meta::NameValue(name_value) => { - if let Some(ident) = name_value.path.get_ident() { - let name = ident.to_string(); - validate_name(&name, &extracted)?; - extracted.meta.insert(name, Some(name_value.lit.clone())); - } - } - Meta::Path(path) => { - if let Some(ident) = path.get_ident() { - let name = ident.to_string(); - validate_name(&name, &extracted)?; - extracted.meta.insert(name, None); - } else { - continue; - } - } - _ => (), - } - } - - Ok(extracted) - } - - fn _str(&self, key: &str) -> Result<Option<String>, Diagnostic> { - Ok(match self.meta.get(key) { - Some(Some(lit)) => { - if let Lit::Str(s) = lit { - Some(s.value()) - } else { - bail_span!( - &self.sig.ident, - "#[{}({} = ...)] must be a string", - self.parent_type, - key - ); - } - } - Some(None) => { - bail_span!( - &self.sig.ident, - "#[{}({} = ...)] is expected", - self.parent_type, - key, - ); - } - None => None, - }) - } - - fn _bool(&self, key: &str) -> Result<bool, Diagnostic> { - Ok(match self.meta.get(key) { - Some(Some(_)) => { - bail_span!( - &self.sig.ident, - "#[{}({})] is expected", - self.parent_type, - key, - ); - } - Some(None) => true, - None => false, - }) - } - - fn method_name(&self) -> Result<String, Diagnostic> { - let name = self._str("name")?; - let magic = self._bool("magic")?; - Ok(if let Some(name) = name { - name - } else { - let name = self.sig.ident.to_string(); - if magic { - format!("__{}__", name) - } else { - name - } - }) - } - - fn setter(&self) -> Result<bool, Diagnostic> { - self._bool("setter") - } - - fn property_name(&self) -> Result<String, Diagnostic> { - let magic = self._bool("magic")?; - let setter = self._bool("setter")?; - let name = self._str("name")?; - - Ok(if let Some(name) = name { - name - } else { - let sig_name = self.sig.ident.to_string(); - let name = if setter { - if let Some(name) = strip_prefix(&sig_name, "set_") { - if name.is_empty() { - bail_span!( - &self.sig.ident, - "A #[{}(setter)] fn with a set_* name must \ - have something after \"set_\"", - self.parent_type - ) - } - name.to_string() - } else { - bail_span!( - &self.sig.ident, - "A #[{}(setter)] fn must either have a `name` \ - parameter or a fn name along the lines of \"set_*\"", - self.parent_type - ) - } - } else { - sig_name - }; - if magic { - format!("__{}__", name) - } else { - name - } - }) - } -} - impl Class { fn add_item(&mut self, item: ClassItem, span: Span) -> Result<(), Diagnostic> { if self.items.insert(item) { @@ -229,7 +54,7 @@ impl Class { } } - fn extract_method(sig: &Signature, meta: Meta) -> Result<ClassItem, Diagnostic> { + fn extract_method(ident: &Ident, meta: Meta) -> Result<ClassItem, Diagnostic> { let nesteds = meta_to_vec(meta).map_err(|meta| { err_span!( meta, @@ -238,19 +63,15 @@ impl Class { ) })?; - let item_meta = ClassItemMeta::from_nested_meta( - "pymethod", - sig, - &nesteds, - ClassItemMeta::METHOD_NAMES, - )?; + let item_meta = + ItemMeta::from_nested_meta("pymethod", &ident, &nesteds, ItemMeta::ATTRIBUTE_NAMES)?; Ok(ClassItem::Method { - item_ident: sig.ident.clone(), + item_ident: ident.clone(), py_name: item_meta.method_name()?, }) } - fn extract_classmethod(sig: &Signature, meta: Meta) -> Result<ClassItem, Diagnostic> { + fn extract_classmethod(ident: &Ident, meta: Meta) -> Result<ClassItem, Diagnostic> { let nesteds = meta_to_vec(meta).map_err(|meta| { err_span!( meta, @@ -258,19 +79,19 @@ impl Class { #[pyclassmethod(name = \"...\")]", ) })?; - let item_meta = ClassItemMeta::from_nested_meta( + let item_meta = ItemMeta::from_nested_meta( "pyclassmethod", - sig, + &ident, &nesteds, - ClassItemMeta::METHOD_NAMES, + ItemMeta::ATTRIBUTE_NAMES, )?; Ok(ClassItem::ClassMethod { - item_ident: sig.ident.clone(), + item_ident: ident.clone(), py_name: item_meta.method_name()?, }) } - fn extract_property(sig: &Signature, meta: Meta) -> Result<ClassItem, Diagnostic> { + fn extract_property(ident: &Ident, meta: Meta) -> Result<ClassItem, Diagnostic> { let nesteds = meta_to_vec(meta).map_err(|meta| { err_span!( meta, @@ -278,31 +99,27 @@ impl Class { #[pyproperty(name = \"...\")]" ) })?; - let item_meta = ClassItemMeta::from_nested_meta( - "pyproperty", - sig, - &nesteds, - ClassItemMeta::PROPERTY_NAMES, - )?; + let item_meta = + ItemMeta::from_nested_meta("pyproperty", &ident, &nesteds, ItemMeta::PROPERTY_NAMES)?; Ok(ClassItem::Property { py_name: item_meta.property_name()?, - item_ident: sig.ident.clone(), + item_ident: ident.clone(), setter: item_meta.setter()?, }) } - fn extract_slot(sig: &Signature, meta: Meta) -> Result<ClassItem, Diagnostic> { + fn extract_slot(ident: &Ident, meta: Meta) -> Result<ClassItem, Diagnostic> { let pyslot_err = "#[pyslot] must be of the form #[pyslot] or #[pyslot(slotname)]"; let nesteds = meta_to_vec(meta).map_err(|meta| err_span!(meta, "{}", pyslot_err))?; if nesteds.len() > 1 { return Err(Diagnostic::spanned_error("e!(#(#nesteds)*), pyslot_err)); } let slot_ident = if nesteds.is_empty() { - let ident_str = sig.ident.to_string(); + let ident_str = ident.to_string(); if let Some(stripped) = strip_prefix(&ident_str, "tp_") { - proc_macro2::Ident::new(stripped, sig.ident.span()) + proc_macro2::Ident::new(stripped, ident.span()) } else { - sig.ident.clone() + ident.clone() } } else { match nesteds.into_iter().next().unwrap() { @@ -315,14 +132,14 @@ impl Class { }; Ok(ClassItem::Slot { slot_ident, - item_ident: sig.ident.clone(), + item_ident: ident.clone(), }) } fn extract_item_from_syn( &mut self, attrs: &mut Vec<Attribute>, - sig: &Signature, + ident: &Ident, ) -> Result<(), Diagnostic> { let mut attr_idxs = Vec::new(); for (i, meta) in attrs @@ -335,17 +152,16 @@ impl Class { Some(name) => name, None => continue, }; - if name == "pymethod" { - self.add_item(Self::extract_method(sig, meta)?, meta_span)?; - } else if name == "pyclassmethod" { - self.add_item(Self::extract_classmethod(sig, meta)?, meta_span)?; - } else if name == "pyproperty" { - self.add_item(Self::extract_property(sig, meta)?, meta_span)?; - } else if name == "pyslot" { - self.add_item(Self::extract_slot(sig, meta)?, meta_span)?; - } else { - continue; - } + let item = match name.to_string().as_str() { + "pymethod" => Self::extract_method(ident, meta)?, + "pyclassmethod" => Self::extract_classmethod(ident, meta)?, + "pyproperty" => Self::extract_property(ident, meta)?, + "pyslot" => Self::extract_slot(ident, meta)?, + _ => { + continue; + } + }; + self.add_item(item, meta_span)?; attr_idxs.push(i); } let mut i = 0; @@ -365,12 +181,7 @@ impl Class { } } -struct ItemSig<'a> { - attrs: &'a mut Vec<Attribute>, - sig: &'a Signature, -} - -fn extract_impl_items(mut items: Vec<ItemSig>) -> Result<TokenStream2, Diagnostic> { +fn extract_impl_items(mut items: Vec<ItemIdent>) -> Result<TokenStream2, Diagnostic> { let mut diagnostics: Vec<Diagnostic> = Vec::new(); let mut class = Class::default(); @@ -378,7 +189,7 @@ fn extract_impl_items(mut items: Vec<ItemSig>) -> Result<TokenStream2, Diagnosti for item in items.iter_mut() { push_diag_result!( diagnostics, - class.extract_item_from_syn(&mut item.attrs, item.sig), + class.extract_item_from_syn(&mut item.attrs, &item.ident), ); } @@ -544,7 +355,10 @@ pub fn impl_pyimpl(attr: AttributeArgs, item: Item) -> Result<TokenStream2, Diag .iter_mut() .filter_map(|item| match item { syn::ImplItem::Method(syn::ImplItemMethod { attrs, sig, .. }) => { - Some(ItemSig { attrs, sig }) + Some(ItemIdent { + attrs, + ident: &sig.ident, + }) } _ => None, }) @@ -574,7 +388,10 @@ pub fn impl_pyimpl(attr: AttributeArgs, item: Item) -> Result<TokenStream2, Diag .iter_mut() .filter_map(|item| match item { syn::TraitItem::Method(syn::TraitItemMethod { attrs, sig, .. }) => { - Some(ItemSig { attrs, sig }) + Some(ItemIdent { + attrs, + ident: &sig.ident, + }) } _ => None, }) @@ -597,30 +414,9 @@ pub fn impl_pyimpl(attr: AttributeArgs, item: Item) -> Result<TokenStream2, Diag fn generate_class_def( ident: &Ident, - attr_name: &'static str, - attr: AttributeArgs, + name: &str, attrs: &[Attribute], ) -> Result<TokenStream2, Diagnostic> { - let mut class_name = None; - for attr in attr { - if let NestedMeta::Meta(meta) = attr { - if let Meta::NameValue(name_value) = meta { - if path_eq(&name_value.path, "name") { - if let Lit::Str(s) = name_value.lit { - class_name = Some(s.value()); - } else { - bail_span!( - name_value.lit, - "#[{}(name = ...)] must be a string", - attr_name - ); - } - } - } - } - } - let class_name = class_name.unwrap_or_else(|| ident.to_string()); - let mut doc: Option<Vec<String>> = None; for attr in attrs.iter() { if attr.path.is_ident("doc") { @@ -646,7 +442,7 @@ fn generate_class_def( let ret = quote! { impl ::rustpython_vm::pyobject::PyClassDef for #ident { - const NAME: &'static str = #class_name; + const NAME: &'static str = #name; const DOC: Option<&'static str> = #doc; } }; @@ -663,7 +459,8 @@ pub fn impl_pyclass(attr: AttributeArgs, item: Item) -> Result<TokenStream2, Dia ), }; - let class_def = generate_class_def(&ident, "pyclass", attr, &attrs)?; + let class_name = def_to_name(&ident, "pyclass", attr)?; + let class_def = generate_class_def(&ident, &class_name, &attrs)?; let ret = quote! { #item @@ -681,7 +478,8 @@ pub fn impl_pystruct_sequence(attr: AttributeArgs, item: Item) -> Result<TokenSt "#[pystruct_sequence] can only be on a struct declaration" ) }; - let class_def = generate_class_def(&struc.ident, "pystruct_sequence", attr, &struc.attrs)?; + let class_name = def_to_name(&struc.ident, "pystruct_sequence", attr)?; + let class_def = generate_class_def(&struc.ident, &class_name, &struc.attrs)?; let mut properties = Vec::new(); let mut field_names = Vec::new(); for (i, field) in struc.fields.iter().enumerate() { @@ -745,11 +543,3 @@ pub fn impl_pystruct_sequence(attr: AttributeArgs, item: Item) -> Result<TokenSt }; Ok(ret) } - -fn strip_prefix<'a>(s: &'a str, prefix: &str) -> Option<&'a str> { - if s.starts_with(prefix) { - Some(&s[prefix.len()..]) - } else { - None - } -} diff --git a/derive/src/pymodule.rs b/derive/src/pymodule.rs new file mode 100644 index 0000000000..1127152139 --- /dev/null +++ b/derive/src/pymodule.rs @@ -0,0 +1,217 @@ +use super::Diagnostic; +use crate::util::{def_to_name, ItemIdent, ItemMeta}; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::{quote, quote_spanned}; +use std::collections::HashSet; +use syn::{parse_quote, spanned::Spanned, Attribute, AttributeArgs, Ident, Item, Meta, NestedMeta}; + +fn meta_to_vec(meta: Meta) -> Result<Vec<NestedMeta>, Meta> { + match meta { + Meta::Path(_) => Ok(Vec::new()), + Meta::List(list) => Ok(list.nested.into_iter().collect()), + Meta::NameValue(_) => Err(meta), + } +} + +#[derive(Default)] +struct Module { + items: HashSet<ModuleItem>, +} + +#[derive(PartialEq, Eq, Hash)] +enum ModuleItem { + Function { item_ident: Ident, py_name: String }, + Class { item_ident: Ident, py_name: String }, +} + +impl Module { + fn add_item(&mut self, item: ModuleItem, span: Span) -> Result<(), Diagnostic> { + if self.items.insert(item) { + Ok(()) + } else { + Err(Diagnostic::span_error( + span, + "Duplicate #[py*] attribute on pyimpl".to_owned(), + )) + } + } + + fn extract_function(ident: &Ident, meta: Meta) -> Result<ModuleItem, Diagnostic> { + let nesteds = meta_to_vec(meta).map_err(|meta| { + err_span!( + meta, + "#[pyfunction = \"...\"] cannot be a name/value, you probably meant \ + #[pyfunction(name = \"...\")]", + ) + })?; + + let item_meta = + ItemMeta::from_nested_meta("pyfunction", &ident, &nesteds, ItemMeta::SIMPLE_NAMES)?; + Ok(ModuleItem::Function { + item_ident: ident.clone(), + py_name: item_meta.simple_name()?, + }) + } + + fn extract_class(ident: &Ident, meta: Meta) -> Result<ModuleItem, Diagnostic> { + let nesteds = meta_to_vec(meta).map_err(|meta| { + err_span!( + meta, + "#[pyclass = \"...\"] cannot be a name/value, you probably meant \ + #[pyclass(name = \"...\")]", + ) + })?; + + let item_meta = + ItemMeta::from_nested_meta("pyclass", &ident, &nesteds, ItemMeta::SIMPLE_NAMES)?; + Ok(ModuleItem::Class { + item_ident: ident.clone(), + py_name: item_meta.simple_name()?, + }) + } + + fn extract_item_from_syn( + &mut self, + attrs: &mut Vec<Attribute>, + ident: &Ident, + ) -> Result<(), Diagnostic> { + let mut attr_idxs = Vec::new(); + for (i, meta) in attrs + .iter() + .filter_map(|attr| attr.parse_meta().ok()) + .enumerate() + { + let meta_span = meta.span(); + let name = match meta.path().get_ident() { + Some(name) => name, + None => continue, + }; + let item = match name.to_string().as_str() { + "pyfunction" => { + attr_idxs.push(i); + Self::extract_function(ident, meta)? + } + "pyclass" => Self::extract_class(ident, meta)?, + _ => { + continue; + } + }; + self.add_item(item, meta_span)?; + } + let mut i = 0; + let mut attr_idxs = &*attr_idxs; + attrs.retain(|_| { + let drop = attr_idxs.first().copied() == Some(i); + if drop { + attr_idxs = &attr_idxs[1..]; + } + i += 1; + !drop + }); + for (i, idx) in attr_idxs.iter().enumerate() { + attrs.remove(idx - i); + } + Ok(()) + } +} + +fn extract_module_items(mut items: Vec<ItemIdent>) -> Result<TokenStream2, Diagnostic> { + let mut diagnostics: Vec<Diagnostic> = Vec::new(); + + let mut module = Module::default(); + + for item in items.iter_mut() { + push_diag_result!( + diagnostics, + module.extract_item_from_syn(&mut item.attrs, item.ident), + ); + } + + let functions = module.items.into_iter().map(|item| match item { + ModuleItem::Function { + item_ident, + py_name, + } => { + let new_func = quote_spanned!(item_ident.span() => .new_function(#item_ident)); + quote! { + vm.__module_set_attr(&module, #py_name, vm.ctx#new_func).unwrap(); + } + } + ModuleItem::Class { + item_ident, + py_name, + } => { + let new_class = quote_spanned!(item_ident.span() => #item_ident::make_class(&vm.ctx)); + quote! { + vm.__module_set_attr(&module, #py_name, #new_class).unwrap(); + } + } + }); + + Diagnostic::from_vec(diagnostics)?; + + Ok(quote! { + #(#functions)* + }) +} + +pub fn impl_pymodule(attr: AttributeArgs, item: Item) -> Result<TokenStream2, Diagnostic> { + match item { + Item::Mod(mut module) => { + let module_name = def_to_name(&module.ident, "pymodule", attr)?; + + if let Some(content) = module.content.as_mut() { + let items = content + .1 + .iter_mut() + .filter_map(|item| match item { + Item::Fn(syn::ItemFn { attrs, sig, .. }) => Some(ItemIdent { + attrs, + ident: &sig.ident, + }), + Item::Struct(syn::ItemStruct { attrs, ident, .. }) => { + Some(ItemIdent { attrs, ident }) + } + Item::Enum(syn::ItemEnum { attrs, ident, .. }) => { + Some(ItemIdent { attrs, ident }) + } + _ => None, + }) + .collect(); + + let extend_mod = extract_module_items(items)?; + content.1.push(parse_quote! { + const MODULE_NAME: &str = #module_name; + }); + content.1.push(parse_quote! { + pub(crate) fn extend_module( + vm: &::rustpython_vm::vm::VirtualMachine, + module: &::rustpython_vm::pyobject::PyObjectRef, + ) { + #extend_mod + } + }); + content.1.push(parse_quote! { + #[allow(dead_code)] + pub(crate) fn make_module( + vm: &::rustpython_vm::vm::VirtualMachine + ) -> ::rustpython_vm::pyobject::PyObjectRef { + let module = vm.new_module(MODULE_NAME, vm.ctx.new_dict()); + extend_module(vm, &module); + module + } + }); + + Ok(quote! { + #module + }) + } else { + bail_span!( + module, + "#[pymodule] can only be on a module declaration with body" + ) + } + } + other => bail_span!(other, "#[pymodule] can only be on a module declaration"), + } +} diff --git a/derive/src/util.rs b/derive/src/util.rs index 7284d8014a..deb2ef0461 100644 --- a/derive/src/util.rs +++ b/derive/src/util.rs @@ -1,3 +1,216 @@ -pub fn path_eq(path: &syn::Path, s: &str) -> bool { +use super::Diagnostic; +use std::collections::HashMap; +use syn::{Attribute, AttributeArgs, Ident, Lit, Meta, NestedMeta, Path}; + +pub fn path_eq(path: &Path, s: &str) -> bool { path.get_ident().map_or(false, |id| id == s) } + +pub fn def_to_name( + ident: &Ident, + attr_name: &'static str, + attr: AttributeArgs, +) -> Result<String, Diagnostic> { + let mut name = None; + for attr in attr { + if let NestedMeta::Meta(meta) = attr { + if let Meta::NameValue(name_value) = meta { + if path_eq(&name_value.path, "name") { + if let Lit::Str(s) = name_value.lit { + name = Some(s.value()); + } else { + bail_span!( + name_value.lit, + "#[{}(name = ...)] must be a string", + attr_name + ); + } + } + } + } + } + Ok(name.unwrap_or_else(|| ident.to_string())) +} + +pub fn strip_prefix<'a>(s: &'a str, prefix: &str) -> Option<&'a str> { + if s.starts_with(prefix) { + Some(&s[prefix.len()..]) + } else { + None + } +} + +pub struct ItemIdent<'a> { + pub attrs: &'a mut Vec<Attribute>, + pub ident: &'a Ident, +} + +pub struct ItemMeta<'a> { + ident: &'a Ident, + parent_type: &'static str, + meta: HashMap<String, Option<Lit>>, +} + +impl<'a> ItemMeta<'a> { + pub const SIMPLE_NAMES: &'static [&'static str] = &["name"]; + pub const ATTRIBUTE_NAMES: &'static [&'static str] = &["name", "magic"]; + pub const PROPERTY_NAMES: &'static [&'static str] = &["name", "magic", "setter"]; + + pub fn from_nested_meta( + parent_type: &'static str, + ident: &'a Ident, + nested_meta: &[NestedMeta], + names: &[&'static str], + ) -> Result<Self, Diagnostic> { + let mut extracted = Self { + ident, + parent_type, + meta: HashMap::new(), + }; + + let validate_name = |name: &str, extracted: &Self| -> Result<(), Diagnostic> { + if names.contains(&name) { + if extracted.meta.contains_key(name) { + bail_span!(ident, "#[{}] must have only one '{}'", parent_type, name); + } else { + Ok(()) + } + } else { + bail_span!( + ident, + "#[{}({})] is not one of allowed attributes {}", + parent_type, + name, + names.join(", ") + ); + } + }; + + for meta in nested_meta { + let meta = match meta { + NestedMeta::Meta(meta) => meta, + NestedMeta::Lit(_) => continue, + }; + + match meta { + Meta::NameValue(name_value) => { + if let Some(ident) = name_value.path.get_ident() { + let name = ident.to_string(); + validate_name(&name, &extracted)?; + extracted.meta.insert(name, Some(name_value.lit.clone())); + } + } + Meta::Path(path) => { + if let Some(ident) = path.get_ident() { + let name = ident.to_string(); + validate_name(&name, &extracted)?; + extracted.meta.insert(name, None); + } else { + continue; + } + } + _ => (), + } + } + + Ok(extracted) + } + + fn _str(&self, key: &str) -> Result<Option<String>, Diagnostic> { + Ok(match self.meta.get(key) { + Some(Some(lit)) => { + if let Lit::Str(s) = lit { + Some(s.value()) + } else { + bail_span!( + &self.ident, + "#[{}({} = ...)] must be a string", + self.parent_type, + key + ); + } + } + Some(None) => { + bail_span!( + &self.ident, + "#[{}({} = ...)] is expected", + self.parent_type, + key, + ); + } + None => None, + }) + } + + fn _bool(&self, key: &str) -> Result<bool, Diagnostic> { + Ok(match self.meta.get(key) { + Some(Some(_)) => { + bail_span!(&self.ident, "#[{}({})] is expected", self.parent_type, key,); + } + Some(None) => true, + None => false, + }) + } + + pub fn simple_name(&self) -> Result<String, Diagnostic> { + Ok(self._str("name")?.unwrap_or_else(|| self.ident.to_string())) + } + + pub fn method_name(&self) -> Result<String, Diagnostic> { + let name = self._str("name")?; + let magic = self._bool("magic")?; + Ok(if let Some(name) = name { + name + } else { + let name = self.ident.to_string(); + if magic { + format!("__{}__", name) + } else { + name + } + }) + } + + pub fn property_name(&self) -> Result<String, Diagnostic> { + let magic = self._bool("magic")?; + let setter = self._bool("setter")?; + let name = self._str("name")?; + + Ok(if let Some(name) = name { + name + } else { + let sig_name = self.ident.to_string(); + let name = if setter { + if let Some(name) = strip_prefix(&sig_name, "set_") { + if name.is_empty() { + bail_span!( + &self.ident, + "A #[{}(setter)] fn with a set_* name must \ + have something after \"set_\"", + self.parent_type + ) + } + name.to_string() + } else { + bail_span!( + &self.ident, + "A #[{}(setter)] fn must either have a `name` \ + parameter or a fn name along the lines of \"set_*\"", + self.parent_type + ) + } + } else { + sig_name + }; + if magic { + format!("__{}__", name) + } else { + name + } + }) + } + + pub fn setter(&self) -> Result<bool, Diagnostic> { + self._bool("setter") + } +} diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index 622da42766..85427a1898 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -11,7 +11,7 @@ use crate::pyobject::{ PyArithmaticValue::{self, *}, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, }; -use crate::sequence::{self, SimpleSeq}; +use crate::sequence; use crate::vm::{ReprGuard, VirtualMachine}; /// tuple() -> empty tuple diff --git a/vm/src/stdlib/binascii.rs b/vm/src/stdlib/binascii.rs index 18db3e7d96..9803463e83 100644 --- a/vm/src/stdlib/binascii.rs +++ b/vm/src/stdlib/binascii.rs @@ -1,147 +1,144 @@ -use crate::function::OptionalArg; -use crate::obj::objbytearray::{PyByteArray, PyByteArrayRef}; -use crate::obj::objbyteinner::PyBytesLike; -use crate::obj::objbytes::{PyBytes, PyBytesRef}; -use crate::obj::objstr::{PyString, PyStringRef}; -use crate::pyobject::{PyObjectRef, PyResult, TryFromObject, TypeProtocol}; -use crate::vm::VirtualMachine; - -use crc::{crc32, Hasher32}; -use itertools::Itertools; - -enum SerializedData { - Bytes(PyBytesRef), - Buffer(PyByteArrayRef), - Ascii(PyStringRef), -} - -impl TryFromObject for SerializedData { - fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> { - match_class!(match obj { - b @ PyBytes => Ok(SerializedData::Bytes(b)), - b @ PyByteArray => Ok(SerializedData::Buffer(b)), - a @ PyString => { - if a.as_str().is_ascii() { - Ok(SerializedData::Ascii(a)) - } else { - Err(vm.new_value_error( - "string argument should contain only ASCII characters".to_owned(), - )) - } - } - obj => Err(vm.new_type_error(format!( - "argument should be bytes, buffer or ASCII string, not '{}'", - obj.class().name, - ))), - }) +pub(crate) use decl::make_module; + +#[pymodule(name = "binascii")] +mod decl { + use crate::function::OptionalArg; + use crate::obj::objbytearray::{PyByteArray, PyByteArrayRef}; + use crate::obj::objbyteinner::PyBytesLike; + use crate::obj::objbytes::{PyBytes, PyBytesRef}; + use crate::obj::objstr::{PyString, PyStringRef}; + use crate::pyobject::{PyObjectRef, PyResult, TryFromObject, TypeProtocol}; + use crate::vm::VirtualMachine; + use crc::{crc32, Hasher32}; + use itertools::Itertools; + + enum SerializedData { + Bytes(PyBytesRef), + Buffer(PyByteArrayRef), + Ascii(PyStringRef), } -} -impl SerializedData { - #[inline] - pub fn with_ref<R>(&self, f: impl FnOnce(&[u8]) -> R) -> R { - match self { - SerializedData::Bytes(b) => f(b.get_value()), - SerializedData::Buffer(b) => f(&b.borrow_value().elements), - SerializedData::Ascii(a) => f(a.as_str().as_bytes()), + impl TryFromObject for SerializedData { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> { + match_class!(match obj { + b @ PyBytes => Ok(SerializedData::Bytes(b)), + b @ PyByteArray => Ok(SerializedData::Buffer(b)), + a @ PyString => { + if a.as_str().is_ascii() { + Ok(SerializedData::Ascii(a)) + } else { + Err(vm.new_value_error( + "string argument should contain only ASCII characters".to_owned(), + )) + } + } + obj => Err(vm.new_type_error(format!( + "argument should be bytes, buffer or ASCII string, not '{}'", + obj.class().name, + ))), + }) } } -} -fn hex_nibble(n: u8) -> u8 { - match n { - 0..=9 => b'0' + n, - 10..=15 => b'a' + n, - _ => unreachable!(), + impl SerializedData { + #[inline] + pub fn with_ref<R>(&self, f: impl FnOnce(&[u8]) -> R) -> R { + match self { + SerializedData::Bytes(b) => f(b.get_value()), + SerializedData::Buffer(b) => f(&b.borrow_value().elements), + SerializedData::Ascii(a) => f(a.as_str().as_bytes()), + } + } } -} -fn binascii_hexlify(data: PyBytesLike) -> Vec<u8> { - data.with_ref(|bytes| { - let mut hex = Vec::<u8>::with_capacity(bytes.len() * 2); - for b in bytes.iter() { - hex.push(hex_nibble(b >> 4)); - hex.push(hex_nibble(b & 0xf)); + fn hex_nibble(n: u8) -> u8 { + match n { + 0..=9 => b'0' + n, + 10..=15 => b'a' + n, + _ => unreachable!(), } - hex - }) -} + } -fn unhex_nibble(c: u8) -> Option<u8> { - match c { - b'0'..=b'9' => Some(c - b'0'), - b'a'..=b'f' => Some(c - b'a' + 10), - b'A'..=b'F' => Some(c - b'A' + 10), - _ => None, + #[pyfunction(name = "b2a_hex")] + #[pyfunction] + fn hexlify(data: PyBytesLike) -> Vec<u8> { + data.with_ref(|bytes| { + let mut hex = Vec::<u8>::with_capacity(bytes.len() * 2); + for b in bytes.iter() { + hex.push(hex_nibble(b >> 4)); + hex.push(hex_nibble(b & 0xf)); + } + hex + }) } -} -fn binascii_unhexlify(data: SerializedData, vm: &VirtualMachine) -> PyResult<Vec<u8>> { - data.with_ref(|hex_bytes| { - if hex_bytes.len() % 2 != 0 { - return Err(vm.new_value_error("Odd-length string".to_owned())); + fn unhex_nibble(c: u8) -> Option<u8> { + match c { + b'0'..=b'9' => Some(c - b'0'), + b'a'..=b'f' => Some(c - b'a' + 10), + b'A'..=b'F' => Some(c - b'A' + 10), + _ => None, } + } - let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2); - for (n1, n2) in hex_bytes.iter().tuples() { - if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) { - unhex.push(n1 << 4 | n2); - } else { - return Err(vm.new_value_error("Non-hexadecimal digit found".to_owned())); + #[pyfunction(name = "a2b_hex")] + #[pyfunction] + fn unhexlify(data: SerializedData, vm: &VirtualMachine) -> PyResult<Vec<u8>> { + data.with_ref(|hex_bytes| { + if hex_bytes.len() % 2 != 0 { + return Err(vm.new_value_error("Odd-length string".to_owned())); } - } - Ok(unhex) - }) -} + let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2); + for (n1, n2) in hex_bytes.iter().tuples() { + if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) { + unhex.push(n1 << 4 | n2); + } else { + return Err(vm.new_value_error("Non-hexadecimal digit found".to_owned())); + } + } -fn binascii_crc32(data: SerializedData, value: OptionalArg<u32>, vm: &VirtualMachine) -> PyResult { - let crc = value.unwrap_or(0); + Ok(unhex) + }) + } - let mut digest = crc32::Digest::new_with_initial(crc32::IEEE, crc); - data.with_ref(|bytes| digest.write(&bytes)); + #[pyfunction] + fn crc32(data: SerializedData, value: OptionalArg<u32>, vm: &VirtualMachine) -> PyResult { + let crc = value.unwrap_or(0); - Ok(vm.ctx.new_int(digest.sum32())) -} + let mut digest = crc32::Digest::new_with_initial(crc32::IEEE, crc); + data.with_ref(|bytes| digest.write(&bytes)); -#[derive(FromArgs)] -struct NewlineArg { - #[pyarg(keyword_only, default = "true")] - newline: bool, -} + Ok(vm.ctx.new_int(digest.sum32())) + } -/// trim a newline from the end of the bytestring, if it exists -fn trim_newline(b: &[u8]) -> &[u8] { - if b.ends_with(b"\n") { - &b[..b.len() - 1] - } else { - b + #[derive(FromArgs)] + struct NewlineArg { + #[pyarg(keyword_only, default = "true")] + newline: bool, } -} -fn binascii_a2b_base64(s: SerializedData, vm: &VirtualMachine) -> PyResult<Vec<u8>> { - s.with_ref(|b| base64::decode(trim_newline(b))) - .map_err(|err| vm.new_value_error(format!("error decoding base64: {}", err))) -} + /// trim a newline from the end of the bytestring, if it exists + fn trim_newline(b: &[u8]) -> &[u8] { + if b.ends_with(b"\n") { + &b[..b.len() - 1] + } else { + b + } + } -fn binascii_b2a_base64(data: PyBytesLike, NewlineArg { newline }: NewlineArg) -> Vec<u8> { - let mut encoded = data.with_ref(base64::encode).into_bytes(); - if newline { - encoded.push(b'\n'); + #[pyfunction] + fn a2b_base64(s: SerializedData, vm: &VirtualMachine) -> PyResult<Vec<u8>> { + s.with_ref(|b| base64::decode(trim_newline(b))) + .map_err(|err| vm.new_value_error(format!("error decoding base64: {}", err))) } - encoded -} -pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { - let ctx = &vm.ctx; - - py_module!(vm, "binascii", { - "hexlify" => ctx.new_function(binascii_hexlify), - "b2a_hex" => ctx.new_function(binascii_hexlify), - "unhexlify" => ctx.new_function(binascii_unhexlify), - "a2b_hex" => ctx.new_function(binascii_unhexlify), - "crc32" => ctx.new_function(binascii_crc32), - "a2b_base64" => ctx.new_function(binascii_a2b_base64), - "b2a_base64" => ctx.new_function(binascii_b2a_base64), - }) + #[pyfunction] + fn b2a_base64(data: PyBytesLike, NewlineArg { newline }: NewlineArg) -> Vec<u8> { + let mut encoded = data.with_ref(base64::encode).into_bytes(); + if newline { + encoded.push(b'\n'); + } + encoded + } } diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 5a5c6b9ead..4ec3633b5c 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -1,395 +1,398 @@ -use crate::function::OptionalArg; -use crate::obj::{objiter, objtype::PyClassRef}; -use crate::pyobject::{ - IdProtocol, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyIterable, PyObjectRef, - PyRef, PyResult, PyValue, -}; -use crate::sequence::{self, SimpleSeq}; -use crate::vm::ReprGuard; -use crate::VirtualMachine; -use itertools::Itertools; -use std::cell::{Cell, RefCell}; -use std::collections::VecDeque; - -#[pyclass(name = "deque")] -#[derive(Debug, Clone)] -struct PyDeque { - deque: RefCell<VecDeque<PyObjectRef>>, - maxlen: Cell<Option<usize>>, -} -type PyDequeRef = PyRef<PyDeque>; - -impl PyValue for PyDeque { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("_collections", "deque") +pub(crate) use _collections::make_module; + +#[pymodule] +mod _collections { + use crate::function::OptionalArg; + use crate::obj::{objiter, objtype::PyClassRef}; + use crate::pyobject::{ + IdProtocol, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyIterable, PyObjectRef, + PyRef, PyResult, PyValue, + }; + use crate::sequence; + use crate::vm::ReprGuard; + use crate::VirtualMachine; + use itertools::Itertools; + use std::cell::{Cell, RefCell}; + use std::collections::VecDeque; + + #[pyclass(name = "deque")] + #[derive(Debug, Clone)] + struct PyDeque { + deque: RefCell<VecDeque<PyObjectRef>>, + maxlen: Cell<Option<usize>>, + } + type PyDequeRef = PyRef<PyDeque>; + + impl PyValue for PyDeque { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_collections", "deque") + } } -} - -#[derive(FromArgs)] -struct PyDequeOptions { - #[pyarg(positional_or_keyword, default = "None")] - maxlen: Option<usize>, -} -impl PyDeque { - fn borrow_deque<'a>(&'a self) -> impl std::ops::Deref<Target = VecDeque<PyObjectRef>> + 'a { - self.deque.borrow() + #[derive(FromArgs)] + struct PyDequeOptions { + #[pyarg(positional_or_keyword, default = "None")] + maxlen: Option<usize>, } -} -#[pyimpl(flags(BASETYPE))] -impl PyDeque { - #[pyslot] - fn tp_new( - cls: PyClassRef, - iter: OptionalArg<PyIterable>, - PyDequeOptions { maxlen }: PyDequeOptions, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let py_deque = PyDeque { - deque: RefCell::default(), - maxlen: maxlen.into(), - }; - if let OptionalArg::Present(iter) = iter { - py_deque.extend(iter, vm)?; - } - py_deque.into_ref_with_type(vm, cls) + impl PyDeque { + fn borrow_deque<'a>(&'a self) -> impl std::ops::Deref<Target = VecDeque<PyObjectRef>> + 'a { + self.deque.borrow() + } } - #[pymethod] - fn append(&self, obj: PyObjectRef) { - let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { - deque.pop_front(); + #[pyimpl(flags(BASETYPE))] + impl PyDeque { + #[pyslot] + fn tp_new( + cls: PyClassRef, + iter: OptionalArg<PyIterable>, + PyDequeOptions { maxlen }: PyDequeOptions, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let py_deque = PyDeque { + deque: RefCell::default(), + maxlen: maxlen.into(), + }; + if let OptionalArg::Present(iter) = iter { + py_deque.extend(iter, vm)?; + } + py_deque.into_ref_with_type(vm, cls) + } + + #[pymethod] + fn append(&self, obj: PyObjectRef) { + let mut deque = self.deque.borrow_mut(); + if self.maxlen.get() == Some(deque.len()) { + deque.pop_front(); + } + deque.push_back(obj); } - deque.push_back(obj); - } - #[pymethod] - fn appendleft(&self, obj: PyObjectRef) { - let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { - deque.pop_back(); + #[pymethod] + fn appendleft(&self, obj: PyObjectRef) { + let mut deque = self.deque.borrow_mut(); + if self.maxlen.get() == Some(deque.len()) { + deque.pop_back(); + } + deque.push_front(obj); } - deque.push_front(obj); - } - #[pymethod] - fn clear(&self) { - self.deque.borrow_mut().clear() - } + #[pymethod] + fn clear(&self) { + self.deque.borrow_mut().clear() + } - #[pymethod] - fn copy(&self) -> Self { - self.clone() - } + #[pymethod] + fn copy(&self) -> Self { + self.clone() + } - #[pymethod] - fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> { - let mut count = 0; - for elem in self.deque.borrow().iter() { - if vm.identical_or_equal(elem, &obj)? { - count += 1; + #[pymethod] + fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> { + let mut count = 0; + for elem in self.deque.borrow().iter() { + if vm.identical_or_equal(elem, &obj)? { + count += 1; + } } + Ok(count) } - Ok(count) - } - #[pymethod] - fn extend(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> { - // TODO: use length_hint here and for extendleft - for elem in iter.iter(vm)? { - self.append(elem?); + #[pymethod] + fn extend(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> { + // TODO: use length_hint here and for extendleft + for elem in iter.iter(vm)? { + self.append(elem?); + } + Ok(()) } - Ok(()) - } - #[pymethod] - fn extendleft(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> { - for elem in iter.iter(vm)? { - self.appendleft(elem?); + #[pymethod] + fn extendleft(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> { + for elem in iter.iter(vm)? { + self.appendleft(elem?); + } + Ok(()) } - Ok(()) - } - #[pymethod] - fn index( - &self, - obj: PyObjectRef, - start: OptionalArg<usize>, - stop: OptionalArg<usize>, - vm: &VirtualMachine, - ) -> PyResult<usize> { - let deque = self.deque.borrow(); - let start = start.unwrap_or(0); - let stop = stop.unwrap_or_else(|| deque.len()); - for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() { - if vm.identical_or_equal(elem, &obj)? { - return Ok(i); + #[pymethod] + fn index( + &self, + obj: PyObjectRef, + start: OptionalArg<usize>, + stop: OptionalArg<usize>, + vm: &VirtualMachine, + ) -> PyResult<usize> { + let deque = self.deque.borrow(); + let start = start.unwrap_or(0); + let stop = stop.unwrap_or_else(|| deque.len()); + for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() { + if vm.identical_or_equal(elem, &obj)? { + return Ok(i); + } } + Err(vm.new_value_error( + vm.to_repr(&obj) + .map(|repr| format!("{} is not in deque", repr)) + .unwrap_or_else(|_| String::new()), + )) } - Err(vm.new_value_error( - vm.to_repr(&obj) - .map(|repr| format!("{} is not in deque", repr)) - .unwrap_or_else(|_| String::new()), - )) - } - #[pymethod] - fn insert(&self, idx: i32, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let mut deque = self.deque.borrow_mut(); + #[pymethod] + fn insert(&self, idx: i32, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { - return Err(vm.new_index_error("deque already at its maximum size".to_owned())); - } + if self.maxlen.get() == Some(deque.len()) { + return Err(vm.new_index_error("deque already at its maximum size".to_owned())); + } - let idx = if idx < 0 { - if -idx as usize > deque.len() { - 0 + let idx = if idx < 0 { + if -idx as usize > deque.len() { + 0 + } else { + deque.len() - ((-idx) as usize) + } + } else if idx as usize >= deque.len() { + deque.len() - 1 } else { - deque.len() - ((-idx) as usize) - } - } else if idx as usize >= deque.len() { - deque.len() - 1 - } else { - idx as usize - }; + idx as usize + }; - deque.insert(idx, obj); + deque.insert(idx, obj); - Ok(()) - } + Ok(()) + } - #[pymethod] - fn pop(&self, vm: &VirtualMachine) -> PyResult { - self.deque - .borrow_mut() - .pop_back() - .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) - } + #[pymethod] + fn pop(&self, vm: &VirtualMachine) -> PyResult { + self.deque + .borrow_mut() + .pop_back() + .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) + } - #[pymethod] - fn popleft(&self, vm: &VirtualMachine) -> PyResult { - self.deque - .borrow_mut() - .pop_front() - .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) - } + #[pymethod] + fn popleft(&self, vm: &VirtualMachine) -> PyResult { + self.deque + .borrow_mut() + .pop_front() + .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) + } - #[pymethod] - fn remove(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let mut deque = self.deque.borrow_mut(); - let mut idx = None; - for (i, elem) in deque.iter().enumerate() { - if vm.identical_or_equal(elem, &obj)? { - idx = Some(i); - break; + #[pymethod] + fn remove(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let mut deque = self.deque.borrow_mut(); + let mut idx = None; + for (i, elem) in deque.iter().enumerate() { + if vm.identical_or_equal(elem, &obj)? { + idx = Some(i); + break; + } } + idx.map(|idx| deque.remove(idx).unwrap()) + .ok_or_else(|| vm.new_value_error("deque.remove(x): x not in deque".to_owned())) } - idx.map(|idx| deque.remove(idx).unwrap()) - .ok_or_else(|| vm.new_value_error("deque.remove(x): x not in deque".to_owned())) - } - #[pymethod] - fn reverse(&self) { - self.deque - .replace_with(|deque| deque.iter().cloned().rev().collect()); - } + #[pymethod] + fn reverse(&self) { + self.deque + .replace_with(|deque| deque.iter().cloned().rev().collect()); + } - #[pymethod] - fn rotate(&self, mid: OptionalArg<isize>) { - let mut deque = self.deque.borrow_mut(); - let mid = mid.unwrap_or(1); - if mid < 0 { - deque.rotate_left(-mid as usize); - } else { - deque.rotate_right(mid as usize); + #[pymethod] + fn rotate(&self, mid: OptionalArg<isize>) { + let mut deque = self.deque.borrow_mut(); + let mid = mid.unwrap_or(1); + if mid < 0 { + deque.rotate_left(-mid as usize); + } else { + deque.rotate_right(mid as usize); + } } - } - #[pyproperty] - fn maxlen(&self) -> Option<usize> { - self.maxlen.get() - } + #[pyproperty] + fn maxlen(&self) -> Option<usize> { + self.maxlen.get() + } - #[pyproperty(setter)] - fn set_maxlen(&self, maxlen: Option<usize>) { - self.maxlen.set(maxlen); - } + #[pyproperty(setter)] + fn set_maxlen(&self, maxlen: Option<usize>) { + self.maxlen.set(maxlen); + } - #[pymethod(name = "__repr__")] - fn repr(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<String> { - let repr = if let Some(_guard) = ReprGuard::enter(zelf.as_object()) { - let elements = zelf - .deque - .borrow() - .iter() - .map(|obj| vm.to_repr(obj)) - .collect::<Result<Vec<_>, _>>()?; - let maxlen = zelf - .maxlen - .get() - .map(|maxlen| format!(", maxlen={}", maxlen)) - .unwrap_or_default(); - format!("deque([{}]{})", elements.into_iter().format(", "), maxlen) - } else { - "[...]".to_owned() - }; - Ok(repr) - } + #[pymethod(name = "__repr__")] + fn repr(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<String> { + let repr = if let Some(_guard) = ReprGuard::enter(zelf.as_object()) { + let elements = zelf + .deque + .borrow() + .iter() + .map(|obj| vm.to_repr(obj)) + .collect::<Result<Vec<_>, _>>()?; + let maxlen = zelf + .maxlen + .get() + .map(|maxlen| format!(", maxlen={}", maxlen)) + .unwrap_or_default(); + format!("deque([{}]{})", elements.into_iter().format(", "), maxlen) + } else { + "[...]".to_owned() + }; + Ok(repr) + } - #[inline] - fn cmp<F>(&self, other: PyObjectRef, op: F, vm: &VirtualMachine) -> PyResult<PyComparisonValue> - where - F: Fn(&VecDeque<PyObjectRef>, &VecDeque<PyObjectRef>) -> PyResult<bool>, - { - let r = if let Some(other) = other.payload_if_subclass::<PyDeque>(vm) { - Implemented(op(&*self.borrow_deque(), &*other.borrow_deque())?) - } else { - NotImplemented - }; - Ok(r) - } + #[inline] + fn cmp<F>( + &self, + other: PyObjectRef, + op: F, + vm: &VirtualMachine, + ) -> PyResult<PyComparisonValue> + where + F: Fn(&VecDeque<PyObjectRef>, &VecDeque<PyObjectRef>) -> PyResult<bool>, + { + let r = if let Some(other) = other.payload_if_subclass::<PyDeque>(vm) { + Implemented(op(&*self.borrow_deque(), &*other.borrow_deque())?) + } else { + NotImplemented + }; + Ok(r) + } - #[pymethod(name = "__eq__")] - fn eq( - zelf: PyRef<Self>, - other: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyComparisonValue> { - if zelf.as_object().is(&other) { - Ok(Implemented(true)) - } else { - zelf.cmp(other, |a, b| sequence::eq(vm, a, b), vm) + #[pymethod(name = "__eq__")] + fn eq( + zelf: PyRef<Self>, + other: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyComparisonValue> { + if zelf.as_object().is(&other) { + Ok(Implemented(true)) + } else { + zelf.cmp(other, |a, b| sequence::eq(vm, a, b), vm) + } } - } - #[pymethod(name = "__ne__")] - fn ne( - zelf: PyRef<Self>, - other: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyComparisonValue> { - Ok(PyDeque::eq(zelf, other, vm)?.map(|v| !v)) - } + #[pymethod(name = "__ne__")] + fn ne( + zelf: PyRef<Self>, + other: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyComparisonValue> { + Ok(PyDeque::eq(zelf, other, vm)?.map(|v| !v)) + } - #[pymethod(name = "__lt__")] - fn lt( - zelf: PyRef<Self>, - other: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyComparisonValue> { - if zelf.as_object().is(&other) { - Ok(Implemented(false)) - } else { - zelf.cmp(other, |a, b| sequence::lt(vm, a, b), vm) + #[pymethod(name = "__lt__")] + fn lt( + zelf: PyRef<Self>, + other: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyComparisonValue> { + if zelf.as_object().is(&other) { + Ok(Implemented(false)) + } else { + zelf.cmp(other, |a, b| sequence::lt(vm, a, b), vm) + } } - } - #[pymethod(name = "__gt__")] - fn gt( - zelf: PyRef<Self>, - other: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyComparisonValue> { - if zelf.as_object().is(&other) { - Ok(Implemented(false)) - } else { - zelf.cmp(other, |a, b| sequence::gt(vm, a, b), vm) + #[pymethod(name = "__gt__")] + fn gt( + zelf: PyRef<Self>, + other: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyComparisonValue> { + if zelf.as_object().is(&other) { + Ok(Implemented(false)) + } else { + zelf.cmp(other, |a, b| sequence::gt(vm, a, b), vm) + } } - } - #[pymethod(name = "__le__")] - fn le( - zelf: PyRef<Self>, - other: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyComparisonValue> { - if zelf.as_object().is(&other) { - Ok(Implemented(true)) - } else { - zelf.cmp(other, |a, b| sequence::le(vm, a, b), vm) + #[pymethod(name = "__le__")] + fn le( + zelf: PyRef<Self>, + other: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyComparisonValue> { + if zelf.as_object().is(&other) { + Ok(Implemented(true)) + } else { + zelf.cmp(other, |a, b| sequence::le(vm, a, b), vm) + } } - } - #[pymethod(name = "__ge__")] - fn ge( - zelf: PyRef<Self>, - other: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyComparisonValue> { - if zelf.as_object().is(&other) { - Ok(Implemented(true)) - } else { - zelf.cmp(other, |a, b| sequence::ge(vm, a, b), vm) + #[pymethod(name = "__ge__")] + fn ge( + zelf: PyRef<Self>, + other: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyComparisonValue> { + if zelf.as_object().is(&other) { + Ok(Implemented(true)) + } else { + zelf.cmp(other, |a, b| sequence::ge(vm, a, b), vm) + } } - } - #[pymethod(name = "__mul__")] - fn mul(&self, n: isize) -> Self { - let deque: &VecDeque<_> = &self.deque.borrow(); - let mul = sequence::seq_mul(deque, n); - let skipped = if let Some(maxlen) = self.maxlen.get() { - mul.len() - maxlen - } else { - 0 - }; - let deque = mul.skip(skipped).cloned().collect(); - PyDeque { - deque: RefCell::new(deque), - maxlen: self.maxlen.clone(), + #[pymethod(name = "__mul__")] + fn mul(&self, n: isize) -> Self { + let deque: &VecDeque<_> = &self.deque.borrow(); + let mul = sequence::seq_mul(deque, n); + let skipped = if let Some(maxlen) = self.maxlen.get() { + mul.len() - maxlen + } else { + 0 + }; + let deque = mul.skip(skipped).cloned().collect(); + PyDeque { + deque: RefCell::new(deque), + maxlen: self.maxlen.clone(), + } } - } - #[pymethod(name = "__len__")] - fn len(&self) -> usize { - self.deque.borrow().len() - } + #[pymethod(name = "__len__")] + fn len(&self) -> usize { + self.deque.borrow().len() + } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyDequeIterator { - PyDequeIterator { - position: Cell::new(0), - deque: zelf, + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyDequeIterator { + PyDequeIterator { + position: Cell::new(0), + deque: zelf, + } } } -} -#[pyclass(name = "_deque_iterator")] -#[derive(Debug)] -struct PyDequeIterator { - position: Cell<usize>, - deque: PyDequeRef, -} - -impl PyValue for PyDequeIterator { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("_collections", "_deque_iterator") + #[pyclass(name = "_deque_iterator")] + #[derive(Debug)] + struct PyDequeIterator { + position: Cell<usize>, + deque: PyDequeRef, } -} -#[pyimpl] -impl PyDequeIterator { - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.deque.deque.borrow().len() { - let ret = self.deque.deque.borrow()[self.position.get()].clone(); - self.position.set(self.position.get() + 1); - Ok(ret) - } else { - Err(objiter::new_stop_iteration(vm)) + impl PyValue for PyDequeIterator { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_collections", "_deque_iterator") } } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} + #[pyimpl] + impl PyDequeIterator { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if self.position.get() < self.deque.deque.borrow().len() { + let ret = self.deque.deque.borrow()[self.position.get()].clone(); + self.position.set(self.position.get() + 1); + Ok(ret) + } else { + Err(objiter::new_stop_iteration(vm)) + } + } -pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { - py_module!(vm, "_collections", { - "deque" => PyDeque::make_class(&vm.ctx), - "_deque_iterator" => PyDequeIterator::make_class(&vm.ctx), - }) + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } + } } diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 3ffc01e806..ea7767fb20 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1,1413 +1,1345 @@ -use std::cell::{Cell, RefCell}; -use std::iter; -use std::rc::Rc; - -use num_bigint::BigInt; -use num_traits::{One, Signed, ToPrimitive, Zero}; - -use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs}; -use crate::obj::objbool; -use crate::obj::objint::{self, PyInt, PyIntRef}; -use crate::obj::objiter::{call_next, get_all, get_iter, get_next_object, new_stop_iteration}; -use crate::obj::objtuple::PyTuple; -use crate::obj::objtype::{self, PyClassRef}; -use crate::pyobject::{ - IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, -}; -use crate::vm::VirtualMachine; - -#[pyclass(name = "chain")] -#[derive(Debug)] -struct PyItertoolsChain { - iterables: Vec<PyObjectRef>, - cur: RefCell<(usize, Option<PyObjectRef>)>, -} - -impl PyValue for PyItertoolsChain { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "chain") - } -} - -#[pyimpl] -impl PyItertoolsChain { - #[pyslot] - fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { - PyItertoolsChain { - iterables: args.args, - cur: RefCell::new((0, None)), +pub(crate) use decl::make_module; + +#[pymodule(name = "itertools")] +mod decl { + use num_bigint::BigInt; + use num_traits::{One, Signed, ToPrimitive, Zero}; + use std::cell::{Cell, RefCell}; + use std::iter; + use std::rc::Rc; + + use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs}; + use crate::obj::objbool; + use crate::obj::objint::{self, PyInt, PyIntRef}; + use crate::obj::objiter::{call_next, get_all, get_iter, get_next_object, new_stop_iteration}; + use crate::obj::objtuple::PyTuple; + use crate::obj::objtype::{self, PyClassRef}; + use crate::pyobject::{ + IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, + }; + use crate::vm::VirtualMachine; + + #[pyclass(name = "chain")] + #[derive(Debug)] + struct PyItertoolsChain { + iterables: Vec<PyObjectRef>, + cur: RefCell<(usize, Option<PyObjectRef>)>, + } + + impl PyValue for PyItertoolsChain { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "chain") } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - let (ref mut cur_idx, ref mut cur_iter) = *self.cur.borrow_mut(); - while *cur_idx < self.iterables.len() { - if cur_iter.is_none() { - *cur_iter = Some(get_iter(vm, &self.iterables[*cur_idx])?); + #[pyimpl] + impl PyItertoolsChain { + #[pyslot] + fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { + PyItertoolsChain { + iterables: args.args, + cur: RefCell::new((0, None)), } + .into_ref_with_type(vm, cls) + } - // can't be directly inside the 'match' clause, otherwise the borrows collide. - let obj = call_next(vm, cur_iter.as_ref().unwrap()); - match obj { - Ok(ok) => return Ok(ok), - Err(err) => { - if objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { - *cur_idx += 1; - *cur_iter = None; - } else { - return Err(err); + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let (ref mut cur_idx, ref mut cur_iter) = *self.cur.borrow_mut(); + while *cur_idx < self.iterables.len() { + if cur_iter.is_none() { + *cur_iter = Some(get_iter(vm, &self.iterables[*cur_idx])?); + } + + // can't be directly inside the 'match' clause, otherwise the borrows collide. + let obj = call_next(vm, cur_iter.as_ref().unwrap()); + match obj { + Ok(ok) => return Ok(ok), + Err(err) => { + if objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { + *cur_idx += 1; + *cur_iter = None; + } else { + return Err(err); + } } } } - } - Err(new_stop_iteration(vm)) - } - - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } - - #[pyclassmethod(name = "from_iterable")] - fn from_iterable( - cls: PyClassRef, - iterable: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let it = get_iter(vm, &iterable)?; - let iterables = get_all(vm, &it)?; - - PyItertoolsChain { - iterables, - cur: RefCell::new((0, None)), + Err(new_stop_iteration(vm)) } - .into_ref_with_type(vm, cls) - } -} -#[pyclass(name = "compress")] -#[derive(Debug)] -struct PyItertoolsCompress { - data: PyObjectRef, - selector: PyObjectRef, -} + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } -impl PyValue for PyItertoolsCompress { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "compress") + #[pyclassmethod(name = "from_iterable")] + fn from_iterable( + cls: PyClassRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let it = get_iter(vm, &iterable)?; + let iterables = get_all(vm, &it)?; + + PyItertoolsChain { + iterables, + cur: RefCell::new((0, None)), + } + .into_ref_with_type(vm, cls) + } } -} -#[pyimpl] -impl PyItertoolsCompress { - #[pyslot] - fn tp_new( - cls: PyClassRef, + #[pyclass(name = "compress")] + #[derive(Debug)] + struct PyItertoolsCompress { data: PyObjectRef, selector: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let data_iter = get_iter(vm, &data)?; - let selector_iter = get_iter(vm, &selector)?; + } - PyItertoolsCompress { - data: data_iter, - selector: selector_iter, + impl PyValue for PyItertoolsCompress { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "compress") } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - loop { - let sel_obj = call_next(vm, &self.selector)?; - let verdict = objbool::boolval(vm, sel_obj.clone())?; - let data_obj = call_next(vm, &self.data)?; - - if verdict { - return Ok(data_obj); + #[pyimpl] + impl PyItertoolsCompress { + #[pyslot] + fn tp_new( + cls: PyClassRef, + data: PyObjectRef, + selector: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let data_iter = get_iter(vm, &data)?; + let selector_iter = get_iter(vm, &selector)?; + + PyItertoolsCompress { + data: data_iter, + selector: selector_iter, } + .into_ref_with_type(vm, cls) } - } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + loop { + let sel_obj = call_next(vm, &self.selector)?; + let verdict = objbool::boolval(vm, sel_obj.clone())?; + let data_obj = call_next(vm, &self.data)?; -#[pyclass] -#[derive(Debug)] -struct PyItertoolsCount { - cur: RefCell<BigInt>, - step: BigInt, -} + if verdict { + return Ok(data_obj); + } + } + } -impl PyValue for PyItertoolsCount { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "count") + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[pyimpl] -impl PyItertoolsCount { - #[pyslot] - fn tp_new( - cls: PyClassRef, - start: OptionalArg<PyIntRef>, - step: OptionalArg<PyIntRef>, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let start = match start.into_option() { - Some(int) => int.as_bigint().clone(), - None => BigInt::zero(), - }; - let step = match step.into_option() { - Some(int) => int.as_bigint().clone(), - None => BigInt::one(), - }; - - PyItertoolsCount { - cur: RefCell::new(start), - step, - } - .into_ref_with_type(vm, cls) + #[pyclass(name = "count")] + #[derive(Debug)] + struct PyItertoolsCount { + cur: RefCell<BigInt>, + step: BigInt, } - #[pymethod(name = "__next__")] - fn next(&self) -> PyResult<PyInt> { - let result = self.cur.borrow().clone(); - *self.cur.borrow_mut() += &self.step; - Ok(PyInt::new(result)) + impl PyValue for PyItertoolsCount { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "count") + } } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} + #[pyimpl] + impl PyItertoolsCount { + #[pyslot] + fn tp_new( + cls: PyClassRef, + start: OptionalArg<PyIntRef>, + step: OptionalArg<PyIntRef>, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let start = match start.into_option() { + Some(int) => int.as_bigint().clone(), + None => BigInt::zero(), + }; + let step = match step.into_option() { + Some(int) => int.as_bigint().clone(), + None => BigInt::one(), + }; -#[pyclass] -#[derive(Debug)] -struct PyItertoolsCycle { - iter: RefCell<PyObjectRef>, - saved: RefCell<Vec<PyObjectRef>>, - index: Cell<usize>, - first_pass: Cell<bool>, -} + PyItertoolsCount { + cur: RefCell::new(start), + step, + } + .into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__next__")] + fn next(&self) -> PyResult<PyInt> { + let result = self.cur.borrow().clone(); + *self.cur.borrow_mut() += &self.step; + Ok(PyInt::new(result)) + } -impl PyValue for PyItertoolsCycle { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "cycle") + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[pyimpl] -impl PyItertoolsCycle { - #[pyslot] - fn tp_new( - cls: PyClassRef, - iterable: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; + #[pyclass(name = "cycle")] + #[derive(Debug)] + struct PyItertoolsCycle { + iter: RefCell<PyObjectRef>, + saved: RefCell<Vec<PyObjectRef>>, + index: Cell<usize>, + first_pass: Cell<bool>, + } - PyItertoolsCycle { - iter: RefCell::new(iter.clone()), - saved: RefCell::new(Vec::new()), - index: Cell::new(0), - first_pass: Cell::new(false), + impl PyValue for PyItertoolsCycle { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "cycle") } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - let item = if let Some(item) = get_next_object(vm, &self.iter.borrow())? { - if self.first_pass.get() { - return Ok(item); - } - - self.saved.borrow_mut().push(item.clone()); - item - } else { - if self.saved.borrow().len() == 0 { - return Err(new_stop_iteration(vm)); + #[pyimpl] + impl PyItertoolsCycle { + #[pyslot] + fn tp_new( + cls: PyClassRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; + + PyItertoolsCycle { + iter: RefCell::new(iter.clone()), + saved: RefCell::new(Vec::new()), + index: Cell::new(0), + first_pass: Cell::new(false), } + .into_ref_with_type(vm, cls) + } - let last_index = self.index.get(); - self.index.set(self.index.get() + 1); + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let item = if let Some(item) = get_next_object(vm, &self.iter.borrow())? { + if self.first_pass.get() { + return Ok(item); + } - if self.index.get() >= self.saved.borrow().len() { - self.index.set(0); - } + self.saved.borrow_mut().push(item.clone()); + item + } else { + if self.saved.borrow().len() == 0 { + return Err(new_stop_iteration(vm)); + } - self.saved.borrow()[last_index].clone() - }; + let last_index = self.index.get(); + self.index.set(self.index.get() + 1); - Ok(item) - } + if self.index.get() >= self.saved.borrow().len() { + self.index.set(0); + } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} + self.saved.borrow()[last_index].clone() + }; -#[pyclass] -#[derive(Debug)] -struct PyItertoolsRepeat { - object: PyObjectRef, - times: Option<RefCell<BigInt>>, -} + Ok(item) + } -impl PyValue for PyItertoolsRepeat { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "repeat") + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[pyimpl] -impl PyItertoolsRepeat { - #[pyslot] - fn tp_new( - cls: PyClassRef, + #[pyclass(name = "repeat")] + #[derive(Debug)] + struct PyItertoolsRepeat { object: PyObjectRef, - times: OptionalArg<PyIntRef>, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let times = match times.into_option() { - Some(int) => Some(RefCell::new(int.as_bigint().clone())), - None => None, - }; - - PyItertoolsRepeat { - object: object.clone(), - times, - } - .into_ref_with_type(vm, cls) + times: Option<RefCell<BigInt>>, } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - if let Some(ref times) = self.times { - if *times.borrow() <= BigInt::zero() { - return Err(new_stop_iteration(vm)); - } - *times.borrow_mut() -= 1; + impl PyValue for PyItertoolsRepeat { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "repeat") } - - Ok(self.object.clone()) } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } + #[pyimpl] + impl PyItertoolsRepeat { + #[pyslot] + fn tp_new( + cls: PyClassRef, + object: PyObjectRef, + times: OptionalArg<PyIntRef>, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let times = match times.into_option() { + Some(int) => Some(RefCell::new(int.as_bigint().clone())), + None => None, + }; - #[pymethod(name = "__length_hint__")] - fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { - match self.times { - Some(ref times) => vm.new_int(times.borrow().clone()), - None => vm.new_int(0), + PyItertoolsRepeat { + object: object.clone(), + times, + } + .into_ref_with_type(vm, cls) } - } -} -#[pyclass(name = "starmap")] -#[derive(Debug)] -struct PyItertoolsStarmap { - function: PyObjectRef, - iter: PyObjectRef, -} + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if let Some(ref times) = self.times { + if *times.borrow() <= BigInt::zero() { + return Err(new_stop_iteration(vm)); + } + *times.borrow_mut() -= 1; + } + + Ok(self.object.clone()) + } -impl PyValue for PyItertoolsStarmap { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "starmap") + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } + + #[pymethod(name = "__length_hint__")] + fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { + match self.times { + Some(ref times) => vm.new_int(times.borrow().clone()), + None => vm.new_int(0), + } + } } -} -#[pyimpl] -impl PyItertoolsStarmap { - #[pyslot] - fn tp_new( - cls: PyClassRef, + #[pyclass(name = "starmap")] + #[derive(Debug)] + struct PyItertoolsStarmap { function: PyObjectRef, - iterable: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; + iter: PyObjectRef, + } - PyItertoolsStarmap { function, iter }.into_ref_with_type(vm, cls) + impl PyValue for PyItertoolsStarmap { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "starmap") + } } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - let obj = call_next(vm, &self.iter)?; - let function = &self.function; + #[pyimpl] + impl PyItertoolsStarmap { + #[pyslot] + fn tp_new( + cls: PyClassRef, + function: PyObjectRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; - vm.invoke(function, vm.extract_elements(&obj)?) - } + PyItertoolsStarmap { function, iter }.into_ref_with_type(vm, cls) + } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let obj = call_next(vm, &self.iter)?; + let function = &self.function; -#[pyclass] -#[derive(Debug)] -struct PyItertoolsTakewhile { - predicate: PyObjectRef, - iterable: PyObjectRef, - stop_flag: RefCell<bool>, -} + vm.invoke(function, vm.extract_elements(&obj)?) + } -impl PyValue for PyItertoolsTakewhile { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "takewhile") + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[pyimpl] -impl PyItertoolsTakewhile { - #[pyslot] - fn tp_new( - cls: PyClassRef, + #[pyclass(name = "takewhile")] + #[derive(Debug)] + struct PyItertoolsTakewhile { predicate: PyObjectRef, iterable: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; + stop_flag: RefCell<bool>, + } - PyItertoolsTakewhile { - predicate, - iterable: iter, - stop_flag: RefCell::new(false), + impl PyValue for PyItertoolsTakewhile { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "takewhile") } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - if *self.stop_flag.borrow() { - return Err(new_stop_iteration(vm)); + #[pyimpl] + impl PyItertoolsTakewhile { + #[pyslot] + fn tp_new( + cls: PyClassRef, + predicate: PyObjectRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; + + PyItertoolsTakewhile { + predicate, + iterable: iter, + stop_flag: RefCell::new(false), + } + .into_ref_with_type(vm, cls) } - // might be StopIteration or anything else, which is propagated upwards - let obj = call_next(vm, &self.iterable)?; - let predicate = &self.predicate; - - let verdict = vm.invoke(predicate, vec![obj.clone()])?; - let verdict = objbool::boolval(vm, verdict)?; - if verdict { - Ok(obj) - } else { - *self.stop_flag.borrow_mut() = true; - Err(new_stop_iteration(vm)) - } - } + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if *self.stop_flag.borrow() { + return Err(new_stop_iteration(vm)); + } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} + // might be StopIteration or anything else, which is propagated upwards + let obj = call_next(vm, &self.iterable)?; + let predicate = &self.predicate; -#[pyclass] -#[derive(Debug)] -struct PyItertoolsDropwhile { - predicate: PyCallable, - iterable: PyObjectRef, - start_flag: Cell<bool>, -} + let verdict = vm.invoke(predicate, vec![obj.clone()])?; + let verdict = objbool::boolval(vm, verdict)?; + if verdict { + Ok(obj) + } else { + *self.stop_flag.borrow_mut() = true; + Err(new_stop_iteration(vm)) + } + } -impl PyValue for PyItertoolsDropwhile { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "dropwhile") + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[pyimpl] -impl PyItertoolsDropwhile { - #[pyslot] - fn tp_new( - cls: PyClassRef, + #[pyclass(name = "dropwhile")] + #[derive(Debug)] + struct PyItertoolsDropwhile { predicate: PyCallable, iterable: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; + start_flag: Cell<bool>, + } - PyItertoolsDropwhile { - predicate, - iterable: iter, - start_flag: Cell::new(false), + impl PyValue for PyItertoolsDropwhile { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "dropwhile") } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - let predicate = &self.predicate; - let iterable = &self.iterable; + #[pyimpl] + impl PyItertoolsDropwhile { + #[pyslot] + fn tp_new( + cls: PyClassRef, + predicate: PyCallable, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; + + PyItertoolsDropwhile { + predicate, + iterable: iter, + start_flag: Cell::new(false), + } + .into_ref_with_type(vm, cls) + } - if !self.start_flag.get() { - loop { - let obj = call_next(vm, iterable)?; - let pred = predicate.clone(); - let pred_value = vm.invoke(&pred.into_object(), vec![obj.clone()])?; - if !objbool::boolval(vm, pred_value)? { - self.start_flag.set(true); - return Ok(obj); + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let predicate = &self.predicate; + let iterable = &self.iterable; + + if !self.start_flag.get() { + loop { + let obj = call_next(vm, iterable)?; + let pred = predicate.clone(); + let pred_value = vm.invoke(&pred.into_object(), vec![obj.clone()])?; + if !objbool::boolval(vm, pred_value)? { + self.start_flag.set(true); + return Ok(obj); + } } } + call_next(vm, iterable) } - call_next(vm, iterable) - } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[pyclass(name = "islice")] -#[derive(Debug)] -struct PyItertoolsIslice { - iterable: PyObjectRef, - cur: RefCell<usize>, - next: RefCell<usize>, - stop: Option<usize>, - step: usize, -} + #[pyclass(name = "islice")] + #[derive(Debug)] + struct PyItertoolsIslice { + iterable: PyObjectRef, + cur: RefCell<usize>, + next: RefCell<usize>, + stop: Option<usize>, + step: usize, + } -impl PyValue for PyItertoolsIslice { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "islice") + impl PyValue for PyItertoolsIslice { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "islice") + } } -} -fn pyobject_to_opt_usize(obj: PyObjectRef, vm: &VirtualMachine) -> Option<usize> { - let is_int = objtype::isinstance(&obj, &vm.ctx.int_type()); - if is_int { - objint::get_value(&obj).to_usize() - } else { - None + fn pyobject_to_opt_usize(obj: PyObjectRef, vm: &VirtualMachine) -> Option<usize> { + let is_int = objtype::isinstance(&obj, &vm.ctx.int_type()); + if is_int { + objint::get_value(&obj).to_usize() + } else { + None + } } -} -#[pyimpl] -impl PyItertoolsIslice { - #[pyslot] - fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { - let (iter, start, stop, step) = match args.args.len() { - 0 | 1 => { - return Err(vm.new_type_error(format!( - "islice expected at least 2 arguments, got {}", - args.args.len() - ))); - } + #[pyimpl] + impl PyItertoolsIslice { + #[pyslot] + fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { + let (iter, start, stop, step) = match args.args.len() { + 0 | 1 => { + return Err(vm.new_type_error(format!( + "islice expected at least 2 arguments, got {}", + args.args.len() + ))); + } - 2 => { - let (iter, stop): (PyObjectRef, PyObjectRef) = args.bind(vm)?; + 2 => { + let (iter, stop): (PyObjectRef, PyObjectRef) = args.bind(vm)?; - (iter, 0usize, stop, 1usize) - } - _ => { - let (iter, start, stop, step): ( - PyObjectRef, - PyObjectRef, - PyObjectRef, - PyObjectRef, - ) = args.bind(vm)?; - - let start = if !start.is(&vm.get_none()) { - pyobject_to_opt_usize(start, &vm).ok_or_else(|| { + (iter, 0usize, stop, 1usize) + } + _ => { + let (iter, start, stop, step): ( + PyObjectRef, + PyObjectRef, + PyObjectRef, + PyObjectRef, + ) = args.bind(vm)?; + + let start = if !start.is(&vm.get_none()) { + pyobject_to_opt_usize(start, &vm).ok_or_else(|| { vm.new_value_error( "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.".to_owned(), ) })? - } else { - 0usize - }; - - let step = if !step.is(&vm.get_none()) { - pyobject_to_opt_usize(step, &vm).ok_or_else(|| { - vm.new_value_error( - "Step for islice() must be a positive integer or None.".to_owned(), - ) - })? - } else { - 1usize - }; + } else { + 0usize + }; + + let step = if !step.is(&vm.get_none()) { + pyobject_to_opt_usize(step, &vm).ok_or_else(|| { + vm.new_value_error( + "Step for islice() must be a positive integer or None.".to_owned(), + ) + })? + } else { + 1usize + }; - (iter, start, stop, step) - } - }; + (iter, start, stop, step) + } + }; - let stop = if !stop.is(&vm.get_none()) { - Some(pyobject_to_opt_usize(stop, &vm).ok_or_else(|| { - vm.new_value_error( + let stop = if !stop.is(&vm.get_none()) { + Some(pyobject_to_opt_usize(stop, &vm).ok_or_else(|| { + vm.new_value_error( "Stop argument for islice() must be None or an integer: 0 <= x <= sys.maxsize." .to_owned(), ) - })?) - } else { - None - }; - - let iter = get_iter(vm, &iter)?; - - PyItertoolsIslice { - iterable: iter, - cur: RefCell::new(0), - next: RefCell::new(start), - stop, - step, - } - .into_ref_with_type(vm, cls) - } + })?) + } else { + None + }; - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - while *self.cur.borrow() < *self.next.borrow() { - call_next(vm, &self.iterable)?; - *self.cur.borrow_mut() += 1; - } + let iter = get_iter(vm, &iter)?; - if let Some(stop) = self.stop { - if *self.cur.borrow() >= stop { - return Err(new_stop_iteration(vm)); + PyItertoolsIslice { + iterable: iter, + cur: RefCell::new(0), + next: RefCell::new(start), + stop, + step, } + .into_ref_with_type(vm, cls) } - let obj = call_next(vm, &self.iterable)?; - *self.cur.borrow_mut() += 1; + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + while *self.cur.borrow() < *self.next.borrow() { + call_next(vm, &self.iterable)?; + *self.cur.borrow_mut() += 1; + } - // TODO is this overflow check required? attempts to copy CPython. - let (next, ovf) = (*self.next.borrow()).overflowing_add(self.step); - *self.next.borrow_mut() = if ovf { self.stop.unwrap() } else { next }; + if let Some(stop) = self.stop { + if *self.cur.borrow() >= stop { + return Err(new_stop_iteration(vm)); + } + } - Ok(obj) - } + let obj = call_next(vm, &self.iterable)?; + *self.cur.borrow_mut() += 1; - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} + // TODO is this overflow check required? attempts to copy CPython. + let (next, ovf) = (*self.next.borrow()).overflowing_add(self.step); + *self.next.borrow_mut() = if ovf { self.stop.unwrap() } else { next }; -#[pyclass] -#[derive(Debug)] -struct PyItertoolsFilterFalse { - predicate: PyObjectRef, - iterable: PyObjectRef, -} + Ok(obj) + } -impl PyValue for PyItertoolsFilterFalse { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "filterfalse") + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[pyimpl] -impl PyItertoolsFilterFalse { - #[pyslot] - fn tp_new( - cls: PyClassRef, + #[pyclass(name = "filterfalse")] + #[derive(Debug)] + struct PyItertoolsFilterFalse { predicate: PyObjectRef, iterable: PyObjectRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; + } - PyItertoolsFilterFalse { - predicate, - iterable: iter, + impl PyValue for PyItertoolsFilterFalse { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "filterfalse") } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - let predicate = &self.predicate; - let iterable = &self.iterable; - - loop { - let obj = call_next(vm, iterable)?; - let pred_value = if predicate.is(&vm.get_none()) { - obj.clone() - } else { - vm.invoke(predicate, vec![obj.clone()])? - }; - - if !objbool::boolval(vm, pred_value)? { - return Ok(obj); + #[pyimpl] + impl PyItertoolsFilterFalse { + #[pyslot] + fn tp_new( + cls: PyClassRef, + predicate: PyObjectRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; + + PyItertoolsFilterFalse { + predicate, + iterable: iter, } + .into_ref_with_type(vm, cls) } - } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let predicate = &self.predicate; + let iterable = &self.iterable; -#[pyclass] -#[derive(Debug)] -struct PyItertoolsAccumulate { - iterable: PyObjectRef, - binop: PyObjectRef, - acc_value: RefCell<Option<PyObjectRef>>, -} + loop { + let obj = call_next(vm, iterable)?; + let pred_value = if predicate.is(&vm.get_none()) { + obj.clone() + } else { + vm.invoke(predicate, vec![obj.clone()])? + }; -impl PyValue for PyItertoolsAccumulate { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "accumulate") + if !objbool::boolval(vm, pred_value)? { + return Ok(obj); + } + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[pyimpl] -impl PyItertoolsAccumulate { - #[pyslot] - fn tp_new( - cls: PyClassRef, + #[pyclass(name = "accumulate")] + #[derive(Debug)] + struct PyItertoolsAccumulate { iterable: PyObjectRef, - binop: OptionalArg<PyObjectRef>, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; + binop: PyObjectRef, + acc_value: RefCell<Option<PyObjectRef>>, + } - PyItertoolsAccumulate { - iterable: iter, - binop: binop.unwrap_or_else(|| vm.get_none()), - acc_value: RefCell::from(Option::None), + impl PyValue for PyItertoolsAccumulate { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "accumulate") } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - let iterable = &self.iterable; - let obj = call_next(vm, iterable)?; + #[pyimpl] + impl PyItertoolsAccumulate { + #[pyslot] + fn tp_new( + cls: PyClassRef, + iterable: PyObjectRef, + binop: OptionalArg<PyObjectRef>, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; + + PyItertoolsAccumulate { + iterable: iter, + binop: binop.unwrap_or_else(|| vm.get_none()), + acc_value: RefCell::from(Option::None), + } + .into_ref_with_type(vm, cls) + } - let next_acc_value = match &*self.acc_value.borrow() { - None => obj.clone(), - Some(value) => { - if self.binop.is(&vm.get_none()) { - vm._add(value.clone(), obj.clone())? - } else { - vm.invoke(&self.binop, vec![value.clone(), obj.clone()])? + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let iterable = &self.iterable; + let obj = call_next(vm, iterable)?; + + let next_acc_value = match &*self.acc_value.borrow() { + None => obj.clone(), + Some(value) => { + if self.binop.is(&vm.get_none()) { + vm._add(value.clone(), obj.clone())? + } else { + vm.invoke(&self.binop, vec![value.clone(), obj.clone()])? + } } - } - }; - self.acc_value.replace(Option::from(next_acc_value.clone())); + }; + self.acc_value.replace(Option::from(next_acc_value.clone())); - Ok(next_acc_value) - } + Ok(next_acc_value) + } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } -} -#[derive(Debug)] -struct PyItertoolsTeeData { - iterable: PyObjectRef, - values: RefCell<Vec<PyObjectRef>>, -} - -impl PyItertoolsTeeData { - fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult<Rc<PyItertoolsTeeData>> { - Ok(Rc::new(PyItertoolsTeeData { - iterable: get_iter(vm, &iterable)?, - values: RefCell::new(vec![]), - })) + #[derive(Debug)] + struct PyItertoolsTeeData { + iterable: PyObjectRef, + values: RefCell<Vec<PyObjectRef>>, } - fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { - if self.values.borrow().len() == index { - let result = call_next(vm, &self.iterable)?; - self.values.borrow_mut().push(result); + impl PyItertoolsTeeData { + fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult<Rc<PyItertoolsTeeData>> { + Ok(Rc::new(PyItertoolsTeeData { + iterable: get_iter(vm, &iterable)?, + values: RefCell::new(vec![]), + })) } - Ok(self.values.borrow()[index].clone()) - } -} -#[pyclass] -#[derive(Debug)] -struct PyItertoolsTee { - tee_data: Rc<PyItertoolsTeeData>, - index: Cell<usize>, -} + fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { + if self.values.borrow().len() == index { + let result = call_next(vm, &self.iterable)?; + self.values.borrow_mut().push(result); + } + Ok(self.values.borrow()[index].clone()) + } + } -impl PyValue for PyItertoolsTee { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "tee") + #[pyclass(name = "tee")] + #[derive(Debug)] + struct PyItertoolsTee { + tee_data: Rc<PyItertoolsTeeData>, + index: Cell<usize>, } -} -#[pyimpl] -impl PyItertoolsTee { - fn from_iter(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let it = get_iter(vm, &iterable)?; - if it.class().is(&PyItertoolsTee::class(vm)) { - return vm.call_method(&it, "__copy__", PyFuncArgs::from(vec![])); - } - Ok(PyItertoolsTee { - tee_data: PyItertoolsTeeData::new(it, vm)?, - index: Cell::from(0), + impl PyValue for PyItertoolsTee { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "tee") } - .into_ref_with_type(vm, PyItertoolsTee::class(vm))? - .into_object()) } - // TODO: make tee() a function, rename this class to itertools._tee and make - // teedata a python class - #[pyslot] - #[allow(clippy::new_ret_no_self)] - fn tp_new( - _cls: PyClassRef, - iterable: PyObjectRef, - n: OptionalArg<usize>, - vm: &VirtualMachine, - ) -> PyResult<PyRef<PyTuple>> { - let n = n.unwrap_or(2); + #[pyimpl] + impl PyItertoolsTee { + fn from_iter(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let it = get_iter(vm, &iterable)?; + if it.class().is(&PyItertoolsTee::class(vm)) { + return vm.call_method(&it, "__copy__", PyFuncArgs::from(vec![])); + } + Ok(PyItertoolsTee { + tee_data: PyItertoolsTeeData::new(it, vm)?, + index: Cell::from(0), + } + .into_ref_with_type(vm, PyItertoolsTee::class(vm))? + .into_object()) + } - let copyable = if iterable.class().has_attr("__copy__") { - vm.call_method(&iterable, "__copy__", PyFuncArgs::from(vec![]))? - } else { - PyItertoolsTee::from_iter(iterable, vm)? - }; + // TODO: make tee() a function, rename this class to itertools._tee and make + // teedata a python class + #[pyslot] + #[allow(clippy::new_ret_no_self)] + fn tp_new( + _cls: PyClassRef, + iterable: PyObjectRef, + n: OptionalArg<usize>, + vm: &VirtualMachine, + ) -> PyResult<PyRef<PyTuple>> { + let n = n.unwrap_or(2); + + let copyable = if iterable.class().has_attr("__copy__") { + vm.call_method(&iterable, "__copy__", PyFuncArgs::from(vec![]))? + } else { + PyItertoolsTee::from_iter(iterable, vm)? + }; - let mut tee_vec: Vec<PyObjectRef> = Vec::with_capacity(n); - for _ in 0..n { - let no_args = PyFuncArgs::from(vec![]); - tee_vec.push(vm.call_method(©able, "__copy__", no_args)?); + let mut tee_vec: Vec<PyObjectRef> = Vec::with_capacity(n); + for _ in 0..n { + let no_args = PyFuncArgs::from(vec![]); + tee_vec.push(vm.call_method(©able, "__copy__", no_args)?); + } + + Ok(PyTuple::from(tee_vec).into_ref(vm)) } - Ok(PyTuple::from(tee_vec).into_ref(vm)) - } + #[pymethod(name = "__copy__")] + fn copy(&self, vm: &VirtualMachine) -> PyResult { + Ok(PyItertoolsTee { + tee_data: Rc::clone(&self.tee_data), + index: self.index.clone(), + } + .into_ref_with_type(vm, Self::class(vm))? + .into_object()) + } - #[pymethod(name = "__copy__")] - fn copy(&self, vm: &VirtualMachine) -> PyResult { - Ok(PyItertoolsTee { - tee_data: Rc::clone(&self.tee_data), - index: self.index.clone(), + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let value = self.tee_data.get_item(vm, self.index.get())?; + self.index.set(self.index.get() + 1); + Ok(value) } - .into_ref_with_type(vm, Self::class(vm))? - .into_object()) - } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - let value = self.tee_data.get_item(vm, self.index.get())?; - self.index.set(self.index.get() + 1); - Ok(value) + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf + #[pyclass(name = "product")] + #[derive(Debug)] + struct PyItertoolsProduct { + pools: Vec<Vec<PyObjectRef>>, + idxs: RefCell<Vec<usize>>, + cur: Cell<usize>, + stop: Cell<bool>, } -} -#[pyclass] -#[derive(Debug)] -struct PyItertoolsProduct { - pools: Vec<Vec<PyObjectRef>>, - idxs: RefCell<Vec<usize>>, - cur: Cell<usize>, - stop: Cell<bool>, -} - -impl PyValue for PyItertoolsProduct { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "product") + impl PyValue for PyItertoolsProduct { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "product") + } } -} -#[derive(FromArgs)] -struct ProductArgs { - #[pyarg(keyword_only, optional = true)] - repeat: OptionalArg<usize>, -} + #[derive(FromArgs)] + struct ProductArgs { + #[pyarg(keyword_only, optional = true)] + repeat: OptionalArg<usize>, + } + + #[pyimpl] + impl PyItertoolsProduct { + #[pyslot] + fn tp_new( + cls: PyClassRef, + iterables: Args<PyObjectRef>, + args: ProductArgs, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let repeat = match args.repeat.into_option() { + Some(i) => i, + None => 1, + }; -#[pyimpl] -impl PyItertoolsProduct { - #[pyslot] - fn tp_new( - cls: PyClassRef, - iterables: Args<PyObjectRef>, - args: ProductArgs, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let repeat = match args.repeat.into_option() { - Some(i) => i, - None => 1, - }; - - let mut pools = Vec::new(); - for arg in iterables.into_iter() { - let it = get_iter(vm, &arg)?; - let pool = get_all(vm, &it)?; - - pools.push(pool); - } - let pools = iter::repeat(pools) - .take(repeat) - .flatten() - .collect::<Vec<Vec<PyObjectRef>>>(); - - let l = pools.len(); - - PyItertoolsProduct { - pools, - idxs: RefCell::new(vec![0; l]), - cur: Cell::new(l - 1), - stop: Cell::new(false), - } - .into_ref_with_type(vm, cls) - } + let mut pools = Vec::new(); + for arg in iterables.into_iter() { + let it = get_iter(vm, &arg)?; + let pool = get_all(vm, &it)?; - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - // stop signal - if self.stop.get() { - return Err(new_stop_iteration(vm)); + pools.push(pool); + } + let pools = iter::repeat(pools) + .take(repeat) + .flatten() + .collect::<Vec<Vec<PyObjectRef>>>(); + + let l = pools.len(); + + PyItertoolsProduct { + pools, + idxs: RefCell::new(vec![0; l]), + cur: Cell::new(l - 1), + stop: Cell::new(false), + } + .into_ref_with_type(vm, cls) } - let pools = &self.pools; - - for p in pools { - if p.is_empty() { + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + // stop signal + if self.stop.get() { return Err(new_stop_iteration(vm)); } - } - - let res = PyTuple::from( - pools - .iter() - .zip(self.idxs.borrow().iter()) - .map(|(pool, idx)| pool[*idx].clone()) - .collect::<Vec<PyObjectRef>>(), - ); - self.update_idxs(); + let pools = &self.pools; - if self.is_end() { - self.stop.set(true); - } - - Ok(res.into_ref(vm).into_object()) - } + for p in pools { + if p.is_empty() { + return Err(new_stop_iteration(vm)); + } + } - fn is_end(&self) -> bool { - let cur = self.cur.get(); - self.idxs.borrow()[cur] == self.pools[cur].len() - 1 && cur == 0 - } + let res = PyTuple::from( + pools + .iter() + .zip(self.idxs.borrow().iter()) + .map(|(pool, idx)| pool[*idx].clone()) + .collect::<Vec<PyObjectRef>>(), + ); - fn update_idxs(&self) { - let lst_idx = &self.pools[self.cur.get()].len() - 1; + self.update_idxs(); - if self.idxs.borrow()[self.cur.get()] == lst_idx { if self.is_end() { - return; + self.stop.set(true); } - self.idxs.borrow_mut()[self.cur.get()] = 0; - self.cur.set(self.cur.get() - 1); - self.update_idxs(); - } else { - self.idxs.borrow_mut()[self.cur.get()] += 1; - self.cur.set(self.idxs.borrow().len() - 1); - } - } - - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf - } -} -#[pyclass] -#[derive(Debug)] -struct PyItertoolsCombinations { - pool: Vec<PyObjectRef>, - indices: RefCell<Vec<usize>>, - r: Cell<usize>, - exhausted: Cell<bool>, -} + Ok(res.into_ref(vm).into_object()) + } -impl PyValue for PyItertoolsCombinations { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "combinations") - } -} + fn is_end(&self) -> bool { + let cur = self.cur.get(); + self.idxs.borrow()[cur] == self.pools[cur].len() - 1 && cur == 0 + } -#[pyimpl] -impl PyItertoolsCombinations { - #[pyslot] - fn tp_new( - cls: PyClassRef, - iterable: PyObjectRef, - r: PyIntRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; - let pool = get_all(vm, &iter)?; + fn update_idxs(&self) { + let lst_idx = &self.pools[self.cur.get()].len() - 1; - let r = r.as_bigint(); - if r.is_negative() { - return Err(vm.new_value_error("r must be non-negative".to_owned())); + if self.idxs.borrow()[self.cur.get()] == lst_idx { + if self.is_end() { + return; + } + self.idxs.borrow_mut()[self.cur.get()] = 0; + self.cur.set(self.cur.get() - 1); + self.update_idxs(); + } else { + self.idxs.borrow_mut()[self.cur.get()] += 1; + self.cur.set(self.idxs.borrow().len() - 1); + } } - let r = r.to_usize().unwrap(); - - let n = pool.len(); - PyItertoolsCombinations { - pool, - indices: RefCell::new((0..r).collect()), - r: Cell::new(r), - exhausted: Cell::new(r > n), + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf + #[pyclass(name = "combinations")] + #[derive(Debug)] + struct PyItertoolsCombinations { + pool: Vec<PyObjectRef>, + indices: RefCell<Vec<usize>>, + r: Cell<usize>, + exhausted: Cell<bool>, } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - // stop signal - if self.exhausted.get() { - return Err(new_stop_iteration(vm)); + impl PyValue for PyItertoolsCombinations { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "combinations") } + } - let n = self.pool.len(); - let r = self.r.get(); - - if r == 0 { - self.exhausted.set(true); - return Ok(vm.ctx.new_tuple(vec![])); - } + #[pyimpl] + impl PyItertoolsCombinations { + #[pyslot] + fn tp_new( + cls: PyClassRef, + iterable: PyObjectRef, + r: PyIntRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; + let pool = get_all(vm, &iter)?; + + let r = r.as_bigint(); + if r.is_negative() { + return Err(vm.new_value_error("r must be non-negative".to_owned())); + } + let r = r.to_usize().unwrap(); - let res = PyTuple::from( - self.indices - .borrow() - .iter() - .map(|&i| self.pool[i].clone()) - .collect::<Vec<PyObjectRef>>(), - ); + let n = pool.len(); - let mut indices = self.indices.borrow_mut(); + PyItertoolsCombinations { + pool, + indices: RefCell::new((0..r).collect()), + r: Cell::new(r), + exhausted: Cell::new(r > n), + } + .into_ref_with_type(vm, cls) + } - // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). - let mut idx = r as isize - 1; - while idx >= 0 && indices[idx as usize] == idx as usize + n - r { - idx -= 1; + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf } - // If no suitable index is found, then the indices are all at - // their maximum value and we're done. - if idx < 0 { - self.exhausted.set(true); - } else { - // Increment the current index which we know is not at its - // maximum. Then move back to the right setting each index - // to its lowest possible value (one higher than the index - // to its left -- this maintains the sort order invariant). - indices[idx as usize] += 1; - for j in idx as usize + 1..r { - indices[j] = indices[j - 1] + 1; + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + // stop signal + if self.exhausted.get() { + return Err(new_stop_iteration(vm)); } - } - Ok(res.into_ref(vm).into_object()) - } -} + let n = self.pool.len(); + let r = self.r.get(); -#[pyclass] -#[derive(Debug)] -struct PyItertoolsCombinationsWithReplacement { - pool: Vec<PyObjectRef>, - indices: RefCell<Vec<usize>>, - r: Cell<usize>, - exhausted: Cell<bool>, -} + if r == 0 { + self.exhausted.set(true); + return Ok(vm.ctx.new_tuple(vec![])); + } -impl PyValue for PyItertoolsCombinationsWithReplacement { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "combinations_with_replacement") - } -} + let res = PyTuple::from( + self.indices + .borrow() + .iter() + .map(|&i| self.pool[i].clone()) + .collect::<Vec<PyObjectRef>>(), + ); -#[pyimpl] -impl PyItertoolsCombinationsWithReplacement { - #[pyslot] - fn tp_new( - cls: PyClassRef, - iterable: PyObjectRef, - r: PyIntRef, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; - let pool = get_all(vm, &iter)?; + let mut indices = self.indices.borrow_mut(); - let r = r.as_bigint(); - if r.is_negative() { - return Err(vm.new_value_error("r must be non-negative".to_owned())); - } - let r = r.to_usize().unwrap(); + // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). + let mut idx = r as isize - 1; + while idx >= 0 && indices[idx as usize] == idx as usize + n - r { + idx -= 1; + } - let n = pool.len(); + // If no suitable index is found, then the indices are all at + // their maximum value and we're done. + if idx < 0 { + self.exhausted.set(true); + } else { + // Increment the current index which we know is not at its + // maximum. Then move back to the right setting each index + // to its lowest possible value (one higher than the index + // to its left -- this maintains the sort order invariant). + indices[idx as usize] += 1; + for j in idx as usize + 1..r { + indices[j] = indices[j - 1] + 1; + } + } - PyItertoolsCombinationsWithReplacement { - pool, - indices: RefCell::new(vec![0; r]), - r: Cell::new(r), - exhausted: Cell::new(n == 0 && r > 0), + Ok(res.into_ref(vm).into_object()) } - .into_ref_with_type(vm, cls) } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf + #[pyclass(name = "combinations_with_replacement")] + #[derive(Debug)] + struct PyItertoolsCombinationsWithReplacement { + pool: Vec<PyObjectRef>, + indices: RefCell<Vec<usize>>, + r: Cell<usize>, + exhausted: Cell<bool>, } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - // stop signal - if self.exhausted.get() { - return Err(new_stop_iteration(vm)); + impl PyValue for PyItertoolsCombinationsWithReplacement { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "combinations_with_replacement") } + } - let n = self.pool.len(); - let r = self.r.get(); - - if r == 0 { - self.exhausted.set(true); - return Ok(vm.ctx.new_tuple(vec![])); - } + #[pyimpl] + impl PyItertoolsCombinationsWithReplacement { + #[pyslot] + fn tp_new( + cls: PyClassRef, + iterable: PyObjectRef, + r: PyIntRef, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; + let pool = get_all(vm, &iter)?; + + let r = r.as_bigint(); + if r.is_negative() { + return Err(vm.new_value_error("r must be non-negative".to_owned())); + } + let r = r.to_usize().unwrap(); - let mut indices = self.indices.borrow_mut(); + let n = pool.len(); - let res = vm - .ctx - .new_tuple(indices.iter().map(|&i| self.pool[i].clone()).collect()); + PyItertoolsCombinationsWithReplacement { + pool, + indices: RefCell::new(vec![0; r]), + r: Cell::new(r), + exhausted: Cell::new(n == 0 && r > 0), + } + .into_ref_with_type(vm, cls) + } - // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). - let mut idx = r as isize - 1; - while idx >= 0 && indices[idx as usize] == n - 1 { - idx -= 1; + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf } - // If no suitable index is found, then the indices are all at - // their maximum value and we're done. - if idx < 0 { - self.exhausted.set(true); - } else { - let index = indices[idx as usize] + 1; + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + // stop signal + if self.exhausted.get() { + return Err(new_stop_iteration(vm)); + } + + let n = self.pool.len(); + let r = self.r.get(); - // Increment the current index which we know is not at its - // maximum. Then set all to the right to the same value. - for j in idx as usize..r { - indices[j as usize] = index as usize; + if r == 0 { + self.exhausted.set(true); + return Ok(vm.ctx.new_tuple(vec![])); } - } - Ok(res) - } -} + let mut indices = self.indices.borrow_mut(); -#[pyclass] -#[derive(Debug)] -struct PyItertoolsPermutations { - pool: Vec<PyObjectRef>, // Collected input iterable - indices: RefCell<Vec<usize>>, // One index per element in pool - cycles: RefCell<Vec<usize>>, // One rollover counter per element in the result - result: RefCell<Option<Vec<usize>>>, // Indexes of the most recently returned result - r: Cell<usize>, // Size of result tuple - exhausted: Cell<bool>, // Set when the iterator is exhausted -} + let res = vm + .ctx + .new_tuple(indices.iter().map(|&i| self.pool[i].clone()).collect()); -impl PyValue for PyItertoolsPermutations { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "permutations") - } -} + // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). + let mut idx = r as isize - 1; + while idx >= 0 && indices[idx as usize] == n - 1 { + idx -= 1; + } -#[pyimpl] -impl PyItertoolsPermutations { - #[pyslot] - fn tp_new( - cls: PyClassRef, - iterable: PyObjectRef, - r: OptionalOption<PyObjectRef>, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let iter = get_iter(vm, &iterable)?; - let pool = get_all(vm, &iter)?; - - let n = pool.len(); - // If r is not provided, r == n. If provided, r must be a positive integer, or None. - // If None, it behaves the same as if it was not provided. - let r = match r.flat_option() { - Some(r) => { - let val = r - .payload::<PyInt>() - .ok_or_else(|| vm.new_type_error("Expected int as r".to_owned()))? - .as_bigint(); - - if val.is_negative() { - return Err(vm.new_value_error("r must be non-negative".to_owned())); + // If no suitable index is found, then the indices are all at + // their maximum value and we're done. + if idx < 0 { + self.exhausted.set(true); + } else { + let index = indices[idx as usize] + 1; + + // Increment the current index which we know is not at its + // maximum. Then set all to the right to the same value. + for j in idx as usize..r { + indices[j as usize] = index as usize; } - val.to_usize().unwrap() } - None => n, - }; - - PyItertoolsPermutations { - pool, - indices: RefCell::new((0..n).collect()), - cycles: RefCell::new((0..r).map(|i| n - i).collect()), - result: RefCell::new(None), - r: Cell::new(r), - exhausted: Cell::new(r > n), - } - .into_ref_with_type(vm, cls) + + Ok(res) + } } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf + #[pyclass(name = "permutations")] + #[derive(Debug)] + struct PyItertoolsPermutations { + pool: Vec<PyObjectRef>, // Collected input iterable + indices: RefCell<Vec<usize>>, // One index per element in pool + cycles: RefCell<Vec<usize>>, // One rollover counter per element in the result + result: RefCell<Option<Vec<usize>>>, // Indexes of the most recently returned result + r: Cell<usize>, // Size of result tuple + exhausted: Cell<bool>, // Set when the iterator is exhausted } - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - // stop signal - if self.exhausted.get() { - return Err(new_stop_iteration(vm)); + impl PyValue for PyItertoolsPermutations { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "permutations") } + } - let n = self.pool.len(); - let r = self.r.get(); + #[pyimpl] + impl PyItertoolsPermutations { + #[pyslot] + fn tp_new( + cls: PyClassRef, + iterable: PyObjectRef, + r: OptionalOption<PyObjectRef>, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let iter = get_iter(vm, &iterable)?; + let pool = get_all(vm, &iter)?; + + let n = pool.len(); + // If r is not provided, r == n. If provided, r must be a positive integer, or None. + // If None, it behaves the same as if it was not provided. + let r = match r.flat_option() { + Some(r) => { + let val = r + .payload::<PyInt>() + .ok_or_else(|| vm.new_type_error("Expected int as r".to_owned()))? + .as_bigint(); + + if val.is_negative() { + return Err(vm.new_value_error("r must be non-negative".to_owned())); + } + val.to_usize().unwrap() + } + None => n, + }; - if n == 0 { - self.exhausted.set(true); - return Ok(vm.ctx.new_tuple(vec![])); + PyItertoolsPermutations { + pool, + indices: RefCell::new((0..n).collect()), + cycles: RefCell::new((0..r).map(|i| n - i).collect()), + result: RefCell::new(None), + r: Cell::new(r), + exhausted: Cell::new(r > n), + } + .into_ref_with_type(vm, cls) } - let result = &mut *self.result.borrow_mut(); - - if let Some(ref mut result) = result { - let mut indices = self.indices.borrow_mut(); - let mut cycles = self.cycles.borrow_mut(); - let mut sentinel = false; - - // Decrement rightmost cycle, moving leftward upon zero rollover - for i in (0..r).rev() { - cycles[i] -= 1; - - if cycles[i] == 0 { - // rotation: indices[i:] = indices[i+1:] + indices[i:i+1] - let index = indices[i]; - for j in i..n - 1 { - indices[j] = indices[j + i]; - } - indices[n - 1] = index; - cycles[i] = n - i; - } else { - let j = cycles[i]; - indices.swap(i, n - j); + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } - for k in i..r { - // start with i, the leftmost element that changed - // yield tuple(pool[k] for k in indices[:r]) - result[k] = indices[k]; - } - sentinel = true; - break; - } - } - if !sentinel { - self.exhausted.set(true); + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + // stop signal + if self.exhausted.get() { return Err(new_stop_iteration(vm)); } - } else { - // On the first pass, initialize result tuple using the indices - *result = Some((0..r).collect()); - } - - Ok(vm.ctx.new_tuple( - result - .as_ref() - .unwrap() - .iter() - .map(|&i| self.pool[i].clone()) - .collect(), - )) - } -} -#[pyclass] -#[derive(Debug)] -struct PyItertoolsZiplongest { - iterators: Vec<PyObjectRef>, - fillvalue: PyObjectRef, - numactive: Cell<usize>, -} + let n = self.pool.len(); + let r = self.r.get(); -impl PyValue for PyItertoolsZiplongest { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("itertools", "zip_longest") - } -} + if n == 0 { + self.exhausted.set(true); + return Ok(vm.ctx.new_tuple(vec![])); + } -#[derive(FromArgs)] -struct ZiplongestArgs { - #[pyarg(keyword_only, optional = true)] - fillvalue: OptionalArg<PyObjectRef>, -} + let result = &mut *self.result.borrow_mut(); -#[pyimpl] -impl PyItertoolsZiplongest { - #[pyslot] - fn tp_new( - cls: PyClassRef, - iterables: Args, - args: ZiplongestArgs, - vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let fillvalue = match args.fillvalue.into_option() { - Some(i) => i, - None => vm.get_none(), - }; - - let iterators = iterables - .into_iter() - .map(|iterable| get_iter(vm, &iterable)) - .collect::<Result<Vec<_>, _>>()?; - - let numactive = Cell::new(iterators.len()); - - PyItertoolsZiplongest { - iterators, - fillvalue, - numactive, - } - .into_ref_with_type(vm, cls) - } + if let Some(ref mut result) = result { + let mut indices = self.indices.borrow_mut(); + let mut cycles = self.cycles.borrow_mut(); + let mut sentinel = false; - #[pymethod(name = "__next__")] - fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.iterators.is_empty() { - Err(new_stop_iteration(vm)) - } else { - let mut result: Vec<PyObjectRef> = Vec::new(); - let mut numactive = self.numactive.get(); + // Decrement rightmost cycle, moving leftward upon zero rollover + for i in (0..r).rev() { + cycles[i] -= 1; - for idx in 0..self.iterators.len() { - let next_obj = match call_next(vm, &self.iterators[idx]) { - Ok(obj) => obj, - Err(err) => { - if !objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { - return Err(err); + if cycles[i] == 0 { + // rotation: indices[i:] = indices[i+1:] + indices[i:i+1] + let index = indices[i]; + for j in i..n - 1 { + indices[j] = indices[j + i]; } - numactive -= 1; - if numactive == 0 { - return Err(new_stop_iteration(vm)); + indices[n - 1] = index; + cycles[i] = n - i; + } else { + let j = cycles[i]; + indices.swap(i, n - j); + + for k in i..r { + // start with i, the leftmost element that changed + // yield tuple(pool[k] for k in indices[:r]) + result[k] = indices[k]; } - self.fillvalue.clone() + sentinel = true; + break; } - }; - result.push(next_obj); + } + if !sentinel { + self.exhausted.set(true); + return Err(new_stop_iteration(vm)); + } + } else { + // On the first pass, initialize result tuple using the indices + *result = Some((0..r).collect()); } - Ok(vm.ctx.new_tuple(result)) + + Ok(vm.ctx.new_tuple( + result + .as_ref() + .unwrap() + .iter() + .map(|&i| self.pool[i].clone()) + .collect(), + )) } } - #[pymethod(name = "__iter__")] - fn iter(zelf: PyRef<Self>) -> PyRef<Self> { - zelf + #[pyclass(name = "zip_longest")] + #[derive(Debug)] + struct PyItertoolsZipLongest { + iterators: Vec<PyObjectRef>, + fillvalue: PyObjectRef, + numactive: Cell<usize>, } -} -pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { - let ctx = &vm.ctx; - - let accumulate = ctx.new_class("accumulate", ctx.object()); - PyItertoolsAccumulate::extend_class(ctx, &accumulate); - - let chain = PyItertoolsChain::make_class(ctx); - - let compress = PyItertoolsCompress::make_class(ctx); - - let combinations = ctx.new_class("combinations", ctx.object()); - PyItertoolsCombinations::extend_class(ctx, &combinations); - - let combinations_with_replacement = - ctx.new_class("combinations_with_replacement", ctx.object()); - PyItertoolsCombinationsWithReplacement::extend_class(ctx, &combinations_with_replacement); - - let count = ctx.new_class("count", ctx.object()); - PyItertoolsCount::extend_class(ctx, &count); - - let cycle = ctx.new_class("cycle", ctx.object()); - PyItertoolsCycle::extend_class(ctx, &cycle); - - let dropwhile = ctx.new_class("dropwhile", ctx.object()); - PyItertoolsDropwhile::extend_class(ctx, &dropwhile); - - let islice = PyItertoolsIslice::make_class(ctx); - - let filterfalse = ctx.new_class("filterfalse", ctx.object()); - PyItertoolsFilterFalse::extend_class(ctx, &filterfalse); - - let permutations = ctx.new_class("permutations", ctx.object()); - PyItertoolsPermutations::extend_class(ctx, &permutations); - - let product = ctx.new_class("product", ctx.object()); - PyItertoolsProduct::extend_class(ctx, &product); + impl PyValue for PyItertoolsZipLongest { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "zip_longest") + } + } - let repeat = ctx.new_class("repeat", ctx.object()); - PyItertoolsRepeat::extend_class(ctx, &repeat); + #[derive(FromArgs)] + struct ZiplongestArgs { + #[pyarg(keyword_only, optional = true)] + fillvalue: OptionalArg<PyObjectRef>, + } + + #[pyimpl] + impl PyItertoolsZipLongest { + #[pyslot] + fn tp_new( + cls: PyClassRef, + iterables: Args, + args: ZiplongestArgs, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let fillvalue = match args.fillvalue.into_option() { + Some(i) => i, + None => vm.get_none(), + }; - let starmap = PyItertoolsStarmap::make_class(ctx); + let iterators = iterables + .into_iter() + .map(|iterable| get_iter(vm, &iterable)) + .collect::<Result<Vec<_>, _>>()?; - let takewhile = ctx.new_class("takewhile", ctx.object()); - PyItertoolsTakewhile::extend_class(ctx, &takewhile); + let numactive = Cell::new(iterators.len()); - let tee = ctx.new_class("tee", ctx.object()); - PyItertoolsTee::extend_class(ctx, &tee); + PyItertoolsZipLongest { + iterators, + fillvalue, + numactive, + } + .into_ref_with_type(vm, cls) + } - let zip_longest = ctx.new_class("zip_longest", ctx.object()); - PyItertoolsZiplongest::extend_class(ctx, &zip_longest); + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if self.iterators.is_empty() { + Err(new_stop_iteration(vm)) + } else { + let mut result: Vec<PyObjectRef> = Vec::new(); + let mut numactive = self.numactive.get(); + + for idx in 0..self.iterators.len() { + let next_obj = match call_next(vm, &self.iterators[idx]) { + Ok(obj) => obj, + Err(err) => { + if !objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { + return Err(err); + } + numactive -= 1; + if numactive == 0 { + return Err(new_stop_iteration(vm)); + } + self.fillvalue.clone() + } + }; + result.push(next_obj); + } + Ok(vm.ctx.new_tuple(result)) + } + } - py_module!(vm, "itertools", { - "accumulate" => accumulate, - "chain" => chain, - "compress" => compress, - "combinations" => combinations, - "combinations_with_replacement" => combinations_with_replacement, - "count" => count, - "cycle" => cycle, - "dropwhile" => dropwhile, - "islice" => islice, - "filterfalse" => filterfalse, - "repeat" => repeat, - "starmap" => starmap, - "takewhile" => takewhile, - "tee" => tee, - "permutations" => permutations, - "product" => product, - "zip_longest" => zip_longest, - }) + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>) -> PyRef<Self> { + zelf + } + } } diff --git a/vm/src/stdlib/platform.rs b/vm/src/stdlib/platform.rs index 26f4949273..af22e54973 100644 --- a/vm/src/stdlib/platform.rs +++ b/vm/src/stdlib/platform.rs @@ -1,39 +1,37 @@ -use crate::pyobject::PyObjectRef; -use crate::version; -use crate::vm::VirtualMachine; +pub(crate) use decl::make_module; -pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { - let ctx = &vm.ctx; - py_module!(vm, "platform", { - "python_branch" => ctx.new_function(platform_python_branch), - "python_build" => ctx.new_function(platform_python_build), - "python_compiler" => ctx.new_function(platform_python_compiler), - "python_implementation" => ctx.new_function(platform_python_implementation), - "python_revision" => ctx.new_function(platform_python_revision), - "python_version" => ctx.new_function(platform_python_version), - }) -} +#[pymodule(name = "platform")] +mod decl { + use crate::version; + use crate::vm::VirtualMachine; -fn platform_python_implementation(_vm: &VirtualMachine) -> String { - "RustPython".to_owned() -} + #[pyfunction] + fn python_implementation(_vm: &VirtualMachine) -> String { + "RustPython".to_owned() + } -fn platform_python_version(_vm: &VirtualMachine) -> String { - version::get_version_number() -} + #[pyfunction] + fn python_version(_vm: &VirtualMachine) -> String { + version::get_version_number() + } -fn platform_python_compiler(_vm: &VirtualMachine) -> String { - version::get_compiler() -} + #[pyfunction] + fn python_compiler(_vm: &VirtualMachine) -> String { + version::get_compiler() + } -fn platform_python_build(_vm: &VirtualMachine) -> (String, String) { - (version::get_git_identifier(), version::get_git_datetime()) -} + #[pyfunction] + fn python_build(_vm: &VirtualMachine) -> (String, String) { + (version::get_git_identifier(), version::get_git_datetime()) + } -fn platform_python_branch(_vm: &VirtualMachine) -> String { - version::get_git_branch() -} + #[pyfunction] + fn python_branch(_vm: &VirtualMachine) -> String { + version::get_git_branch() + } -fn platform_python_revision(_vm: &VirtualMachine) -> String { - version::get_git_revision() + #[pyfunction] + fn python_revision(_vm: &VirtualMachine) -> String { + version::get_git_revision() + } } diff --git a/vm/src/stdlib/pystruct.rs b/vm/src/stdlib/pystruct.rs index f31aa36c7b..320d9875ae 100644 --- a/vm/src/stdlib/pystruct.rs +++ b/vm/src/stdlib/pystruct.rs @@ -9,236 +9,242 @@ * https://docs.rs/byteorder/1.2.6/byteorder/ */ -use byteorder::{ReadBytesExt, WriteBytesExt}; -use num_bigint::BigInt; -use num_traits::ToPrimitive; -use std::cmp; -use std::io::{Cursor, Read, Write}; -use std::iter::Peekable; - -use crate::exceptions::PyBaseExceptionRef; -use crate::function::Args; -use crate::obj::{ - objbool::IntoPyBool, objbytes::PyBytesRef, objstr::PyString, objstr::PyStringRef, - objtuple::PyTuple, objtype::PyClassRef, -}; -use crate::pyobject::{Either, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject}; +use crate::pyobject::PyObjectRef; use crate::VirtualMachine; -#[derive(Debug)] -enum Endianness { - Native, - Little, - Big, - Network, -} +#[pymodule] +mod _struct { + use byteorder::{ReadBytesExt, WriteBytesExt}; + use num_bigint::BigInt; + use num_traits::ToPrimitive; + use std::io::{Cursor, Read, Write}; + use std::iter::Peekable; + + use crate::exceptions::PyBaseExceptionRef; + use crate::function::Args; + use crate::obj::{ + objbool::IntoPyBool, objbytes::PyBytesRef, objstr::PyString, objstr::PyStringRef, + objtuple::PyTuple, objtype::PyClassRef, + }; + use crate::pyobject::{ + Either, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + }; + use crate::VirtualMachine; + + #[derive(Debug)] + enum Endianness { + Native, + Little, + Big, + Network, + } -#[derive(Debug)] -struct FormatCode { - repeat: u32, - code: char, -} + #[derive(Debug)] + struct FormatCode { + repeat: u32, + code: char, + } -impl FormatCode { - fn unit_size(&self) -> usize { - match self.code { - 'x' | 'c' | 'b' | 'B' | '?' | 's' | 'p' => 1, - 'h' | 'H' => 2, - 'i' | 'l' | 'I' | 'L' | 'f' => 4, - 'q' | 'Q' | 'd' => 8, - 'n' | 'N' | 'P' => std::mem::size_of::<usize>(), - c => { - panic!("Unsupported format code {:?}", c); + impl FormatCode { + fn unit_size(&self) -> usize { + match self.code { + 'x' | 'c' | 'b' | 'B' | '?' | 's' | 'p' => 1, + 'h' | 'H' => 2, + 'i' | 'l' | 'I' | 'L' | 'f' => 4, + 'q' | 'Q' | 'd' => 8, + 'n' | 'N' | 'P' => std::mem::size_of::<usize>(), + c => { + panic!("Unsupported format code {:?}", c); + } } } - } - fn size(&self) -> usize { - self.unit_size() * self.repeat as usize - } + fn size(&self) -> usize { + self.unit_size() * self.repeat as usize + } - fn arg_count(&self) -> usize { - match self.code { - 'x' => 0, - 's' | 'p' => 1, - _ => self.repeat as usize, + fn arg_count(&self) -> usize { + match self.code { + 'x' => 0, + 's' | 'p' => 1, + _ => self.repeat as usize, + } } } -} - -#[derive(Debug)] -struct FormatSpec { - endianness: Endianness, - codes: Vec<FormatCode>, -} -impl FormatSpec { - fn parse(fmt: &str) -> Result<FormatSpec, String> { - let mut chars = fmt.chars().peekable(); - - // First determine "<", ">","!" or "=" - let endianness = parse_endiannes(&mut chars); + #[derive(Debug)] + struct FormatSpec { + endianness: Endianness, + codes: Vec<FormatCode>, + } - // Now, analyze struct string furter: - let codes = parse_format_codes(&mut chars)?; + impl FormatSpec { + fn parse(fmt: &str) -> Result<FormatSpec, String> { + let mut chars = fmt.chars().peekable(); - Ok(FormatSpec { endianness, codes }) - } + // First determine "<", ">","!" or "=" + let endianness = parse_endiannes(&mut chars); - fn pack(&self, args: &[PyObjectRef], vm: &VirtualMachine) -> PyResult<Vec<u8>> { - let arg_count: usize = self.codes.iter().map(|c| c.arg_count()).sum(); - if arg_count != args.len() { - return Err(new_struct_error( - vm, - format!( - "pack expected {} items for packing (got {})", - self.codes.len(), - args.len() - ), - )); - } + // Now, analyze struct string furter: + let codes = parse_format_codes(&mut chars)?; - // Create data vector: - let mut data = Vec::<u8>::new(); - let mut arg_idx = 0; - // Loop over all opcodes: - for code in self.codes.iter() { - debug!("code: {:?}", code); - let pack_item = match self.endianness { - Endianness::Little => pack_item::<byteorder::LittleEndian>, - Endianness::Big => pack_item::<byteorder::BigEndian>, - Endianness::Network => pack_item::<byteorder::NetworkEndian>, - Endianness::Native => pack_item::<byteorder::NativeEndian>, - }; - arg_idx += pack_item(vm, code, &args[arg_idx..], &mut data)?; + Ok(FormatSpec { endianness, codes }) } - Ok(data) - } + fn pack(&self, args: &[PyObjectRef], vm: &VirtualMachine) -> PyResult<Vec<u8>> { + let arg_count: usize = self.codes.iter().map(|c| c.arg_count()).sum(); + if arg_count != args.len() { + return Err(new_struct_error( + vm, + format!( + "pack expected {} items for packing (got {})", + self.codes.len(), + args.len() + ), + )); + } - fn unpack(&self, data: &[u8], vm: &VirtualMachine) -> PyResult<PyTuple> { - if self.size() != data.len() { - return Err(new_struct_error( - vm, - format!("unpack requires a buffer of {} bytes", self.size()), - )); - } + // Create data vector: + let mut data = Vec::<u8>::new(); + let mut arg_idx = 0; + // Loop over all opcodes: + for code in self.codes.iter() { + debug!("code: {:?}", code); + let pack_item = match self.endianness { + Endianness::Little => pack_item::<byteorder::LittleEndian>, + Endianness::Big => pack_item::<byteorder::BigEndian>, + Endianness::Network => pack_item::<byteorder::NetworkEndian>, + Endianness::Native => pack_item::<byteorder::NativeEndian>, + }; + arg_idx += pack_item(vm, code, &args[arg_idx..], &mut data)?; + } - let mut rdr = Cursor::new(data); - let mut items = vec![]; - for code in &self.codes { - debug!("unpack code: {:?}", code); - match self.endianness { - Endianness::Little => { - unpack_code::<byteorder::LittleEndian>(vm, &code, &mut rdr, &mut items)? - } - Endianness::Big => { - unpack_code::<byteorder::BigEndian>(vm, &code, &mut rdr, &mut items)? - } - Endianness::Network => { - unpack_code::<byteorder::NetworkEndian>(vm, &code, &mut rdr, &mut items)? - } - Endianness::Native => { - unpack_code::<byteorder::NativeEndian>(vm, &code, &mut rdr, &mut items)? - } - }; + Ok(data) } - Ok(PyTuple::from(items)) - } + fn unpack(&self, data: &[u8], vm: &VirtualMachine) -> PyResult<PyTuple> { + if self.size() != data.len() { + return Err(new_struct_error( + vm, + format!("unpack requires a buffer of {} bytes", self.size()), + )); + } - fn size(&self) -> usize { - self.codes.iter().map(FormatCode::size).sum() - } -} + let mut rdr = Cursor::new(data); + let mut items = vec![]; + for code in &self.codes { + debug!("unpack code: {:?}", code); + match self.endianness { + Endianness::Little => { + unpack_code::<byteorder::LittleEndian>(vm, &code, &mut rdr, &mut items)? + } + Endianness::Big => { + unpack_code::<byteorder::BigEndian>(vm, &code, &mut rdr, &mut items)? + } + Endianness::Network => { + unpack_code::<byteorder::NetworkEndian>(vm, &code, &mut rdr, &mut items)? + } + Endianness::Native => { + unpack_code::<byteorder::NativeEndian>(vm, &code, &mut rdr, &mut items)? + } + }; + } -/// Parse endianness -/// See also: https://docs.python.org/3/library/struct.html?highlight=struct#byte-order-size-and-alignment -fn parse_endiannes<I>(chars: &mut Peekable<I>) -> Endianness -where - I: Sized + Iterator<Item = char>, -{ - match chars.peek() { - Some('@') => { - chars.next().unwrap(); - Endianness::Native + Ok(PyTuple::from(items)) } - Some('=') => { - chars.next().unwrap(); - Endianness::Native - } - Some('<') => { - chars.next().unwrap(); - Endianness::Little - } - Some('>') => { - chars.next().unwrap(); - Endianness::Big + + fn size(&self) -> usize { + self.codes.iter().map(FormatCode::size).sum() } - Some('!') => { - chars.next().unwrap(); - Endianness::Network + } + + /// Parse endianness + /// See also: https://docs.python.org/3/library/struct.html?highlight=struct#byte-order-size-and-alignment + fn parse_endiannes<I>(chars: &mut Peekable<I>) -> Endianness + where + I: Sized + Iterator<Item = char>, + { + match chars.peek() { + Some('@') => { + chars.next().unwrap(); + Endianness::Native + } + Some('=') => { + chars.next().unwrap(); + Endianness::Native + } + Some('<') => { + chars.next().unwrap(); + Endianness::Little + } + Some('>') => { + chars.next().unwrap(); + Endianness::Big + } + Some('!') => { + chars.next().unwrap(); + Endianness::Network + } + _ => Endianness::Native, } - _ => Endianness::Native, } -} -fn parse_format_codes<I>(chars: &mut Peekable<I>) -> Result<Vec<FormatCode>, String> -where - I: Sized + Iterator<Item = char>, -{ - let mut codes = vec![]; - while chars.peek().is_some() { - // determine repeat operator: - let repeat = match chars.peek() { - Some('0'..='9') => { - let mut repeat = 0; - while let Some('0'..='9') = chars.peek() { - if let Some(c) = chars.next() { - let current_digit = c.to_digit(10).unwrap(); - repeat = repeat * 10 + current_digit; + fn parse_format_codes<I>(chars: &mut Peekable<I>) -> Result<Vec<FormatCode>, String> + where + I: Sized + Iterator<Item = char>, + { + let mut codes = vec![]; + while chars.peek().is_some() { + // determine repeat operator: + let repeat = match chars.peek() { + Some('0'..='9') => { + let mut repeat = 0; + while let Some('0'..='9') = chars.peek() { + if let Some(c) = chars.next() { + let current_digit = c.to_digit(10).unwrap(); + repeat = repeat * 10 + current_digit; + } } + repeat } - repeat - } - _ => 1, - }; + _ => 1, + }; - // determine format char: - let c = chars.next(); - match c { - Some(c) if is_supported_format_character(c) => { - codes.push(FormatCode { repeat, code: c }) + // determine format char: + let c = chars.next(); + match c { + Some(c) if is_supported_format_character(c) => { + codes.push(FormatCode { repeat, code: c }) + } + _ => return Err(format!("Illegal format code {:?}", c)), } - _ => return Err(format!("Illegal format code {:?}", c)), } - } - Ok(codes) -} + Ok(codes) + } -fn is_supported_format_character(c: char) -> bool { - match c { - 'x' | 'c' | 'b' | 'B' | '?' | 'h' | 'H' | 'i' | 'I' | 'l' | 'L' | 'q' | 'Q' | 'n' | 'N' - | 'f' | 'd' | 's' | 'p' | 'P' => true, - _ => false, + fn is_supported_format_character(c: char) -> bool { + match c { + 'x' | 'c' | 'b' | 'B' | '?' | 'h' | 'H' | 'i' | 'I' | 'l' | 'L' | 'q' | 'Q' | 'n' + | 'N' | 'f' | 'd' | 's' | 'p' | 'P' => true, + _ => false, + } } -} -fn get_int_or_index<T>(vm: &VirtualMachine, arg: &PyObjectRef) -> PyResult<T> -where - T: TryFromObject, -{ - match vm.to_index(arg) { - Some(index) => Ok(T::try_from_object(vm, index?.into_object())?), - None => Err(new_struct_error( - vm, - "required argument is not an integer".to_owned(), - )), + fn get_int_or_index<T>(vm: &VirtualMachine, arg: &PyObjectRef) -> PyResult<T> + where + T: TryFromObject, + { + match vm.to_index(arg) { + Some(index) => Ok(T::try_from_object(vm, index?.into_object())?), + None => Err(new_struct_error( + vm, + "required argument is not an integer".to_owned(), + )), + } } -} -macro_rules! make_pack_no_endianess { + macro_rules! make_pack_no_endianess { ($T:ty) => { paste::item! { fn [<pack_ $T>](vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> { @@ -249,7 +255,7 @@ macro_rules! make_pack_no_endianess { }; } -macro_rules! make_pack_with_endianess_int { + macro_rules! make_pack_with_endianess_int { ($T:ty) => { paste::item! { fn [<pack_ $T>]<Endianness>(vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> @@ -263,7 +269,7 @@ macro_rules! make_pack_with_endianess_int { }; } -macro_rules! make_pack_with_endianess { + macro_rules! make_pack_with_endianess { ($T:ty) => { paste::item! { fn [<pack_ $T>]<Endianness>(vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> @@ -278,452 +284,455 @@ macro_rules! make_pack_with_endianess { }; } -make_pack_no_endianess!(i8); -make_pack_no_endianess!(u8); -make_pack_with_endianess_int!(i16); -make_pack_with_endianess_int!(u16); -make_pack_with_endianess_int!(i32); -make_pack_with_endianess_int!(u32); -make_pack_with_endianess_int!(i64); -make_pack_with_endianess_int!(u64); -make_pack_with_endianess!(f32); -make_pack_with_endianess!(f64); - -fn pack_bool(vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> { - let v = if IntoPyBool::try_from_object(vm, arg.clone())?.to_bool() { - 1 - } else { - 0 - }; - data.write_u8(v).unwrap(); - Ok(()) -} - -fn pack_isize<Endianness>( - vm: &VirtualMachine, - arg: &PyObjectRef, - data: &mut dyn Write, -) -> PyResult<()> -where - Endianness: byteorder::ByteOrder, -{ - let v: isize = get_int_or_index(vm, arg)?; - match std::mem::size_of::<isize>() { - 8 => data.write_i64::<Endianness>(v as i64).unwrap(), - 4 => data.write_i32::<Endianness>(v as i32).unwrap(), - _ => unreachable!("unexpected architecture"), - } - Ok(()) -} - -fn pack_usize<Endianness>( - vm: &VirtualMachine, - arg: &PyObjectRef, - data: &mut dyn Write, -) -> PyResult<()> -where - Endianness: byteorder::ByteOrder, -{ - let v: usize = get_int_or_index(vm, arg)?; - match std::mem::size_of::<usize>() { - 8 => data.write_u64::<Endianness>(v as u64).unwrap(), - 4 => data.write_u32::<Endianness>(v as u32).unwrap(), - _ => unreachable!("unexpected architecture"), - } - Ok(()) -} - -fn pack_string( - vm: &VirtualMachine, - arg: &PyObjectRef, - data: &mut dyn Write, - length: usize, -) -> PyResult<()> { - let mut v = PyBytesRef::try_from_object(vm, arg.clone())? - .get_value() - .to_vec(); - v.resize(length, 0); - match data.write_all(&v) { - Ok(_) => Ok(()), - Err(e) => Err(new_struct_error(vm, format!("{:?}", e))), + make_pack_no_endianess!(i8); + make_pack_no_endianess!(u8); + make_pack_with_endianess_int!(i16); + make_pack_with_endianess_int!(u16); + make_pack_with_endianess_int!(i32); + make_pack_with_endianess_int!(u32); + make_pack_with_endianess_int!(i64); + make_pack_with_endianess_int!(u64); + make_pack_with_endianess!(f32); + make_pack_with_endianess!(f64); + + fn pack_bool(vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> { + let v = if IntoPyBool::try_from_object(vm, arg.clone())?.to_bool() { + 1 + } else { + 0 + }; + data.write_u8(v).unwrap(); + Ok(()) } -} -fn pack_pascal( - vm: &VirtualMachine, - arg: &PyObjectRef, - data: &mut dyn Write, - length: usize, -) -> PyResult<()> { - let mut v = PyBytesRef::try_from_object(vm, arg.clone())? - .get_value() - .to_vec(); - let string_length = cmp::min(cmp::min(v.len(), 255), length - 1); - data.write_u8(string_length as u8).unwrap(); - v.resize(length - 1, 0); - match data.write_all(&v) { - Ok(_) => Ok(()), - Err(e) => Err(new_struct_error(vm, format!("{:?}", e))), + fn pack_isize<Endianness>( + vm: &VirtualMachine, + arg: &PyObjectRef, + data: &mut dyn Write, + ) -> PyResult<()> + where + Endianness: byteorder::ByteOrder, + { + let v: isize = get_int_or_index(vm, arg)?; + match std::mem::size_of::<isize>() { + 8 => data.write_i64::<Endianness>(v as i64).unwrap(), + 4 => data.write_i32::<Endianness>(v as i32).unwrap(), + _ => unreachable!("unexpected architecture"), + } + Ok(()) } -} -fn pack_char(vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> { - let v = PyBytesRef::try_from_object(vm, arg.clone())?; - if v.len() == 1 { - data.write_u8(v[0]).unwrap(); + fn pack_usize<Endianness>( + vm: &VirtualMachine, + arg: &PyObjectRef, + data: &mut dyn Write, + ) -> PyResult<()> + where + Endianness: byteorder::ByteOrder, + { + let v: usize = get_int_or_index(vm, arg)?; + match std::mem::size_of::<usize>() { + 8 => data.write_u64::<Endianness>(v as u64).unwrap(), + 4 => data.write_u32::<Endianness>(v as u32).unwrap(), + _ => unreachable!("unexpected architecture"), + } Ok(()) - } else { - Err(new_struct_error( - vm, - "char format requires a bytes object of length 1".to_owned(), - )) } -} -fn pack_item<Endianness>( - vm: &VirtualMachine, - code: &FormatCode, - args: &[PyObjectRef], - data: &mut dyn Write, -) -> PyResult<usize> -where - Endianness: byteorder::ByteOrder, -{ - let pack = match code.code { - 'c' => pack_char, - 'b' => pack_i8, - 'B' => pack_u8, - '?' => pack_bool, - 'h' => pack_i16::<Endianness>, - 'H' => pack_u16::<Endianness>, - 'i' | 'l' => pack_i32::<Endianness>, - 'I' | 'L' => pack_u32::<Endianness>, - 'q' => pack_i64::<Endianness>, - 'Q' => pack_u64::<Endianness>, - 'n' => pack_isize::<Endianness>, - 'N' | 'P' => pack_usize::<Endianness>, - 'f' => pack_f32::<Endianness>, - 'd' => pack_f64::<Endianness>, - 's' => { - pack_string(vm, &args[0], data, code.repeat as usize)?; - return Ok(1); - } - 'p' => { - pack_pascal(vm, &args[0], data, code.repeat as usize)?; - return Ok(1); - } - 'x' => { - for _ in 0..code.repeat as usize { - data.write_u8(0).unwrap(); - } - return Ok(0); - } - c => { - panic!("Unsupported format code {:?}", c); + fn pack_string( + vm: &VirtualMachine, + arg: &PyObjectRef, + data: &mut dyn Write, + length: usize, + ) -> PyResult<()> { + let mut v = PyBytesRef::try_from_object(vm, arg.clone())? + .get_value() + .to_vec(); + v.resize(length, 0); + match data.write_all(&v) { + Ok(_) => Ok(()), + Err(e) => Err(new_struct_error(vm, format!("{:?}", e))), } - }; - - for arg in args.iter().take(code.repeat as usize) { - pack(vm, arg, data)?; } - Ok(code.repeat as usize) -} -fn struct_pack(fmt: PyStringRef, args: Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> { - let format_spec = FormatSpec::parse(fmt.as_str()).map_err(|e| new_struct_error(vm, e))?; - format_spec.pack(args.as_ref(), vm) -} + fn pack_pascal( + vm: &VirtualMachine, + arg: &PyObjectRef, + data: &mut dyn Write, + length: usize, + ) -> PyResult<()> { + let mut v = PyBytesRef::try_from_object(vm, arg.clone())? + .get_value() + .to_vec(); + let string_length = std::cmp::min(std::cmp::min(v.len(), 255), length - 1); + data.write_u8(string_length as u8).unwrap(); + v.resize(length - 1, 0); + match data.write_all(&v) { + Ok(_) => Ok(()), + Err(e) => Err(new_struct_error(vm, format!("{:?}", e))), + } + } -#[inline] -fn unpack<F, T, G>(vm: &VirtualMachine, rdr: &mut dyn Read, read: F, transform: G) -> PyResult -where - F: Fn(&mut dyn Read) -> std::io::Result<T>, - G: Fn(T) -> PyResult, -{ - match read(rdr) { - Ok(v) => transform(v), - Err(_) => Err(new_struct_error( - vm, - format!( - "unpack requires a buffer of {} bytes", - std::mem::size_of::<T>() - ), - )), + fn pack_char(vm: &VirtualMachine, arg: &PyObjectRef, data: &mut dyn Write) -> PyResult<()> { + let v = PyBytesRef::try_from_object(vm, arg.clone())?; + if v.len() == 1 { + data.write_u8(v[0]).unwrap(); + Ok(()) + } else { + Err(new_struct_error( + vm, + "char format requires a bytes object of length 1".to_owned(), + )) + } } -} -#[inline] -fn unpack_int<F, T>(vm: &VirtualMachine, rdr: &mut dyn Read, read: F) -> PyResult -where - F: Fn(&mut dyn Read) -> std::io::Result<T>, - T: Into<BigInt> + ToPrimitive, -{ - unpack(vm, rdr, read, |v| Ok(vm.ctx.new_int(v))) -} + fn pack_item<Endianness>( + vm: &VirtualMachine, + code: &FormatCode, + args: &[PyObjectRef], + data: &mut dyn Write, + ) -> PyResult<usize> + where + Endianness: byteorder::ByteOrder, + { + let pack = match code.code { + 'c' => pack_char, + 'b' => pack_i8, + 'B' => pack_u8, + '?' => pack_bool, + 'h' => pack_i16::<Endianness>, + 'H' => pack_u16::<Endianness>, + 'i' | 'l' => pack_i32::<Endianness>, + 'I' | 'L' => pack_u32::<Endianness>, + 'q' => pack_i64::<Endianness>, + 'Q' => pack_u64::<Endianness>, + 'n' => pack_isize::<Endianness>, + 'N' | 'P' => pack_usize::<Endianness>, + 'f' => pack_f32::<Endianness>, + 'd' => pack_f64::<Endianness>, + 's' => { + pack_string(vm, &args[0], data, code.repeat as usize)?; + return Ok(1); + } + 'p' => { + pack_pascal(vm, &args[0], data, code.repeat as usize)?; + return Ok(1); + } + 'x' => { + for _ in 0..code.repeat as usize { + data.write_u8(0).unwrap(); + } + return Ok(0); + } + c => { + panic!("Unsupported format code {:?}", c); + } + }; -fn unpack_bool(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult { - unpack(vm, rdr, |rdr| rdr.read_u8(), |v| Ok(vm.ctx.new_bool(v > 0))) -} + for arg in args.iter().take(code.repeat as usize) { + pack(vm, arg, data)?; + } + Ok(code.repeat as usize) + } -fn unpack_i8(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult { - unpack_int(vm, rdr, |rdr| rdr.read_i8()) -} + #[pyfunction] + fn pack(fmt: PyStringRef, args: Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> { + let format_spec = FormatSpec::parse(fmt.as_str()).map_err(|e| new_struct_error(vm, e))?; + format_spec.pack(args.as_ref(), vm) + } -fn unpack_u8(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult { - unpack_int(vm, rdr, |rdr| rdr.read_u8()) -} + #[inline] + fn _unpack<F, T, G>(vm: &VirtualMachine, rdr: &mut dyn Read, read: F, transform: G) -> PyResult + where + F: Fn(&mut dyn Read) -> std::io::Result<T>, + G: Fn(T) -> PyResult, + { + match read(rdr) { + Ok(v) => transform(v), + Err(_) => Err(new_struct_error( + vm, + format!( + "unpack requires a buffer of {} bytes", + std::mem::size_of::<T>() + ), + )), + } + } -fn unpack_i16<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - unpack_int(vm, rdr, |rdr| rdr.read_i16::<Endianness>()) -} + #[inline] + fn unpack_int<F, T>(vm: &VirtualMachine, rdr: &mut dyn Read, read: F) -> PyResult + where + F: Fn(&mut dyn Read) -> std::io::Result<T>, + T: Into<BigInt> + ToPrimitive, + { + _unpack(vm, rdr, read, |v| Ok(vm.ctx.new_int(v))) + } -fn unpack_u16<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - unpack_int(vm, rdr, |rdr| rdr.read_u16::<Endianness>()) -} + fn unpack_bool(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult { + _unpack(vm, rdr, |rdr| rdr.read_u8(), |v| Ok(vm.ctx.new_bool(v > 0))) + } -fn unpack_i32<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - unpack_int(vm, rdr, |rdr| rdr.read_i32::<Endianness>()) -} + fn unpack_i8(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult { + unpack_int(vm, rdr, |rdr| rdr.read_i8()) + } -fn unpack_u32<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - unpack_int(vm, rdr, |rdr| rdr.read_u32::<Endianness>()) -} + fn unpack_u8(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult { + unpack_int(vm, rdr, |rdr| rdr.read_u8()) + } -fn unpack_i64<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - unpack_int(vm, rdr, |rdr| rdr.read_i64::<Endianness>()) -} + fn unpack_i16<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + unpack_int(vm, rdr, |rdr| rdr.read_i16::<Endianness>()) + } -fn unpack_u64<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - unpack_int(vm, rdr, |rdr| rdr.read_u64::<Endianness>()) -} + fn unpack_u16<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + unpack_int(vm, rdr, |rdr| rdr.read_u16::<Endianness>()) + } -fn unpack_isize<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - match std::mem::size_of::<isize>() { - 8 => unpack_i64::<Endianness>(vm, rdr), - 4 => unpack_i32::<Endianness>(vm, rdr), - _ => unreachable!("unexpected architecture"), + fn unpack_i32<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + unpack_int(vm, rdr, |rdr| rdr.read_i32::<Endianness>()) } -} -fn unpack_usize<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - match std::mem::size_of::<usize>() { - 8 => unpack_u64::<Endianness>(vm, rdr), - 4 => unpack_u32::<Endianness>(vm, rdr), - _ => unreachable!("unexpected architecture"), + fn unpack_u32<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + unpack_int(vm, rdr, |rdr| rdr.read_u32::<Endianness>()) } -} -fn unpack_f32<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - unpack( - vm, - rdr, - |rdr| rdr.read_f32::<Endianness>(), - |v| Ok(vm.ctx.new_float(f64::from(v))), - ) -} + fn unpack_i64<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + unpack_int(vm, rdr, |rdr| rdr.read_i64::<Endianness>()) + } -fn unpack_f64<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult -where - Endianness: byteorder::ByteOrder, -{ - unpack( - vm, - rdr, - |rdr| rdr.read_f64::<Endianness>(), - |v| Ok(vm.ctx.new_float(v)), - ) -} + fn unpack_u64<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + unpack_int(vm, rdr, |rdr| rdr.read_u64::<Endianness>()) + } -fn unpack_empty(_vm: &VirtualMachine, rdr: &mut dyn Read, length: u32) { - let mut handle = rdr.take(length as u64); - let mut buf: Vec<u8> = Vec::new(); - let _ = handle.read_to_end(&mut buf); -} + fn unpack_isize<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + match std::mem::size_of::<isize>() { + 8 => unpack_i64::<Endianness>(vm, rdr), + 4 => unpack_i32::<Endianness>(vm, rdr), + _ => unreachable!("unexpected architecture"), + } + } -fn unpack_char(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult { - unpack_string(vm, rdr, 1) -} + fn unpack_usize<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + match std::mem::size_of::<usize>() { + 8 => unpack_u64::<Endianness>(vm, rdr), + 4 => unpack_u32::<Endianness>(vm, rdr), + _ => unreachable!("unexpected architecture"), + } + } -fn unpack_string(vm: &VirtualMachine, rdr: &mut dyn Read, length: u32) -> PyResult { - let mut handle = rdr.take(length as u64); - let mut buf: Vec<u8> = Vec::new(); - handle.read_to_end(&mut buf).map_err(|_| { - new_struct_error(vm, format!("unpack requires a buffer of {} bytes", length,)) - })?; - Ok(vm.ctx.new_bytes(buf)) -} + fn unpack_f32<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + _unpack( + vm, + rdr, + |rdr| rdr.read_f32::<Endianness>(), + |v| Ok(vm.ctx.new_float(f64::from(v))), + ) + } -fn unpack_pascal(vm: &VirtualMachine, rdr: &mut dyn Read, length: u32) -> PyResult { - let mut handle = rdr.take(length as u64); - let mut buf: Vec<u8> = Vec::new(); - handle.read_to_end(&mut buf).map_err(|_| { - new_struct_error(vm, format!("unpack requires a buffer of {} bytes", length,)) - })?; - let string_length = buf[0] as usize; - Ok(vm.ctx.new_bytes(buf[1..=string_length].to_vec())) -} + fn unpack_f64<Endianness>(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult + where + Endianness: byteorder::ByteOrder, + { + _unpack( + vm, + rdr, + |rdr| rdr.read_f64::<Endianness>(), + |v| Ok(vm.ctx.new_float(v)), + ) + } -fn struct_unpack(fmt: PyStringRef, buffer: PyBytesRef, vm: &VirtualMachine) -> PyResult<PyTuple> { - let fmt_str = fmt.as_str(); - let format_spec = FormatSpec::parse(fmt_str).map_err(|e| new_struct_error(vm, e))?; - format_spec.unpack(buffer.get_value(), vm) -} + fn unpack_empty(_vm: &VirtualMachine, rdr: &mut dyn Read, length: u32) { + let mut handle = rdr.take(length as u64); + let mut buf: Vec<u8> = Vec::new(); + let _ = handle.read_to_end(&mut buf); + } -fn unpack_code<Endianness>( - vm: &VirtualMachine, - code: &FormatCode, - rdr: &mut dyn Read, - items: &mut Vec<PyObjectRef>, -) -> PyResult<()> -where - Endianness: byteorder::ByteOrder, -{ - let unpack = match code.code { - 'b' => unpack_i8, - 'B' => unpack_u8, - 'c' => unpack_char, - '?' => unpack_bool, - 'h' => unpack_i16::<Endianness>, - 'H' => unpack_u16::<Endianness>, - 'i' | 'l' => unpack_i32::<Endianness>, - 'I' | 'L' => unpack_u32::<Endianness>, - 'q' => unpack_i64::<Endianness>, - 'Q' => unpack_u64::<Endianness>, - 'n' => unpack_isize::<Endianness>, - 'N' => unpack_usize::<Endianness>, - 'P' => unpack_usize::<Endianness>, // FIXME: native-only - 'f' => unpack_f32::<Endianness>, - 'd' => unpack_f64::<Endianness>, - 'x' => { - unpack_empty(vm, rdr, code.repeat); - return Ok(()); - } - 's' => { - items.push(unpack_string(vm, rdr, code.repeat)?); - return Ok(()); - } - 'p' => { - items.push(unpack_pascal(vm, rdr, code.repeat)?); - return Ok(()); - } - c => { - panic!("Unsupported format code {:?}", c); - } - }; - for _ in 0..code.repeat { - items.push(unpack(vm, rdr)?); + fn unpack_char(vm: &VirtualMachine, rdr: &mut dyn Read) -> PyResult { + unpack_string(vm, rdr, 1) } - Ok(()) -} -fn struct_calcsize(fmt: Either<PyStringRef, PyBytesRef>, vm: &VirtualMachine) -> PyResult<usize> { - // FIXME: the given fmt must be parsed as ascii string - // https://github.com/RustPython/RustPython/pull/1792#discussion_r387340905 - let parsed = match fmt { - Either::A(string) => FormatSpec::parse(string.as_str()), - Either::B(bytes) => FormatSpec::parse(std::str::from_utf8(&bytes).unwrap()), - }; - let format_spec = parsed.map_err(|e| new_struct_error(vm, e))?; - Ok(format_spec.size()) -} + fn unpack_string(vm: &VirtualMachine, rdr: &mut dyn Read, length: u32) -> PyResult { + let mut handle = rdr.take(length as u64); + let mut buf: Vec<u8> = Vec::new(); + handle.read_to_end(&mut buf).map_err(|_| { + new_struct_error(vm, format!("unpack requires a buffer of {} bytes", length,)) + })?; + Ok(vm.ctx.new_bytes(buf)) + } -#[pyclass(name = "Struct")] -#[derive(Debug)] -struct PyStruct { - spec: FormatSpec, - fmt_str: PyStringRef, -} + fn unpack_pascal(vm: &VirtualMachine, rdr: &mut dyn Read, length: u32) -> PyResult { + let mut handle = rdr.take(length as u64); + let mut buf: Vec<u8> = Vec::new(); + handle.read_to_end(&mut buf).map_err(|_| { + new_struct_error(vm, format!("unpack requires a buffer of {} bytes", length,)) + })?; + let string_length = buf[0] as usize; + Ok(vm.ctx.new_bytes(buf[1..=string_length].to_vec())) + } -impl PyValue for PyStruct { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("_struct", "Struct") + #[pyfunction] + fn unpack(fmt: PyStringRef, buffer: PyBytesRef, vm: &VirtualMachine) -> PyResult<PyTuple> { + let fmt_str = fmt.as_str(); + let format_spec = FormatSpec::parse(fmt_str).map_err(|e| new_struct_error(vm, e))?; + format_spec.unpack(buffer.get_value(), vm) } -} -#[pyimpl] -impl PyStruct { - #[pyslot] - fn tp_new( - cls: PyClassRef, - fmt: Either<PyStringRef, PyBytesRef>, + fn unpack_code<Endianness>( vm: &VirtualMachine, - ) -> PyResult<PyRef<Self>> { - let fmt_str = match fmt { - Either::A(s) => s, - Either::B(b) => PyString::from(std::str::from_utf8(b.get_value()).unwrap()) - .into_ref_with_type(vm, vm.ctx.str_type())?, + code: &FormatCode, + rdr: &mut dyn Read, + items: &mut Vec<PyObjectRef>, + ) -> PyResult<()> + where + Endianness: byteorder::ByteOrder, + { + let unpack = match code.code { + 'b' => unpack_i8, + 'B' => unpack_u8, + 'c' => unpack_char, + '?' => unpack_bool, + 'h' => unpack_i16::<Endianness>, + 'H' => unpack_u16::<Endianness>, + 'i' | 'l' => unpack_i32::<Endianness>, + 'I' | 'L' => unpack_u32::<Endianness>, + 'q' => unpack_i64::<Endianness>, + 'Q' => unpack_u64::<Endianness>, + 'n' => unpack_isize::<Endianness>, + 'N' => unpack_usize::<Endianness>, + 'P' => unpack_usize::<Endianness>, // FIXME: native-only + 'f' => unpack_f32::<Endianness>, + 'd' => unpack_f64::<Endianness>, + 'x' => { + unpack_empty(vm, rdr, code.repeat); + return Ok(()); + } + 's' => { + items.push(unpack_string(vm, rdr, code.repeat)?); + return Ok(()); + } + 'p' => { + items.push(unpack_pascal(vm, rdr, code.repeat)?); + return Ok(()); + } + c => { + panic!("Unsupported format code {:?}", c); + } + }; + for _ in 0..code.repeat { + items.push(unpack(vm, rdr)?); + } + Ok(()) + } + + #[pyfunction] + fn calcsize(fmt: Either<PyStringRef, PyBytesRef>, vm: &VirtualMachine) -> PyResult<usize> { + // FIXME: the given fmt must be parsed as ascii string + // https://github.com/RustPython/RustPython/pull/1792#discussion_r387340905 + let parsed = match fmt { + Either::A(string) => FormatSpec::parse(string.as_str()), + Either::B(bytes) => FormatSpec::parse(std::str::from_utf8(&bytes).unwrap()), }; - let spec = FormatSpec::parse(fmt_str.as_str()).map_err(|e| new_struct_error(vm, e))?; - PyStruct { spec, fmt_str }.into_ref_with_type(vm, cls) + let format_spec = parsed.map_err(|e| new_struct_error(vm, e))?; + Ok(format_spec.size()) } - #[pyproperty] - fn format(&self) -> PyStringRef { - self.fmt_str.clone() + #[pyclass(name = "Struct")] + #[derive(Debug)] + struct PyStruct { + spec: FormatSpec, + fmt_str: PyStringRef, } - #[pyproperty] - fn size(&self) -> usize { - self.spec.size() + + impl PyValue for PyStruct { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_struct", "Struct") + } } - #[pymethod] - fn pack(&self, args: Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> { - self.spec.pack(args.as_ref(), vm) + #[pyimpl] + impl PyStruct { + #[pyslot] + fn tp_new( + cls: PyClassRef, + fmt: Either<PyStringRef, PyBytesRef>, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let fmt_str = match fmt { + Either::A(s) => s, + Either::B(b) => PyString::from(std::str::from_utf8(b.get_value()).unwrap()) + .into_ref_with_type(vm, vm.ctx.str_type())?, + }; + let spec = FormatSpec::parse(fmt_str.as_str()).map_err(|e| new_struct_error(vm, e))?; + PyStruct { spec, fmt_str }.into_ref_with_type(vm, cls) + } + + #[pyproperty] + fn format(&self) -> PyStringRef { + self.fmt_str.clone() + } + + #[pyproperty] + fn size(&self) -> usize { + self.spec.size() + } + + #[pymethod] + fn pack(&self, args: Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> { + self.spec.pack(args.as_ref(), vm) + } + #[pymethod] + fn unpack(&self, data: PyBytesRef, vm: &VirtualMachine) -> PyResult<PyTuple> { + self.spec.unpack(data.get_value(), vm) + } } - #[pymethod] - fn unpack(&self, data: PyBytesRef, vm: &VirtualMachine) -> PyResult<PyTuple> { - self.spec.unpack(data.get_value(), vm) + + // seems weird that this is part of the "public" API, but whatever + // TODO: implement a format code->spec cache like CPython does? + #[pyfunction] + fn _clearcache() {} + + fn new_struct_error(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef { + // _struct.error must exist + let class = vm.try_class("_struct", "error").unwrap(); + vm.new_exception_msg(class, msg) } } -// seems weird that this is part of the "public" API, but whatever -// TODO: implement a format code->spec cache like CPython does? -fn clearcache() {} - -pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { +pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; let struct_error = ctx.new_class("struct.error", ctx.exceptions.exception_type.clone()); - py_module!(vm, "_struct", { - "_clearcache" => ctx.new_function(clearcache), - "pack" => ctx.new_function(struct_pack), - "unpack" => ctx.new_function(struct_unpack), - "calcsize" => ctx.new_function(struct_calcsize), + let module = _struct::make_module(vm); + extend_module!(vm, module, { "error" => struct_error, - "Struct" => PyStruct::make_class(ctx), - }) -} - -fn new_struct_error(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef { - // _struct.error must exist - let class = vm.try_class("_struct", "error").unwrap(); - vm.new_exception_msg(class, msg) + }); + module } diff --git a/vm/src/stdlib/random.rs b/vm/src/stdlib/random.rs index 7079ad5642..63144451cc 100644 --- a/vm/src/stdlib/random.rs +++ b/vm/src/stdlib/random.rs @@ -1,127 +1,123 @@ //! Random module. -use std::cell::RefCell; - -use num_bigint::{BigInt, Sign}; -use num_traits::Signed; -use rand::RngCore; - -use crate::function::OptionalOption; -use crate::obj::objint::PyIntRef; -use crate::obj::objtype::PyClassRef; -use crate::pyobject::{PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue}; -use crate::VirtualMachine; - -#[derive(Debug)] -enum PyRng { - Std(rand::rngs::ThreadRng), - MT(Box<mt19937::MT19937>), -} - -impl Default for PyRng { - fn default() -> Self { - PyRng::Std(rand::thread_rng()) +pub(crate) use _random::make_module; + +#[pymodule] +mod _random { + use crate::function::OptionalOption; + use crate::obj::objint::PyIntRef; + use crate::obj::objtype::PyClassRef; + use crate::pyobject::{PyClassImpl, PyRef, PyResult, PyValue}; + use crate::VirtualMachine; + use num_bigint::{BigInt, Sign}; + use num_traits::Signed; + use rand::RngCore; + use std::cell::RefCell; + + #[derive(Debug)] + enum PyRng { + Std(rand::rngs::ThreadRng), + MT(Box<mt19937::MT19937>), } -} -impl RngCore for PyRng { - fn next_u32(&mut self) -> u32 { - match self { - Self::Std(s) => s.next_u32(), - Self::MT(m) => m.next_u32(), + impl Default for PyRng { + fn default() -> Self { + PyRng::Std(rand::thread_rng()) } } - fn next_u64(&mut self) -> u64 { - match self { - Self::Std(s) => s.next_u64(), - Self::MT(m) => m.next_u64(), + + impl RngCore for PyRng { + fn next_u32(&mut self) -> u32 { + match self { + Self::Std(s) => s.next_u32(), + Self::MT(m) => m.next_u32(), + } } - } - fn fill_bytes(&mut self, dest: &mut [u8]) { - match self { - Self::Std(s) => s.fill_bytes(dest), - Self::MT(m) => m.fill_bytes(dest), + fn next_u64(&mut self) -> u64 { + match self { + Self::Std(s) => s.next_u64(), + Self::MT(m) => m.next_u64(), + } } - } - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { - match self { - Self::Std(s) => s.try_fill_bytes(dest), - Self::MT(m) => m.try_fill_bytes(dest), + fn fill_bytes(&mut self, dest: &mut [u8]) { + match self { + Self::Std(s) => s.fill_bytes(dest), + Self::MT(m) => m.fill_bytes(dest), + } + } + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> { + match self { + Self::Std(s) => s.try_fill_bytes(dest), + Self::MT(m) => m.try_fill_bytes(dest), + } } } -} -#[pyclass(name = "Random")] -#[derive(Debug)] -struct PyRandom { - rng: RefCell<PyRng>, -} - -impl PyValue for PyRandom { - fn class(vm: &VirtualMachine) -> PyClassRef { - vm.class("_random", "Random") + #[pyclass(name = "Random")] + #[derive(Debug)] + struct PyRandom { + rng: RefCell<PyRng>, } -} -#[pyimpl(flags(BASETYPE))] -impl PyRandom { - #[pyslot(new)] - fn new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { - PyRandom { - rng: RefCell::new(PyRng::default()), + impl PyValue for PyRandom { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("_random", "Random") } - .into_ref_with_type(vm, cls) } - #[pymethod] - fn random(&self) -> f64 { - mt19937::gen_res53(&mut *self.rng.borrow_mut()) - } + #[pyimpl(flags(BASETYPE))] + impl PyRandom { + #[pyslot(new)] + fn new(cls: PyClassRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { + PyRandom { + rng: RefCell::new(PyRng::default()), + } + .into_ref_with_type(vm, cls) + } + + #[pymethod] + fn random(&self) -> f64 { + mt19937::gen_res53(&mut *self.rng.borrow_mut()) + } - #[pymethod] - fn seed(&self, n: OptionalOption<PyIntRef>) { - let new_rng = match n.flat_option() { - None => PyRng::default(), - Some(n) => { - let (_, mut key) = n.as_bigint().abs().to_u32_digits(); - if cfg!(target_endian = "big") { - key.reverse(); + #[pymethod] + fn seed(&self, n: OptionalOption<PyIntRef>) { + let new_rng = match n.flat_option() { + None => PyRng::default(), + Some(n) => { + let (_, mut key) = n.as_bigint().abs().to_u32_digits(); + if cfg!(target_endian = "big") { + key.reverse(); + } + PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(&key))) } - PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(&key))) - } - }; + }; - *self.rng.borrow_mut() = new_rng; - } + *self.rng.borrow_mut() = new_rng; + } - #[pymethod] - fn getrandbits(&self, mut k: usize) -> BigInt { - let mut rng = self.rng.borrow_mut(); + #[pymethod] + fn getrandbits(&self, mut k: usize) -> BigInt { + let mut rng = self.rng.borrow_mut(); - let mut gen_u32 = |k| rng.next_u32() >> (32 - k) as u32; + let mut gen_u32 = |k| rng.next_u32() >> (32 - k) as u32; - if k <= 32 { - return gen_u32(k).into(); - } + if k <= 32 { + return gen_u32(k).into(); + } - let words = (k - 1) / 8 + 1; - let mut wordarray = vec![0u32; words]; + let words = (k - 1) / 8 + 1; + let mut wordarray = vec![0u32; words]; - let it = wordarray.iter_mut(); - #[cfg(target_endian = "big")] - let it = it.rev(); - for word in it { - *word = gen_u32(k); - k -= 32; - } + let it = wordarray.iter_mut(); + #[cfg(target_endian = "big")] + let it = it.rev(); + for word in it { + *word = gen_u32(k); + k -= 32; + } - BigInt::from_slice(Sign::NoSign, &wordarray) + BigInt::from_slice(Sign::NoSign, &wordarray) + } } } - -pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { - let ctx = &vm.ctx; - py_module!(vm, "_random", { - "Random" => PyRandom::make_class(ctx), - }) -}