From c2ee9ca3e02810c76f82bfc32a6643f2b0af0b84 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Wed, 30 Sep 2020 21:20:13 -0500 Subject: [PATCH 001/893] WIP - native _sre --- constants.rs | 114 ++++++++++++++++++++++++++++++++++++++++++++++ interp.rs | 126 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 constants.rs create mode 100644 interp.rs diff --git a/constants.rs b/constants.rs new file mode 100644 index 0000000000..f6aeb3182f --- /dev/null +++ b/constants.rs @@ -0,0 +1,114 @@ +/* + * Secret Labs' Regular Expression Engine + * + * regular expression matching engine + * + * NOTE: This file is generated by sre_constants.py. If you need + * to change anything in here, edit sre_constants.py and run it. + * + * Copyright (c) 1997-2001 by Secret Labs AB. All rights reserved. + * + * See the _sre.c file for information on usage and redistribution. + */ + +use bitflags::bitflags; + +pub const SRE_MAGIC: usize = 20140917; +#[derive(num_enum::TryFromPrimitive, Debug)] +#[repr(u32)] +#[allow(non_camel_case_types)] +pub enum SreOpcode { + FAILURE = 0, + SUCCESS = 1, + ANY = 2, + ANY_ALL = 3, + ASSERT = 4, + ASSERT_NOT = 5, + AT = 6, + BRANCH = 7, + CALL = 8, + CATEGORY = 9, + CHARSET = 10, + BIGCHARSET = 11, + GROUPREF = 12, + GROUPREF_EXISTS = 13, + GROUPREF_IGNORE = 14, + IN = 15, + IN_IGNORE = 16, + INFO = 17, + JUMP = 18, + LITERAL = 19, + LITERAL_IGNORE = 20, + MARK = 21, + MAX_UNTIL = 22, + MIN_UNTIL = 23, + NOT_LITERAL = 24, + NOT_LITERAL_IGNORE = 25, + NEGATE = 26, + RANGE = 27, + REPEAT = 28, + REPEAT_ONE = 29, + SUBPATTERN = 30, + MIN_REPEAT_ONE = 31, + RANGE_IGNORE = 32, +} +#[derive(num_enum::TryFromPrimitive, Debug)] +#[repr(u32)] +#[allow(non_camel_case_types)] +pub enum SreAtCode { + BEGINNING = 0, + BEGINNING_LINE = 1, + BEGINNING_STRING = 2, + BOUNDARY = 3, + NON_BOUNDARY = 4, + END = 5, + END_LINE = 6, + END_STRING = 7, + LOC_BOUNDARY = 8, + LOC_NON_BOUNDARY = 9, + UNI_BOUNDARY = 10, + UNI_NON_BOUNDARY = 11, +} +#[derive(num_enum::TryFromPrimitive, Debug)] +#[repr(u32)] +#[allow(non_camel_case_types)] +pub enum SreCatCode { + DIGIT = 0, + NOT_DIGIT = 1, + SPACE = 2, + NOT_SPACE = 3, + WORD = 4, + NOT_WORD = 5, + LINEBREAK = 6, + NOT_LINEBREAK = 7, + LOC_WORD = 8, + LOC_NOT_WORD = 9, + UNI_DIGIT = 10, + UNI_NOT_DIGIT = 11, + UNI_SPACE = 12, + UNI_NOT_SPACE = 13, + UNI_WORD = 14, + UNI_NOT_WORD = 15, + UNI_LINEBREAK = 16, + UNI_NOT_LINEBREAK = 17, +} +bitflags! { + pub struct SreFlag: u16 { + const TEMPLATE = 1; + const IGNORECASE = 2; + const LOCALE = 4; + const MULTILINE = 8; + const DOTALL = 16; + const UNICODE = 32; + const VERBOSE = 64; + const DEBUG = 128; + const ASCII = 256; + } +} +bitflags! { + pub struct SreInfo: u32 { + const PREFIX = 1; + const LITERAL = 2; + const CHARSET = 4; + } +} diff --git a/interp.rs b/interp.rs new file mode 100644 index 0000000000..7f93a82eb4 --- /dev/null +++ b/interp.rs @@ -0,0 +1,126 @@ +// good luck to those that follow; here be dragons + +use crate::builtins::PyStrRef; + +use super::constants::{SreFlag, SreOpcode}; + +use std::convert::TryFrom; +use std::{iter, slice}; + +pub struct State { + start: usize, + s_pos: usize, + end: usize, + pos: usize, + flags: SreFlag, + marks: Vec, + lastindex: isize, + marks_stack: Vec, + context_stack: Vec, + repeat: Option, + s: PyStrRef, +} + +// struct State1<'a> { +// state: &'a mut State, +// } + +struct MatchContext { + s_pos: usize, + code_pos: usize, +} + +// struct Context<'a> { +// context_stack: &mut Vec, +// } + +impl State { + pub fn new(s: PyStrRef, start: usize, end: usize, flags: SreFlag) -> Self { + let end = std::cmp::min(end, s.char_len()); + Self { + start, + s_pos: start, + end, + pos: start, + flags, + marks: Vec::new(), + lastindex: -1, + marks_stack: Vec::new(), + context_stack: Vec::new(), + repeat: None, + s, + } + } +} + +// struct OpcodeDispatcher { +// executing_contexts: HashMap>, +// } + +pub struct BadSreCode; + +pub fn parse_ops(code: &[u32]) -> impl Iterator> + '_ { + let mut it = code.iter().copied(); + std::iter::from_fn(move || -> Option> { + let op = it.next()?; + let op = SreOpcode::try_from(op) + .ok() + .and_then(|op| extract_code(op, &mut it)); + Some(op) + }) + .map(|x| x.ok_or(BadSreCode)) +} + +type It<'a> = iter::Copied>; +fn extract_code(op: SreOpcode, it: &mut It) -> Option { + let skip = |it: &mut It| { + let skip = it.next()? as usize; + if skip > it.len() { + None + } else { + Some(skip) + } + }; + match op { + SreOpcode::FAILURE => {} + SreOpcode::SUCCESS => {} + SreOpcode::ANY => {} + SreOpcode::ANY_ALL => {} + SreOpcode::ASSERT => {} + SreOpcode::ASSERT_NOT => {} + SreOpcode::AT => {} + SreOpcode::BRANCH => {} + SreOpcode::CALL => {} + SreOpcode::CATEGORY => {} + SreOpcode::CHARSET => {} + SreOpcode::BIGCHARSET => {} + SreOpcode::GROUPREF => {} + SreOpcode::GROUPREF_EXISTS => {} + SreOpcode::GROUPREF_IGNORE => {} + SreOpcode::IN => {} + SreOpcode::IN_IGNORE => {} + SreOpcode::INFO => { + // let skip = it.next()?; + } + SreOpcode::JUMP => {} + SreOpcode::LITERAL => {} + SreOpcode::LITERAL_IGNORE => {} + SreOpcode::MARK => {} + SreOpcode::MAX_UNTIL => {} + SreOpcode::MIN_UNTIL => {} + SreOpcode::NOT_LITERAL => {} + SreOpcode::NOT_LITERAL_IGNORE => {} + SreOpcode::NEGATE => {} + SreOpcode::RANGE => {} + SreOpcode::REPEAT => {} + SreOpcode::REPEAT_ONE => {} + SreOpcode::SUBPATTERN => {} + SreOpcode::MIN_REPEAT_ONE => {} + SreOpcode::RANGE_IGNORE => {} + } + todo!() +} + +pub enum Op { + Info {}, +} From e1362ead3c8c7462b0027a7ecb4b9bbd23e76b5a Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 1 Dec 2020 17:51:11 +0200 Subject: [PATCH 002/893] WIP structure --- constants.rs | 4 + interp.rs | 467 +++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 379 insertions(+), 92 deletions(-) diff --git a/constants.rs b/constants.rs index f6aeb3182f..a80534d70b 100644 --- a/constants.rs +++ b/constants.rs @@ -14,6 +14,10 @@ use bitflags::bitflags; pub const SRE_MAGIC: usize = 20140917; +pub const SRE_CODESIZE: usize = 4; +pub const SRE_MAXREPEAT: usize = usize::max_value(); +pub const SRE_MAXGROUPS: usize = usize::max_value() / std::mem::size_of::() / 2; + #[derive(num_enum::TryFromPrimitive, Debug)] #[repr(u32)] #[allow(non_camel_case_types)] diff --git a/interp.rs b/interp.rs index 7f93a82eb4..43c2a54515 100644 --- a/interp.rs +++ b/interp.rs @@ -1,126 +1,409 @@ // good luck to those that follow; here be dragons +use super::constants::{SreFlag, SreOpcode, SRE_MAXREPEAT}; use crate::builtins::PyStrRef; - -use super::constants::{SreFlag, SreOpcode}; - +use rustpython_common::borrow::BorrowValue; +use std::collections::HashMap; use std::convert::TryFrom; -use std::{iter, slice}; -pub struct State { +pub struct State<'a> { + // py_string: PyStrRef, + string: &'a str, start: usize, - s_pos: usize, end: usize, - pos: usize, flags: SreFlag, + pattern_codes: Vec, marks: Vec, lastindex: isize, marks_stack: Vec, context_stack: Vec, repeat: Option, - s: PyStrRef, + string_position: usize, } -// struct State1<'a> { -// state: &'a mut State, -// } - -struct MatchContext { - s_pos: usize, - code_pos: usize, -} - -// struct Context<'a> { -// context_stack: &mut Vec, -// } - -impl State { - pub fn new(s: PyStrRef, start: usize, end: usize, flags: SreFlag) -> Self { - let end = std::cmp::min(end, s.char_len()); +impl<'a> State<'a> { + pub(crate) fn new( + // py_string: PyStrRef, + string: &'a str, + start: usize, + end: usize, + flags: SreFlag, + pattern_codes: Vec, + ) -> Self { + // let string = py_string.borrow_value(); Self { + // py_string, + string, start, - s_pos: start, end, - pos: start, flags, - marks: Vec::new(), + pattern_codes, lastindex: -1, marks_stack: Vec::new(), context_stack: Vec::new(), repeat: None, - s, + marks: Vec::new(), + string_position: start, } } + + fn reset(&mut self) { + self.marks.clear(); + self.lastindex = -1; + self.marks_stack.clear(); + self.context_stack.clear(); + self.repeat = None; + } } -// struct OpcodeDispatcher { -// executing_contexts: HashMap>, -// } +pub(crate) fn pymatch(mut state: State) -> bool { + let ctx = MatchContext { + string_position: state.start, + code_position: 0, + has_matched: None, + }; + state.context_stack.push(ctx); -pub struct BadSreCode; + let mut has_matched = None; + loop { + if state.context_stack.is_empty() { + break; + } + let ctx_id = state.context_stack.len() - 1; + let mut drive = MatchContextDrive::drive(ctx_id, state); + let mut dispatcher = OpcodeDispatcher::new(); -pub fn parse_ops(code: &[u32]) -> impl Iterator> + '_ { - let mut it = code.iter().copied(); - std::iter::from_fn(move || -> Option> { - let op = it.next()?; - let op = SreOpcode::try_from(op) - .ok() - .and_then(|op| extract_code(op, &mut it)); - Some(op) - }) - .map(|x| x.ok_or(BadSreCode)) + has_matched = dispatcher.pymatch(&mut drive); + state = drive.take(); + if has_matched.is_some() { + state.context_stack.pop(); + } + } + has_matched.unwrap_or(false) } -type It<'a> = iter::Copied>; -fn extract_code(op: SreOpcode, it: &mut It) -> Option { - let skip = |it: &mut It| { - let skip = it.next()? as usize; - if skip > it.len() { - None - } else { - Some(skip) +#[derive(Debug, Copy, Clone)] +struct MatchContext { + string_position: usize, + code_position: usize, + has_matched: Option, +} + +struct MatchContextDrive<'a> { + state: State<'a>, + ctx_id: usize, +} + +impl<'a> MatchContextDrive<'a> { + fn id(&self) -> usize { + self.ctx_id + } + fn ctx_mut(&mut self) -> &mut MatchContext { + &mut self.state.context_stack[self.ctx_id] + } + fn ctx(&self) -> &MatchContext { + &self.state.context_stack[self.ctx_id] + } + fn push_new_context(&mut self, pattern_offset: usize) -> usize { + let ctx = self.ctx(); + let child_ctx = MatchContext { + string_position: ctx.string_position, + code_position: ctx.code_position + pattern_offset, + has_matched: None, + }; + self.state.context_stack.push(child_ctx); + self.state.context_stack.len() - 1 + } + fn drive(ctx_id: usize, state: State<'a>) -> Self { + Self { state, ctx_id } + } + fn take(self) -> State<'a> { + self.state + } + fn str(&self) -> &str { + unsafe { + std::str::from_utf8_unchecked( + &self.state.string.as_bytes()[self.ctx().string_position..], + ) + } + } + fn peek_char(&self) -> char { + self.str().chars().next().unwrap() + } + fn peek_code(&self, peek: usize) -> u32 { + self.state.pattern_codes[self.ctx().code_position + peek] + } + fn skip_char(&mut self, skip_count: usize) { + let skipped = self.str().char_indices().nth(skip_count).unwrap().0; + self.ctx_mut().string_position += skipped; + } + fn skip_code(&mut self, skip_count: usize) { + self.ctx_mut().code_position += skip_count; + } + fn remaining_chars(&self) -> usize { + let end = self.state.end; + end - self.ctx().string_position + self.str().len() + } + fn remaining_codes(&self) -> usize { + self.state.pattern_codes.len() - self.ctx().code_position + } + fn at_beginning(&self) -> bool { + self.ctx().string_position == 0 + } + fn at_end(&self) -> bool { + self.str().is_empty() + } + fn at_linebreak(&self) -> bool { + match self.str().chars().next() { + Some(c) => c == '\n', + None => false, } + } +} + +struct OpcodeDispatcher { + executing_contexts: HashMap>, +} + +macro_rules! once { + ($val:expr) => { + Box::new(OpEmpty {}) }; - match op { - SreOpcode::FAILURE => {} - SreOpcode::SUCCESS => {} - SreOpcode::ANY => {} - SreOpcode::ANY_ALL => {} - SreOpcode::ASSERT => {} - SreOpcode::ASSERT_NOT => {} - SreOpcode::AT => {} - SreOpcode::BRANCH => {} - SreOpcode::CALL => {} - SreOpcode::CATEGORY => {} - SreOpcode::CHARSET => {} - SreOpcode::BIGCHARSET => {} - SreOpcode::GROUPREF => {} - SreOpcode::GROUPREF_EXISTS => {} - SreOpcode::GROUPREF_IGNORE => {} - SreOpcode::IN => {} - SreOpcode::IN_IGNORE => {} - SreOpcode::INFO => { - // let skip = it.next()?; - } - SreOpcode::JUMP => {} - SreOpcode::LITERAL => {} - SreOpcode::LITERAL_IGNORE => {} - SreOpcode::MARK => {} - SreOpcode::MAX_UNTIL => {} - SreOpcode::MIN_UNTIL => {} - SreOpcode::NOT_LITERAL => {} - SreOpcode::NOT_LITERAL_IGNORE => {} - SreOpcode::NEGATE => {} - SreOpcode::RANGE => {} - SreOpcode::REPEAT => {} - SreOpcode::REPEAT_ONE => {} - SreOpcode::SUBPATTERN => {} - SreOpcode::MIN_REPEAT_ONE => {} - SreOpcode::RANGE_IGNORE => {} - } - todo!() -} - -pub enum Op { - Info {}, +} + +trait OpcodeExecutor { + fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()>; +} + +struct OpFailure {} +impl OpcodeExecutor for OpFailure { + fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + drive.ctx_mut().has_matched = Some(false); + None + } +} + +struct OpEmpty {} +impl OpcodeExecutor for OpEmpty { + fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + None + } +} + +struct OpOnce { + f: Option, +} +impl OpcodeExecutor for OpOnce { + fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + let f = self.f.take()?; + f(drive); + None + } +} +fn once(f: F) -> Box> { + Box::new(OpOnce { f: Some(f) }) +} + +struct OpMinRepeatOne { + trace_id: usize, + mincount: usize, + maxcount: usize, + count: usize, + child_ctx_id: usize, +} +impl OpcodeExecutor for OpMinRepeatOne { + fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + match self.trace_id { + 0 => self._0(drive), + _ => unreachable!(), + } + } +} +impl Default for OpMinRepeatOne { + fn default() -> Self { + OpMinRepeatOne { + trace_id: 0, + mincount: 0, + maxcount: 0, + count: 0, + child_ctx_id: 0, + } + } +} +impl OpMinRepeatOne { + fn _0(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + self.mincount = drive.peek_code(2) as usize; + self.maxcount = drive.peek_code(3) as usize; + + if drive.remaining_chars() < self.mincount { + drive.ctx_mut().has_matched = Some(false); + return None; + } + + drive.state.string_position = drive.ctx().string_position; + + self.count = if self.mincount == 0 { + 0 + } else { + let count = count_repetitions(drive, self.mincount); + if count < self.mincount { + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.skip_char(count); + count + }; + + if drive.peek_code(drive.peek_code(1) as usize + 1) == SreOpcode::SUCCESS as u32 { + drive.state.string_position = drive.ctx().string_position; + drive.ctx_mut().has_matched = Some(true); + return None; + } + + // mark push + self.trace_id = 1; + self._1(drive) + } + fn _1(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + if self.maxcount == SRE_MAXREPEAT || self.count <= self.maxcount { + drive.state.string_position = drive.ctx().string_position; + self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); + self.trace_id = 2; + return Some(()); + } + + // mark discard + drive.ctx_mut().has_matched = Some(false); + None + } + fn _2(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + if let Some(true) = drive.state.context_stack[self.child_ctx_id].has_matched { + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.state.string_position = drive.ctx().string_position; + if count_repetitions(drive, 1) == 0 { + self.trace_id = 3; + return self._3(drive); + } + drive.skip_char(1); + self.count += 1; + // marks pop keep + self.trace_id = 1; + self._1(drive) + } + fn _3(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + // mark discard + drive.ctx_mut().has_matched = Some(false); + None + } +} + +impl OpcodeDispatcher { + fn new() -> Self { + Self { + executing_contexts: HashMap::new(), + } + } + // Returns True if the current context matches, False if it doesn't and + // None if matching is not finished, ie must be resumed after child + // contexts have been matched. + fn pymatch(&mut self, drive: &mut MatchContextDrive) -> Option { + while drive.remaining_codes() > 0 && drive.ctx().has_matched.is_none() { + let code = drive.peek_code(0); + let opcode = SreOpcode::try_from(code).unwrap(); + self.dispatch(opcode, drive); + // self.drive = self.drive; + } + match drive.ctx().has_matched { + Some(matched) => Some(matched), + None => { + drive.ctx_mut().has_matched = Some(false); + Some(false) + } + } + } + + // Dispatches a context on a given opcode. Returns True if the context + // is done matching, False if it must be resumed when next encountered. + fn dispatch(&mut self, opcode: SreOpcode, drive: &mut MatchContextDrive) -> bool { + let mut executor = match self.executing_contexts.remove_entry(&drive.id()) { + Some((_, mut executor)) => executor, + None => self.dispatch_table(opcode, drive), + }; + if let Some(()) = executor.next(drive) { + self.executing_contexts.insert(drive.id(), executor); + false + } else { + true + } + } + + fn dispatch_table( + &mut self, + opcode: SreOpcode, + drive: &mut MatchContextDrive, + ) -> Box { + // move || { + match opcode { + SreOpcode::FAILURE => { + Box::new(OpFailure {}) + } + SreOpcode::SUCCESS => once(|drive| { + drive.state.string_position = drive.ctx().string_position; + drive.ctx_mut().has_matched = Some(true); + }), + SreOpcode::ANY => once!(true), + SreOpcode::ANY_ALL => once!(true), + SreOpcode::ASSERT => once!(true), + SreOpcode::ASSERT_NOT => once!(true), + SreOpcode::AT => once!(true), + SreOpcode::BRANCH => once!(true), + SreOpcode::CALL => once!(true), + SreOpcode::CATEGORY => once!(true), + SreOpcode::CHARSET => once!(true), + SreOpcode::BIGCHARSET => once!(true), + SreOpcode::GROUPREF => once!(true), + SreOpcode::GROUPREF_EXISTS => once!(true), + SreOpcode::GROUPREF_IGNORE => once!(true), + SreOpcode::IN => once!(true), + SreOpcode::IN_IGNORE => once!(true), + SreOpcode::INFO => once!(true), + SreOpcode::JUMP => once!(true), + SreOpcode::LITERAL => { + if drive.at_end() || drive.peek_char() as u32 != drive.peek_code(1) { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_char(1); + } + drive.skip_code(2); + once!(true) + } + SreOpcode::LITERAL_IGNORE => once!(true), + SreOpcode::MARK => once!(true), + SreOpcode::MAX_UNTIL => once!(true), + SreOpcode::MIN_UNTIL => once!(true), + SreOpcode::NOT_LITERAL => once!(true), + SreOpcode::NOT_LITERAL_IGNORE => once!(true), + SreOpcode::NEGATE => once!(true), + SreOpcode::RANGE => once!(true), + SreOpcode::REPEAT => once!(true), + SreOpcode::REPEAT_ONE => once!(true), + SreOpcode::SUBPATTERN => once!(true), + SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), + SreOpcode::RANGE_IGNORE => once!(true), + } + } +} + +// Returns the number of repetitions of a single item, starting from the +// current string position. The code pointer is expected to point to a +// REPEAT_ONE operation (with the repeated 4 ahead). +fn count_repetitions(drive: &mut MatchContextDrive, maxcount: usize) -> usize { + let mut count = 0; + let mut real_maxcount = drive.state.end - drive.ctx().string_position; + if maxcount < real_maxcount && maxcount != SRE_MAXREPEAT { + real_maxcount = maxcount; + } + count } From 82922bf0d796f6c79f5799a1b14b9d1ad9c26431 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 23 Dec 2020 09:01:50 +0200 Subject: [PATCH 003/893] upgrade re version; implement helper functions; --- constants.rs | 52 ++-- interp.rs | 717 ++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 588 insertions(+), 181 deletions(-) diff --git a/constants.rs b/constants.rs index a80534d70b..f5ab92c531 100644 --- a/constants.rs +++ b/constants.rs @@ -13,11 +13,7 @@ use bitflags::bitflags; -pub const SRE_MAGIC: usize = 20140917; -pub const SRE_CODESIZE: usize = 4; -pub const SRE_MAXREPEAT: usize = usize::max_value(); -pub const SRE_MAXGROUPS: usize = usize::max_value() / std::mem::size_of::() / 2; - +pub const SRE_MAGIC: usize = 20171005; #[derive(num_enum::TryFromPrimitive, Debug)] #[repr(u32)] #[allow(non_camel_case_types)] @@ -36,25 +32,33 @@ pub enum SreOpcode { BIGCHARSET = 11, GROUPREF = 12, GROUPREF_EXISTS = 13, - GROUPREF_IGNORE = 14, - IN = 15, - IN_IGNORE = 16, - INFO = 17, - JUMP = 18, - LITERAL = 19, - LITERAL_IGNORE = 20, - MARK = 21, - MAX_UNTIL = 22, - MIN_UNTIL = 23, - NOT_LITERAL = 24, - NOT_LITERAL_IGNORE = 25, - NEGATE = 26, - RANGE = 27, - REPEAT = 28, - REPEAT_ONE = 29, - SUBPATTERN = 30, - MIN_REPEAT_ONE = 31, - RANGE_IGNORE = 32, + IN = 14, + INFO = 15, + JUMP = 16, + LITERAL = 17, + MARK = 18, + MAX_UNTIL = 19, + MIN_UNTIL = 20, + NOT_LITERAL = 21, + NEGATE = 22, + RANGE = 23, + REPEAT = 24, + REPEAT_ONE = 25, + SUBPATTERN = 26, + MIN_REPEAT_ONE = 27, + GROUPREF_IGNORE = 28, + IN_IGNORE = 29, + LITERAL_IGNORE = 30, + NOT_LITERAL_IGNORE = 31, + GROUPREF_LOC_IGNORE = 32, + IN_LOC_IGNORE = 33, + LITERAL_LOC_IGNORE = 34, + NOT_LITERAL_LOC_IGNORE = 35, + GROUPREF_UNI_IGNORE = 36, + IN_UNI_IGNORE = 37, + LITERAL_UNI_IGNORE = 38, + NOT_LITERAL_UNI_IGNORE = 39, + RANGE_UNI_IGNORE = 40, } #[derive(num_enum::TryFromPrimitive, Debug)] #[repr(u32)] diff --git a/interp.rs b/interp.rs index 43c2a54515..bc6753a5fe 100644 --- a/interp.rs +++ b/interp.rs @@ -1,14 +1,14 @@ // good luck to those that follow; here be dragons -use super::constants::{SreFlag, SreOpcode, SRE_MAXREPEAT}; -use crate::builtins::PyStrRef; -use rustpython_common::borrow::BorrowValue; +use super::_sre::MAXREPEAT; +use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; use std::collections::HashMap; use std::convert::TryFrom; pub struct State<'a> { - // py_string: PyStrRef, string: &'a str, + // chars count + string_len: usize, start: usize, end: usize, flags: SreFlag, @@ -23,17 +23,18 @@ pub struct State<'a> { impl<'a> State<'a> { pub(crate) fn new( - // py_string: PyStrRef, string: &'a str, start: usize, end: usize, flags: SreFlag, pattern_codes: Vec, ) -> Self { - // let string = py_string.borrow_value(); + let string_len = string.chars().count(); + let end = std::cmp::min(end, string_len); + let start = std::cmp::min(start, end); Self { - // py_string, string, + string_len, start, end, flags, @@ -59,6 +60,7 @@ impl<'a> State<'a> { pub(crate) fn pymatch(mut state: State) -> bool { let ctx = MatchContext { string_position: state.start, + string_offset: state.string.char_indices().nth(state.start).unwrap().0, code_position: 0, has_matched: None, }; @@ -85,6 +87,7 @@ pub(crate) fn pymatch(mut state: State) -> bool { #[derive(Debug, Copy, Clone)] struct MatchContext { string_position: usize, + string_offset: usize, code_position: usize, has_matched: Option, } @@ -108,6 +111,7 @@ impl<'a> MatchContextDrive<'a> { let ctx = self.ctx(); let child_ctx = MatchContext { string_position: ctx.string_position, + string_offset: ctx.string_offset, code_position: ctx.code_position + pattern_offset, has_matched: None, }; @@ -122,9 +126,7 @@ impl<'a> MatchContextDrive<'a> { } fn str(&self) -> &str { unsafe { - std::str::from_utf8_unchecked( - &self.state.string.as_bytes()[self.ctx().string_position..], - ) + std::str::from_utf8_unchecked(&self.state.string.as_bytes()[self.ctx().string_offset..]) } } fn peek_char(&self) -> char { @@ -135,61 +137,90 @@ impl<'a> MatchContextDrive<'a> { } fn skip_char(&mut self, skip_count: usize) { let skipped = self.str().char_indices().nth(skip_count).unwrap().0; - self.ctx_mut().string_position += skipped; + self.ctx_mut().string_position += skip_count; + self.ctx_mut().string_offset += skipped; } fn skip_code(&mut self, skip_count: usize) { self.ctx_mut().code_position += skip_count; } fn remaining_chars(&self) -> usize { - let end = self.state.end; - end - self.ctx().string_position + self.str().len() + self.state.end - self.ctx().string_position } fn remaining_codes(&self) -> usize { self.state.pattern_codes.len() - self.ctx().code_position } fn at_beginning(&self) -> bool { - self.ctx().string_position == 0 + self.ctx().string_position == self.state.start } fn at_end(&self) -> bool { - self.str().is_empty() + self.ctx().string_position == self.state.end } fn at_linebreak(&self) -> bool { - match self.str().chars().next() { - Some(c) => c == '\n', - None => false, + !self.at_end() && is_linebreak(self.peek_char()) + } + fn at_boundary bool>(&self, mut word_checker: F) -> bool { + if self.at_beginning() && self.at_end() { + return false; + } + let that = !self.at_beginning() && word_checker(self.back_peek_char()); + let this = !self.at_end() && word_checker(self.peek_char()); + this != that + } + fn back_peek_offset(&self) -> usize { + let bytes = self.state.string.as_bytes(); + let mut offset = self.ctx().string_offset - 1; + if !is_utf8_first_byte(bytes[offset]) { + offset -= 1; + if !is_utf8_first_byte(bytes[offset]) { + offset -= 1; + if !is_utf8_first_byte(bytes[offset]) { + offset -= 1; + if !is_utf8_first_byte(bytes[offset]) { + panic!("not utf-8 code point"); + } + } + } + } + offset + } + fn back_peek_char(&self) -> char { + let bytes = self.state.string.as_bytes(); + let offset = self.back_peek_offset(); + let current_offset = self.ctx().string_offset; + let code = match current_offset - offset { + 1 => u32::from_ne_bytes([0, 0, 0, bytes[offset]]), + 2 => u32::from_ne_bytes([0, 0, bytes[offset], bytes[offset + 1]]), + 3 => u32::from_ne_bytes([0, bytes[offset], bytes[offset + 1], bytes[offset + 2]]), + 4 => u32::from_ne_bytes([ + bytes[offset], + bytes[offset + 1], + bytes[offset + 2], + bytes[offset + 3], + ]), + _ => unreachable!(), + }; + unsafe { std::mem::transmute(code) } + } + fn back_skip_char(&mut self, skip_count: usize) { + self.ctx_mut().string_position -= skip_count; + for _ in 0..skip_count { + self.ctx_mut().string_offset = self.back_peek_offset(); } } -} - -struct OpcodeDispatcher { - executing_contexts: HashMap>, -} - -macro_rules! once { - ($val:expr) => { - Box::new(OpEmpty {}) - }; } trait OpcodeExecutor { fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()>; } -struct OpFailure {} -impl OpcodeExecutor for OpFailure { +struct OpUnimplemented {} +impl OpcodeExecutor for OpUnimplemented { fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { drive.ctx_mut().has_matched = Some(false); None } } -struct OpEmpty {} -impl OpcodeExecutor for OpEmpty { - fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { - None - } -} - struct OpOnce { f: Option, } @@ -204,6 +235,10 @@ fn once(f: F) -> Box> { Box::new(OpOnce { f: Some(f) }) } +fn unimplemented() -> Box { + Box::new(OpUnimplemented {}) +} + struct OpMinRepeatOne { trace_id: usize, mincount: usize, @@ -213,10 +248,11 @@ struct OpMinRepeatOne { } impl OpcodeExecutor for OpMinRepeatOne { fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { - match self.trace_id { - 0 => self._0(drive), - _ => unreachable!(), - } + None + // match self.trace_id { + // 0 => self._0(drive), + // _ => unreachable!(), + // } } } impl Default for OpMinRepeatOne { @@ -230,75 +266,78 @@ impl Default for OpMinRepeatOne { } } } -impl OpMinRepeatOne { - fn _0(&mut self, drive: &mut MatchContextDrive) -> Option<()> { - self.mincount = drive.peek_code(2) as usize; - self.maxcount = drive.peek_code(3) as usize; +// impl OpMinRepeatOne { +// fn _0(&mut self, drive: &mut MatchContextDrive) -> Option<()> { +// self.mincount = drive.peek_code(2) as usize; +// self.maxcount = drive.peek_code(3) as usize; - if drive.remaining_chars() < self.mincount { - drive.ctx_mut().has_matched = Some(false); - return None; - } +// if drive.remaining_chars() < self.mincount { +// drive.ctx_mut().has_matched = Some(false); +// return None; +// } - drive.state.string_position = drive.ctx().string_position; +// drive.state.string_position = drive.ctx().string_position; - self.count = if self.mincount == 0 { - 0 - } else { - let count = count_repetitions(drive, self.mincount); - if count < self.mincount { - drive.ctx_mut().has_matched = Some(false); - return None; - } - drive.skip_char(count); - count - }; +// self.count = if self.mincount == 0 { +// 0 +// } else { +// let count = count_repetitions(drive, self.mincount); +// if count < self.mincount { +// drive.ctx_mut().has_matched = Some(false); +// return None; +// } +// drive.skip_char(count); +// count +// }; - if drive.peek_code(drive.peek_code(1) as usize + 1) == SreOpcode::SUCCESS as u32 { - drive.state.string_position = drive.ctx().string_position; - drive.ctx_mut().has_matched = Some(true); - return None; - } +// if drive.peek_code(drive.peek_code(1) as usize + 1) == SreOpcode::SUCCESS as u32 { +// drive.state.string_position = drive.ctx().string_position; +// drive.ctx_mut().has_matched = Some(true); +// return None; +// } - // mark push - self.trace_id = 1; - self._1(drive) - } - fn _1(&mut self, drive: &mut MatchContextDrive) -> Option<()> { - if self.maxcount == SRE_MAXREPEAT || self.count <= self.maxcount { - drive.state.string_position = drive.ctx().string_position; - self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); - self.trace_id = 2; - return Some(()); - } +// // mark push +// self.trace_id = 1; +// self._1(drive) +// } +// fn _1(&mut self, drive: &mut MatchContextDrive) -> Option<()> { +// if self.maxcount == SRE_MAXREPEAT || self.count <= self.maxcount { +// drive.state.string_position = drive.ctx().string_position; +// self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); +// self.trace_id = 2; +// return Some(()); +// } - // mark discard - drive.ctx_mut().has_matched = Some(false); - None - } - fn _2(&mut self, drive: &mut MatchContextDrive) -> Option<()> { - if let Some(true) = drive.state.context_stack[self.child_ctx_id].has_matched { - drive.ctx_mut().has_matched = Some(true); - return None; - } - drive.state.string_position = drive.ctx().string_position; - if count_repetitions(drive, 1) == 0 { - self.trace_id = 3; - return self._3(drive); - } - drive.skip_char(1); - self.count += 1; - // marks pop keep - self.trace_id = 1; - self._1(drive) - } - fn _3(&mut self, drive: &mut MatchContextDrive) -> Option<()> { - // mark discard - drive.ctx_mut().has_matched = Some(false); - None - } -} +// // mark discard +// drive.ctx_mut().has_matched = Some(false); +// None +// } +// fn _2(&mut self, drive: &mut MatchContextDrive) -> Option<()> { +// if let Some(true) = drive.state.context_stack[self.child_ctx_id].has_matched { +// drive.ctx_mut().has_matched = Some(true); +// return None; +// } +// drive.state.string_position = drive.ctx().string_position; +// if count_repetitions(drive, 1) == 0 { +// self.trace_id = 3; +// return self._3(drive); +// } +// drive.skip_char(1); +// self.count += 1; +// // marks pop keep +// self.trace_id = 1; +// self._1(drive) +// } +// fn _3(&mut self, drive: &mut MatchContextDrive) -> Option<()> { +// // mark discard +// drive.ctx_mut().has_matched = Some(false); +// None +// } +// } +struct OpcodeDispatcher { + executing_contexts: HashMap>, +} impl OpcodeDispatcher { fn new() -> Self { Self { @@ -313,7 +352,6 @@ impl OpcodeDispatcher { let code = drive.peek_code(0); let opcode = SreOpcode::try_from(code).unwrap(); self.dispatch(opcode, drive); - // self.drive = self.drive; } match drive.ctx().has_matched { Some(matched) => Some(matched), @@ -328,8 +366,8 @@ impl OpcodeDispatcher { // is done matching, False if it must be resumed when next encountered. fn dispatch(&mut self, opcode: SreOpcode, drive: &mut MatchContextDrive) -> bool { let mut executor = match self.executing_contexts.remove_entry(&drive.id()) { - Some((_, mut executor)) => executor, - None => self.dispatch_table(opcode, drive), + Some((_, executor)) => executor, + None => self.dispatch_table(opcode), }; if let Some(()) = executor.next(drive) { self.executing_contexts.insert(drive.id(), executor); @@ -339,71 +377,436 @@ impl OpcodeDispatcher { } } - fn dispatch_table( - &mut self, - opcode: SreOpcode, - drive: &mut MatchContextDrive, - ) -> Box { - // move || { + fn dispatch_table(&mut self, opcode: SreOpcode) -> Box { match opcode { - SreOpcode::FAILURE => { - Box::new(OpFailure {}) - } + SreOpcode::FAILURE => once(|drive| { + drive.ctx_mut().has_matched = Some(false); + }), SreOpcode::SUCCESS => once(|drive| { drive.state.string_position = drive.ctx().string_position; drive.ctx_mut().has_matched = Some(true); }), - SreOpcode::ANY => once!(true), - SreOpcode::ANY_ALL => once!(true), - SreOpcode::ASSERT => once!(true), - SreOpcode::ASSERT_NOT => once!(true), - SreOpcode::AT => once!(true), - SreOpcode::BRANCH => once!(true), - SreOpcode::CALL => once!(true), - SreOpcode::CATEGORY => once!(true), - SreOpcode::CHARSET => once!(true), - SreOpcode::BIGCHARSET => once!(true), - SreOpcode::GROUPREF => once!(true), - SreOpcode::GROUPREF_EXISTS => once!(true), - SreOpcode::GROUPREF_IGNORE => once!(true), - SreOpcode::IN => once!(true), - SreOpcode::IN_IGNORE => once!(true), - SreOpcode::INFO => once!(true), - SreOpcode::JUMP => once!(true), - SreOpcode::LITERAL => { - if drive.at_end() || drive.peek_char() as u32 != drive.peek_code(1) { + SreOpcode::ANY => once(|drive| { + if drive.at_end() || drive.at_linebreak() { drive.ctx_mut().has_matched = Some(false); } else { + drive.skip_code(1); drive.skip_char(1); } + }), + SreOpcode::ANY_ALL => once(|drive| { + if drive.at_end() { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_code(1); + drive.skip_char(1); + } + }), + SreOpcode::ASSERT => Box::new(OpAssert::default()), + SreOpcode::ASSERT_NOT => unimplemented(), + SreOpcode::AT => once(|drive| { + let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); + if !at(drive, atcode) { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_code(2); + } + }), + SreOpcode::BRANCH => unimplemented(), + SreOpcode::CALL => unimplemented(), + SreOpcode::CATEGORY => unimplemented(), + SreOpcode::CHARSET => unimplemented(), + SreOpcode::BIGCHARSET => unimplemented(), + SreOpcode::GROUPREF => unimplemented(), + SreOpcode::GROUPREF_EXISTS => unimplemented(), + SreOpcode::GROUPREF_IGNORE => unimplemented(), + SreOpcode::IN => unimplemented(), + SreOpcode::IN_IGNORE => unimplemented(), + SreOpcode::INFO | SreOpcode::JUMP => once(|drive| { + drive.skip_code(drive.peek_code(1) as usize + 1); + }), + SreOpcode::LITERAL => once(|drive| { + if drive.at_end() || drive.peek_char() as u32 != drive.peek_code(1) { + drive.ctx_mut().has_matched = Some(false); + } drive.skip_code(2); - once!(true) - } - SreOpcode::LITERAL_IGNORE => once!(true), - SreOpcode::MARK => once!(true), - SreOpcode::MAX_UNTIL => once!(true), - SreOpcode::MIN_UNTIL => once!(true), - SreOpcode::NOT_LITERAL => once!(true), - SreOpcode::NOT_LITERAL_IGNORE => once!(true), - SreOpcode::NEGATE => once!(true), - SreOpcode::RANGE => once!(true), - SreOpcode::REPEAT => once!(true), - SreOpcode::REPEAT_ONE => once!(true), - SreOpcode::SUBPATTERN => once!(true), + drive.skip_char(1); + }), + SreOpcode::LITERAL_IGNORE => once(|drive| { + let code = drive.peek_code(1); + let c = drive.peek_char(); + if drive.at_end() + || (c.to_ascii_lowercase() as u32 != code + && c.to_ascii_uppercase() as u32 != code) + { + drive.ctx_mut().has_matched = Some(false); + } + drive.skip_code(2); + drive.skip_char(1); + }), + SreOpcode::MARK => unimplemented(), + SreOpcode::MAX_UNTIL => unimplemented(), + SreOpcode::MIN_UNTIL => unimplemented(), + SreOpcode::NOT_LITERAL => once(|drive| { + if drive.at_end() || drive.peek_char() as u32 == drive.peek_code(1) { + drive.ctx_mut().has_matched = Some(false); + } + drive.skip_code(2); + drive.skip_char(1); + }), + SreOpcode::NOT_LITERAL_IGNORE => once(|drive| { + let code = drive.peek_code(1); + let c = drive.peek_char(); + if drive.at_end() + || (c.to_ascii_lowercase() as u32 == code + || c.to_ascii_uppercase() as u32 == code) + { + drive.ctx_mut().has_matched = Some(false); + } + drive.skip_code(2); + drive.skip_char(1); + }), + SreOpcode::NEGATE => unimplemented(), + SreOpcode::RANGE => unimplemented(), + SreOpcode::REPEAT => unimplemented(), + SreOpcode::REPEAT_ONE => unimplemented(), + SreOpcode::SUBPATTERN => unimplemented(), SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), - SreOpcode::RANGE_IGNORE => once!(true), + SreOpcode::GROUPREF_LOC_IGNORE => unimplemented(), + SreOpcode::IN_LOC_IGNORE => unimplemented(), + SreOpcode::LITERAL_LOC_IGNORE => unimplemented(), + SreOpcode::NOT_LITERAL_LOC_IGNORE => unimplemented(), + SreOpcode::GROUPREF_UNI_IGNORE => unimplemented(), + SreOpcode::IN_UNI_IGNORE => unimplemented(), + SreOpcode::LITERAL_UNI_IGNORE => unimplemented(), + SreOpcode::NOT_LITERAL_UNI_IGNORE => unimplemented(), + SreOpcode::RANGE_UNI_IGNORE => unimplemented(), } } + + // Returns the number of repetitions of a single item, starting from the + // current string position. The code pointer is expected to point to a + // REPEAT_ONE operation (with the repeated 4 ahead). + fn count_repetitions(&mut self, drive: &mut MatchContextDrive, maxcount: usize) -> usize { + let mut count = 0; + let mut real_maxcount = drive.remaining_chars(); + if maxcount < real_maxcount && maxcount != MAXREPEAT { + real_maxcount = maxcount; + } + let code_position = drive.ctx().code_position; + let string_position = drive.ctx().string_position; + drive.skip_code(4); + let reset_position = drive.ctx().code_position; + while count < real_maxcount { + drive.ctx_mut().code_position = reset_position; + let opcode = SreOpcode::try_from(drive.peek_code(1)).unwrap(); + self.dispatch(opcode, drive); + if drive.ctx().has_matched == Some(false) { + break; + } + count += 1; + } + drive.ctx_mut().has_matched = None; + drive.ctx_mut().code_position = code_position; + drive.ctx_mut().string_position = string_position; + count + } +} + +fn at(drive: &mut MatchContextDrive, atcode: SreAtCode) -> bool { + match atcode { + SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => drive.at_beginning(), + SreAtCode::BEGINNING_LINE => drive.at_beginning() || is_linebreak(drive.back_peek_char()), + SreAtCode::BOUNDARY => drive.at_boundary(is_word), + SreAtCode::NON_BOUNDARY => !drive.at_boundary(is_word), + SreAtCode::END => (drive.remaining_chars() == 1 && drive.at_linebreak()) || drive.at_end(), + SreAtCode::END_LINE => drive.at_linebreak() || drive.at_end(), + SreAtCode::END_STRING => drive.at_end(), + SreAtCode::LOC_BOUNDARY => drive.at_boundary(is_loc_word), + SreAtCode::LOC_NON_BOUNDARY => !drive.at_boundary(is_loc_word), + SreAtCode::UNI_BOUNDARY => drive.at_boundary(is_uni_word), + SreAtCode::UNI_NON_BOUNDARY => !drive.at_boundary(is_uni_word), + } +} + +fn category(catcode: SreCatCode, c: char) -> bool { + match catcode { + SreCatCode::DIGIT => is_digit(c), + SreCatCode::NOT_DIGIT => !is_digit(c), + SreCatCode::SPACE => is_space(c), + SreCatCode::NOT_SPACE => !is_space(c), + SreCatCode::WORD => is_word(c), + SreCatCode::NOT_WORD => !is_word(c), + SreCatCode::LINEBREAK => is_linebreak(c), + SreCatCode::NOT_LINEBREAK => !is_linebreak(c), + SreCatCode::LOC_WORD => is_loc_word(c), + SreCatCode::LOC_NOT_WORD => !is_loc_word(c), + SreCatCode::UNI_DIGIT => is_uni_digit(c), + SreCatCode::UNI_NOT_DIGIT => !is_uni_digit(c), + SreCatCode::UNI_SPACE => is_uni_space(c), + SreCatCode::UNI_NOT_SPACE => !is_uni_space(c), + SreCatCode::UNI_WORD => is_uni_word(c), + SreCatCode::UNI_NOT_WORD => !is_uni_word(c), + SreCatCode::UNI_LINEBREAK => is_uni_linebreak(c), + SreCatCode::UNI_NOT_LINEBREAK => !is_uni_linebreak(c), + } +} + +fn charset(set: &[u32], c: char) -> bool { + /* check if character is a member of the given set */ + let ch = c as u32; + let mut ok = true; + let mut i = 0; + while i < set.len() { + let opcode = match SreOpcode::try_from(set[i]) { + Ok(code) => code, + Err(_) => { + break; + } + }; + match opcode { + SreOpcode::FAILURE => { + return !ok; + } + SreOpcode::CATEGORY => { + /* */ + let catcode = match SreCatCode::try_from(set[i + 1]) { + Ok(code) => code, + Err(_) => { + break; + } + }; + if category(catcode, c) { + return ok; + } + i += 2; + } + SreOpcode::CHARSET => { + /* */ + if ch < 256 && (set[(ch / 32) as usize] & (1 << (32 - 1))) != 0 { + return ok; + } + i += 8; + } + SreOpcode::BIGCHARSET => { + /* <256 blockindices> */ + let count = set[i + 1]; + if ch < 0x10000 { + let blockindices: &[u8] = unsafe { std::mem::transmute(&set[i + 2..]) }; + let block = blockindices[(ch >> 8) as usize]; + if set[2 + 64 + ((block as u32 * 256 + (ch & 255)) / 32) as usize] + & (1 << (ch & (32 - 1))) + != 0 + { + return ok; + } + } + i += 2 + 64 + count as usize * 8; + } + SreOpcode::LITERAL => { + /* */ + if ch == set[i + 1] { + return ok; + } + i += 2; + } + SreOpcode::NEGATE => { + ok = !ok; + i += 1; + } + SreOpcode::RANGE => { + /* */ + if set[i + 1] <= ch && ch <= set[i + 2] { + return ok; + } + i += 3; + } + SreOpcode::RANGE_UNI_IGNORE => { + /* */ + if set[i + 1] <= ch && ch <= set[i + 2] { + return ok; + } + let ch = upper_unicode(c) as u32; + if set[i + 1] <= ch && ch <= set[i + 2] { + return ok; + } + i += 3; + } + _ => { + break; + } + } + } + /* internal error -- there's not much we can do about it + here, so let's just pretend it didn't match... */ + false +} + +fn count(drive: MatchContextDrive, maxcount: usize) -> usize { + let string_position = drive.state.string_position; + let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); + + let opcode = SreOpcode::try_from(drive.peek_code(1)).unwrap(); + match opcode { + SreOpcode::FAILURE => {} + SreOpcode::SUCCESS => {} + SreOpcode::ANY => {} + SreOpcode::ANY_ALL => {} + SreOpcode::ASSERT => {} + SreOpcode::ASSERT_NOT => {} + SreOpcode::AT => {} + SreOpcode::BRANCH => {} + SreOpcode::CALL => {} + SreOpcode::CATEGORY => {} + SreOpcode::CHARSET => {} + SreOpcode::BIGCHARSET => {} + SreOpcode::GROUPREF => {} + SreOpcode::GROUPREF_EXISTS => {} + SreOpcode::IN => { + } + SreOpcode::INFO => {} + SreOpcode::JUMP => {} + SreOpcode::LITERAL => {} + SreOpcode::MARK => {} + SreOpcode::MAX_UNTIL => {} + SreOpcode::MIN_UNTIL => {} + SreOpcode::NOT_LITERAL => {} + SreOpcode::NEGATE => {} + SreOpcode::RANGE => {} + SreOpcode::REPEAT => {} + SreOpcode::REPEAT_ONE => {} + SreOpcode::SUBPATTERN => {} + SreOpcode::MIN_REPEAT_ONE => {} + SreOpcode::GROUPREF_IGNORE => {} + SreOpcode::IN_IGNORE => {} + SreOpcode::LITERAL_IGNORE => {} + SreOpcode::NOT_LITERAL_IGNORE => {} + SreOpcode::GROUPREF_LOC_IGNORE => {} + SreOpcode::IN_LOC_IGNORE => {} + SreOpcode::LITERAL_LOC_IGNORE => {} + SreOpcode::NOT_LITERAL_LOC_IGNORE => {} + SreOpcode::GROUPREF_UNI_IGNORE => {} + SreOpcode::IN_UNI_IGNORE => {} + SreOpcode::LITERAL_UNI_IGNORE => {} + SreOpcode::NOT_LITERAL_UNI_IGNORE => {} + SreOpcode::RANGE_UNI_IGNORE => {} + } +} + +fn eq_loc_ignore(code: u32, c: char) -> bool { + code == c as u32 || code == lower_locate(c) as u32 || code == upper_locate(c) as u32 } -// Returns the number of repetitions of a single item, starting from the -// current string position. The code pointer is expected to point to a -// REPEAT_ONE operation (with the repeated 4 ahead). -fn count_repetitions(drive: &mut MatchContextDrive, maxcount: usize) -> usize { - let mut count = 0; - let mut real_maxcount = drive.state.end - drive.ctx().string_position; - if maxcount < real_maxcount && maxcount != SRE_MAXREPEAT { - real_maxcount = maxcount; +fn is_word(c: char) -> bool { + c.is_ascii_alphanumeric() || c == '_' +} +fn is_space(c: char) -> bool { + c.is_ascii_whitespace() +} +fn is_digit(c: char) -> bool { + c.is_ascii_digit() +} +fn is_loc_alnum(c: char) -> bool { + // TODO: check with cpython + c.is_alphanumeric() +} +fn is_loc_word(c: char) -> bool { + is_loc_alnum(c) || c == '_' +} +fn is_linebreak(c: char) -> bool { + c == '\n' +} +pub(crate) fn lower_ascii(c: char) -> char { + c.to_ascii_lowercase() +} +fn lower_locate(c: char) -> char { + // TODO: check with cpython + // https://doc.rust-lang.org/std/primitive.char.html#method.to_lowercase + c.to_lowercase().next().unwrap() +} +fn upper_locate(c: char) -> char { + // TODO: check with cpython + // https://doc.rust-lang.org/std/primitive.char.html#method.to_uppercase + c.to_uppercase().next().unwrap() +} +fn is_uni_digit(c: char) -> bool { + // TODO: check with cpython + c.is_digit(10) +} +fn is_uni_space(c: char) -> bool { + // TODO: check with cpython + c.is_whitespace() +} +fn is_uni_linebreak(c: char) -> bool { + match c { + '\u{000A}' | '\u{000B}' | '\u{000C}' | '\u{000D}' | '\u{001C}' | '\u{001D}' + | '\u{001E}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true, + _ => false, + } +} +fn is_uni_alnum(c: char) -> bool { + // TODO: check with cpython + c.is_alphanumeric() +} +fn is_uni_word(c: char) -> bool { + is_uni_alnum(c) || c == '_' +} +pub(crate) fn lower_unicode(c: char) -> char { + // TODO: check with cpython + c.to_lowercase().next().unwrap() +} +pub(crate) fn upper_unicode(c: char) -> char { + // TODO: check with cpython + c.to_uppercase().next().unwrap() +} + +fn is_utf8_first_byte(b: u8) -> bool { + // In UTF-8, there are three kinds of byte... + // 0xxxxxxx : ASCII + // 10xxxxxx : 2nd, 3rd or 4th byte of code + // 11xxxxxx : 1st byte of multibyte code + (b & 0b10000000 == 0) || (b & 0b11000000 == 0b11000000) +} + +struct OpAssert { + child_ctx_id: usize, + jump_id: usize, +} +impl Default for OpAssert { + fn default() -> Self { + OpAssert { + child_ctx_id: 0, + jump_id: 0, + } + } +} +impl OpcodeExecutor for OpAssert { + fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + match self.jump_id { + 0 => self._0(drive), + 1 => self._1(drive), + _ => unreachable!(), + } + } +} +impl OpAssert { + fn _0(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + let back = drive.peek_code(2) as usize; + if back > drive.ctx().string_position { + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.state.string_position = drive.ctx().string_position - back; + self.child_ctx_id = drive.push_new_context(3); + self.jump_id = 1; + Some(()) + } + fn _1(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + if drive.state.context_stack[self.child_ctx_id].has_matched == Some(true) { + drive.skip_code(drive.peek_code(1) as usize + 1); + } else { + drive.ctx_mut().has_matched = Some(false); + } + None } - count } From 4e03fb361f731116d392cfdda0820353454fcaac Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 25 Dec 2020 16:31:22 +0200 Subject: [PATCH 004/893] upgrade sre_parse.py; impl marks count --- interp.rs | 633 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 365 insertions(+), 268 deletions(-) diff --git a/interp.rs b/interp.rs index bc6753a5fe..8e0968cbf5 100644 --- a/interp.rs +++ b/interp.rs @@ -5,6 +5,7 @@ use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; use std::collections::HashMap; use std::convert::TryFrom; +#[derive(Debug)] pub struct State<'a> { string: &'a str, // chars count @@ -13,9 +14,9 @@ pub struct State<'a> { end: usize, flags: SreFlag, pattern_codes: Vec, - marks: Vec, + marks: Vec>, lastindex: isize, - marks_stack: Vec, + marks_stack: Vec<(Vec>, isize)>, context_stack: Vec, repeat: Option, string_position: usize, @@ -55,12 +56,47 @@ impl<'a> State<'a> { self.context_stack.clear(); self.repeat = None; } + + fn set_mark(&mut self, mark_nr: usize, position: usize) { + if mark_nr & 1 != 0 { + self.lastindex = mark_nr as isize / 2 + 1; + } + if mark_nr >= self.marks.len() { + self.marks.resize(mark_nr + 1, None); + } + self.marks[mark_nr] = Some(position); + } + fn get_marks(&self, group_index: usize) -> (Option, Option) { + let marks_index = 2 * group_index; + if marks_index + 1 < self.marks.len() { + (self.marks[marks_index], self.marks[marks_index + 1]) + } else { + (None, None) + } + } + fn marks_push(&mut self) { + self.marks_stack.push(self.marks.clone(), self.lastindex); + } + fn marks_pop(&mut self) { + (self.marks, self.lastindex) = self.marks_stack.pop().unwrap(); + } + fn marks_pop_keep(&mut self) { + (self.marks, self.lastindex) = self.marks_stack.last().unwrap(); + } + fn marks_pop_discard(&mut self) { + self.marks_stack.pop(); + } } pub(crate) fn pymatch(mut state: State) -> bool { let ctx = MatchContext { string_position: state.start, - string_offset: state.string.char_indices().nth(state.start).unwrap().0, + string_offset: state + .string + .char_indices() + .nth(state.start) + .map(|x| x.0) + .unwrap_or(0), code_position: 0, has_matched: None, }; @@ -72,7 +108,7 @@ pub(crate) fn pymatch(mut state: State) -> bool { break; } let ctx_id = state.context_stack.len() - 1; - let mut drive = MatchContextDrive::drive(ctx_id, state); + let mut drive = StackDrive::drive(ctx_id, state); let mut dispatcher = OpcodeDispatcher::new(); has_matched = dispatcher.pymatch(&mut drive); @@ -92,68 +128,52 @@ struct MatchContext { has_matched: Option, } -struct MatchContextDrive<'a> { - state: State<'a>, - ctx_id: usize, -} - -impl<'a> MatchContextDrive<'a> { - fn id(&self) -> usize { - self.ctx_id - } - fn ctx_mut(&mut self) -> &mut MatchContext { - &mut self.state.context_stack[self.ctx_id] - } - fn ctx(&self) -> &MatchContext { - &self.state.context_stack[self.ctx_id] - } - fn push_new_context(&mut self, pattern_offset: usize) -> usize { - let ctx = self.ctx(); - let child_ctx = MatchContext { - string_position: ctx.string_position, - string_offset: ctx.string_offset, - code_position: ctx.code_position + pattern_offset, - has_matched: None, - }; - self.state.context_stack.push(child_ctx); - self.state.context_stack.len() - 1 - } - fn drive(ctx_id: usize, state: State<'a>) -> Self { - Self { state, ctx_id } - } - fn take(self) -> State<'a> { - self.state - } +trait MatchContextDrive { + fn ctx_mut(&mut self) -> &mut MatchContext; + fn ctx(&self) -> &MatchContext; + fn state(&self) -> &State; fn str(&self) -> &str { unsafe { - std::str::from_utf8_unchecked(&self.state.string.as_bytes()[self.ctx().string_offset..]) + std::str::from_utf8_unchecked( + &self.state().string.as_bytes()[self.ctx().string_offset..], + ) } } + fn pattern(&self) -> &[u32] { + &self.state().pattern_codes[self.ctx().code_position..] + } fn peek_char(&self) -> char { self.str().chars().next().unwrap() } fn peek_code(&self, peek: usize) -> u32 { - self.state.pattern_codes[self.ctx().code_position + peek] + self.state().pattern_codes[self.ctx().code_position + peek] } fn skip_char(&mut self, skip_count: usize) { - let skipped = self.str().char_indices().nth(skip_count).unwrap().0; - self.ctx_mut().string_position += skip_count; - self.ctx_mut().string_offset += skipped; + match self.str().char_indices().nth(skip_count).map(|x| x.0) { + Some(skipped) => { + self.ctx_mut().string_position += skip_count; + self.ctx_mut().string_offset += skipped; + } + None => { + self.ctx_mut().string_position = self.state().end; + self.ctx_mut().string_offset = self.state().string.len(); // bytes len + } + } } fn skip_code(&mut self, skip_count: usize) { self.ctx_mut().code_position += skip_count; } fn remaining_chars(&self) -> usize { - self.state.end - self.ctx().string_position + self.state().end - self.ctx().string_position } fn remaining_codes(&self) -> usize { - self.state.pattern_codes.len() - self.ctx().code_position + self.state().pattern_codes.len() - self.ctx().code_position } fn at_beginning(&self) -> bool { - self.ctx().string_position == self.state.start + self.ctx().string_position == self.state().start } fn at_end(&self) -> bool { - self.ctx().string_position == self.state.end + self.ctx().string_position == self.state().end } fn at_linebreak(&self) -> bool { !self.at_end() && is_linebreak(self.peek_char()) @@ -167,7 +187,7 @@ impl<'a> MatchContextDrive<'a> { this != that } fn back_peek_offset(&self) -> usize { - let bytes = self.state.string.as_bytes(); + let bytes = self.state().string.as_bytes(); let mut offset = self.ctx().string_offset - 1; if !is_utf8_first_byte(bytes[offset]) { offset -= 1; @@ -184,7 +204,7 @@ impl<'a> MatchContextDrive<'a> { offset } fn back_peek_char(&self) -> char { - let bytes = self.state.string.as_bytes(); + let bytes = self.state().string.as_bytes(); let offset = self.back_peek_offset(); let current_offset = self.ctx().string_offset; let code = match current_offset - offset { @@ -209,13 +229,74 @@ impl<'a> MatchContextDrive<'a> { } } +struct StackDrive<'a> { + state: State<'a>, + ctx_id: usize, +} +impl<'a> StackDrive<'a> { + fn id(&self) -> usize { + self.ctx_id + } + fn drive(ctx_id: usize, state: State<'a>) -> Self { + Self { state, ctx_id } + } + fn take(self) -> State<'a> { + self.state + } + fn push_new_context(&mut self, pattern_offset: usize) -> usize { + let ctx = self.ctx(); + let child_ctx = MatchContext { + string_position: ctx.string_position, + string_offset: ctx.string_offset, + code_position: ctx.code_position + pattern_offset, + has_matched: None, + }; + self.state.context_stack.push(child_ctx); + self.state.context_stack.len() - 1 + } +} +impl MatchContextDrive for StackDrive<'_> { + fn ctx_mut(&mut self) -> &mut MatchContext { + &mut self.state.context_stack[self.ctx_id] + } + fn ctx(&self) -> &MatchContext { + &self.state.context_stack[self.ctx_id] + } + fn state(&self) -> &State { + &self.state + } +} + +struct WrapDrive<'a> { + stack_drive: &'a StackDrive<'a>, + ctx: MatchContext, +} +impl<'a> WrapDrive<'a> { + fn drive(ctx: MatchContext, stack_drive: &'a StackDrive<'a>) -> Self { + Self { stack_drive, ctx } + } +} +impl MatchContextDrive for WrapDrive<'_> { + fn ctx_mut(&mut self) -> &mut MatchContext { + &mut self.ctx + } + + fn ctx(&self) -> &MatchContext { + &self.ctx + } + + fn state(&self) -> &State { + self.stack_drive.state() + } +} + trait OpcodeExecutor { - fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()>; + fn next(&mut self, drive: &mut StackDrive) -> Option<()>; } struct OpUnimplemented {} impl OpcodeExecutor for OpUnimplemented { - fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { drive.ctx_mut().has_matched = Some(false); None } @@ -224,14 +305,14 @@ impl OpcodeExecutor for OpUnimplemented { struct OpOnce { f: Option, } -impl OpcodeExecutor for OpOnce { - fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { +impl OpcodeExecutor for OpOnce { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { let f = self.f.take()?; f(drive); None } } -fn once(f: F) -> Box> { +fn once(f: F) -> Box> { Box::new(OpOnce { f: Some(f) }) } @@ -239,102 +320,6 @@ fn unimplemented() -> Box { Box::new(OpUnimplemented {}) } -struct OpMinRepeatOne { - trace_id: usize, - mincount: usize, - maxcount: usize, - count: usize, - child_ctx_id: usize, -} -impl OpcodeExecutor for OpMinRepeatOne { - fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { - None - // match self.trace_id { - // 0 => self._0(drive), - // _ => unreachable!(), - // } - } -} -impl Default for OpMinRepeatOne { - fn default() -> Self { - OpMinRepeatOne { - trace_id: 0, - mincount: 0, - maxcount: 0, - count: 0, - child_ctx_id: 0, - } - } -} -// impl OpMinRepeatOne { -// fn _0(&mut self, drive: &mut MatchContextDrive) -> Option<()> { -// self.mincount = drive.peek_code(2) as usize; -// self.maxcount = drive.peek_code(3) as usize; - -// if drive.remaining_chars() < self.mincount { -// drive.ctx_mut().has_matched = Some(false); -// return None; -// } - -// drive.state.string_position = drive.ctx().string_position; - -// self.count = if self.mincount == 0 { -// 0 -// } else { -// let count = count_repetitions(drive, self.mincount); -// if count < self.mincount { -// drive.ctx_mut().has_matched = Some(false); -// return None; -// } -// drive.skip_char(count); -// count -// }; - -// if drive.peek_code(drive.peek_code(1) as usize + 1) == SreOpcode::SUCCESS as u32 { -// drive.state.string_position = drive.ctx().string_position; -// drive.ctx_mut().has_matched = Some(true); -// return None; -// } - -// // mark push -// self.trace_id = 1; -// self._1(drive) -// } -// fn _1(&mut self, drive: &mut MatchContextDrive) -> Option<()> { -// if self.maxcount == SRE_MAXREPEAT || self.count <= self.maxcount { -// drive.state.string_position = drive.ctx().string_position; -// self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); -// self.trace_id = 2; -// return Some(()); -// } - -// // mark discard -// drive.ctx_mut().has_matched = Some(false); -// None -// } -// fn _2(&mut self, drive: &mut MatchContextDrive) -> Option<()> { -// if let Some(true) = drive.state.context_stack[self.child_ctx_id].has_matched { -// drive.ctx_mut().has_matched = Some(true); -// return None; -// } -// drive.state.string_position = drive.ctx().string_position; -// if count_repetitions(drive, 1) == 0 { -// self.trace_id = 3; -// return self._3(drive); -// } -// drive.skip_char(1); -// self.count += 1; -// // marks pop keep -// self.trace_id = 1; -// self._1(drive) -// } -// fn _3(&mut self, drive: &mut MatchContextDrive) -> Option<()> { -// // mark discard -// drive.ctx_mut().has_matched = Some(false); -// None -// } -// } - struct OpcodeDispatcher { executing_contexts: HashMap>, } @@ -347,7 +332,7 @@ impl OpcodeDispatcher { // Returns True if the current context matches, False if it doesn't and // None if matching is not finished, ie must be resumed after child // contexts have been matched. - fn pymatch(&mut self, drive: &mut MatchContextDrive) -> Option { + fn pymatch(&mut self, drive: &mut StackDrive) -> Option { while drive.remaining_codes() > 0 && drive.ctx().has_matched.is_none() { let code = drive.peek_code(0); let opcode = SreOpcode::try_from(code).unwrap(); @@ -364,7 +349,7 @@ impl OpcodeDispatcher { // Dispatches a context on a given opcode. Returns True if the context // is done matching, False if it must be resumed when next encountered. - fn dispatch(&mut self, opcode: SreOpcode, drive: &mut MatchContextDrive) -> bool { + fn dispatch(&mut self, opcode: SreOpcode, drive: &mut StackDrive) -> bool { let mut executor = match self.executing_contexts.remove_entry(&drive.id()) { Some((_, executor)) => executor, None => self.dispatch_table(opcode), @@ -414,58 +399,67 @@ impl OpcodeDispatcher { }), SreOpcode::BRANCH => unimplemented(), SreOpcode::CALL => unimplemented(), - SreOpcode::CATEGORY => unimplemented(), - SreOpcode::CHARSET => unimplemented(), - SreOpcode::BIGCHARSET => unimplemented(), + SreOpcode::CATEGORY => once(|drive| { + let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); + if drive.at_end() || !category(catcode, drive.peek_char()) { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_code(2); + drive.skip_char(1); + } + }), + SreOpcode::CHARSET | SreOpcode::BIGCHARSET => unreachable!("unexpected opcode"), SreOpcode::GROUPREF => unimplemented(), SreOpcode::GROUPREF_EXISTS => unimplemented(), SreOpcode::GROUPREF_IGNORE => unimplemented(), - SreOpcode::IN => unimplemented(), - SreOpcode::IN_IGNORE => unimplemented(), + SreOpcode::IN => once(|drive| { + general_op_in(drive, |x| x); + }), + SreOpcode::IN_IGNORE => once(|drive| { + general_op_in(drive, lower_ascii); + }), + SreOpcode::IN_UNI_IGNORE => once(|drive| { + general_op_in(drive, lower_unicode); + }), + SreOpcode::IN_LOC_IGNORE => once(|drive| { + let skip = drive.peek_code(1) as usize; + if drive.at_end() || !charset_loc_ignore(&drive.pattern()[1..], drive.peek_char()) { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_code(skip + 1); + drive.skip_char(1); + } + }), SreOpcode::INFO | SreOpcode::JUMP => once(|drive| { drive.skip_code(drive.peek_code(1) as usize + 1); }), SreOpcode::LITERAL => once(|drive| { - if drive.at_end() || drive.peek_char() as u32 != drive.peek_code(1) { - drive.ctx_mut().has_matched = Some(false); - } - drive.skip_code(2); - drive.skip_char(1); + general_op_literal(drive, |code, c| code == c as u32); + }), + SreOpcode::NOT_LITERAL => once(|drive| { + general_op_literal(drive, |code, c| code != c as u32); }), SreOpcode::LITERAL_IGNORE => once(|drive| { - let code = drive.peek_code(1); - let c = drive.peek_char(); - if drive.at_end() - || (c.to_ascii_lowercase() as u32 != code - && c.to_ascii_uppercase() as u32 != code) - { - drive.ctx_mut().has_matched = Some(false); - } - drive.skip_code(2); - drive.skip_char(1); + general_op_literal(drive, |code, c| code == lower_ascii(c) as u32); + }), + SreOpcode::NOT_LITERAL_IGNORE => once(|drive| { + general_op_literal(drive, |code, c| code != lower_ascii(c) as u32); + }), + SreOpcode::LITERAL_UNI_IGNORE => once(|drive| { + general_op_literal(drive, |code, c| code == lower_unicode(c) as u32); + }), + SreOpcode::NOT_LITERAL_UNI_IGNORE => once(|drive| { + general_op_literal(drive, |code, c| code != lower_unicode(c) as u32); + }), + SreOpcode::LITERAL_LOC_IGNORE => once(|drive| { + general_op_literal(drive, |code, c| char_loc_ignore(code, c)); + }), + SreOpcode::NOT_LITERAL_LOC_IGNORE => once(|drive| { + general_op_literal(drive, |code, c| !char_loc_ignore(code, c)); }), SreOpcode::MARK => unimplemented(), SreOpcode::MAX_UNTIL => unimplemented(), SreOpcode::MIN_UNTIL => unimplemented(), - SreOpcode::NOT_LITERAL => once(|drive| { - if drive.at_end() || drive.peek_char() as u32 == drive.peek_code(1) { - drive.ctx_mut().has_matched = Some(false); - } - drive.skip_code(2); - drive.skip_char(1); - }), - SreOpcode::NOT_LITERAL_IGNORE => once(|drive| { - let code = drive.peek_code(1); - let c = drive.peek_char(); - if drive.at_end() - || (c.to_ascii_lowercase() as u32 == code - || c.to_ascii_uppercase() as u32 == code) - { - drive.ctx_mut().has_matched = Some(false); - } - drive.skip_code(2); - drive.skip_char(1); - }), SreOpcode::NEGATE => unimplemented(), SreOpcode::RANGE => unimplemented(), SreOpcode::REPEAT => unimplemented(), @@ -473,47 +467,45 @@ impl OpcodeDispatcher { SreOpcode::SUBPATTERN => unimplemented(), SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), SreOpcode::GROUPREF_LOC_IGNORE => unimplemented(), - SreOpcode::IN_LOC_IGNORE => unimplemented(), - SreOpcode::LITERAL_LOC_IGNORE => unimplemented(), - SreOpcode::NOT_LITERAL_LOC_IGNORE => unimplemented(), SreOpcode::GROUPREF_UNI_IGNORE => unimplemented(), - SreOpcode::IN_UNI_IGNORE => unimplemented(), - SreOpcode::LITERAL_UNI_IGNORE => unimplemented(), - SreOpcode::NOT_LITERAL_UNI_IGNORE => unimplemented(), SreOpcode::RANGE_UNI_IGNORE => unimplemented(), } } +} - // Returns the number of repetitions of a single item, starting from the - // current string position. The code pointer is expected to point to a - // REPEAT_ONE operation (with the repeated 4 ahead). - fn count_repetitions(&mut self, drive: &mut MatchContextDrive, maxcount: usize) -> usize { - let mut count = 0; - let mut real_maxcount = drive.remaining_chars(); - if maxcount < real_maxcount && maxcount != MAXREPEAT { - real_maxcount = maxcount; - } - let code_position = drive.ctx().code_position; - let string_position = drive.ctx().string_position; - drive.skip_code(4); - let reset_position = drive.ctx().code_position; - while count < real_maxcount { - drive.ctx_mut().code_position = reset_position; - let opcode = SreOpcode::try_from(drive.peek_code(1)).unwrap(); - self.dispatch(opcode, drive); - if drive.ctx().has_matched == Some(false) { - break; - } - count += 1; - } - drive.ctx_mut().has_matched = None; - drive.ctx_mut().code_position = code_position; - drive.ctx_mut().string_position = string_position; - count +fn char_loc_ignore(code: u32, c: char) -> bool { + code == c as u32 || code == lower_locate(c) as u32 || code == upper_locate(c) as u32 +} + +fn charset_loc_ignore(set: &[u32], c: char) -> bool { + let lo = lower_locate(c); + if charset(set, c) { + return true; } + let up = upper_locate(c); + up != lo && charset(set, up) } -fn at(drive: &mut MatchContextDrive, atcode: SreAtCode) -> bool { +fn general_op_literal bool>(drive: &mut StackDrive, f: F) { + if drive.at_end() || !f(drive.peek_code(1), drive.peek_char()) { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_code(2); + drive.skip_char(1); + } +} + +fn general_op_in char>(drive: &mut StackDrive, f: F) { + let skip = drive.peek_code(1) as usize; + if drive.at_end() || !charset(&drive.pattern()[1..], f(drive.peek_char())) { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_code(skip + 1); + drive.skip_char(1); + } +} + +fn at(drive: &StackDrive, atcode: SreAtCode) -> bool { match atcode { SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => drive.at_beginning(), SreAtCode::BEGINNING_LINE => drive.at_beginning() || is_linebreak(drive.back_peek_char()), @@ -642,54 +634,67 @@ fn charset(set: &[u32], c: char) -> bool { false } -fn count(drive: MatchContextDrive, maxcount: usize) -> usize { - let string_position = drive.state.string_position; +fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { + let drive = WrapDrive::drive(stack_drive.ctx().clone(), stack_drive); let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); + let opcode = match SreOpcode::try_from(drive.peek_code(1)) { + Ok(code) => code, + Err(_) => { + panic!("FIXME:COUNT1"); + } + }; - let opcode = SreOpcode::try_from(drive.peek_code(1)).unwrap(); match opcode { - SreOpcode::FAILURE => {} - SreOpcode::SUCCESS => {} - SreOpcode::ANY => {} - SreOpcode::ANY_ALL => {} - SreOpcode::ASSERT => {} - SreOpcode::ASSERT_NOT => {} - SreOpcode::AT => {} - SreOpcode::BRANCH => {} - SreOpcode::CALL => {} - SreOpcode::CATEGORY => {} - SreOpcode::CHARSET => {} - SreOpcode::BIGCHARSET => {} - SreOpcode::GROUPREF => {} - SreOpcode::GROUPREF_EXISTS => {} + SreOpcode::ANY => { + while !drive.at_end() && !drive.at_linebreak() { + drive.skip_char(1); + } + } + SreOpcode::ANY_ALL => { + drive.skip_char(drive.remaining_chars()); + } SreOpcode::IN => { + // TODO: pattern[2 or 1..]? + while !drive.at_end() && charset(&drive.pattern()[2..], drive.peek_char()) { + drive.skip_char(1); + } + } + SreOpcode::LITERAL => { + general_count_literal(drive, |code, c| code == c as u32); + } + SreOpcode::NOT_LITERAL => { + general_count_literal(drive, |code, c| code != c as u32); + } + SreOpcode::LITERAL_IGNORE => { + general_count_literal(drive, |code, c| code == lower_ascii(c) as u32); + } + SreOpcode::NOT_LITERAL_IGNORE => { + general_count_literal(drive, |code, c| code != lower_ascii(c) as u32); + } + SreOpcode::LITERAL_LOC_IGNORE => { + general_count_literal(drive, |code, c| char_loc_ignore(code, c)); + } + SreOpcode::NOT_LITERAL_LOC_IGNORE => { + general_count_literal(drive, |code, c| !char_loc_ignore(code, c)); + } + SreOpcode::LITERAL_UNI_IGNORE => { + general_count_literal(drive, |code, c| code == lower_unicode(c) as u32); + } + SreOpcode::NOT_LITERAL_UNI_IGNORE => { + general_count_literal(drive, |code, c| code != lower_unicode(c) as u32); } - SreOpcode::INFO => {} - SreOpcode::JUMP => {} - SreOpcode::LITERAL => {} - SreOpcode::MARK => {} - SreOpcode::MAX_UNTIL => {} - SreOpcode::MIN_UNTIL => {} - SreOpcode::NOT_LITERAL => {} - SreOpcode::NEGATE => {} - SreOpcode::RANGE => {} - SreOpcode::REPEAT => {} - SreOpcode::REPEAT_ONE => {} - SreOpcode::SUBPATTERN => {} - SreOpcode::MIN_REPEAT_ONE => {} - SreOpcode::GROUPREF_IGNORE => {} - SreOpcode::IN_IGNORE => {} - SreOpcode::LITERAL_IGNORE => {} - SreOpcode::NOT_LITERAL_IGNORE => {} - SreOpcode::GROUPREF_LOC_IGNORE => {} - SreOpcode::IN_LOC_IGNORE => {} - SreOpcode::LITERAL_LOC_IGNORE => {} - SreOpcode::NOT_LITERAL_LOC_IGNORE => {} - SreOpcode::GROUPREF_UNI_IGNORE => {} - SreOpcode::IN_UNI_IGNORE => {} - SreOpcode::LITERAL_UNI_IGNORE => {} - SreOpcode::NOT_LITERAL_UNI_IGNORE => {} - SreOpcode::RANGE_UNI_IGNORE => {} + _ => { + panic!("TODO: Not Implemented."); + } + } + + drive.ctx().string_position - stack_drive.ctx().string_position +} + +fn general_count_literal bool>(drive: &mut WrapDrive, f: F) { + let ch = drive.peek_code(1); + while !drive.at_end() && f(ch, drive.peek_char()) { + drive.skip_char(1); } } @@ -781,7 +786,7 @@ impl Default for OpAssert { } } impl OpcodeExecutor for OpAssert { - fn next(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { 0 => self._0(drive), 1 => self._1(drive), @@ -790,7 +795,7 @@ impl OpcodeExecutor for OpAssert { } } impl OpAssert { - fn _0(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + fn _0(&mut self, drive: &mut StackDrive) -> Option<()> { let back = drive.peek_code(2) as usize; if back > drive.ctx().string_position { drive.ctx_mut().has_matched = Some(false); @@ -801,7 +806,7 @@ impl OpAssert { self.jump_id = 1; Some(()) } - fn _1(&mut self, drive: &mut MatchContextDrive) -> Option<()> { + fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { if drive.state.context_stack[self.child_ctx_id].has_matched == Some(true) { drive.skip_code(drive.peek_code(1) as usize + 1); } else { @@ -810,3 +815,95 @@ impl OpAssert { None } } + +struct OpMinRepeatOne { + jump_id: usize, + mincount: usize, + maxcount: usize, + count: usize, + child_ctx_id: usize, +} +impl OpcodeExecutor for OpMinRepeatOne { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { + match self.jump_id { + 0 => self._0(drive), + 1 => self._1(drive), + 2 => self._2(drive), + _ => unreachable!(), + } + } +} +impl Default for OpMinRepeatOne { + fn default() -> Self { + OpMinRepeatOne { + jump_id: 0, + mincount: 0, + maxcount: 0, + count: 0, + child_ctx_id: 0, + } + } +} +impl OpMinRepeatOne { + fn _0(&mut self, drive: &mut StackDrive) -> Option<()> { + self.mincount = drive.peek_code(2) as usize; + self.maxcount = drive.peek_code(3) as usize; + + if drive.remaining_chars() < self.mincount { + drive.ctx_mut().has_matched = Some(false); + return None; + } + + drive.state.string_position = drive.ctx().string_position; + + self.count = if self.mincount == 0 { + 0 + } else { + let count = count(drive, self.mincount); + if count < self.mincount { + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.skip_char(count); + count + }; + + if drive.peek_code(drive.peek_code(1) as usize + 1) == SreOpcode::SUCCESS as u32 { + drive.state.string_position = drive.ctx().string_position; + drive.ctx_mut().has_matched = Some(true); + return None; + } + + drive.state.marks_push(); + self.jump_id = 1; + self._1(drive) + } + fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { + if self.maxcount == MAXREPEAT || self.count <= self.maxcount { + drive.state.string_position = drive.ctx().string_position; + self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); + self.jump_id = 2; + return Some(()); + } + + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(false); + None + } + fn _2(&mut self, drive: &mut StackDrive) -> Option<()> { + if let Some(true) = drive.state.context_stack[self.child_ctx_id].has_matched { + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.state.string_position = drive.ctx().string_position; + if count(drive, 1) == 0 { + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.skip_char(1); + self.count += 1; + drive.state.marks_pop_keep(); + self.jump_id = 1; + self._1(drive) + } +} From 78485fd8df79444283d3cb978654500da0ce0590 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 26 Dec 2020 16:59:30 +0200 Subject: [PATCH 005/893] create _sre.Match --- interp.rs | 68 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 22 deletions(-) diff --git a/interp.rs b/interp.rs index 8e0968cbf5..71d89cc0e8 100644 --- a/interp.rs +++ b/interp.rs @@ -1,21 +1,24 @@ // good luck to those that follow; here be dragons -use super::_sre::MAXREPEAT; +use rustpython_common::borrow::BorrowValue; + +use super::_sre::{Match, Pattern, MAXREPEAT}; use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; +use crate::builtins::PyStrRef; use std::collections::HashMap; use std::convert::TryFrom; #[derive(Debug)] -pub struct State<'a> { +pub(crate) struct State<'a> { string: &'a str, // chars count string_len: usize, - start: usize, - end: usize, + pub start: usize, + pub end: usize, flags: SreFlag, - pattern_codes: Vec, + pattern_codes: &'a [u32], marks: Vec>, - lastindex: isize, + pub lastindex: isize, marks_stack: Vec<(Vec>, isize)>, context_stack: Vec, repeat: Option, @@ -28,7 +31,7 @@ impl<'a> State<'a> { start: usize, end: usize, flags: SreFlag, - pattern_codes: Vec, + pattern_codes: &'a [u32], ) -> Self { let string_len = string.chars().count(); let end = std::cmp::min(end, string_len); @@ -75,20 +78,36 @@ impl<'a> State<'a> { } } fn marks_push(&mut self) { - self.marks_stack.push(self.marks.clone(), self.lastindex); + self.marks_stack.push((self.marks.clone(), self.lastindex)); } fn marks_pop(&mut self) { - (self.marks, self.lastindex) = self.marks_stack.pop().unwrap(); + let (marks, lastindex) = self.marks_stack.pop().unwrap(); + self.marks = marks; + self.lastindex = lastindex; } fn marks_pop_keep(&mut self) { - (self.marks, self.lastindex) = self.marks_stack.last().unwrap(); + let (marks, lastindex) = self.marks_stack.last().unwrap().clone(); + self.marks = marks; + self.lastindex = lastindex; } fn marks_pop_discard(&mut self) { self.marks_stack.pop(); } } -pub(crate) fn pymatch(mut state: State) -> bool { +pub(crate) fn pymatch( + string: PyStrRef, + start: usize, + end: usize, + pattern: &Pattern, +) -> Option { + let mut state = State::new( + string.borrow_value(), + start, + end, + pattern.flags.clone(), + &pattern.code, + ); let ctx = MatchContext { string_position: state.start, string_offset: state @@ -117,7 +136,12 @@ pub(crate) fn pymatch(mut state: State) -> bool { state.context_stack.pop(); } } - has_matched.unwrap_or(false) + + if has_matched == None || has_matched == Some(false) { + return None; + } + + Some(Match::new(&state, pattern.pattern.clone(), string.clone())) } #[derive(Debug, Copy, Clone)] @@ -635,7 +659,7 @@ fn charset(set: &[u32], c: char) -> bool { } fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { - let drive = WrapDrive::drive(stack_drive.ctx().clone(), stack_drive); + let mut drive = WrapDrive::drive(stack_drive.ctx().clone(), stack_drive); let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); let opcode = match SreOpcode::try_from(drive.peek_code(1)) { Ok(code) => code, @@ -660,28 +684,28 @@ fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { } } SreOpcode::LITERAL => { - general_count_literal(drive, |code, c| code == c as u32); + general_count_literal(&mut drive, |code, c| code == c as u32); } SreOpcode::NOT_LITERAL => { - general_count_literal(drive, |code, c| code != c as u32); + general_count_literal(&mut drive, |code, c| code != c as u32); } SreOpcode::LITERAL_IGNORE => { - general_count_literal(drive, |code, c| code == lower_ascii(c) as u32); + general_count_literal(&mut drive, |code, c| code == lower_ascii(c) as u32); } SreOpcode::NOT_LITERAL_IGNORE => { - general_count_literal(drive, |code, c| code != lower_ascii(c) as u32); + general_count_literal(&mut drive, |code, c| code != lower_ascii(c) as u32); } SreOpcode::LITERAL_LOC_IGNORE => { - general_count_literal(drive, |code, c| char_loc_ignore(code, c)); + general_count_literal(&mut drive, |code, c| char_loc_ignore(code, c)); } SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_count_literal(drive, |code, c| !char_loc_ignore(code, c)); + general_count_literal(&mut drive, |code, c| !char_loc_ignore(code, c)); } SreOpcode::LITERAL_UNI_IGNORE => { - general_count_literal(drive, |code, c| code == lower_unicode(c) as u32); + general_count_literal(&mut drive, |code, c| code == lower_unicode(c) as u32); } SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_count_literal(drive, |code, c| code != lower_unicode(c) as u32); + general_count_literal(&mut drive, |code, c| code != lower_unicode(c) as u32); } _ => { panic!("TODO: Not Implemented."); @@ -691,7 +715,7 @@ fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { drive.ctx().string_position - stack_drive.ctx().string_position } -fn general_count_literal bool>(drive: &mut WrapDrive, f: F) { +fn general_count_literal bool>(drive: &mut WrapDrive, mut f: F) { let ch = drive.peek_code(1); while !drive.at_end() && f(ch, drive.peek_char()) { drive.skip_char(1); From aa0f20b93e86e6b63eadefb1f17958b2d513ccd3 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 26 Dec 2020 18:44:33 +0200 Subject: [PATCH 006/893] impl Pattern.fullmatch, Pattern.search --- interp.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/interp.rs b/interp.rs index 71d89cc0e8..f87cc85abc 100644 --- a/interp.rs +++ b/interp.rs @@ -1,10 +1,9 @@ // good luck to those that follow; here be dragons -use rustpython_common::borrow::BorrowValue; - use super::_sre::{Match, Pattern, MAXREPEAT}; use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; use crate::builtins::PyStrRef; +use rustpython_common::borrow::BorrowValue; use std::collections::HashMap; use std::convert::TryFrom; From 312e5b875677bbc57b085ffae98f230e56717d21 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 27 Dec 2020 21:27:06 +0200 Subject: [PATCH 007/893] impl opcode groupref and assert_not --- interp.rs | 143 ++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 124 insertions(+), 19 deletions(-) diff --git a/interp.rs b/interp.rs index f87cc85abc..88dbb034da 100644 --- a/interp.rs +++ b/interp.rs @@ -109,12 +109,7 @@ pub(crate) fn pymatch( ); let ctx = MatchContext { string_position: state.start, - string_offset: state - .string - .char_indices() - .nth(state.start) - .map(|x| x.0) - .unwrap_or(0), + string_offset: calc_string_offset(state.string, state.start), code_position: 0, has_matched: None, }; @@ -411,7 +406,7 @@ impl OpcodeDispatcher { } }), SreOpcode::ASSERT => Box::new(OpAssert::default()), - SreOpcode::ASSERT_NOT => unimplemented(), + SreOpcode::ASSERT_NOT => Box::new(OpAssertNot::default()), SreOpcode::AT => once(|drive| { let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); if !at(drive, atcode) { @@ -432,9 +427,6 @@ impl OpcodeDispatcher { } }), SreOpcode::CHARSET | SreOpcode::BIGCHARSET => unreachable!("unexpected opcode"), - SreOpcode::GROUPREF => unimplemented(), - SreOpcode::GROUPREF_EXISTS => unimplemented(), - SreOpcode::GROUPREF_IGNORE => unimplemented(), SreOpcode::IN => once(|drive| { general_op_in(drive, |x| x); }), @@ -480,7 +472,12 @@ impl OpcodeDispatcher { SreOpcode::NOT_LITERAL_LOC_IGNORE => once(|drive| { general_op_literal(drive, |code, c| !char_loc_ignore(code, c)); }), - SreOpcode::MARK => unimplemented(), + SreOpcode::MARK => once(|drive| { + drive + .state + .set_mark(drive.peek_code(1) as usize, drive.ctx().string_position); + drive.skip_code(2); + }), SreOpcode::MAX_UNTIL => unimplemented(), SreOpcode::MIN_UNTIL => unimplemented(), SreOpcode::NEGATE => unimplemented(), @@ -489,13 +486,36 @@ impl OpcodeDispatcher { SreOpcode::REPEAT_ONE => unimplemented(), SreOpcode::SUBPATTERN => unimplemented(), SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), - SreOpcode::GROUPREF_LOC_IGNORE => unimplemented(), - SreOpcode::GROUPREF_UNI_IGNORE => unimplemented(), + SreOpcode::GROUPREF => once(|drive| general_op_groupref(drive, |x| x)), + SreOpcode::GROUPREF_IGNORE => once(|drive| general_op_groupref(drive, lower_ascii)), + SreOpcode::GROUPREF_LOC_IGNORE => { + once(|drive| general_op_groupref(drive, lower_locate)) + } + SreOpcode::GROUPREF_UNI_IGNORE => { + once(|drive| general_op_groupref(drive, lower_unicode)) + } + SreOpcode::GROUPREF_EXISTS => once(|drive| { + let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); + match (group_start, group_end) { + (Some(start), Some(end)) if start <= end => { + drive.skip_code(3); + } + _ => drive.skip_code(drive.peek_code(2) as usize + 1), + } + }), SreOpcode::RANGE_UNI_IGNORE => unimplemented(), } } } +fn calc_string_offset(string: &str, position: usize) -> usize { + string + .char_indices() + .nth(position) + .map(|(i, _)| i) + .unwrap_or(0) +} + fn char_loc_ignore(code: u32, c: char) -> bool { code == c as u32 || code == lower_locate(c) as u32 || code == upper_locate(c) as u32 } @@ -509,6 +529,40 @@ fn charset_loc_ignore(set: &[u32], c: char) -> bool { up != lo && charset(set, up) } +fn general_op_groupref char>(drive: &mut StackDrive, mut f: F) { + let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); + let (group_start, group_end) = match (group_start, group_end) { + (Some(start), Some(end)) if start <= end => (start, end), + _ => { + drive.ctx_mut().has_matched = Some(false); + return; + } + }; + let mut wdrive = WrapDrive::drive(*drive.ctx(), &drive); + let mut gdrive = WrapDrive::drive( + MatchContext { + string_position: group_start, + // TODO: cache the offset + string_offset: calc_string_offset(drive.state.string, group_start), + ..*drive.ctx() + }, + &drive, + ); + for _ in group_start..group_end { + if wdrive.at_end() || f(wdrive.peek_char()) != f(gdrive.peek_char()) { + drive.ctx_mut().has_matched = Some(false); + return; + } + wdrive.skip_char(1); + gdrive.skip_char(1); + } + let position = wdrive.ctx().string_position; + let offset = wdrive.ctx().string_offset; + drive.skip_code(2); + drive.ctx_mut().string_position = position; + drive.ctx_mut().string_offset = offset; +} + fn general_op_literal bool>(drive: &mut StackDrive, f: F) { if drive.at_end() || !f(drive.peek_code(1), drive.peek_char()) { drive.ctx_mut().has_matched = Some(false); @@ -766,11 +820,19 @@ fn is_uni_space(c: char) -> bool { c.is_whitespace() } fn is_uni_linebreak(c: char) -> bool { - match c { - '\u{000A}' | '\u{000B}' | '\u{000C}' | '\u{000D}' | '\u{001C}' | '\u{001D}' - | '\u{001E}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true, - _ => false, - } + matches!( + c, + '\u{000A}' + | '\u{000B}' + | '\u{000C}' + | '\u{000D}' + | '\u{001C}' + | '\u{001D}' + | '\u{001E}' + | '\u{0085}' + | '\u{2028}' + | '\u{2029}' + ) } fn is_uni_alnum(c: char) -> bool { // TODO: check with cpython @@ -802,7 +864,7 @@ struct OpAssert { } impl Default for OpAssert { fn default() -> Self { - OpAssert { + Self { child_ctx_id: 0, jump_id: 0, } @@ -839,6 +901,49 @@ impl OpAssert { } } +struct OpAssertNot { + child_ctx_id: usize, + jump_id: usize, +} +impl Default for OpAssertNot { + fn default() -> Self { + Self { + child_ctx_id: 0, + jump_id: 0, + } + } +} +impl OpcodeExecutor for OpAssertNot { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { + match self.jump_id { + 0 => self._0(drive), + 1 => self._1(drive), + _ => unreachable!(), + } + } +} +impl OpAssertNot { + fn _0(&mut self, drive: &mut StackDrive) -> Option<()> { + let back = drive.peek_code(2) as usize; + if back > drive.ctx().string_position { + drive.skip_code(drive.peek_code(1) as usize + 1); + return None; + } + drive.state.string_position = drive.ctx().string_position - back; + self.child_ctx_id = drive.push_new_context(3); + self.jump_id = 1; + Some(()) + } + fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { + if drive.state.context_stack[self.child_ctx_id].has_matched == Some(true) { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_code(drive.peek_code(1) as usize + 1); + } + None + } +} + struct OpMinRepeatOne { jump_id: usize, mincount: usize, From 0b2c8d1fa256a0b40223f97b2a5de6b2747c3397 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 29 Dec 2020 10:33:28 +0200 Subject: [PATCH 008/893] impl OpMaxUntil --- interp.rs | 133 +++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 122 insertions(+), 11 deletions(-) diff --git a/interp.rs b/interp.rs index 88dbb034da..85787ae516 100644 --- a/interp.rs +++ b/interp.rs @@ -20,7 +20,7 @@ pub(crate) struct State<'a> { pub lastindex: isize, marks_stack: Vec<(Vec>, isize)>, context_stack: Vec, - repeat: Option, + repeat_stack: Vec, string_position: usize, } @@ -45,7 +45,7 @@ impl<'a> State<'a> { lastindex: -1, marks_stack: Vec::new(), context_stack: Vec::new(), - repeat: None, + repeat_stack: Vec::new(), marks: Vec::new(), string_position: start, } @@ -56,7 +56,7 @@ impl<'a> State<'a> { self.lastindex = -1; self.marks_stack.clear(); self.context_stack.clear(); - self.repeat = None; + self.repeat_stack.clear(); } fn set_mark(&mut self, mark_nr: usize, position: usize) { @@ -104,7 +104,7 @@ pub(crate) fn pymatch( string.borrow_value(), start, end, - pattern.flags.clone(), + pattern.flags, &pattern.code, ); let ctx = MatchContext { @@ -237,6 +237,7 @@ trait MatchContextDrive { ]), _ => unreachable!(), }; + // TODO: char::from_u32_unchecked is stable from 1.5.0 unsafe { std::mem::transmute(code) } } fn back_skip_char(&mut self, skip_count: usize) { @@ -426,7 +427,6 @@ impl OpcodeDispatcher { drive.skip_char(1); } }), - SreOpcode::CHARSET | SreOpcode::BIGCHARSET => unreachable!("unexpected opcode"), SreOpcode::IN => once(|drive| { general_op_in(drive, |x| x); }), @@ -467,7 +467,7 @@ impl OpcodeDispatcher { general_op_literal(drive, |code, c| code != lower_unicode(c) as u32); }), SreOpcode::LITERAL_LOC_IGNORE => once(|drive| { - general_op_literal(drive, |code, c| char_loc_ignore(code, c)); + general_op_literal(drive, char_loc_ignore); }), SreOpcode::NOT_LITERAL_LOC_IGNORE => once(|drive| { general_op_literal(drive, |code, c| !char_loc_ignore(code, c)); @@ -478,9 +478,8 @@ impl OpcodeDispatcher { .set_mark(drive.peek_code(1) as usize, drive.ctx().string_position); drive.skip_code(2); }), - SreOpcode::MAX_UNTIL => unimplemented(), + SreOpcode::MAX_UNTIL => Box::new(OpMaxUntil::default()), SreOpcode::MIN_UNTIL => unimplemented(), - SreOpcode::NEGATE => unimplemented(), SreOpcode::RANGE => unimplemented(), SreOpcode::REPEAT => unimplemented(), SreOpcode::REPEAT_ONE => unimplemented(), @@ -504,6 +503,10 @@ impl OpcodeDispatcher { } }), SreOpcode::RANGE_UNI_IGNORE => unimplemented(), + _ => { + // TODO + unreachable!("unexpected opcode") + } } } } @@ -661,7 +664,7 @@ fn charset(set: &[u32], c: char) -> bool { /* <256 blockindices> */ let count = set[i + 1]; if ch < 0x10000 { - let blockindices: &[u8] = unsafe { std::mem::transmute(&set[i + 2..]) }; + let (_, blockindices, _) = unsafe { set[i + 2..].align_to::() }; let block = blockindices[(ch >> 8) as usize]; if set[2 + 64 + ((block as u32 * 256 + (ch & 255)) / 32) as usize] & (1 << (ch & (32 - 1))) @@ -712,7 +715,7 @@ fn charset(set: &[u32], c: char) -> bool { } fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { - let mut drive = WrapDrive::drive(stack_drive.ctx().clone(), stack_drive); + let mut drive = WrapDrive::drive(*stack_drive.ctx(), stack_drive); let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); let opcode = match SreOpcode::try_from(drive.peek_code(1)) { Ok(code) => code, @@ -749,7 +752,7 @@ fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { general_count_literal(&mut drive, |code, c| code != lower_ascii(c) as u32); } SreOpcode::LITERAL_LOC_IGNORE => { - general_count_literal(&mut drive, |code, c| char_loc_ignore(code, c)); + general_count_literal(&mut drive, char_loc_ignore); } SreOpcode::NOT_LITERAL_LOC_IGNORE => { general_count_literal(&mut drive, |code, c| !char_loc_ignore(code, c)); @@ -1035,3 +1038,111 @@ impl OpMinRepeatOne { self._1(drive) } } + +#[derive(Debug, Copy, Clone)] +struct RepeatContext { + skip: usize, + mincount: usize, + maxcount: usize, + count: isize, + last_position: isize, +} + +struct OpMaxUntil { + jump_id: usize, + count: isize, + save_last_position: isize, + child_ctx_id: usize, +} +impl Default for OpMaxUntil { + fn default() -> Self { + Self { + jump_id: 0, + count: 0, + save_last_position: -1, + child_ctx_id: 0, + } + } +} +impl OpcodeExecutor for OpMaxUntil { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { + match self.jump_id { + 0 => { + drive.state.string_position = drive.ctx().string_position; + let repeat = match drive.state.repeat_stack.last_mut() { + Some(repeat) => repeat, + None => { + todo!("Internal re error: MAX_UNTIL without REPEAT."); + } + }; + self.count = repeat.count + 1; + + if self.count < repeat.mincount as isize { + // not enough matches + repeat.count = self.count; + self.child_ctx_id = drive.push_new_context(4); + self.jump_id = 1; + return Some(()); + } + + if (self.count < repeat.maxcount as isize || repeat.maxcount == MAXREPEAT) + && (drive.state.string_position as isize != repeat.last_position) + { + // we may have enough matches, if we can match another item, do so + repeat.count = self.count; + self.save_last_position = repeat.last_position; + repeat.last_position = drive.state.string_position as isize; + drive.state.marks_push(); + self.child_ctx_id = drive.push_new_context(4); + self.jump_id = 2; + return Some(()); + } + + self.child_ctx_id = drive.push_new_context(1); + + self.jump_id = 3; + Some(()) + } + 1 => { + let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + drive.ctx_mut().has_matched = child_ctx.has_matched; + if drive.ctx().has_matched != Some(true) { + drive.state.string_position = drive.ctx().string_position; + let repeat = drive.state.repeat_stack.last_mut().unwrap(); + repeat.count = self.count - 1; + } + None + } + 2 => { + let repeat = drive.state.repeat_stack.last_mut().unwrap(); + repeat.last_position = drive.state.string_position as isize; + let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + if child_ctx.has_matched == Some(true) { + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(true); + return None; + } + repeat.count = self.count - 1; + drive.state.marks_pop(); + drive.state.string_position = drive.ctx().string_position; + + self.child_ctx_id = drive.push_new_context(1); + + self.jump_id = 3; + Some(()) + } + 3 => { + // cannot match more repeated items here. make sure the tail matches + let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + drive.ctx_mut().has_matched = child_ctx.has_matched; + if drive.ctx().has_matched != Some(true) { + drive.state.string_position = drive.ctx().string_position; + } else { + drive.state.repeat_stack.pop(); + } + None + } + _ => unreachable!(), + } + } +} From 93c2b8b55513989982156136e6b92a1f07e02c24 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 29 Dec 2020 13:02:43 +0200 Subject: [PATCH 009/893] impl OpBranch --- interp.rs | 49 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/interp.rs b/interp.rs index 85787ae516..03cfdfa96a 100644 --- a/interp.rs +++ b/interp.rs @@ -417,7 +417,6 @@ impl OpcodeDispatcher { } }), SreOpcode::BRANCH => unimplemented(), - SreOpcode::CALL => unimplemented(), SreOpcode::CATEGORY => once(|drive| { let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); if drive.at_end() || !category(catcode, drive.peek_char()) { @@ -1146,3 +1145,51 @@ impl OpcodeExecutor for OpMaxUntil { } } } + +struct OpBranch { + jump_id: usize, + child_ctx_id: usize, + current_branch_length: usize, +} +impl Default for OpBranch { + fn default() -> Self { + Self { jump_id: 0, child_ctx_id: 0, current_branch_length: 0 } + } +} +impl OpcodeExecutor for OpBranch { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { + match self.jump_id { + 0 => { + drive.state.marks_push(); + // jump out the head + self.current_branch_length = 1; + self.jump_id = 1; + self.next(drive) + } + 1 => { + drive.skip_code(self.current_branch_length); + self.current_branch_length = drive.peek_code(0) as usize; + if self.current_branch_length == 0 { + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.state.string_position = drive.ctx().string_position; + self.child_ctx_id = drive.push_new_context(1); + self.jump_id = 2; + Some(()) + } + 2 => { + let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + if child_ctx.has_matched == Some(true) { + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.state.marks_pop_keep(); + self.jump_id = 1; + Some(()) + } + _ => unreachable!() + } + } +} \ No newline at end of file From 5a4459856ca57886279f0cc4f1abc33c7cf4a397 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 29 Dec 2020 13:56:37 +0200 Subject: [PATCH 010/893] impl OpRepeat --- interp.rs | 84 ++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 80 insertions(+), 4 deletions(-) diff --git a/interp.rs b/interp.rs index 03cfdfa96a..fbcbf260e5 100644 --- a/interp.rs +++ b/interp.rs @@ -416,7 +416,7 @@ impl OpcodeDispatcher { drive.skip_code(2); } }), - SreOpcode::BRANCH => unimplemented(), + SreOpcode::BRANCH => Box::new(OpBranch::default()), SreOpcode::CATEGORY => once(|drive| { let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); if drive.at_end() || !category(catcode, drive.peek_char()) { @@ -1146,6 +1146,39 @@ impl OpcodeExecutor for OpMaxUntil { } } +struct OpMinUntil { + jump_id: usize, + count: isize, + child_ctx_id: usize, +} +impl Default for OpMinUntil { + fn default() -> Self { + Self { + jump_id: 0, + count: 0, + child_ctx_id: 0, + } + } +} +impl OpcodeExecutor for OpMinUntil { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { + match self.jump_id { + 0 => { + drive.state.string_position = drive.ctx().string_position; + let repeat = match drive.state.repeat_stack.last_mut() { + Some(repeat) => repeat, + None => { + todo!("Internal re error: MAX_UNTIL without REPEAT."); + } + }; + self.count = repeat.count + 1; + None + } + _ => unreachable!(), + } + } +} + struct OpBranch { jump_id: usize, child_ctx_id: usize, @@ -1153,7 +1186,11 @@ struct OpBranch { } impl Default for OpBranch { fn default() -> Self { - Self { jump_id: 0, child_ctx_id: 0, current_branch_length: 0 } + Self { + jump_id: 0, + child_ctx_id: 0, + current_branch_length: 0, + } } } impl OpcodeExecutor for OpBranch { @@ -1189,7 +1226,46 @@ impl OpcodeExecutor for OpBranch { self.jump_id = 1; Some(()) } - _ => unreachable!() + _ => unreachable!(), + } + } +} + +struct OpRepeat { + jump_id: usize, + child_ctx_id: usize, +} +impl Default for OpRepeat { + fn default() -> Self { + Self { + jump_id: 0, + child_ctx_id: 0, + } + } +} +impl OpcodeExecutor for OpRepeat { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { + match self.jump_id { + 0 => { + let repeat = RepeatContext { + skip: drive.peek_code(1), + mincount: drive.peek_code(2), + maxcount: drive.peek_code(3), + count: -1, + last_position: -1, + }; + drive.state.repeat_stack.push(repeat); + drive.state.string_position = drive.ctx().string_position; + self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); + self.jump_id = 1; + Some(()) + } + 1 => { + let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + drive.ctx_mut().has_matched = child_ctx.has_matched; + None + } + _ => unreachable!(), } } -} \ No newline at end of file +} From af7901dcb21cba20bcfa8635b3c86d84574a988e Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 29 Dec 2020 16:48:30 +0200 Subject: [PATCH 011/893] impl OpMinUntil --- interp.rs | 59 ++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/interp.rs b/interp.rs index fbcbf260e5..ec6d41eade 100644 --- a/interp.rs +++ b/interp.rs @@ -478,11 +478,9 @@ impl OpcodeDispatcher { drive.skip_code(2); }), SreOpcode::MAX_UNTIL => Box::new(OpMaxUntil::default()), - SreOpcode::MIN_UNTIL => unimplemented(), - SreOpcode::RANGE => unimplemented(), - SreOpcode::REPEAT => unimplemented(), + SreOpcode::MIN_UNTIL => Box::new(OpMinUntil::default()), + SreOpcode::REPEAT => Box::new(OpRepeat::default()), SreOpcode::REPEAT_ONE => unimplemented(), - SreOpcode::SUBPATTERN => unimplemented(), SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), SreOpcode::GROUPREF => once(|drive| general_op_groupref(drive, |x| x)), SreOpcode::GROUPREF_IGNORE => once(|drive| general_op_groupref(drive, lower_ascii)), @@ -501,9 +499,8 @@ impl OpcodeDispatcher { _ => drive.skip_code(drive.peek_code(2) as usize + 1), } }), - SreOpcode::RANGE_UNI_IGNORE => unimplemented(), _ => { - // TODO + // TODO error expcetion unreachable!("unexpected opcode") } } @@ -1172,8 +1169,52 @@ impl OpcodeExecutor for OpMinUntil { } }; self.count = repeat.count + 1; + + if self.count < repeat.mincount as isize { + // not enough matches + repeat.count = self.count; + self.child_ctx_id = drive.push_new_context(4); + self.jump_id = 1; + return Some(()); + } + + // see if the tail matches + drive.state.marks_push(); + self.child_ctx_id = drive.push_new_context(1); + self.jump_id = 2; + Some(()) + } + 1 => { + let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + drive.ctx_mut().has_matched = child_ctx.has_matched; + if drive.ctx().has_matched != Some(true) { + drive.state.string_position = drive.ctx().string_position; + let repeat = drive.state.repeat_stack.last_mut().unwrap(); + repeat.count = self.count - 1; + } None } + 2 => { + let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + if child_ctx.has_matched == Some(true) { + drive.state.repeat_stack.pop(); + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.state.string_position = drive.ctx().string_position; + drive.state.marks_pop(); + + // match more until tail matches + let repeat = drive.state.repeat_stack.last_mut().unwrap(); + if self.count >= repeat.maxcount as isize && repeat.maxcount != MAXREPEAT { + drive.ctx_mut().has_matched = Some(false); + return None; + } + repeat.count = self.count; + self.child_ctx_id = drive.push_new_context(4); + self.jump_id = 1; + Some(()) + } _ => unreachable!(), } } @@ -1248,9 +1289,9 @@ impl OpcodeExecutor for OpRepeat { match self.jump_id { 0 => { let repeat = RepeatContext { - skip: drive.peek_code(1), - mincount: drive.peek_code(2), - maxcount: drive.peek_code(3), + skip: drive.peek_code(1) as usize, + mincount: drive.peek_code(2) as usize, + maxcount: drive.peek_code(3) as usize, count: -1, last_position: -1, }; From fa2adaf2ff9bf8acf642ad210480e686848d4061 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 29 Dec 2020 17:53:14 +0200 Subject: [PATCH 012/893] Impl OpRepeatONe --- interp.rs | 132 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 103 insertions(+), 29 deletions(-) diff --git a/interp.rs b/interp.rs index ec6d41eade..64d70216e3 100644 --- a/interp.rs +++ b/interp.rs @@ -313,14 +313,6 @@ trait OpcodeExecutor { fn next(&mut self, drive: &mut StackDrive) -> Option<()>; } -struct OpUnimplemented {} -impl OpcodeExecutor for OpUnimplemented { - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - drive.ctx_mut().has_matched = Some(false); - None - } -} - struct OpOnce { f: Option, } @@ -335,10 +327,6 @@ fn once(f: F) -> Box> { Box::new(OpOnce { f: Some(f) }) } -fn unimplemented() -> Box { - Box::new(OpUnimplemented {}) -} - struct OpcodeDispatcher { executing_contexts: HashMap>, } @@ -480,7 +468,7 @@ impl OpcodeDispatcher { SreOpcode::MAX_UNTIL => Box::new(OpMaxUntil::default()), SreOpcode::MIN_UNTIL => Box::new(OpMinUntil::default()), SreOpcode::REPEAT => Box::new(OpRepeat::default()), - SreOpcode::REPEAT_ONE => unimplemented(), + SreOpcode::REPEAT_ONE => Box::new(OpMinRepeatOne::default()), SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), SreOpcode::GROUPREF => once(|drive| general_op_groupref(drive, |x| x)), SreOpcode::GROUPREF_IGNORE => once(|drive| general_op_groupref(drive, lower_ascii)), @@ -500,7 +488,7 @@ impl OpcodeDispatcher { } }), _ => { - // TODO error expcetion + // TODO python expcetion? unreachable!("unexpected opcode") } } @@ -713,6 +701,7 @@ fn charset(set: &[u32], c: char) -> bool { fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { let mut drive = WrapDrive::drive(*stack_drive.ctx(), stack_drive); let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); + let end = drive.ctx().string_position + maxcount; let opcode = match SreOpcode::try_from(drive.peek_code(1)) { Ok(code) => code, Err(_) => { @@ -722,54 +711,56 @@ fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { match opcode { SreOpcode::ANY => { - while !drive.at_end() && !drive.at_linebreak() { + while !drive.ctx().string_position < end && !drive.at_linebreak() { drive.skip_char(1); } } SreOpcode::ANY_ALL => { - drive.skip_char(drive.remaining_chars()); + drive.skip_char(maxcount); } SreOpcode::IN => { // TODO: pattern[2 or 1..]? - while !drive.at_end() && charset(&drive.pattern()[2..], drive.peek_char()) { + while !drive.ctx().string_position < end + && charset(&drive.pattern()[2..], drive.peek_char()) + { drive.skip_char(1); } } SreOpcode::LITERAL => { - general_count_literal(&mut drive, |code, c| code == c as u32); + general_count_literal(&mut drive, end, |code, c| code == c as u32); } SreOpcode::NOT_LITERAL => { - general_count_literal(&mut drive, |code, c| code != c as u32); + general_count_literal(&mut drive, end, |code, c| code != c as u32); } SreOpcode::LITERAL_IGNORE => { - general_count_literal(&mut drive, |code, c| code == lower_ascii(c) as u32); + general_count_literal(&mut drive, end, |code, c| code == lower_ascii(c) as u32); } SreOpcode::NOT_LITERAL_IGNORE => { - general_count_literal(&mut drive, |code, c| code != lower_ascii(c) as u32); + general_count_literal(&mut drive, end, |code, c| code != lower_ascii(c) as u32); } SreOpcode::LITERAL_LOC_IGNORE => { - general_count_literal(&mut drive, char_loc_ignore); + general_count_literal(&mut drive, end, char_loc_ignore); } SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_count_literal(&mut drive, |code, c| !char_loc_ignore(code, c)); + general_count_literal(&mut drive, end, |code, c| !char_loc_ignore(code, c)); } SreOpcode::LITERAL_UNI_IGNORE => { - general_count_literal(&mut drive, |code, c| code == lower_unicode(c) as u32); + general_count_literal(&mut drive, end, |code, c| code == lower_unicode(c) as u32); } SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_count_literal(&mut drive, |code, c| code != lower_unicode(c) as u32); + general_count_literal(&mut drive, end, |code, c| code != lower_unicode(c) as u32); } _ => { - panic!("TODO: Not Implemented."); + todo!("repeated single character pattern?"); } } - drive.ctx().string_position - stack_drive.ctx().string_position + drive.ctx().string_position - drive.state().string_position } -fn general_count_literal bool>(drive: &mut WrapDrive, mut f: F) { +fn general_count_literal bool>(drive: &mut WrapDrive, end: usize, mut f: F) { let ch = drive.peek_code(1); - while !drive.at_end() && f(ch, drive.peek_char()) { + while !drive.ctx().string_position < end && f(ch, drive.peek_char()) { drive.skip_char(1); } } @@ -1310,3 +1301,86 @@ impl OpcodeExecutor for OpRepeat { } } } + +struct OpRepeatOne { + jump_id: usize, + child_ctx_id: usize, + mincount: usize, + maxcount: usize, + count: usize, +} +impl Default for OpRepeatOne { + fn default() -> Self { + Self { + jump_id: 0, + child_ctx_id: 0, + mincount: 0, + maxcount: 0, + count: 0, + } + } +} +impl OpcodeExecutor for OpRepeatOne { + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { + match self.jump_id { + 0 => { + self.mincount = drive.peek_code(2) as usize; + self.maxcount = drive.peek_code(3) as usize; + + if drive.remaining_chars() < self.mincount { + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.state.string_position = drive.ctx().string_position; + self.count = count(drive, self.maxcount); + drive.skip_char(self.count); + if self.count < self.mincount { + drive.ctx_mut().has_matched = Some(false); + return None; + } + + let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 { + // tail is empty. we're finished + drive.state.string_position = drive.ctx().string_position; + drive.ctx_mut().has_matched = Some(true); + return None; + } + + drive.state.marks_push(); + // TODO: + // Special case: Tail starts with a literal. Skip positions where + // the rest of the pattern cannot possibly match. + self.jump_id = 1; + self.next(drive) + } + 1 => { + // General case: backtracking + if self.count >= self.mincount { + drive.state.string_position = drive.ctx().string_position; + self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); + self.jump_id = 2; + return Some(()); + } + + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(false); + None + } + 2 => { + let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + if child_ctx.has_matched == Some(true) { + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.back_skip_char(1); + self.count -= 1; + drive.state.marks_pop_keep(); + + self.jump_id = 1; + Some(()) + } + _ => unreachable!(), + } + } +} From ae44580371afb3e347f31647368c2a7a8b1a1578 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 30 Dec 2020 14:14:13 +0200 Subject: [PATCH 013/893] general case for count --- interp.rs | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/interp.rs b/interp.rs index 64d70216e3..99574443c0 100644 --- a/interp.rs +++ b/interp.rs @@ -3,6 +3,7 @@ use super::_sre::{Match, Pattern, MAXREPEAT}; use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; use crate::builtins::PyStrRef; +use crate::pyobject::PyRef; use rustpython_common::borrow::BorrowValue; use std::collections::HashMap; use std::convert::TryFrom; @@ -98,7 +99,7 @@ pub(crate) fn pymatch( string: PyStrRef, start: usize, end: usize, - pattern: &Pattern, + pattern: PyRef, ) -> Option { let mut state = State::new( string.borrow_value(), @@ -135,7 +136,7 @@ pub(crate) fn pymatch( return None; } - Some(Match::new(&state, pattern.pattern.clone(), string.clone())) + Some(Match::new(&state, pattern.clone().into_object(), string.clone())) } #[derive(Debug, Copy, Clone)] @@ -425,7 +426,7 @@ impl OpcodeDispatcher { }), SreOpcode::IN_LOC_IGNORE => once(|drive| { let skip = drive.peek_code(1) as usize; - if drive.at_end() || !charset_loc_ignore(&drive.pattern()[1..], drive.peek_char()) { + if drive.at_end() || !charset_loc_ignore(&drive.pattern()[2..], drive.peek_char()) { drive.ctx_mut().has_matched = Some(false); } else { drive.skip_code(skip + 1); @@ -561,7 +562,7 @@ fn general_op_literal bool>(drive: &mut StackDrive, f: F fn general_op_in char>(drive: &mut StackDrive, f: F) { let skip = drive.peek_code(1) as usize; - if drive.at_end() || !charset(&drive.pattern()[1..], f(drive.peek_char())) { + if drive.at_end() || !charset(&drive.pattern()[2..], f(drive.peek_char())) { drive.ctx_mut().has_matched = Some(false); } else { drive.skip_code(skip + 1); @@ -698,7 +699,28 @@ fn charset(set: &[u32], c: char) -> bool { false } -fn count(stack_drive: &StackDrive, maxcount: usize) -> usize { +fn count(drive: &mut StackDrive, maxcount: usize) -> usize { + let mut count = 0; + let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); + + let save_ctx = *drive.ctx(); + drive.skip_code(4); + let reset_position = drive.ctx().code_position; + + let mut dispatcher = OpcodeDispatcher::new(); + while count < maxcount { + drive.ctx_mut().code_position = reset_position; + dispatcher.dispatch(SreOpcode::try_from(drive.peek_code(0)).unwrap(), drive); + if drive.ctx().has_matched == Some(false) { + break; + } + count += 1; + } + *drive.ctx_mut() = save_ctx; + count +} + +fn _count(stack_drive: &StackDrive, maxcount: usize) -> usize { let mut drive = WrapDrive::drive(*stack_drive.ctx(), stack_drive); let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); let end = drive.ctx().string_position + maxcount; From f7287553e9ed42df638190b56ce4474eb7783c38 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 30 Dec 2020 18:04:00 +0200 Subject: [PATCH 014/893] impl re.Match object --- interp.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/interp.rs b/interp.rs index 99574443c0..ac0eadb4ae 100644 --- a/interp.rs +++ b/interp.rs @@ -17,12 +17,12 @@ pub(crate) struct State<'a> { pub end: usize, flags: SreFlag, pattern_codes: &'a [u32], - marks: Vec>, + pub marks: Vec>, pub lastindex: isize, marks_stack: Vec<(Vec>, isize)>, context_stack: Vec, repeat_stack: Vec, - string_position: usize, + pub string_position: usize, } impl<'a> State<'a> { @@ -136,7 +136,7 @@ pub(crate) fn pymatch( return None; } - Some(Match::new(&state, pattern.clone().into_object(), string.clone())) + Some(Match::new(&state, pattern.clone(), string.clone())) } #[derive(Debug, Copy, Clone)] From 04bb80f157128d59e2a5e25261311ee1be8c143e Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 31 Dec 2020 11:22:17 +0200 Subject: [PATCH 015/893] impl Match.group --- interp.rs | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/interp.rs b/interp.rs index ac0eadb4ae..f9b0898a31 100644 --- a/interp.rs +++ b/interp.rs @@ -115,6 +115,7 @@ pub(crate) fn pymatch( has_matched: None, }; state.context_stack.push(ctx); + let mut dispatcher = OpcodeDispatcher::new(); let mut has_matched = None; loop { @@ -123,7 +124,6 @@ pub(crate) fn pymatch( } let ctx_id = state.context_stack.len() - 1; let mut drive = StackDrive::drive(ctx_id, state); - let mut dispatcher = OpcodeDispatcher::new(); has_matched = dispatcher.pymatch(&mut drive); state = drive.take(); @@ -132,11 +132,11 @@ pub(crate) fn pymatch( } } - if has_matched == None || has_matched == Some(false) { - return None; + if has_matched != Some(true) { + None + } else { + Some(Match::new(&state, pattern.clone(), string.clone())) } - - Some(Match::new(&state, pattern.clone(), string.clone())) } #[derive(Debug, Copy, Clone)] @@ -344,7 +344,9 @@ impl OpcodeDispatcher { while drive.remaining_codes() > 0 && drive.ctx().has_matched.is_none() { let code = drive.peek_code(0); let opcode = SreOpcode::try_from(code).unwrap(); - self.dispatch(opcode, drive); + if !self.dispatch(opcode, drive) { + return None; + } } match drive.ctx().has_matched { Some(matched) => Some(matched), @@ -469,7 +471,7 @@ impl OpcodeDispatcher { SreOpcode::MAX_UNTIL => Box::new(OpMaxUntil::default()), SreOpcode::MIN_UNTIL => Box::new(OpMinUntil::default()), SreOpcode::REPEAT => Box::new(OpRepeat::default()), - SreOpcode::REPEAT_ONE => Box::new(OpMinRepeatOne::default()), + SreOpcode::REPEAT_ONE => Box::new(OpRepeatOne::default()), SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), SreOpcode::GROUPREF => once(|drive| general_op_groupref(drive, |x| x)), SreOpcode::GROUPREF_IGNORE => once(|drive| general_op_groupref(drive, lower_ascii)), @@ -1329,7 +1331,7 @@ struct OpRepeatOne { child_ctx_id: usize, mincount: usize, maxcount: usize, - count: usize, + count: isize, } impl Default for OpRepeatOne { fn default() -> Self { @@ -1354,9 +1356,9 @@ impl OpcodeExecutor for OpRepeatOne { return None; } drive.state.string_position = drive.ctx().string_position; - self.count = count(drive, self.maxcount); - drive.skip_char(self.count); - if self.count < self.mincount { + self.count = count(drive, self.maxcount) as isize; + drive.skip_char(self.count as usize); + if self.count < self.mincount as isize { drive.ctx_mut().has_matched = Some(false); return None; } @@ -1378,7 +1380,7 @@ impl OpcodeExecutor for OpRepeatOne { } 1 => { // General case: backtracking - if self.count >= self.mincount { + if self.count >= self.mincount as isize { drive.state.string_position = drive.ctx().string_position; self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); self.jump_id = 2; From 8c442f599bb67f4a9878cad9e4df180ab52b39b7 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Jan 2021 10:30:05 +0200 Subject: [PATCH 016/893] rework OpMaxUntil; restruct popping context; add tests; --- interp.rs | 660 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 364 insertions(+), 296 deletions(-) diff --git a/interp.rs b/interp.rs index f9b0898a31..cb06ee0b8a 100644 --- a/interp.rs +++ b/interp.rs @@ -23,6 +23,7 @@ pub(crate) struct State<'a> { context_stack: Vec, repeat_stack: Vec, pub string_position: usize, + popped_context: Option, } impl<'a> State<'a> { @@ -49,6 +50,7 @@ impl<'a> State<'a> { repeat_stack: Vec::new(), marks: Vec::new(), string_position: start, + popped_context: None, } } @@ -58,6 +60,7 @@ impl<'a> State<'a> { self.marks_stack.clear(); self.context_stack.clear(); self.repeat_stack.clear(); + self.popped_context = None; } fn set_mark(&mut self, mark_nr: usize, position: usize) { @@ -128,7 +131,7 @@ pub(crate) fn pymatch( has_matched = dispatcher.pymatch(&mut drive); state = drive.take(); if has_matched.is_some() { - state.context_stack.pop(); + state.popped_context = state.context_stack.pop(); } } @@ -151,6 +154,9 @@ trait MatchContextDrive { fn ctx_mut(&mut self) -> &mut MatchContext; fn ctx(&self) -> &MatchContext; fn state(&self) -> &State; + fn repeat_ctx(&self) -> &RepeatContext { + self.state().repeat_stack.last().unwrap() + } fn str(&self) -> &str { unsafe { std::str::from_utf8_unchecked( @@ -181,6 +187,9 @@ trait MatchContextDrive { } fn skip_code(&mut self, skip_count: usize) { self.ctx_mut().code_position += skip_count; + if self.ctx().code_position > self.state().pattern_codes.len() { + self.ctx_mut().code_position = self.state().pattern_codes.len(); + } } fn remaining_chars(&self) -> usize { self.state().end - self.ctx().string_position @@ -263,16 +272,17 @@ impl<'a> StackDrive<'a> { fn take(self) -> State<'a> { self.state } - fn push_new_context(&mut self, pattern_offset: usize) -> usize { + fn push_new_context(&mut self, pattern_offset: usize) { let ctx = self.ctx(); - let child_ctx = MatchContext { - string_position: ctx.string_position, - string_offset: ctx.string_offset, - code_position: ctx.code_position + pattern_offset, - has_matched: None, - }; + let mut child_ctx = MatchContext { ..*ctx }; + child_ctx.code_position += pattern_offset; + if child_ctx.code_position > self.state.pattern_codes.len() { + child_ctx.code_position = self.state.pattern_codes.len(); + } self.state.context_stack.push(child_ctx); - self.state.context_stack.len() - 1 + } + fn repeat_ctx_mut(&mut self) -> &mut RepeatContext { + self.state.repeat_stack.last_mut().unwrap() } } impl MatchContextDrive for StackDrive<'_> { @@ -328,6 +338,39 @@ fn once(f: F) -> Box> { Box::new(OpOnce { f: Some(f) }) } +// F1 F2 are same identical, but workaround for closure +struct OpTwice { + f1: Option, + f2: Option, +} +impl OpcodeExecutor for OpTwice +where + F1: FnOnce(&mut StackDrive), + F2: FnOnce(&mut StackDrive), +{ + fn next(&mut self, drive: &mut StackDrive) -> Option<()> { + if let Some(f1) = self.f1.take() { + f1(drive); + Some(()) + } else if let Some(f2) = self.f2.take() { + f2(drive); + None + } else { + unreachable!() + } + } +} +fn twice(f1: F1, f2: F2) -> Box> +where + F1: FnOnce(&mut StackDrive), + F2: FnOnce(&mut StackDrive), +{ + Box::new(OpTwice { + f1: Some(f1), + f2: Some(f2), + }) +} + struct OpcodeDispatcher { executing_contexts: HashMap>, } @@ -397,8 +440,44 @@ impl OpcodeDispatcher { drive.skip_char(1); } }), - SreOpcode::ASSERT => Box::new(OpAssert::default()), - SreOpcode::ASSERT_NOT => Box::new(OpAssertNot::default()), + SreOpcode::ASSERT => twice( + |drive| { + let back = drive.peek_code(2) as usize; + if back > drive.ctx().string_position { + drive.ctx_mut().has_matched = Some(false); + return; + } + drive.state.string_position = drive.ctx().string_position - back; + drive.push_new_context(3); + }, + |drive| { + let child_ctx = drive.state.popped_context.unwrap(); + if child_ctx.has_matched == Some(true) { + drive.skip_code(drive.peek_code(1) as usize + 1); + } else { + drive.ctx_mut().has_matched = Some(false); + } + }, + ), + SreOpcode::ASSERT_NOT => twice( + |drive| { + let back = drive.peek_code(2) as usize; + if back > drive.ctx().string_position { + drive.skip_code(drive.peek_code(1) as usize + 1); + return; + } + drive.state.string_position = drive.ctx().string_position - back; + drive.push_new_context(3); + }, + |drive| { + let child_ctx = drive.state.popped_context.unwrap(); + if child_ctx.has_matched == Some(true) { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.skip_code(drive.peek_code(1) as usize + 1); + } + }, + ), SreOpcode::AT => once(|drive| { let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); if !at(drive, atcode) { @@ -468,9 +547,29 @@ impl OpcodeDispatcher { .set_mark(drive.peek_code(1) as usize, drive.ctx().string_position); drive.skip_code(2); }), + SreOpcode::REPEAT => twice( + // create repeat context. all the hard work is done by the UNTIL + // operator (MAX_UNTIL, MIN_UNTIL) + // <1=min> <2=max> item tail + |drive| { + let repeat = RepeatContext { + count: -1, + code_position: drive.ctx().code_position, + last_position: std::usize::MAX, + }; + drive.state.repeat_stack.push(repeat); + drive.state.string_position = drive.ctx().string_position; + // execute UNTIL operator + drive.push_new_context(drive.peek_code(1) as usize + 1); + }, + |drive| { + drive.state.repeat_stack.pop(); + let child_ctx = drive.state.popped_context.unwrap(); + drive.ctx_mut().has_matched = child_ctx.has_matched; + }, + ), SreOpcode::MAX_UNTIL => Box::new(OpMaxUntil::default()), - SreOpcode::MIN_UNTIL => Box::new(OpMinUntil::default()), - SreOpcode::REPEAT => Box::new(OpRepeat::default()), + SreOpcode::MIN_UNTIL => todo!("min until"), SreOpcode::REPEAT_ONE => Box::new(OpRepeatOne::default()), SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), SreOpcode::GROUPREF => once(|drive| general_op_groupref(drive, |x| x)), @@ -872,90 +971,12 @@ fn is_utf8_first_byte(b: u8) -> bool { (b & 0b10000000 == 0) || (b & 0b11000000 == 0b11000000) } -struct OpAssert { - child_ctx_id: usize, - jump_id: usize, -} -impl Default for OpAssert { - fn default() -> Self { - Self { - child_ctx_id: 0, - jump_id: 0, - } - } -} -impl OpcodeExecutor for OpAssert { - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => self._0(drive), - 1 => self._1(drive), - _ => unreachable!(), - } - } -} -impl OpAssert { - fn _0(&mut self, drive: &mut StackDrive) -> Option<()> { - let back = drive.peek_code(2) as usize; - if back > drive.ctx().string_position { - drive.ctx_mut().has_matched = Some(false); - return None; - } - drive.state.string_position = drive.ctx().string_position - back; - self.child_ctx_id = drive.push_new_context(3); - self.jump_id = 1; - Some(()) - } - fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { - if drive.state.context_stack[self.child_ctx_id].has_matched == Some(true) { - drive.skip_code(drive.peek_code(1) as usize + 1); - } else { - drive.ctx_mut().has_matched = Some(false); - } - None - } -} - -struct OpAssertNot { - child_ctx_id: usize, - jump_id: usize, -} -impl Default for OpAssertNot { - fn default() -> Self { - Self { - child_ctx_id: 0, - jump_id: 0, - } - } -} -impl OpcodeExecutor for OpAssertNot { - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => self._0(drive), - 1 => self._1(drive), - _ => unreachable!(), - } - } -} -impl OpAssertNot { - fn _0(&mut self, drive: &mut StackDrive) -> Option<()> { - let back = drive.peek_code(2) as usize; - if back > drive.ctx().string_position { - drive.skip_code(drive.peek_code(1) as usize + 1); - return None; - } - drive.state.string_position = drive.ctx().string_position - back; - self.child_ctx_id = drive.push_new_context(3); - self.jump_id = 1; - Some(()) - } - fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { - if drive.state.context_stack[self.child_ctx_id].has_matched == Some(true) { - drive.ctx_mut().has_matched = Some(false); - } else { - drive.skip_code(drive.peek_code(1) as usize + 1); - } - None - } +#[derive(Debug, Copy, Clone)] +struct RepeatContext { + count: isize, + code_position: usize, + // zero-width match protection + last_position: usize, } struct OpMinRepeatOne { @@ -963,7 +984,6 @@ struct OpMinRepeatOne { mincount: usize, maxcount: usize, count: usize, - child_ctx_id: usize, } impl OpcodeExecutor for OpMinRepeatOne { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { @@ -982,7 +1002,6 @@ impl Default for OpMinRepeatOne { mincount: 0, maxcount: 0, count: 0, - child_ctx_id: 0, } } } @@ -1023,7 +1042,7 @@ impl OpMinRepeatOne { fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { if self.maxcount == MAXREPEAT || self.count <= self.maxcount { drive.state.string_position = drive.ctx().string_position; - self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); + drive.push_new_context(drive.peek_code(1) as usize + 1); self.jump_id = 2; return Some(()); } @@ -1033,7 +1052,8 @@ impl OpMinRepeatOne { None } fn _2(&mut self, drive: &mut StackDrive) -> Option<()> { - if let Some(true) = drive.state.context_stack[self.child_ctx_id].has_matched { + let child_ctx = drive.state.popped_context.unwrap(); + if child_ctx.has_matched == Some(true) { drive.ctx_mut().has_matched = Some(true); return None; } @@ -1050,201 +1070,290 @@ impl OpMinRepeatOne { } } -#[derive(Debug, Copy, Clone)] -struct RepeatContext { - skip: usize, - mincount: usize, - maxcount: usize, - count: isize, - last_position: isize, -} - +// Everything is stored in RepeatContext struct OpMaxUntil { jump_id: usize, count: isize, - save_last_position: isize, - child_ctx_id: usize, + save_last_position: usize, } impl Default for OpMaxUntil { fn default() -> Self { - Self { + OpMaxUntil { jump_id: 0, count: 0, - save_last_position: -1, - child_ctx_id: 0, + save_last_position: 0, } } } impl OpcodeExecutor for OpMaxUntil { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { - 0 => { - drive.state.string_position = drive.ctx().string_position; - let repeat = match drive.state.repeat_stack.last_mut() { - Some(repeat) => repeat, - None => { - todo!("Internal re error: MAX_UNTIL without REPEAT."); - } - }; - self.count = repeat.count + 1; - - if self.count < repeat.mincount as isize { - // not enough matches - repeat.count = self.count; - self.child_ctx_id = drive.push_new_context(4); - self.jump_id = 1; - return Some(()); - } - - if (self.count < repeat.maxcount as isize || repeat.maxcount == MAXREPEAT) - && (drive.state.string_position as isize != repeat.last_position) - { - // we may have enough matches, if we can match another item, do so - repeat.count = self.count; - self.save_last_position = repeat.last_position; - repeat.last_position = drive.state.string_position as isize; - drive.state.marks_push(); - self.child_ctx_id = drive.push_new_context(4); - self.jump_id = 2; - return Some(()); - } - - self.child_ctx_id = drive.push_new_context(1); - - self.jump_id = 3; - Some(()) - } - 1 => { - let child_ctx = &drive.state.context_stack[self.child_ctx_id]; - drive.ctx_mut().has_matched = child_ctx.has_matched; - if drive.ctx().has_matched != Some(true) { - drive.state.string_position = drive.ctx().string_position; - let repeat = drive.state.repeat_stack.last_mut().unwrap(); - repeat.count = self.count - 1; - } - None - } - 2 => { - let repeat = drive.state.repeat_stack.last_mut().unwrap(); - repeat.last_position = drive.state.string_position as isize; - let child_ctx = &drive.state.context_stack[self.child_ctx_id]; - if child_ctx.has_matched == Some(true) { - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(true); - return None; - } - repeat.count = self.count - 1; - drive.state.marks_pop(); - drive.state.string_position = drive.ctx().string_position; + 0 => self._0(drive), + 1 => self._1(drive), + 2 => self._2(drive), + 3 => self._3(drive), + 4 => self._4(drive), + _ => unreachable!(), + } + } +} +impl OpMaxUntil { + fn _0(&mut self, drive: &mut StackDrive) -> Option<()> { + let RepeatContext { + count, + code_position, + last_position, + } = *drive.repeat_ctx(); + drive.ctx_mut().code_position = code_position; + let mincount = drive.peek_code(2) as usize; + let maxcount = drive.peek_code(3) as usize; + self.count = count + 1; + + if (self.count as usize) < mincount { + // not enough matches + drive.repeat_ctx_mut().count = self.count; + drive.push_new_context(4); + self.jump_id = 1; + return Some(()); + } - self.child_ctx_id = drive.push_new_context(1); + if ((count as usize) < maxcount || maxcount == MAXREPEAT) + && drive.state.string_position != last_position + { + // we may have enough matches, if we can match another item, do so + drive.repeat_ctx_mut().count = self.count; + drive.state.marks_push(); + // self.save_last_position = last_position; + // drive.repeat_ctx_mut().last_position = drive.state.string_position; + drive.push_new_context(4); + self.jump_id = 2; + return Some(()); + } - self.jump_id = 3; - Some(()) - } - 3 => { - // cannot match more repeated items here. make sure the tail matches - let child_ctx = &drive.state.context_stack[self.child_ctx_id]; - drive.ctx_mut().has_matched = child_ctx.has_matched; - if drive.ctx().has_matched != Some(true) { - drive.state.string_position = drive.ctx().string_position; - } else { - drive.state.repeat_stack.pop(); - } - None - } - _ => unreachable!(), + self.jump_id = 3; + self.next(drive) + } + fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { + let child_ctx = drive.state.popped_context.unwrap(); + drive.ctx_mut().has_matched = child_ctx.has_matched; + if drive.ctx().has_matched != Some(true) { + drive.repeat_ctx_mut().count = self.count - 1; + drive.state.string_position = drive.ctx().string_position; } + None + } + fn _2(&mut self, drive: &mut StackDrive) -> Option<()> { + // drive.repeat_ctx_mut().last_position = self.save_last_position; + let child_ctx = drive.state.popped_context.unwrap(); + if child_ctx.has_matched == Some(true) { + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.state.marks_pop(); + drive.repeat_ctx_mut().count = self.count - 1; + drive.state.string_position = drive.ctx().string_position; + self.jump_id = 3; + self.next(drive) + } + fn _3(&mut self, drive: &mut StackDrive) -> Option<()> { + // cannot match more repeated items here. make sure the tail matches + drive.skip_code(drive.peek_code(1) as usize + 1); + drive.push_new_context(1); + self.jump_id = 4; + Some(()) + } + fn _4(&mut self, drive: &mut StackDrive) -> Option<()> { + let child_ctx = drive.state.popped_context.unwrap(); + drive.ctx_mut().has_matched = child_ctx.has_matched; + if drive.ctx().has_matched != Some(true) { + drive.state.string_position = drive.ctx().string_position; + } + None } } +// struct OpMaxUntil { +// jump_id: usize, +// count: isize, +// save_last_position: isize, +// } +// impl Default for OpMaxUntil { +// fn default() -> Self { +// Self { +// jump_id: 0, +// count: 0, +// save_last_position: -1, +// } +// } +// } +// impl OpcodeExecutor for OpMaxUntil { +// fn next(&mut self, drive: &mut StackDrive) -> Option<()> { +// match self.jump_id { +// 0 => { +// drive.state.string_position = drive.ctx().string_position; +// let repeat = match drive.state.repeat_stack.last_mut() { +// Some(repeat) => repeat, +// None => { +// panic!("Internal re error: MAX_UNTIL without REPEAT."); +// } +// }; +// self.count = repeat.count + 1; + +// if self.count < repeat.mincount as isize { +// // not enough matches +// repeat.count = self.count; +// drive.push_new_context(4); +// self.jump_id = 1; +// return Some(()); +// } + +// if (self.count < repeat.maxcount as isize || repeat.maxcount == MAXREPEAT) +// && (drive.state.string_position as isize != repeat.last_position) +// { +// // we may have enough matches, if we can match another item, do so +// repeat.count = self.count; +// self.save_last_position = repeat.last_position; +// repeat.last_position = drive.state.string_position as isize; +// drive.state.marks_push(); +// drive.push_new_context(4); +// self.jump_id = 2; +// return Some(()); +// } + +// drive.push_new_context(1); + +// self.jump_id = 3; +// Some(()) +// } +// 1 => { +// let child_ctx = drive.state.popped_context.unwrap(); +// drive.ctx_mut().has_matched = child_ctx.has_matched; +// if drive.ctx().has_matched != Some(true) { +// drive.state.string_position = drive.ctx().string_position; +// let repeat = drive.state.repeat_stack.last_mut().unwrap(); +// repeat.count = self.count - 1; +// } +// None +// } +// 2 => { +// let repeat = drive.state.repeat_stack.last_mut().unwrap(); +// repeat.last_position = drive.state.string_position as isize; +// let child_ctx = drive.state.popped_context.unwrap(); +// if child_ctx.has_matched == Some(true) { +// drive.state.marks_pop_discard(); +// drive.ctx_mut().has_matched = Some(true); +// return None; +// } +// repeat.count = self.count - 1; +// drive.state.marks_pop(); +// drive.state.string_position = drive.ctx().string_position; + +// drive.push_new_context(1); + +// self.jump_id = 3; +// Some(()) +// } +// 3 => { +// // cannot match more repeated items here. make sure the tail matches +// let child_ctx = drive.state.popped_context.unwrap(); +// drive.ctx_mut().has_matched = child_ctx.has_matched; +// if drive.ctx().has_matched != Some(true) { +// drive.state.string_position = drive.ctx().string_position; +// } else { +// drive.state.repeat_stack.pop(); +// } +// None +// } +// _ => unreachable!(), +// } +// } +// } + struct OpMinUntil { jump_id: usize, count: isize, - child_ctx_id: usize, } impl Default for OpMinUntil { fn default() -> Self { Self { jump_id: 0, count: 0, - child_ctx_id: 0, } } } impl OpcodeExecutor for OpMinUntil { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => { - drive.state.string_position = drive.ctx().string_position; - let repeat = match drive.state.repeat_stack.last_mut() { - Some(repeat) => repeat, - None => { - todo!("Internal re error: MAX_UNTIL without REPEAT."); - } - }; - self.count = repeat.count + 1; - - if self.count < repeat.mincount as isize { - // not enough matches - repeat.count = self.count; - self.child_ctx_id = drive.push_new_context(4); - self.jump_id = 1; - return Some(()); - } - - // see if the tail matches - drive.state.marks_push(); - self.child_ctx_id = drive.push_new_context(1); - self.jump_id = 2; - Some(()) - } - 1 => { - let child_ctx = &drive.state.context_stack[self.child_ctx_id]; - drive.ctx_mut().has_matched = child_ctx.has_matched; - if drive.ctx().has_matched != Some(true) { - drive.state.string_position = drive.ctx().string_position; - let repeat = drive.state.repeat_stack.last_mut().unwrap(); - repeat.count = self.count - 1; - } - None - } - 2 => { - let child_ctx = &drive.state.context_stack[self.child_ctx_id]; - if child_ctx.has_matched == Some(true) { - drive.state.repeat_stack.pop(); - drive.ctx_mut().has_matched = Some(true); - return None; - } - drive.state.string_position = drive.ctx().string_position; - drive.state.marks_pop(); - - // match more until tail matches - let repeat = drive.state.repeat_stack.last_mut().unwrap(); - if self.count >= repeat.maxcount as isize && repeat.maxcount != MAXREPEAT { - drive.ctx_mut().has_matched = Some(false); - return None; - } - repeat.count = self.count; - self.child_ctx_id = drive.push_new_context(4); - self.jump_id = 1; - Some(()) - } - _ => unreachable!(), - } + None + // match self.jump_id { + // 0 => { + // drive.state.string_position = drive.ctx().string_position; + // let repeat = match drive.state.repeat_stack.last_mut() { + // Some(repeat) => repeat, + // None => { + // todo!("Internal re error: MAX_UNTIL without REPEAT."); + // } + // }; + // self.count = repeat.count + 1; + + // if self.count < repeat.mincount as isize { + // // not enough matches + // repeat.count = self.count; + // drive.push_new_context(4); + // self.jump_id = 1; + // return Some(()); + // } + + // // see if the tail matches + // drive.state.marks_push(); + // drive.push_new_context(1); + // self.jump_id = 2; + // Some(()) + // } + // 1 => { + // let child_ctx = drive.state.popped_context.unwrap(); + // drive.ctx_mut().has_matched = child_ctx.has_matched; + // if drive.ctx().has_matched != Some(true) { + // drive.state.string_position = drive.ctx().string_position; + // let repeat = drive.state.repeat_stack.last_mut().unwrap(); + // repeat.count = self.count - 1; + // } + // None + // } + // 2 => { + // let child_ctx = drive.state.popped_context.unwrap(); + // if child_ctx.has_matched == Some(true) { + // drive.state.repeat_stack.pop(); + // drive.ctx_mut().has_matched = Some(true); + // return None; + // } + // drive.state.string_position = drive.ctx().string_position; + // drive.state.marks_pop(); + + // // match more until tail matches + // let repeat = drive.state.repeat_stack.last_mut().unwrap(); + // if self.count >= repeat.maxcount as isize && repeat.maxcount != MAXREPEAT { + // drive.ctx_mut().has_matched = Some(false); + // return None; + // } + // repeat.count = self.count; + // drive.push_new_context(4); + // self.jump_id = 1; + // Some(()) + // } + // _ => unreachable!(), + // } } } struct OpBranch { jump_id: usize, - child_ctx_id: usize, current_branch_length: usize, } impl Default for OpBranch { fn default() -> Self { Self { jump_id: 0, - child_ctx_id: 0, current_branch_length: 0, } } @@ -1268,12 +1377,12 @@ impl OpcodeExecutor for OpBranch { return None; } drive.state.string_position = drive.ctx().string_position; - self.child_ctx_id = drive.push_new_context(1); + drive.push_new_context(1); self.jump_id = 2; Some(()) } 2 => { - let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + let child_ctx = drive.state.popped_context.unwrap(); if child_ctx.has_matched == Some(true) { drive.ctx_mut().has_matched = Some(true); return None; @@ -1287,48 +1396,8 @@ impl OpcodeExecutor for OpBranch { } } -struct OpRepeat { - jump_id: usize, - child_ctx_id: usize, -} -impl Default for OpRepeat { - fn default() -> Self { - Self { - jump_id: 0, - child_ctx_id: 0, - } - } -} -impl OpcodeExecutor for OpRepeat { - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => { - let repeat = RepeatContext { - skip: drive.peek_code(1) as usize, - mincount: drive.peek_code(2) as usize, - maxcount: drive.peek_code(3) as usize, - count: -1, - last_position: -1, - }; - drive.state.repeat_stack.push(repeat); - drive.state.string_position = drive.ctx().string_position; - self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); - self.jump_id = 1; - Some(()) - } - 1 => { - let child_ctx = &drive.state.context_stack[self.child_ctx_id]; - drive.ctx_mut().has_matched = child_ctx.has_matched; - None - } - _ => unreachable!(), - } - } -} - struct OpRepeatOne { jump_id: usize, - child_ctx_id: usize, mincount: usize, maxcount: usize, count: isize, @@ -1337,7 +1406,6 @@ impl Default for OpRepeatOne { fn default() -> Self { Self { jump_id: 0, - child_ctx_id: 0, mincount: 0, maxcount: 0, count: 0, @@ -1382,7 +1450,7 @@ impl OpcodeExecutor for OpRepeatOne { // General case: backtracking if self.count >= self.mincount as isize { drive.state.string_position = drive.ctx().string_position; - self.child_ctx_id = drive.push_new_context(drive.peek_code(1) as usize + 1); + drive.push_new_context(drive.peek_code(1) as usize + 1); self.jump_id = 2; return Some(()); } @@ -1392,7 +1460,7 @@ impl OpcodeExecutor for OpRepeatOne { None } 2 => { - let child_ctx = &drive.state.context_stack[self.child_ctx_id]; + let child_ctx = drive.state.popped_context.unwrap(); if child_ctx.has_matched == Some(true) { drive.ctx_mut().has_matched = Some(true); return None; From 8fba935bba46ade66107a3c7aabb65e4eae00dcb Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Jan 2021 10:54:44 +0200 Subject: [PATCH 017/893] OpMaxUntil zero-width protection --- interp.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/interp.rs b/interp.rs index cb06ee0b8a..4a015187e9 100644 --- a/interp.rs +++ b/interp.rs @@ -1107,6 +1107,7 @@ impl OpMaxUntil { drive.ctx_mut().code_position = code_position; let mincount = drive.peek_code(2) as usize; let maxcount = drive.peek_code(3) as usize; + drive.state.string_position = drive.ctx().string_position; self.count = count + 1; if (self.count as usize) < mincount { @@ -1123,8 +1124,8 @@ impl OpMaxUntil { // we may have enough matches, if we can match another item, do so drive.repeat_ctx_mut().count = self.count; drive.state.marks_push(); - // self.save_last_position = last_position; - // drive.repeat_ctx_mut().last_position = drive.state.string_position; + self.save_last_position = last_position; + drive.repeat_ctx_mut().last_position = drive.state.string_position; drive.push_new_context(4); self.jump_id = 2; return Some(()); @@ -1143,7 +1144,7 @@ impl OpMaxUntil { None } fn _2(&mut self, drive: &mut StackDrive) -> Option<()> { - // drive.repeat_ctx_mut().last_position = self.save_last_position; + drive.repeat_ctx_mut().last_position = self.save_last_position; let child_ctx = drive.state.popped_context.unwrap(); if child_ctx.has_matched == Some(true) { drive.state.marks_pop_discard(); From af1a53cb0530f27804442f86d3e1e05a4abcacef Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Jan 2021 11:34:40 +0200 Subject: [PATCH 018/893] impl Match.groups() --- interp.rs | 249 ++++++++++++++++-------------------------------------- 1 file changed, 72 insertions(+), 177 deletions(-) diff --git a/interp.rs b/interp.rs index 4a015187e9..ad0790a31b 100644 --- a/interp.rs +++ b/interp.rs @@ -1070,7 +1070,6 @@ impl OpMinRepeatOne { } } -// Everything is stored in RepeatContext struct OpMaxUntil { jump_id: usize, count: isize, @@ -1088,189 +1087,85 @@ impl Default for OpMaxUntil { impl OpcodeExecutor for OpMaxUntil { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { - 0 => self._0(drive), - 1 => self._1(drive), - 2 => self._2(drive), - 3 => self._3(drive), - 4 => self._4(drive), - _ => unreachable!(), - } - } -} -impl OpMaxUntil { - fn _0(&mut self, drive: &mut StackDrive) -> Option<()> { - let RepeatContext { - count, - code_position, - last_position, - } = *drive.repeat_ctx(); - drive.ctx_mut().code_position = code_position; - let mincount = drive.peek_code(2) as usize; - let maxcount = drive.peek_code(3) as usize; - drive.state.string_position = drive.ctx().string_position; - self.count = count + 1; + 0 => { + let RepeatContext { + count, + code_position, + last_position, + } = *drive.repeat_ctx(); + drive.ctx_mut().code_position = code_position; + let mincount = drive.peek_code(2) as usize; + let maxcount = drive.peek_code(3) as usize; + drive.state.string_position = drive.ctx().string_position; + self.count = count + 1; - if (self.count as usize) < mincount { - // not enough matches - drive.repeat_ctx_mut().count = self.count; - drive.push_new_context(4); - self.jump_id = 1; - return Some(()); - } + if (self.count as usize) < mincount { + // not enough matches + drive.repeat_ctx_mut().count = self.count; + drive.push_new_context(4); + self.jump_id = 1; + return Some(()); + } - if ((count as usize) < maxcount || maxcount == MAXREPEAT) - && drive.state.string_position != last_position - { - // we may have enough matches, if we can match another item, do so - drive.repeat_ctx_mut().count = self.count; - drive.state.marks_push(); - self.save_last_position = last_position; - drive.repeat_ctx_mut().last_position = drive.state.string_position; - drive.push_new_context(4); - self.jump_id = 2; - return Some(()); - } + if ((count as usize) < maxcount || maxcount == MAXREPEAT) + && drive.state.string_position != last_position + { + // we may have enough matches, if we can match another item, do so + drive.repeat_ctx_mut().count = self.count; + drive.state.marks_push(); + self.save_last_position = last_position; + drive.repeat_ctx_mut().last_position = drive.state.string_position; + drive.push_new_context(4); + self.jump_id = 2; + return Some(()); + } - self.jump_id = 3; - self.next(drive) - } - fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { - let child_ctx = drive.state.popped_context.unwrap(); - drive.ctx_mut().has_matched = child_ctx.has_matched; - if drive.ctx().has_matched != Some(true) { - drive.repeat_ctx_mut().count = self.count - 1; - drive.state.string_position = drive.ctx().string_position; - } - None - } - fn _2(&mut self, drive: &mut StackDrive) -> Option<()> { - drive.repeat_ctx_mut().last_position = self.save_last_position; - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(true); - return None; - } - drive.state.marks_pop(); - drive.repeat_ctx_mut().count = self.count - 1; - drive.state.string_position = drive.ctx().string_position; - self.jump_id = 3; - self.next(drive) - } - fn _3(&mut self, drive: &mut StackDrive) -> Option<()> { - // cannot match more repeated items here. make sure the tail matches - drive.skip_code(drive.peek_code(1) as usize + 1); - drive.push_new_context(1); - self.jump_id = 4; - Some(()) - } - fn _4(&mut self, drive: &mut StackDrive) -> Option<()> { - let child_ctx = drive.state.popped_context.unwrap(); - drive.ctx_mut().has_matched = child_ctx.has_matched; - if drive.ctx().has_matched != Some(true) { - drive.state.string_position = drive.ctx().string_position; + self.jump_id = 3; + self.next(drive) + } + 1 => { + let child_ctx = drive.state.popped_context.unwrap(); + drive.ctx_mut().has_matched = child_ctx.has_matched; + if drive.ctx().has_matched != Some(true) { + drive.repeat_ctx_mut().count = self.count - 1; + drive.state.string_position = drive.ctx().string_position; + } + None + } + 2 => { + drive.repeat_ctx_mut().last_position = self.save_last_position; + let child_ctx = drive.state.popped_context.unwrap(); + if child_ctx.has_matched == Some(true) { + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.state.marks_pop(); + drive.repeat_ctx_mut().count = self.count - 1; + drive.state.string_position = drive.ctx().string_position; + self.jump_id = 3; + self.next(drive) + } + 3 => { + // cannot match more repeated items here. make sure the tail matches + drive.skip_code(drive.peek_code(1) as usize + 1); + drive.push_new_context(1); + self.jump_id = 4; + Some(()) + } + 4 => { + let child_ctx = drive.state.popped_context.unwrap(); + drive.ctx_mut().has_matched = child_ctx.has_matched; + if drive.ctx().has_matched != Some(true) { + drive.state.string_position = drive.ctx().string_position; + } + None + } + _ => unreachable!(), } - None } } -// struct OpMaxUntil { -// jump_id: usize, -// count: isize, -// save_last_position: isize, -// } -// impl Default for OpMaxUntil { -// fn default() -> Self { -// Self { -// jump_id: 0, -// count: 0, -// save_last_position: -1, -// } -// } -// } -// impl OpcodeExecutor for OpMaxUntil { -// fn next(&mut self, drive: &mut StackDrive) -> Option<()> { -// match self.jump_id { -// 0 => { -// drive.state.string_position = drive.ctx().string_position; -// let repeat = match drive.state.repeat_stack.last_mut() { -// Some(repeat) => repeat, -// None => { -// panic!("Internal re error: MAX_UNTIL without REPEAT."); -// } -// }; -// self.count = repeat.count + 1; - -// if self.count < repeat.mincount as isize { -// // not enough matches -// repeat.count = self.count; -// drive.push_new_context(4); -// self.jump_id = 1; -// return Some(()); -// } - -// if (self.count < repeat.maxcount as isize || repeat.maxcount == MAXREPEAT) -// && (drive.state.string_position as isize != repeat.last_position) -// { -// // we may have enough matches, if we can match another item, do so -// repeat.count = self.count; -// self.save_last_position = repeat.last_position; -// repeat.last_position = drive.state.string_position as isize; -// drive.state.marks_push(); -// drive.push_new_context(4); -// self.jump_id = 2; -// return Some(()); -// } - -// drive.push_new_context(1); - -// self.jump_id = 3; -// Some(()) -// } -// 1 => { -// let child_ctx = drive.state.popped_context.unwrap(); -// drive.ctx_mut().has_matched = child_ctx.has_matched; -// if drive.ctx().has_matched != Some(true) { -// drive.state.string_position = drive.ctx().string_position; -// let repeat = drive.state.repeat_stack.last_mut().unwrap(); -// repeat.count = self.count - 1; -// } -// None -// } -// 2 => { -// let repeat = drive.state.repeat_stack.last_mut().unwrap(); -// repeat.last_position = drive.state.string_position as isize; -// let child_ctx = drive.state.popped_context.unwrap(); -// if child_ctx.has_matched == Some(true) { -// drive.state.marks_pop_discard(); -// drive.ctx_mut().has_matched = Some(true); -// return None; -// } -// repeat.count = self.count - 1; -// drive.state.marks_pop(); -// drive.state.string_position = drive.ctx().string_position; - -// drive.push_new_context(1); - -// self.jump_id = 3; -// Some(()) -// } -// 3 => { -// // cannot match more repeated items here. make sure the tail matches -// let child_ctx = drive.state.popped_context.unwrap(); -// drive.ctx_mut().has_matched = child_ctx.has_matched; -// if drive.ctx().has_matched != Some(true) { -// drive.state.string_position = drive.ctx().string_position; -// } else { -// drive.state.repeat_stack.pop(); -// } -// None -// } -// _ => unreachable!(), -// } -// } -// } - struct OpMinUntil { jump_id: usize, count: isize, From db84f329816d812349b49fa843b4d0480674aba2 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Jan 2021 13:10:19 +0200 Subject: [PATCH 019/893] fix Opcode::CHARSET --- interp.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/interp.rs b/interp.rs index ad0790a31b..289dbc2ec9 100644 --- a/interp.rs +++ b/interp.rs @@ -741,7 +741,8 @@ fn charset(set: &[u32], c: char) -> bool { } SreOpcode::CHARSET => { /* */ - if ch < 256 && (set[(ch / 32) as usize] & (1 << (32 - 1))) != 0 { + let set = &set[1..]; + if ch < 256 && ((set[(ch >> 5) as usize] & (1u32 << (ch & 31))) != 0) { return ok; } i += 8; From 36433a9f4d026df404e419d604e67295fa4db758 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 1 Jan 2021 20:13:06 +0200 Subject: [PATCH 020/893] fix Opcode::BIGCHARSET --- interp.rs | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/interp.rs b/interp.rs index 289dbc2ec9..ca5efd9e4f 100644 --- a/interp.rs +++ b/interp.rs @@ -497,22 +497,16 @@ impl OpcodeDispatcher { } }), SreOpcode::IN => once(|drive| { - general_op_in(drive, |x| x); + general_op_in(drive, |set, c| charset(set, c)); }), SreOpcode::IN_IGNORE => once(|drive| { - general_op_in(drive, lower_ascii); + general_op_in(drive, |set, c| charset(set, lower_ascii(c))); }), SreOpcode::IN_UNI_IGNORE => once(|drive| { - general_op_in(drive, lower_unicode); + general_op_in(drive, |set, c| charset(set, lower_unicode(c))); }), SreOpcode::IN_LOC_IGNORE => once(|drive| { - let skip = drive.peek_code(1) as usize; - if drive.at_end() || !charset_loc_ignore(&drive.pattern()[2..], drive.peek_char()) { - drive.ctx_mut().has_matched = Some(false); - } else { - drive.skip_code(skip + 1); - drive.skip_char(1); - } + general_op_in(drive, |set, c| charset_loc_ignore(set, c)); }), SreOpcode::INFO | SreOpcode::JUMP => once(|drive| { drive.skip_code(drive.peek_code(1) as usize + 1); @@ -661,9 +655,9 @@ fn general_op_literal bool>(drive: &mut StackDrive, f: F } } -fn general_op_in char>(drive: &mut StackDrive, f: F) { +fn general_op_in bool>(drive: &mut StackDrive, f: F) { let skip = drive.peek_code(1) as usize; - if drive.at_end() || !charset(&drive.pattern()[2..], f(drive.peek_char())) { + if drive.at_end() || !f(&drive.pattern()[2..], drive.peek_char()) { drive.ctx_mut().has_matched = Some(false); } else { drive.skip_code(skip + 1); @@ -749,18 +743,20 @@ fn charset(set: &[u32], c: char) -> bool { } SreOpcode::BIGCHARSET => { /* <256 blockindices> */ - let count = set[i + 1]; + let count = set[i + 1] as usize; if ch < 0x10000 { - let (_, blockindices, _) = unsafe { set[i + 2..].align_to::() }; - let block = blockindices[(ch >> 8) as usize]; - if set[2 + 64 + ((block as u32 * 256 + (ch & 255)) / 32) as usize] - & (1 << (ch & (32 - 1))) + let set = &set[2..]; + let block_index = ch >> 8; + let (_, blockindices, _) = unsafe { set.align_to::() }; + let blocks = &set[64..]; + let block = blockindices[block_index as usize]; + if blocks[((block as u32 * 256 + (ch & 255)) / 32) as usize] & (1u32 << (ch & 31)) != 0 { return ok; } } - i += 2 + 64 + count as usize * 8; + i += 2 + 64 + count * 8; } SreOpcode::LITERAL => { /* */ From f05f6cb44df000ad7cffcd79b6447c97756d15dd Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 3 Jan 2021 19:44:40 +0200 Subject: [PATCH 021/893] impl Pattern.sub --- interp.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/interp.rs b/interp.rs index ca5efd9e4f..98844ee9ca 100644 --- a/interp.rs +++ b/interp.rs @@ -142,6 +142,27 @@ pub(crate) fn pymatch( } } +pub(crate) fn search( + string: PyStrRef, + start: usize, + end: usize, + pattern: PyRef, +) -> Option { + // TODO: optimize by op info and skip prefix + let end = std::cmp::min(end, string.char_len()); + for i in start..end { + if let Some(m) = pymatch( + string.clone(), + i, + end, + pattern.clone(), + ) { + return Some(m); + } + } + None +} + #[derive(Debug, Copy, Clone)] struct MatchContext { string_position: usize, @@ -750,7 +771,8 @@ fn charset(set: &[u32], c: char) -> bool { let (_, blockindices, _) = unsafe { set.align_to::() }; let blocks = &set[64..]; let block = blockindices[block_index as usize]; - if blocks[((block as u32 * 256 + (ch & 255)) / 32) as usize] & (1u32 << (ch & 31)) + if blocks[((block as u32 * 256 + (ch & 255)) / 32) as usize] + & (1u32 << (ch & 31)) != 0 { return ok; From 817eb66167810a163eb889ce2626418b51bd6afb Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 4 Jan 2021 16:01:17 +0200 Subject: [PATCH 022/893] fix OpMinUntil --- interp.rs | 100 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 20 deletions(-) diff --git a/interp.rs b/interp.rs index 98844ee9ca..a3263eddce 100644 --- a/interp.rs +++ b/interp.rs @@ -151,12 +151,7 @@ pub(crate) fn search( // TODO: optimize by op info and skip prefix let end = std::cmp::min(end, string.char_len()); for i in start..end { - if let Some(m) = pymatch( - string.clone(), - i, - end, - pattern.clone(), - ) { + if let Some(m) = pymatch(string.clone(), i, end, pattern.clone()) { return Some(m); } } @@ -294,12 +289,13 @@ impl<'a> StackDrive<'a> { self.state } fn push_new_context(&mut self, pattern_offset: usize) { - let ctx = self.ctx(); - let mut child_ctx = MatchContext { ..*ctx }; + let mut child_ctx = MatchContext { ..*self.ctx() }; child_ctx.code_position += pattern_offset; - if child_ctx.code_position > self.state.pattern_codes.len() { - child_ctx.code_position = self.state.pattern_codes.len(); - } + self.state.context_stack.push(child_ctx); + } + fn push_new_context_at(&mut self, code_position: usize) { + let mut child_ctx = MatchContext { ..*self.ctx() }; + child_ctx.code_position = code_position; self.state.context_stack.push(child_ctx); } fn repeat_ctx_mut(&mut self) -> &mut RepeatContext { @@ -571,6 +567,8 @@ impl OpcodeDispatcher { count: -1, code_position: drive.ctx().code_position, last_position: std::usize::MAX, + mincount: drive.peek_code(2) as usize, + maxcount: drive.peek_code(3) as usize, }; drive.state.repeat_stack.push(repeat); drive.state.string_position = drive.ctx().string_position; @@ -584,7 +582,7 @@ impl OpcodeDispatcher { }, ), SreOpcode::MAX_UNTIL => Box::new(OpMaxUntil::default()), - SreOpcode::MIN_UNTIL => todo!("min until"), + SreOpcode::MIN_UNTIL => Box::new(OpMinUntil::default()), SreOpcode::REPEAT_ONE => Box::new(OpRepeatOne::default()), SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), SreOpcode::GROUPREF => once(|drive| general_op_groupref(drive, |x| x)), @@ -996,6 +994,8 @@ struct RepeatContext { code_position: usize, // zero-width match protection last_position: usize, + mincount: usize, + maxcount: usize, } struct OpMinRepeatOne { @@ -1111,22 +1111,22 @@ impl OpcodeExecutor for OpMaxUntil { count, code_position, last_position, + mincount, + maxcount, } = *drive.repeat_ctx(); - drive.ctx_mut().code_position = code_position; - let mincount = drive.peek_code(2) as usize; - let maxcount = drive.peek_code(3) as usize; + drive.state.string_position = drive.ctx().string_position; self.count = count + 1; if (self.count as usize) < mincount { // not enough matches drive.repeat_ctx_mut().count = self.count; - drive.push_new_context(4); + drive.push_new_context_at(code_position + 4); self.jump_id = 1; return Some(()); } - if ((count as usize) < maxcount || maxcount == MAXREPEAT) + if ((self.count as usize) < maxcount || maxcount == MAXREPEAT) && drive.state.string_position != last_position { // we may have enough matches, if we can match another item, do so @@ -1134,7 +1134,7 @@ impl OpcodeExecutor for OpMaxUntil { drive.state.marks_push(); self.save_last_position = last_position; drive.repeat_ctx_mut().last_position = drive.state.string_position; - drive.push_new_context(4); + drive.push_new_context_at(code_position + 4); self.jump_id = 2; return Some(()); } @@ -1167,7 +1167,6 @@ impl OpcodeExecutor for OpMaxUntil { } 3 => { // cannot match more repeated items here. make sure the tail matches - drive.skip_code(drive.peek_code(1) as usize + 1); drive.push_new_context(1); self.jump_id = 4; Some(()) @@ -1188,18 +1187,79 @@ impl OpcodeExecutor for OpMaxUntil { struct OpMinUntil { jump_id: usize, count: isize, + save_repeat: Option, } impl Default for OpMinUntil { fn default() -> Self { Self { jump_id: 0, count: 0, + save_repeat: None, } } } impl OpcodeExecutor for OpMinUntil { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - None + match self.jump_id { + 0 => { + let RepeatContext { + count, + code_position, + last_position: _, + mincount, + maxcount: _, + } = *drive.repeat_ctx(); + drive.state.string_position = drive.ctx().string_position; + self.count = count + 1; + + if (self.count as usize) < mincount { + // not enough matches + drive.repeat_ctx_mut().count = self.count; + drive.push_new_context_at(code_position + 4); + self.jump_id = 1; + return Some(()); + } + + // see if the tail matches + drive.state.marks_push(); + self.save_repeat = drive.state.repeat_stack.pop(); + drive.push_new_context(1); + self.jump_id = 2; + Some(()) + } + 1 => { + let child_ctx = drive.state.popped_context.unwrap(); + drive.ctx_mut().has_matched = child_ctx.has_matched; + if drive.ctx().has_matched != Some(true) { + drive.repeat_ctx_mut().count = self.count - 1; + drive.state.string_position = drive.ctx().string_position; + } + None + } + 2 => { + let child_ctx = drive.state.popped_context.unwrap(); + if child_ctx.has_matched == Some(true) { + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.state.repeat_stack.push(self.save_repeat.unwrap()); + drive.state.string_position = drive.ctx().string_position; + drive.state.marks_pop(); + + // match more unital tail matches + let maxcount = drive.repeat_ctx().maxcount; + let code_position = drive.repeat_ctx().code_position; + if self.count as usize >= maxcount && maxcount != MAXREPEAT { + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.repeat_ctx_mut().count = self.count; + drive.push_new_context_at(code_position + 4); + self.jump_id = 1; + Some(()) + } + _ => unreachable!(), + } // match self.jump_id { // 0 => { // drive.state.string_position = drive.ctx().string_position; From 33ef82364516422776beaa498db8d96a167c7b95 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 6 Jan 2021 10:30:57 +0200 Subject: [PATCH 023/893] impl Match.groupdict --- interp.rs | 58 ------------------------------------------------------- 1 file changed, 58 deletions(-) diff --git a/interp.rs b/interp.rs index a3263eddce..c118984b38 100644 --- a/interp.rs +++ b/interp.rs @@ -1260,64 +1260,6 @@ impl OpcodeExecutor for OpMinUntil { } _ => unreachable!(), } - // match self.jump_id { - // 0 => { - // drive.state.string_position = drive.ctx().string_position; - // let repeat = match drive.state.repeat_stack.last_mut() { - // Some(repeat) => repeat, - // None => { - // todo!("Internal re error: MAX_UNTIL without REPEAT."); - // } - // }; - // self.count = repeat.count + 1; - - // if self.count < repeat.mincount as isize { - // // not enough matches - // repeat.count = self.count; - // drive.push_new_context(4); - // self.jump_id = 1; - // return Some(()); - // } - - // // see if the tail matches - // drive.state.marks_push(); - // drive.push_new_context(1); - // self.jump_id = 2; - // Some(()) - // } - // 1 => { - // let child_ctx = drive.state.popped_context.unwrap(); - // drive.ctx_mut().has_matched = child_ctx.has_matched; - // if drive.ctx().has_matched != Some(true) { - // drive.state.string_position = drive.ctx().string_position; - // let repeat = drive.state.repeat_stack.last_mut().unwrap(); - // repeat.count = self.count - 1; - // } - // None - // } - // 2 => { - // let child_ctx = drive.state.popped_context.unwrap(); - // if child_ctx.has_matched == Some(true) { - // drive.state.repeat_stack.pop(); - // drive.ctx_mut().has_matched = Some(true); - // return None; - // } - // drive.state.string_position = drive.ctx().string_position; - // drive.state.marks_pop(); - - // // match more until tail matches - // let repeat = drive.state.repeat_stack.last_mut().unwrap(); - // if self.count >= repeat.maxcount as isize && repeat.maxcount != MAXREPEAT { - // drive.ctx_mut().has_matched = Some(false); - // return None; - // } - // repeat.count = self.count; - // drive.push_new_context(4); - // self.jump_id = 1; - // Some(()) - // } - // _ => unreachable!(), - // } } } From 76c95abbb5ecd91e05f216adefdd437b0f87b79b Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 7 Jan 2021 17:37:18 +0200 Subject: [PATCH 024/893] impl Match.lastgroup --- interp.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/interp.rs b/interp.rs index c118984b38..febbfa2437 100644 --- a/interp.rs +++ b/interp.rs @@ -150,7 +150,7 @@ pub(crate) fn search( ) -> Option { // TODO: optimize by op info and skip prefix let end = std::cmp::min(end, string.char_len()); - for i in start..end { + for i in start..end + 1 { if let Some(m) = pymatch(string.clone(), i, end, pattern.clone()) { return Some(m); } @@ -1382,6 +1382,13 @@ impl OpcodeExecutor for OpRepeatOne { drive.ctx_mut().has_matched = Some(true); return None; } + if self.count <= self.mincount as isize { + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(false); + return None; + } + + // TODO: unnesscary double check drive.back_skip_char(1); self.count -= 1; drive.state.marks_pop_keep(); From 13a8b6cc4e38927273774ab96565f2184840f900 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 17 Jan 2021 19:55:20 +0200 Subject: [PATCH 025/893] add bytes support and refactor --- interp.rs | 457 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 266 insertions(+), 191 deletions(-) diff --git a/interp.rs b/interp.rs index febbfa2437..7c9246142a 100644 --- a/interp.rs +++ b/interp.rs @@ -1,18 +1,18 @@ // good luck to those that follow; here be dragons -use super::_sre::{Match, Pattern, MAXREPEAT}; +use super::_sre::MAXREPEAT; use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; -use crate::builtins::PyStrRef; -use crate::pyobject::PyRef; -use rustpython_common::borrow::BorrowValue; +use crate::builtins::PyBytes; +use crate::bytesinner::is_py_ascii_whitespace; +use crate::pyobject::{IntoPyObject, PyObjectRef}; +use crate::VirtualMachine; use std::collections::HashMap; use std::convert::TryFrom; +use std::unreachable; #[derive(Debug)] pub(crate) struct State<'a> { - string: &'a str, - // chars count - string_len: usize, + pub string: StrDrive<'a>, pub start: usize, pub end: usize, flags: SreFlag, @@ -24,22 +24,21 @@ pub(crate) struct State<'a> { repeat_stack: Vec, pub string_position: usize, popped_context: Option, + pub has_matched: Option, } impl<'a> State<'a> { pub(crate) fn new( - string: &'a str, + string: StrDrive<'a>, start: usize, end: usize, flags: SreFlag, pattern_codes: &'a [u32], ) -> Self { - let string_len = string.chars().count(); - let end = std::cmp::min(end, string_len); + let end = std::cmp::min(end, string.count()); let start = std::cmp::min(start, end); Self { string, - string_len, start, end, flags, @@ -51,16 +50,19 @@ impl<'a> State<'a> { marks: Vec::new(), string_position: start, popped_context: None, + has_matched: None, } } - fn reset(&mut self) { - self.marks.clear(); + pub fn reset(&mut self) { self.lastindex = -1; self.marks_stack.clear(); self.context_stack.clear(); self.repeat_stack.clear(); + self.marks.clear(); + self.string_position = self.start; self.popped_context = None; + self.has_matched = None; } fn set_mark(&mut self, mark_nr: usize, position: usize) { @@ -96,66 +98,133 @@ impl<'a> State<'a> { fn marks_pop_discard(&mut self) { self.marks_stack.pop(); } + + pub fn pymatch(mut self) -> Self { + let ctx = MatchContext { + string_position: self.start, + string_offset: self.string.offset(0, self.start), + code_position: 0, + has_matched: None, + }; + self.context_stack.push(ctx); + + let mut dispatcher = OpcodeDispatcher::new(); + let mut has_matched = None; + + loop { + if self.context_stack.is_empty() { + break; + } + let ctx_id = self.context_stack.len() - 1; + let mut drive = StackDrive::drive(ctx_id, self); + + has_matched = dispatcher.pymatch(&mut drive); + self = drive.take(); + if has_matched.is_some() { + self.popped_context = self.context_stack.pop(); + } + } + + self.has_matched = has_matched; + self + } + + pub fn search(mut self) -> Self { + // TODO: optimize by op info and skip prefix + loop { + self = self.pymatch(); + + if self.has_matched == Some(true) { + return self; + } + self.start += 1; + if self.start > self.end { + return self; + } + self.reset(); + } + } } -pub(crate) fn pymatch( - string: PyStrRef, - start: usize, - end: usize, - pattern: PyRef, -) -> Option { - let mut state = State::new( - string.borrow_value(), - start, - end, - pattern.flags, - &pattern.code, - ); - let ctx = MatchContext { - string_position: state.start, - string_offset: calc_string_offset(state.string, state.start), - code_position: 0, - has_matched: None, - }; - state.context_stack.push(ctx); - let mut dispatcher = OpcodeDispatcher::new(); +#[derive(Debug, Clone, Copy)] +pub(crate) enum StrDrive<'a> { + Str(&'a str), + Bytes(&'a [u8]), +} +impl<'a> StrDrive<'a> { + fn offset(&self, offset: usize, skip: usize) -> usize { + match *self { + StrDrive::Str(s) => s + .get(offset..) + .and_then(|s| s.char_indices().nth(skip).map(|x| x.0 + offset)) + .unwrap_or_else(|| s.len()), + StrDrive::Bytes(b) => std::cmp::min(offset + skip, b.len()), + } + } - let mut has_matched = None; - loop { - if state.context_stack.is_empty() { - break; + pub fn count(&self) -> usize { + match *self { + StrDrive::Str(s) => s.chars().count(), + StrDrive::Bytes(b) => b.len(), } - let ctx_id = state.context_stack.len() - 1; - let mut drive = StackDrive::drive(ctx_id, state); + } - has_matched = dispatcher.pymatch(&mut drive); - state = drive.take(); - if has_matched.is_some() { - state.popped_context = state.context_stack.pop(); + fn peek(&self, offset: usize) -> u32 { + match *self { + StrDrive::Str(s) => unsafe { s.get_unchecked(offset..) }.chars().next().unwrap() as u32, + StrDrive::Bytes(b) => b[offset] as u32, } } - if has_matched != Some(true) { - None - } else { - Some(Match::new(&state, pattern.clone(), string.clone())) + fn back_peek(&self, offset: usize) -> u32 { + match *self { + StrDrive::Str(s) => { + let bytes = s.as_bytes(); + let back_offset = utf8_back_peek_offset(bytes, offset); + match offset - back_offset { + 1 => u32::from_ne_bytes([0, 0, 0, bytes[offset]]), + 2 => u32::from_ne_bytes([0, 0, bytes[offset], bytes[offset + 1]]), + 3 => { + u32::from_ne_bytes([0, bytes[offset], bytes[offset + 1], bytes[offset + 2]]) + } + 4 => u32::from_ne_bytes([ + bytes[offset], + bytes[offset + 1], + bytes[offset + 2], + bytes[offset + 3], + ]), + _ => unreachable!(), + } + } + StrDrive::Bytes(b) => b[offset - 1] as u32, + } } -} -pub(crate) fn search( - string: PyStrRef, - start: usize, - end: usize, - pattern: PyRef, -) -> Option { - // TODO: optimize by op info and skip prefix - let end = std::cmp::min(end, string.char_len()); - for i in start..end + 1 { - if let Some(m) = pymatch(string.clone(), i, end, pattern.clone()) { - return Some(m); + fn back_offset(&self, offset: usize, skip: usize) -> usize { + match *self { + StrDrive::Str(s) => { + let bytes = s.as_bytes(); + let mut back_offset = offset; + for _ in 0..skip { + back_offset = utf8_back_peek_offset(bytes, back_offset); + } + back_offset + } + StrDrive::Bytes(_) => offset - skip, + } + } + + pub fn slice_to_pyobject(&self, start: usize, end: usize, vm: &VirtualMachine) -> PyObjectRef { + match *self { + StrDrive::Str(s) => s + .chars() + .take(end) + .skip(start) + .collect::() + .into_pyobject(vm), + StrDrive::Bytes(b) => PyBytes::from(b[start..end].to_vec()).into_pyobject(vm), } } - None } #[derive(Debug, Copy, Clone)] @@ -173,33 +242,21 @@ trait MatchContextDrive { fn repeat_ctx(&self) -> &RepeatContext { self.state().repeat_stack.last().unwrap() } - fn str(&self) -> &str { - unsafe { - std::str::from_utf8_unchecked( - &self.state().string.as_bytes()[self.ctx().string_offset..], - ) - } - } fn pattern(&self) -> &[u32] { &self.state().pattern_codes[self.ctx().code_position..] } - fn peek_char(&self) -> char { - self.str().chars().next().unwrap() + fn peek_char(&self) -> u32 { + self.state().string.peek(self.ctx().string_offset) } fn peek_code(&self, peek: usize) -> u32 { self.state().pattern_codes[self.ctx().code_position + peek] } fn skip_char(&mut self, skip_count: usize) { - match self.str().char_indices().nth(skip_count).map(|x| x.0) { - Some(skipped) => { - self.ctx_mut().string_position += skip_count; - self.ctx_mut().string_offset += skipped; - } - None => { - self.ctx_mut().string_position = self.state().end; - self.ctx_mut().string_offset = self.state().string.len(); // bytes len - } - } + self.ctx_mut().string_offset = self + .state() + .string + .offset(self.ctx().string_offset, skip_count); + self.ctx_mut().string_position += skip_count; } fn skip_code(&mut self, skip_count: usize) { self.ctx_mut().code_position += skip_count; @@ -222,7 +279,7 @@ trait MatchContextDrive { fn at_linebreak(&self) -> bool { !self.at_end() && is_linebreak(self.peek_char()) } - fn at_boundary bool>(&self, mut word_checker: F) -> bool { + fn at_boundary bool>(&self, mut word_checker: F) -> bool { if self.at_beginning() && self.at_end() { return false; } @@ -230,47 +287,15 @@ trait MatchContextDrive { let this = !self.at_end() && word_checker(self.peek_char()); this != that } - fn back_peek_offset(&self) -> usize { - let bytes = self.state().string.as_bytes(); - let mut offset = self.ctx().string_offset - 1; - if !is_utf8_first_byte(bytes[offset]) { - offset -= 1; - if !is_utf8_first_byte(bytes[offset]) { - offset -= 1; - if !is_utf8_first_byte(bytes[offset]) { - offset -= 1; - if !is_utf8_first_byte(bytes[offset]) { - panic!("not utf-8 code point"); - } - } - } - } - offset - } - fn back_peek_char(&self) -> char { - let bytes = self.state().string.as_bytes(); - let offset = self.back_peek_offset(); - let current_offset = self.ctx().string_offset; - let code = match current_offset - offset { - 1 => u32::from_ne_bytes([0, 0, 0, bytes[offset]]), - 2 => u32::from_ne_bytes([0, 0, bytes[offset], bytes[offset + 1]]), - 3 => u32::from_ne_bytes([0, bytes[offset], bytes[offset + 1], bytes[offset + 2]]), - 4 => u32::from_ne_bytes([ - bytes[offset], - bytes[offset + 1], - bytes[offset + 2], - bytes[offset + 3], - ]), - _ => unreachable!(), - }; - // TODO: char::from_u32_unchecked is stable from 1.5.0 - unsafe { std::mem::transmute(code) } + fn back_peek_char(&self) -> u32 { + self.state().string.back_peek(self.ctx().string_offset) } fn back_skip_char(&mut self, skip_count: usize) { self.ctx_mut().string_position -= skip_count; - for _ in 0..skip_count { - self.ctx_mut().string_offset = self.back_peek_offset(); - } + self.ctx_mut().string_offset = self + .state() + .string + .back_offset(self.ctx().string_offset, skip_count); } } @@ -529,22 +554,22 @@ impl OpcodeDispatcher { drive.skip_code(drive.peek_code(1) as usize + 1); }), SreOpcode::LITERAL => once(|drive| { - general_op_literal(drive, |code, c| code == c as u32); + general_op_literal(drive, |code, c| code == c); }), SreOpcode::NOT_LITERAL => once(|drive| { - general_op_literal(drive, |code, c| code != c as u32); + general_op_literal(drive, |code, c| code != c); }), SreOpcode::LITERAL_IGNORE => once(|drive| { - general_op_literal(drive, |code, c| code == lower_ascii(c) as u32); + general_op_literal(drive, |code, c| code == lower_ascii(c)); }), SreOpcode::NOT_LITERAL_IGNORE => once(|drive| { - general_op_literal(drive, |code, c| code != lower_ascii(c) as u32); + general_op_literal(drive, |code, c| code != lower_ascii(c)); }), SreOpcode::LITERAL_UNI_IGNORE => once(|drive| { - general_op_literal(drive, |code, c| code == lower_unicode(c) as u32); + general_op_literal(drive, |code, c| code == lower_unicode(c)); }), SreOpcode::NOT_LITERAL_UNI_IGNORE => once(|drive| { - general_op_literal(drive, |code, c| code != lower_unicode(c) as u32); + general_op_literal(drive, |code, c| code != lower_unicode(c)); }), SreOpcode::LITERAL_LOC_IGNORE => once(|drive| { general_op_literal(drive, char_loc_ignore); @@ -610,19 +635,11 @@ impl OpcodeDispatcher { } } -fn calc_string_offset(string: &str, position: usize) -> usize { - string - .char_indices() - .nth(position) - .map(|(i, _)| i) - .unwrap_or(0) -} - -fn char_loc_ignore(code: u32, c: char) -> bool { - code == c as u32 || code == lower_locate(c) as u32 || code == upper_locate(c) as u32 +fn char_loc_ignore(code: u32, c: u32) -> bool { + code == c || code == lower_locate(c) || code == upper_locate(c) } -fn charset_loc_ignore(set: &[u32], c: char) -> bool { +fn charset_loc_ignore(set: &[u32], c: u32) -> bool { let lo = lower_locate(c); if charset(set, c) { return true; @@ -631,7 +648,7 @@ fn charset_loc_ignore(set: &[u32], c: char) -> bool { up != lo && charset(set, up) } -fn general_op_groupref char>(drive: &mut StackDrive, mut f: F) { +fn general_op_groupref u32>(drive: &mut StackDrive, mut f: F) { let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); let (group_start, group_end) = match (group_start, group_end) { (Some(start), Some(end)) if start <= end => (start, end), @@ -645,7 +662,7 @@ fn general_op_groupref char>(drive: &mut StackDrive, mut f: F) MatchContext { string_position: group_start, // TODO: cache the offset - string_offset: calc_string_offset(drive.state.string, group_start), + string_offset: drive.state.string.offset(0, group_start), ..*drive.ctx() }, &drive, @@ -665,7 +682,7 @@ fn general_op_groupref char>(drive: &mut StackDrive, mut f: F) drive.ctx_mut().string_offset = offset; } -fn general_op_literal bool>(drive: &mut StackDrive, f: F) { +fn general_op_literal bool>(drive: &mut StackDrive, f: F) { if drive.at_end() || !f(drive.peek_code(1), drive.peek_char()) { drive.ctx_mut().has_matched = Some(false); } else { @@ -674,7 +691,7 @@ fn general_op_literal bool>(drive: &mut StackDrive, f: F } } -fn general_op_in bool>(drive: &mut StackDrive, f: F) { +fn general_op_in bool>(drive: &mut StackDrive, f: F) { let skip = drive.peek_code(1) as usize; if drive.at_end() || !f(&drive.pattern()[2..], drive.peek_char()) { drive.ctx_mut().has_matched = Some(false); @@ -700,7 +717,7 @@ fn at(drive: &StackDrive, atcode: SreAtCode) -> bool { } } -fn category(catcode: SreCatCode, c: char) -> bool { +fn category(catcode: SreCatCode, c: u32) -> bool { match catcode { SreCatCode::DIGIT => is_digit(c), SreCatCode::NOT_DIGIT => !is_digit(c), @@ -723,9 +740,8 @@ fn category(catcode: SreCatCode, c: char) -> bool { } } -fn charset(set: &[u32], c: char) -> bool { +fn charset(set: &[u32], ch: u32) -> bool { /* check if character is a member of the given set */ - let ch = c as u32; let mut ok = true; let mut i = 0; while i < set.len() { @@ -747,7 +763,7 @@ fn charset(set: &[u32], c: char) -> bool { break; } }; - if category(catcode, c) { + if category(catcode, ch) { return ok; } i += 2; @@ -801,7 +817,7 @@ fn charset(set: &[u32], c: char) -> bool { if set[i + 1] <= ch && ch <= set[i + 2] { return ok; } - let ch = upper_unicode(c) as u32; + let ch = upper_unicode(ch); if set[i + 1] <= ch && ch <= set[i + 2] { return ok; } @@ -898,86 +914,128 @@ fn _count(stack_drive: &StackDrive, maxcount: usize) -> usize { drive.ctx().string_position - drive.state().string_position } -fn general_count_literal bool>(drive: &mut WrapDrive, end: usize, mut f: F) { +fn general_count_literal bool>(drive: &mut WrapDrive, end: usize, mut f: F) { let ch = drive.peek_code(1); while !drive.ctx().string_position < end && f(ch, drive.peek_char()) { drive.skip_char(1); } } -fn eq_loc_ignore(code: u32, c: char) -> bool { - code == c as u32 || code == lower_locate(c) as u32 || code == upper_locate(c) as u32 +fn eq_loc_ignore(code: u32, ch: u32) -> bool { + code == ch || code == lower_locate(ch) || code == upper_locate(ch) } -fn is_word(c: char) -> bool { - c.is_ascii_alphanumeric() || c == '_' +fn is_word(ch: u32) -> bool { + ch == '_' as u32 + || u8::try_from(ch) + .map(|x| x.is_ascii_alphanumeric()) + .unwrap_or(false) } -fn is_space(c: char) -> bool { - c.is_ascii_whitespace() +fn is_space(ch: u32) -> bool { + u8::try_from(ch) + .map(is_py_ascii_whitespace) + .unwrap_or(false) } -fn is_digit(c: char) -> bool { - c.is_ascii_digit() +fn is_digit(ch: u32) -> bool { + u8::try_from(ch) + .map(|x| x.is_ascii_digit()) + .unwrap_or(false) } -fn is_loc_alnum(c: char) -> bool { +fn is_loc_alnum(ch: u32) -> bool { // TODO: check with cpython - c.is_alphanumeric() + u8::try_from(ch) + .map(|x| x.is_ascii_alphanumeric()) + .unwrap_or(false) } -fn is_loc_word(c: char) -> bool { - is_loc_alnum(c) || c == '_' +fn is_loc_word(ch: u32) -> bool { + ch == '_' as u32 || is_loc_alnum(ch) } -fn is_linebreak(c: char) -> bool { - c == '\n' +fn is_linebreak(ch: u32) -> bool { + ch == '\n' as u32 } -pub(crate) fn lower_ascii(c: char) -> char { - c.to_ascii_lowercase() +pub(crate) fn lower_ascii(ch: u32) -> u32 { + u8::try_from(ch) + .map(|x| x.to_ascii_lowercase() as u32) + .unwrap_or(ch) } -fn lower_locate(c: char) -> char { +fn lower_locate(ch: u32) -> u32 { // TODO: check with cpython // https://doc.rust-lang.org/std/primitive.char.html#method.to_lowercase - c.to_lowercase().next().unwrap() + lower_ascii(ch) } -fn upper_locate(c: char) -> char { +fn upper_locate(ch: u32) -> u32 { // TODO: check with cpython // https://doc.rust-lang.org/std/primitive.char.html#method.to_uppercase - c.to_uppercase().next().unwrap() + u8::try_from(ch) + .map(|x| x.to_ascii_uppercase() as u32) + .unwrap_or(ch) } -fn is_uni_digit(c: char) -> bool { +fn is_uni_digit(ch: u32) -> bool { // TODO: check with cpython - c.is_digit(10) + char::try_from(ch).map(|x| x.is_digit(10)).unwrap_or(false) } -fn is_uni_space(c: char) -> bool { +fn is_uni_space(ch: u32) -> bool { // TODO: check with cpython - c.is_whitespace() -} -fn is_uni_linebreak(c: char) -> bool { + is_space(ch) + || matches!( + ch, + 0x0009 + | 0x000A + | 0x000B + | 0x000C + | 0x000D + | 0x001C + | 0x001D + | 0x001E + | 0x001F + | 0x0020 + | 0x0085 + | 0x00A0 + | 0x1680 + | 0x2000 + | 0x2001 + | 0x2002 + | 0x2003 + | 0x2004 + | 0x2005 + | 0x2006 + | 0x2007 + | 0x2008 + | 0x2009 + | 0x200A + | 0x2028 + | 0x2029 + | 0x202F + | 0x205F + | 0x3000 + ) +} +fn is_uni_linebreak(ch: u32) -> bool { matches!( - c, - '\u{000A}' - | '\u{000B}' - | '\u{000C}' - | '\u{000D}' - | '\u{001C}' - | '\u{001D}' - | '\u{001E}' - | '\u{0085}' - | '\u{2028}' - | '\u{2029}' + ch, + 0x000A | 0x000B | 0x000C | 0x000D | 0x001C | 0x001D | 0x001E | 0x0085 | 0x2028 | 0x2029 ) } -fn is_uni_alnum(c: char) -> bool { +fn is_uni_alnum(ch: u32) -> bool { // TODO: check with cpython - c.is_alphanumeric() + char::try_from(ch) + .map(|x| x.is_alphanumeric()) + .unwrap_or(false) } -fn is_uni_word(c: char) -> bool { - is_uni_alnum(c) || c == '_' +fn is_uni_word(ch: u32) -> bool { + ch == '_' as u32 || is_uni_alnum(ch) } -pub(crate) fn lower_unicode(c: char) -> char { +pub(crate) fn lower_unicode(ch: u32) -> u32 { // TODO: check with cpython - c.to_lowercase().next().unwrap() + char::try_from(ch) + .map(|x| x.to_lowercase().next().unwrap() as u32) + .unwrap_or(ch) } -pub(crate) fn upper_unicode(c: char) -> char { +pub(crate) fn upper_unicode(ch: u32) -> u32 { // TODO: check with cpython - c.to_uppercase().next().unwrap() + char::try_from(ch) + .map(|x| x.to_uppercase().next().unwrap() as u32) + .unwrap_or(ch) } fn is_utf8_first_byte(b: u8) -> bool { @@ -988,6 +1046,23 @@ fn is_utf8_first_byte(b: u8) -> bool { (b & 0b10000000 == 0) || (b & 0b11000000 == 0b11000000) } +fn utf8_back_peek_offset(bytes: &[u8], offset: usize) -> usize { + let mut offset = offset - 1; + if !is_utf8_first_byte(bytes[offset]) { + offset -= 1; + if !is_utf8_first_byte(bytes[offset]) { + offset -= 1; + if !is_utf8_first_byte(bytes[offset]) { + offset -= 1; + if !is_utf8_first_byte(bytes[offset]) { + panic!("not utf-8 code point"); + } + } + } + } + offset +} + #[derive(Debug, Copy, Clone)] struct RepeatContext { count: isize, From 97000fc4e046c5f9d7da42357d8b3290e33972db Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 19 Jan 2021 17:14:00 +0200 Subject: [PATCH 026/893] fix multiple bugs; skip crash tests --- interp.rs | 220 ++++++++++++++++++++++-------------------------------- 1 file changed, 88 insertions(+), 132 deletions(-) diff --git a/interp.rs b/interp.rs index 7c9246142a..dfb2e04e47 100644 --- a/interp.rs +++ b/interp.rs @@ -131,18 +131,17 @@ impl<'a> State<'a> { pub fn search(mut self) -> Self { // TODO: optimize by op info and skip prefix - loop { + while self.start <= self.end { self = self.pymatch(); if self.has_matched == Some(true) { return self; } self.start += 1; - if self.start > self.end { - return self; - } self.reset(); } + + self } } @@ -182,16 +181,19 @@ impl<'a> StrDrive<'a> { let bytes = s.as_bytes(); let back_offset = utf8_back_peek_offset(bytes, offset); match offset - back_offset { - 1 => u32::from_ne_bytes([0, 0, 0, bytes[offset]]), - 2 => u32::from_ne_bytes([0, 0, bytes[offset], bytes[offset + 1]]), - 3 => { - u32::from_ne_bytes([0, bytes[offset], bytes[offset + 1], bytes[offset + 2]]) - } + 1 => u32::from_ne_bytes([0, 0, 0, bytes[offset - 1]]), + 2 => u32::from_ne_bytes([0, 0, bytes[offset - 2], bytes[offset - 1]]), + 3 => u32::from_ne_bytes([ + 0, + bytes[offset - 3], + bytes[offset - 2], + bytes[offset - 1], + ]), 4 => u32::from_ne_bytes([ - bytes[offset], - bytes[offset + 1], - bytes[offset + 2], - bytes[offset + 3], + bytes[offset - 4], + bytes[offset - 3], + bytes[offset - 2], + bytes[offset - 1], ]), _ => unreachable!(), } @@ -222,7 +224,10 @@ impl<'a> StrDrive<'a> { .skip(start) .collect::() .into_pyobject(vm), - StrDrive::Bytes(b) => PyBytes::from(b[start..end].to_vec()).into_pyobject(vm), + StrDrive::Bytes(b) => { + PyBytes::from(b.iter().take(end).skip(start).cloned().collect::>()) + .into_pyobject(vm) + } } } } @@ -256,13 +261,11 @@ trait MatchContextDrive { .state() .string .offset(self.ctx().string_offset, skip_count); - self.ctx_mut().string_position += skip_count; + self.ctx_mut().string_position = + std::cmp::min(self.ctx().string_position + skip_count, self.state().end); } fn skip_code(&mut self, skip_count: usize) { self.ctx_mut().code_position += skip_count; - if self.ctx().code_position > self.state().pattern_codes.len() { - self.ctx_mut().code_position = self.state().pattern_codes.len(); - } } fn remaining_chars(&self) -> usize { self.state().end - self.ctx().string_position @@ -314,9 +317,7 @@ impl<'a> StackDrive<'a> { self.state } fn push_new_context(&mut self, pattern_offset: usize) { - let mut child_ctx = MatchContext { ..*self.ctx() }; - child_ctx.code_position += pattern_offset; - self.state.context_stack.push(child_ctx); + self.push_new_context_at(self.ctx().code_position + pattern_offset); } fn push_new_context_at(&mut self, code_position: usize) { let mut child_ctx = MatchContext { ..*self.ctx() }; @@ -352,11 +353,9 @@ impl MatchContextDrive for WrapDrive<'_> { fn ctx_mut(&mut self) -> &mut MatchContext { &mut self.ctx } - fn ctx(&self) -> &MatchContext { &self.ctx } - fn state(&self) -> &State { self.stack_drive.state() } @@ -833,6 +832,7 @@ fn charset(set: &[u32], ch: u32) -> bool { false } +/* General case */ fn count(drive: &mut StackDrive, maxcount: usize) -> usize { let mut count = 0; let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); @@ -854,6 +854,8 @@ fn count(drive: &mut StackDrive, maxcount: usize) -> usize { count } +/* TODO: check literal cases should improve the perfermance + fn _count(stack_drive: &StackDrive, maxcount: usize) -> usize { let mut drive = WrapDrive::drive(*stack_drive.ctx(), stack_drive); let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); @@ -924,6 +926,7 @@ fn general_count_literal bool>(drive: &mut WrapDrive, end: fn eq_loc_ignore(code: u32, ch: u32) -> bool { code == ch || code == lower_locate(ch) || code == upper_locate(ch) } +*/ fn is_word(ch: u32) -> bool { ch == '_' as u32 @@ -1073,6 +1076,7 @@ struct RepeatContext { maxcount: usize, } +#[derive(Default)] struct OpMinRepeatOne { jump_id: usize, mincount: usize, @@ -1082,102 +1086,79 @@ struct OpMinRepeatOne { impl OpcodeExecutor for OpMinRepeatOne { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { - 0 => self._0(drive), - 1 => self._1(drive), - 2 => self._2(drive), - _ => unreachable!(), - } - } -} -impl Default for OpMinRepeatOne { - fn default() -> Self { - OpMinRepeatOne { - jump_id: 0, - mincount: 0, - maxcount: 0, - count: 0, - } - } -} -impl OpMinRepeatOne { - fn _0(&mut self, drive: &mut StackDrive) -> Option<()> { - self.mincount = drive.peek_code(2) as usize; - self.maxcount = drive.peek_code(3) as usize; + 0 => { + self.mincount = drive.peek_code(2) as usize; + self.maxcount = drive.peek_code(3) as usize; - if drive.remaining_chars() < self.mincount { - drive.ctx_mut().has_matched = Some(false); - return None; - } + if drive.remaining_chars() < self.mincount { + drive.ctx_mut().has_matched = Some(false); + return None; + } - drive.state.string_position = drive.ctx().string_position; + drive.state.string_position = drive.ctx().string_position; - self.count = if self.mincount == 0 { - 0 - } else { - let count = count(drive, self.mincount); - if count < self.mincount { - drive.ctx_mut().has_matched = Some(false); - return None; - } - drive.skip_char(count); - count - }; + self.count = if self.mincount == 0 { + 0 + } else { + let count = count(drive, self.mincount); + if count < self.mincount { + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.skip_char(count); + count + }; - if drive.peek_code(drive.peek_code(1) as usize + 1) == SreOpcode::SUCCESS as u32 { - drive.state.string_position = drive.ctx().string_position; - drive.ctx_mut().has_matched = Some(true); - return None; - } + if drive.peek_code(drive.peek_code(1) as usize + 1) == SreOpcode::SUCCESS as u32 { + drive.state.string_position = drive.ctx().string_position; + drive.ctx_mut().has_matched = Some(true); + return None; + } - drive.state.marks_push(); - self.jump_id = 1; - self._1(drive) - } - fn _1(&mut self, drive: &mut StackDrive) -> Option<()> { - if self.maxcount == MAXREPEAT || self.count <= self.maxcount { - drive.state.string_position = drive.ctx().string_position; - drive.push_new_context(drive.peek_code(1) as usize + 1); - self.jump_id = 2; - return Some(()); - } + drive.state.marks_push(); + self.jump_id = 1; + self.next(drive) + } + 1 => { + if self.maxcount == MAXREPEAT || self.count <= self.maxcount { + drive.state.string_position = drive.ctx().string_position; + drive.push_new_context(drive.peek_code(1) as usize + 1); + self.jump_id = 2; + return Some(()); + } - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(false); - None - } - fn _2(&mut self, drive: &mut StackDrive) -> Option<()> { - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.ctx_mut().has_matched = Some(true); - return None; - } - drive.state.string_position = drive.ctx().string_position; - if count(drive, 1) == 0 { - drive.ctx_mut().has_matched = Some(false); - return None; + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(false); + None + } + 2 => { + let child_ctx = drive.state.popped_context.unwrap(); + if child_ctx.has_matched == Some(true) { + drive.ctx_mut().has_matched = Some(true); + return None; + } + drive.state.string_position = drive.ctx().string_position; + if count(drive, 1) == 0 { + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.skip_char(1); + self.count += 1; + drive.state.marks_pop_keep(); + self.jump_id = 1; + self.next(drive) + } + _ => unreachable!(), } - drive.skip_char(1); - self.count += 1; - drive.state.marks_pop_keep(); - self.jump_id = 1; - self._1(drive) } } +#[derive(Default)] struct OpMaxUntil { jump_id: usize, count: isize, save_last_position: usize, } -impl Default for OpMaxUntil { - fn default() -> Self { - OpMaxUntil { - jump_id: 0, - count: 0, - save_last_position: 0, - } - } -} impl OpcodeExecutor for OpMaxUntil { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { @@ -1259,20 +1240,12 @@ impl OpcodeExecutor for OpMaxUntil { } } +#[derive(Default)] struct OpMinUntil { jump_id: usize, count: isize, save_repeat: Option, } -impl Default for OpMinUntil { - fn default() -> Self { - Self { - jump_id: 0, - count: 0, - save_repeat: None, - } - } -} impl OpcodeExecutor for OpMinUntil { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { @@ -1338,18 +1311,11 @@ impl OpcodeExecutor for OpMinUntil { } } +#[derive(Default)] struct OpBranch { jump_id: usize, current_branch_length: usize, } -impl Default for OpBranch { - fn default() -> Self { - Self { - jump_id: 0, - current_branch_length: 0, - } - } -} impl OpcodeExecutor for OpBranch { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { @@ -1388,22 +1354,13 @@ impl OpcodeExecutor for OpBranch { } } +#[derive(Default)] struct OpRepeatOne { jump_id: usize, mincount: usize, maxcount: usize, count: isize, } -impl Default for OpRepeatOne { - fn default() -> Self { - Self { - jump_id: 0, - mincount: 0, - maxcount: 0, - count: 0, - } - } -} impl OpcodeExecutor for OpRepeatOne { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { @@ -1413,7 +1370,6 @@ impl OpcodeExecutor for OpRepeatOne { if drive.remaining_chars() < self.mincount { drive.ctx_mut().has_matched = Some(false); - return None; } drive.state.string_position = drive.ctx().string_position; self.count = count(drive, self.maxcount) as isize; From 84113cba2cd66b83106f4564b03ae8cb999197c3 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 20 Jan 2021 16:33:30 +0200 Subject: [PATCH 027/893] fix zero width repeat --- interp.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/interp.rs b/interp.rs index dfb2e04e47..8b5a023ffd 100644 --- a/interp.rs +++ b/interp.rs @@ -1245,6 +1245,7 @@ struct OpMinUntil { jump_id: usize, count: isize, save_repeat: Option, + save_last_position: usize, } impl OpcodeExecutor for OpMinUntil { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { @@ -1280,6 +1281,7 @@ impl OpcodeExecutor for OpMinUntil { drive.ctx_mut().has_matched = child_ctx.has_matched; if drive.ctx().has_matched != Some(true) { drive.repeat_ctx_mut().count = self.count - 1; + drive.repeat_ctx_mut().last_position = self.save_last_position; drive.state.string_position = drive.ctx().string_position; } None @@ -1295,13 +1297,26 @@ impl OpcodeExecutor for OpMinUntil { drive.state.marks_pop(); // match more unital tail matches - let maxcount = drive.repeat_ctx().maxcount; - let code_position = drive.repeat_ctx().code_position; - if self.count as usize >= maxcount && maxcount != MAXREPEAT { + let RepeatContext { + count: _, + code_position, + last_position, + mincount: _, + maxcount, + } = *drive.repeat_ctx(); + + if self.count as usize >= maxcount && maxcount != MAXREPEAT + || drive.state.string_position == last_position + { drive.ctx_mut().has_matched = Some(false); return None; } drive.repeat_ctx_mut().count = self.count; + + /* zero-width match protection */ + self.save_last_position = last_position; + drive.repeat_ctx_mut().last_position = drive.state.string_position; + drive.push_new_context_at(code_position + 4); self.jump_id = 1; Some(()) From f2311b56fcbc87431eb0e291b68f44bb3658c2c1 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 21 Jan 2021 12:48:32 +0200 Subject: [PATCH 028/893] fix op branch --- interp.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/interp.rs b/interp.rs index 8b5a023ffd..5fa7ca8c0e 100644 --- a/interp.rs +++ b/interp.rs @@ -1329,28 +1329,30 @@ impl OpcodeExecutor for OpMinUntil { #[derive(Default)] struct OpBranch { jump_id: usize, - current_branch_length: usize, + branch_offset: usize, } impl OpcodeExecutor for OpBranch { + // alternation + // <0=skip> code ... fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { 0 => { drive.state.marks_push(); // jump out the head - self.current_branch_length = 1; + self.branch_offset = 1; self.jump_id = 1; self.next(drive) } 1 => { - drive.skip_code(self.current_branch_length); - self.current_branch_length = drive.peek_code(0) as usize; - if self.current_branch_length == 0 { + let next_branch_length = drive.peek_code(self.branch_offset) as usize; + if next_branch_length == 0 { drive.state.marks_pop_discard(); drive.ctx_mut().has_matched = Some(false); return None; } drive.state.string_position = drive.ctx().string_position; - drive.push_new_context(1); + drive.push_new_context(self.branch_offset + 1); + self.branch_offset += next_branch_length; self.jump_id = 2; Some(()) } From 6a792324b027bea56c39cb04fe072f719fd1efed Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 21 Jan 2021 16:14:23 +0200 Subject: [PATCH 029/893] fix at_beginning --- interp.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/interp.rs b/interp.rs index 5fa7ca8c0e..11010c7152 100644 --- a/interp.rs +++ b/interp.rs @@ -274,7 +274,8 @@ trait MatchContextDrive { self.state().pattern_codes.len() - self.ctx().code_position } fn at_beginning(&self) -> bool { - self.ctx().string_position == self.state().start + // self.ctx().string_position == self.state().start + self.ctx().string_position == 0 } fn at_end(&self) -> bool { self.ctx().string_position == self.state().end From 7f0dad7901751abcc09ac8308742c91627401c81 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 21 Jan 2021 19:24:10 +0200 Subject: [PATCH 030/893] fix back_peek_char --- interp.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/interp.rs b/interp.rs index 11010c7152..beb96f1f09 100644 --- a/interp.rs +++ b/interp.rs @@ -181,15 +181,15 @@ impl<'a> StrDrive<'a> { let bytes = s.as_bytes(); let back_offset = utf8_back_peek_offset(bytes, offset); match offset - back_offset { - 1 => u32::from_ne_bytes([0, 0, 0, bytes[offset - 1]]), - 2 => u32::from_ne_bytes([0, 0, bytes[offset - 2], bytes[offset - 1]]), - 3 => u32::from_ne_bytes([ + 1 => u32::from_be_bytes([0, 0, 0, bytes[offset - 1]]), + 2 => u32::from_be_bytes([0, 0, bytes[offset - 2], bytes[offset - 1]]), + 3 => u32::from_be_bytes([ 0, bytes[offset - 3], bytes[offset - 2], bytes[offset - 1], ]), - 4 => u32::from_ne_bytes([ + 4 => u32::from_be_bytes([ bytes[offset - 4], bytes[offset - 3], bytes[offset - 2], From 4416158eceb9dd6931ca25cfc610f50582dd83dc Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 22 Jan 2021 16:40:33 +0200 Subject: [PATCH 031/893] fix multiple bugs; pass tests --- interp.rs | 97 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 70 insertions(+), 27 deletions(-) diff --git a/interp.rs b/interp.rs index beb96f1f09..71e3ead143 100644 --- a/interp.rs +++ b/interp.rs @@ -25,6 +25,7 @@ pub(crate) struct State<'a> { pub string_position: usize, popped_context: Option, pub has_matched: Option, + pub match_all: bool, } impl<'a> State<'a> { @@ -51,6 +52,7 @@ impl<'a> State<'a> { string_position: start, popped_context: None, has_matched: None, + match_all: false, } } @@ -105,6 +107,7 @@ impl<'a> State<'a> { string_offset: self.string.offset(0, self.start), code_position: 0, has_matched: None, + toplevel: true, }; self.context_stack.push(ctx); @@ -132,6 +135,7 @@ impl<'a> State<'a> { pub fn search(mut self) -> Self { // TODO: optimize by op info and skip prefix while self.start <= self.end { + self.match_all = false; self = self.pymatch(); if self.has_matched == Some(true) { @@ -232,12 +236,13 @@ impl<'a> StrDrive<'a> { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Clone, Copy)] struct MatchContext { string_position: usize, string_offset: usize, code_position: usize, has_matched: Option, + toplevel: bool, } trait MatchContextDrive { @@ -463,8 +468,12 @@ impl OpcodeDispatcher { drive.ctx_mut().has_matched = Some(false); }), SreOpcode::SUCCESS => once(|drive| { - drive.state.string_position = drive.ctx().string_position; - drive.ctx_mut().has_matched = Some(true); + if drive.ctx().toplevel && drive.state.match_all && !drive.at_end() { + drive.ctx_mut().has_matched = Some(false); + } else { + drive.state.string_position = drive.ctx().string_position; + drive.ctx_mut().has_matched = Some(true); + } }), SreOpcode::ANY => once(|drive| { if drive.at_end() || drive.at_linebreak() { @@ -482,15 +491,19 @@ impl OpcodeDispatcher { drive.skip_char(1); } }), + /* assert subpattern */ + /* */ SreOpcode::ASSERT => twice( |drive| { let back = drive.peek_code(2) as usize; - if back > drive.ctx().string_position { + let passed = drive.ctx().string_position - drive.state.start; + if passed < back { drive.ctx_mut().has_matched = Some(false); return; } drive.state.string_position = drive.ctx().string_position - back; drive.push_new_context(3); + drive.state.context_stack.last_mut().unwrap().toplevel = false; }, |drive| { let child_ctx = drive.state.popped_context.unwrap(); @@ -504,12 +517,14 @@ impl OpcodeDispatcher { SreOpcode::ASSERT_NOT => twice( |drive| { let back = drive.peek_code(2) as usize; - if back > drive.ctx().string_position { + let passed = drive.ctx().string_position - drive.state.start; + if passed < back { drive.skip_code(drive.peek_code(1) as usize + 1); return; } drive.state.string_position = drive.ctx().string_position - back; drive.push_new_context(3); + drive.state.context_stack.last_mut().unwrap().toplevel = false; }, |drive| { let child_ctx = drive.state.popped_context.unwrap(); @@ -770,17 +785,17 @@ fn charset(set: &[u32], ch: u32) -> bool { } SreOpcode::CHARSET => { /* */ - let set = &set[1..]; + let set = &set[i + 1..]; if ch < 256 && ((set[(ch >> 5) as usize] & (1u32 << (ch & 31))) != 0) { return ok; } - i += 8; + i += 1 + 8; } SreOpcode::BIGCHARSET => { /* <256 blockindices> */ let count = set[i + 1] as usize; if ch < 0x10000 { - let set = &set[2..]; + let set = &set[i + 2..]; let block_index = ch >> 8; let (_, blockindices, _) = unsafe { set.align_to::() }; let blocks = &set[64..]; @@ -1085,6 +1100,7 @@ struct OpMinRepeatOne { count: usize, } impl OpcodeExecutor for OpMinRepeatOne { + /* <1=min> <2=max> item tail */ fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { 0 => { @@ -1110,7 +1126,11 @@ impl OpcodeExecutor for OpMinRepeatOne { count }; - if drive.peek_code(drive.peek_code(1) as usize + 1) == SreOpcode::SUCCESS as u32 { + let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 + && !(drive.ctx().toplevel && drive.state.match_all && !drive.at_end()) + { + // tail is empty. we're finished drive.state.string_position = drive.ctx().string_position; drive.ctx_mut().has_matched = Some(true); return None; @@ -1377,9 +1397,18 @@ struct OpRepeatOne { jump_id: usize, mincount: usize, maxcount: usize, - count: isize, + count: usize, + following_literal: Option, } impl OpcodeExecutor for OpRepeatOne { + /* match repeated sequence (maximizing regexp) */ + + /* this operator only works if the repeated item is + exactly one character wide, and we're not already + collecting backtracking points. for other cases, + use the MAX_REPEAT operator */ + + /* <1=min> <2=max> item tail */ fn next(&mut self, drive: &mut StackDrive) -> Option<()> { match self.jump_id { 0 => { @@ -1388,17 +1417,21 @@ impl OpcodeExecutor for OpRepeatOne { if drive.remaining_chars() < self.mincount { drive.ctx_mut().has_matched = Some(false); + return None; } + drive.state.string_position = drive.ctx().string_position; - self.count = count(drive, self.maxcount) as isize; - drive.skip_char(self.count as usize); - if self.count < self.mincount as isize { + + self.count = count(drive, self.maxcount); + drive.skip_char(self.count); + if self.count < self.mincount { drive.ctx_mut().has_matched = Some(false); return None; } let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 { + if next_code == SreOpcode::SUCCESS as u32 && drive.at_end() && !drive.ctx().toplevel + { // tail is empty. we're finished drive.state.string_position = drive.ctx().string_position; drive.ctx_mut().has_matched = Some(true); @@ -1406,24 +1439,34 @@ impl OpcodeExecutor for OpRepeatOne { } drive.state.marks_push(); - // TODO: + // Special case: Tail starts with a literal. Skip positions where // the rest of the pattern cannot possibly match. + if next_code == SreOpcode::LITERAL as u32 { + self.following_literal = Some(drive.peek_code(drive.peek_code(1) as usize + 2)) + } + self.jump_id = 1; self.next(drive) } 1 => { - // General case: backtracking - if self.count >= self.mincount as isize { - drive.state.string_position = drive.ctx().string_position; - drive.push_new_context(drive.peek_code(1) as usize + 1); - self.jump_id = 2; - return Some(()); + if let Some(c) = self.following_literal { + while drive.at_end() || drive.peek_char() != c { + if self.count <= self.mincount { + drive.state.marks_pop_discard(); + drive.ctx_mut().has_matched = Some(false); + return None; + } + drive.back_skip_char(1); + self.count -= 1; + } } - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(false); - None + // General case: backtracking + drive.state.string_position = drive.ctx().string_position; + drive.push_new_context(drive.peek_code(1) as usize + 1); + self.jump_id = 2; + Some(()) } 2 => { let child_ctx = drive.state.popped_context.unwrap(); @@ -1431,19 +1474,19 @@ impl OpcodeExecutor for OpRepeatOne { drive.ctx_mut().has_matched = Some(true); return None; } - if self.count <= self.mincount as isize { + if self.count <= self.mincount { drive.state.marks_pop_discard(); drive.ctx_mut().has_matched = Some(false); return None; } - // TODO: unnesscary double check drive.back_skip_char(1); self.count -= 1; + drive.state.marks_pop_keep(); self.jump_id = 1; - Some(()) + self.next(drive) } _ => unreachable!(), } From db5bd646b928048c7183f46ede40a114f19a82e4 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Wed, 27 Jan 2021 18:54:38 -0600 Subject: [PATCH 032/893] Initial commit to switch to new repo --- .gitignore | 2 ++ Cargo.toml | 7 +++++++ constants.rs => src/constants.rs | 0 interp.rs => src/engine.rs | 0 src/lib.rs | 2 ++ 5 files changed, 11 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml rename constants.rs => src/constants.rs (100%) rename interp.rs => src/engine.rs (100%) create mode 100644 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..96ef6c0b94 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000000..f9f504966b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "sre-engine" +version = "0.1.0" +authors = ["Kangzhi Shi ", "RustPython Team"] +edition = "2018" + +[dependencies] diff --git a/constants.rs b/src/constants.rs similarity index 100% rename from constants.rs rename to src/constants.rs diff --git a/interp.rs b/src/engine.rs similarity index 100% rename from interp.rs rename to src/engine.rs diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000000..f305aa094a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2 @@ +pub mod constants; +pub mod engine; From 9c95994dab20fea56004940ba50323ee7e327f81 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Wed, 27 Jan 2021 19:18:42 -0600 Subject: [PATCH 033/893] Modify to work outside of rustpython-vm --- Cargo.toml | 2 ++ src/engine.rs | 38 +++++++++++--------------------------- src/lib.rs | 12 ++++++++++++ 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f9f504966b..3a82ba73f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,3 +5,5 @@ authors = ["Kangzhi Shi ", "RustPython Team"] edition = "2018" [dependencies] +num_enum = "0.5" +bitflags = "1.2" diff --git a/src/engine.rs b/src/engine.rs index 71e3ead143..9aff2dcc91 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,17 +1,16 @@ // good luck to those that follow; here be dragons -use super::_sre::MAXREPEAT; use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; -use crate::builtins::PyBytes; -use crate::bytesinner::is_py_ascii_whitespace; -use crate::pyobject::{IntoPyObject, PyObjectRef}; -use crate::VirtualMachine; +use super::MAXREPEAT; use std::collections::HashMap; use std::convert::TryFrom; -use std::unreachable; + +const fn is_py_ascii_whitespace(b: u8) -> bool { + matches!(b, b'\t' | b'\n' | b'\x0C' | b'\r' | b' ' | b'\x0B') +} #[derive(Debug)] -pub(crate) struct State<'a> { +pub struct State<'a> { pub string: StrDrive<'a>, pub start: usize, pub end: usize, @@ -29,7 +28,7 @@ pub(crate) struct State<'a> { } impl<'a> State<'a> { - pub(crate) fn new( + pub fn new( string: StrDrive<'a>, start: usize, end: usize, @@ -150,7 +149,7 @@ impl<'a> State<'a> { } #[derive(Debug, Clone, Copy)] -pub(crate) enum StrDrive<'a> { +pub enum StrDrive<'a> { Str(&'a str), Bytes(&'a [u8]), } @@ -219,21 +218,6 @@ impl<'a> StrDrive<'a> { StrDrive::Bytes(_) => offset - skip, } } - - pub fn slice_to_pyobject(&self, start: usize, end: usize, vm: &VirtualMachine) -> PyObjectRef { - match *self { - StrDrive::Str(s) => s - .chars() - .take(end) - .skip(start) - .collect::() - .into_pyobject(vm), - StrDrive::Bytes(b) => { - PyBytes::from(b.iter().take(end).skip(start).cloned().collect::>()) - .into_pyobject(vm) - } - } - } } #[derive(Debug, Clone, Copy)] @@ -972,7 +956,7 @@ fn is_loc_word(ch: u32) -> bool { fn is_linebreak(ch: u32) -> bool { ch == '\n' as u32 } -pub(crate) fn lower_ascii(ch: u32) -> u32 { +pub fn lower_ascii(ch: u32) -> u32 { u8::try_from(ch) .map(|x| x.to_ascii_lowercase() as u32) .unwrap_or(ch) @@ -1044,13 +1028,13 @@ fn is_uni_alnum(ch: u32) -> bool { fn is_uni_word(ch: u32) -> bool { ch == '_' as u32 || is_uni_alnum(ch) } -pub(crate) fn lower_unicode(ch: u32) -> u32 { +pub fn lower_unicode(ch: u32) -> u32 { // TODO: check with cpython char::try_from(ch) .map(|x| x.to_lowercase().next().unwrap() as u32) .unwrap_or(ch) } -pub(crate) fn upper_unicode(ch: u32) -> u32 { +pub fn upper_unicode(ch: u32) -> u32 { // TODO: check with cpython char::try_from(ch) .map(|x| x.to_uppercase().next().unwrap() as u32) diff --git a/src/lib.rs b/src/lib.rs index f305aa094a..4a3ed1b754 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,14 @@ pub mod constants; pub mod engine; + +pub const CODESIZE: usize = 4; + +#[cfg(target_pointer_width = "32")] +pub const MAXREPEAT: usize = usize::MAX; +#[cfg(target_pointer_width = "64")] +pub const MAXREPEAT: usize = u32::MAX as usize; + +#[cfg(target_pointer_width = "32")] +pub const MAXGROUPS: usize = MAXREPEAT / 4 / 2; +#[cfg(target_pointer_width = "64")] +pub const MAXGROUPS: usize = MAXREPEAT / 2; From 2592067f95f530850db8451fd45982dc3bf161e0 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Wed, 27 Jan 2021 19:19:31 -0600 Subject: [PATCH 034/893] Add LICENSE --- Cargo.toml | 1 + LICENSE | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 LICENSE diff --git a/Cargo.toml b/Cargo.toml index 3a82ba73f1..03db7aba4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ name = "sre-engine" version = "0.1.0" authors = ["Kangzhi Shi ", "RustPython Team"] +license = "MIT" edition = "2018" [dependencies] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..7213274e0f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 RustPython Team + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. From 2473c3e49f6ba539549a8bf4f87db9146044fc52 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 1 Feb 2021 20:25:45 +0200 Subject: [PATCH 035/893] add tests; fix OpAssert panic --- src/engine.rs | 15 ++++++++------- src/lib.rs | 2 ++ src/tests.rs | 19 +++++++++++++++++++ 3 files changed, 29 insertions(+), 7 deletions(-) create mode 100644 src/tests.rs diff --git a/src/engine.rs b/src/engine.rs index 9aff2dcc91..f6a6092f9b 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -369,20 +369,18 @@ fn once(f: F) -> Box> { Box::new(OpOnce { f: Some(f) }) } -// F1 F2 are same identical, but workaround for closure struct OpTwice { f1: Option, f2: Option, } impl OpcodeExecutor for OpTwice where - F1: FnOnce(&mut StackDrive), + F1: FnOnce(&mut StackDrive) -> Option<()>, F2: FnOnce(&mut StackDrive), { fn next(&mut self, drive: &mut StackDrive) -> Option<()> { if let Some(f1) = self.f1.take() { - f1(drive); - Some(()) + f1(drive) } else if let Some(f2) = self.f2.take() { f2(drive); None @@ -393,7 +391,7 @@ where } fn twice(f1: F1, f2: F2) -> Box> where - F1: FnOnce(&mut StackDrive), + F1: FnOnce(&mut StackDrive) -> Option<()>, F2: FnOnce(&mut StackDrive), { Box::new(OpTwice { @@ -483,11 +481,12 @@ impl OpcodeDispatcher { let passed = drive.ctx().string_position - drive.state.start; if passed < back { drive.ctx_mut().has_matched = Some(false); - return; + return None; } drive.state.string_position = drive.ctx().string_position - back; drive.push_new_context(3); drive.state.context_stack.last_mut().unwrap().toplevel = false; + Some(()) }, |drive| { let child_ctx = drive.state.popped_context.unwrap(); @@ -504,11 +503,12 @@ impl OpcodeDispatcher { let passed = drive.ctx().string_position - drive.state.start; if passed < back { drive.skip_code(drive.peek_code(1) as usize + 1); - return; + return None; } drive.state.string_position = drive.ctx().string_position - back; drive.push_new_context(3); drive.state.context_stack.last_mut().unwrap().toplevel = false; + Some(()) }, |drive| { let child_ctx = drive.state.popped_context.unwrap(); @@ -598,6 +598,7 @@ impl OpcodeDispatcher { drive.state.string_position = drive.ctx().string_position; // execute UNTIL operator drive.push_new_context(drive.peek_code(1) as usize + 1); + Some(()) }, |drive| { drive.state.repeat_stack.pop(); diff --git a/src/lib.rs b/src/lib.rs index 4a3ed1b754..eae3be617d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ pub mod constants; pub mod engine; +#[cfg(test)] +mod tests; pub const CODESIZE: usize = 4; diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000000..95922e88a7 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,19 @@ +use engine::{State, StrDrive}; + +use super::*; + +#[test] +fn test_2427() { + let str_drive = StrDrive::Str("x"); + // r'(? = vec![15, 4, 0, 1, 1, 5, 5, 1, 17, 46, 1, 17, 120, 6, 10, 1]; + let mut state = State::new( + str_drive, + 0, + std::usize::MAX, + constants::SreFlag::UNICODE, + &code, + ); + state = state.pymatch(); + assert!(state.has_matched == Some(true)); +} From cc4441b50ffa361c8f52eeecad54a9459c471e9b Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Mon, 1 Feb 2021 14:27:01 -0600 Subject: [PATCH 036/893] Compile regex patterns for tests with a script --- generate_tests.py | 37 +++++++++++++++++++++++++++++++++++++ src/engine.rs | 12 ++++++++++++ src/lib.rs | 2 -- src/tests.rs | 19 ------------------- tests/lookbehind.py | 1 + tests/lookbehind.re | 2 ++ tests/tests.rs | 26 ++++++++++++++++++++++++++ 7 files changed, 78 insertions(+), 21 deletions(-) create mode 100644 generate_tests.py delete mode 100644 src/tests.rs create mode 100644 tests/lookbehind.py create mode 100644 tests/lookbehind.re create mode 100644 tests/tests.rs diff --git a/generate_tests.py b/generate_tests.py new file mode 100644 index 0000000000..49a24792be --- /dev/null +++ b/generate_tests.py @@ -0,0 +1,37 @@ +import os +from pathlib import Path +import re +import sre_constants +import sre_compile +import sre_parse +import json + +m = re.search(r"const SRE_MAGIC: usize = (\d+);", open("src/constants.rs").read()) +sre_engine_magic = int(m.group(1)) +del m + +assert sre_constants.MAGIC == sre_engine_magic + +class CompiledPattern: + @classmethod + def compile(cls, pattern, flags=0): + p = sre_parse.parse(pattern) + code = sre_compile._code(p, flags) + self = cls() + self.pattern = pattern + self.code = code + self.flags = re.RegexFlag(flags | p.state.flags) + return self + +for k, v in re.RegexFlag.__members__.items(): + setattr(CompiledPattern, k, v) + +with os.scandir("tests") as d: + for f in d: + path = Path(f.path) + if path.suffix == ".py": + pattern = eval(path.read_text(), {"re": CompiledPattern}) + path.with_suffix(".re").write_text( + f"// {pattern.pattern!r}, flags={pattern.flags!r}\n" + f"Pattern {{ code: &{json.dumps(pattern.code)}, flags: SreFlag::from_bits_truncate({int(pattern.flags)}) }}" + ) diff --git a/src/engine.rs b/src/engine.rs index f6a6092f9b..a48b799e1b 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -153,6 +153,18 @@ pub enum StrDrive<'a> { Str(&'a str), Bytes(&'a [u8]), } + +impl<'a> From<&'a str> for StrDrive<'a> { + fn from(s: &'a str) -> Self { + Self::Str(s) + } +} +impl<'a> From<&'a [u8]> for StrDrive<'a> { + fn from(b: &'a [u8]) -> Self { + Self::Bytes(b) + } +} + impl<'a> StrDrive<'a> { fn offset(&self, offset: usize, skip: usize) -> usize { match *self { diff --git a/src/lib.rs b/src/lib.rs index eae3be617d..4a3ed1b754 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,5 @@ pub mod constants; pub mod engine; -#[cfg(test)] -mod tests; pub const CODESIZE: usize = 4; diff --git a/src/tests.rs b/src/tests.rs deleted file mode 100644 index 95922e88a7..0000000000 --- a/src/tests.rs +++ /dev/null @@ -1,19 +0,0 @@ -use engine::{State, StrDrive}; - -use super::*; - -#[test] -fn test_2427() { - let str_drive = StrDrive::Str("x"); - // r'(? = vec![15, 4, 0, 1, 1, 5, 5, 1, 17, 46, 1, 17, 120, 6, 10, 1]; - let mut state = State::new( - str_drive, - 0, - std::usize::MAX, - constants::SreFlag::UNICODE, - &code, - ); - state = state.pymatch(); - assert!(state.has_matched == Some(true)); -} diff --git a/tests/lookbehind.py b/tests/lookbehind.py new file mode 100644 index 0000000000..3da6425959 --- /dev/null +++ b/tests/lookbehind.py @@ -0,0 +1 @@ +re.compile(r'(?( + &self, + string: impl Into>, + range: std::ops::Range, + ) -> engine::State<'a> { + engine::State::new(string.into(), range.start, range.end, self.flags, self.code) + } +} + +#[test] +fn test_2427() { + // r'(? Date: Wed, 3 Feb 2021 13:32:48 +0200 Subject: [PATCH 037/893] fix OpAssert positive lookbehind --- src/engine.rs | 28 ++++++++++++++++++++++++---- tests/positive_lookbehind.py | 1 + tests/positive_lookbehind.re | 2 ++ tests/tests.rs | 9 +++++++++ 4 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 tests/positive_lookbehind.py create mode 100644 tests/positive_lookbehind.re diff --git a/src/engine.rs b/src/engine.rs index a48b799e1b..5e0e0f4208 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -490,14 +490,24 @@ impl OpcodeDispatcher { SreOpcode::ASSERT => twice( |drive| { let back = drive.peek_code(2) as usize; - let passed = drive.ctx().string_position - drive.state.start; + let passed = drive.ctx().string_position; if passed < back { drive.ctx_mut().has_matched = Some(false); return None; } + let back_offset = drive + .state + .string + .back_offset(drive.ctx().string_offset, back); + drive.state.string_position = drive.ctx().string_position - back; + drive.push_new_context(3); - drive.state.context_stack.last_mut().unwrap().toplevel = false; + let child_ctx = drive.state.context_stack.last_mut().unwrap(); + child_ctx.toplevel = false; + child_ctx.string_position -= back; + child_ctx.string_offset = back_offset; + Some(()) }, |drive| { @@ -512,14 +522,24 @@ impl OpcodeDispatcher { SreOpcode::ASSERT_NOT => twice( |drive| { let back = drive.peek_code(2) as usize; - let passed = drive.ctx().string_position - drive.state.start; + let passed = drive.ctx().string_position; if passed < back { drive.skip_code(drive.peek_code(1) as usize + 1); return None; } + let back_offset = drive + .state + .string + .back_offset(drive.ctx().string_offset, back); + drive.state.string_position = drive.ctx().string_position - back; + drive.push_new_context(3); - drive.state.context_stack.last_mut().unwrap().toplevel = false; + let child_ctx = drive.state.context_stack.last_mut().unwrap(); + child_ctx.toplevel = false; + child_ctx.string_position -= back; + child_ctx.string_offset = back_offset; + Some(()) }, |drive| { diff --git a/tests/positive_lookbehind.py b/tests/positive_lookbehind.py new file mode 100644 index 0000000000..2a0ab29253 --- /dev/null +++ b/tests/positive_lookbehind.py @@ -0,0 +1 @@ +re.compile(r'(?<=abc)def') \ No newline at end of file diff --git a/tests/positive_lookbehind.re b/tests/positive_lookbehind.re new file mode 100644 index 0000000000..68923b58ee --- /dev/null +++ b/tests/positive_lookbehind.re @@ -0,0 +1,2 @@ +// '(?<=abc)def', flags=re.UNICODE +Pattern { code: &[15, 4, 0, 3, 3, 4, 9, 3, 17, 97, 17, 98, 17, 99, 1, 17, 100, 17, 101, 17, 102, 1], flags: SreFlag::from_bits_truncate(32) } \ No newline at end of file diff --git a/tests/tests.rs b/tests/tests.rs index 4db110177d..8a18b5f333 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -24,3 +24,12 @@ fn test_2427() { state = state.pymatch(); assert!(state.has_matched == Some(true)); } + +#[test] +fn test_assert() { + // '(?<=abc)def', flags=re.UNICODE + let pattern = include!("positive_lookbehind.re"); + let mut state = pattern.state("abcdef", 0..usize::MAX); + state = state.search(); + assert!(state.has_matched == Some(true)); +} From 2a43d66e11a4a2245c12b24484b14eb01d7033a1 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Thu, 1 Apr 2021 21:38:14 -0500 Subject: [PATCH 038/893] Fix clippy lint --- src/constants.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index f5ab92c531..f3962b339a 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -16,7 +16,7 @@ use bitflags::bitflags; pub const SRE_MAGIC: usize = 20171005; #[derive(num_enum::TryFromPrimitive, Debug)] #[repr(u32)] -#[allow(non_camel_case_types)] +#[allow(non_camel_case_types, clippy::upper_case_acronyms)] pub enum SreOpcode { FAILURE = 0, SUCCESS = 1, @@ -62,7 +62,7 @@ pub enum SreOpcode { } #[derive(num_enum::TryFromPrimitive, Debug)] #[repr(u32)] -#[allow(non_camel_case_types)] +#[allow(non_camel_case_types, clippy::upper_case_acronyms)] pub enum SreAtCode { BEGINNING = 0, BEGINNING_LINE = 1, @@ -79,7 +79,7 @@ pub enum SreAtCode { } #[derive(num_enum::TryFromPrimitive, Debug)] #[repr(u32)] -#[allow(non_camel_case_types)] +#[allow(non_camel_case_types, clippy::upper_case_acronyms)] pub enum SreCatCode { DIGIT = 0, NOT_DIGIT = 1, From ca1346ee031624b37e6e5d531a1c5dcaa5b284fc Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Thu, 1 Apr 2021 22:05:45 -0500 Subject: [PATCH 039/893] Have generate_tests.py generate Patterns inline in tests.rs --- generate_tests.py | 21 +++++++++++++++------ tests/lookbehind.py | 1 - tests/lookbehind.re | 2 -- tests/positive_lookbehind.py | 1 - tests/positive_lookbehind.re | 2 -- tests/tests.rs | 16 ++++++++++------ 6 files changed, 25 insertions(+), 18 deletions(-) delete mode 100644 tests/lookbehind.py delete mode 100644 tests/lookbehind.re delete mode 100644 tests/positive_lookbehind.py delete mode 100644 tests/positive_lookbehind.re diff --git a/generate_tests.py b/generate_tests.py index 49a24792be..7af1d2f0c2 100644 --- a/generate_tests.py +++ b/generate_tests.py @@ -26,12 +26,21 @@ def compile(cls, pattern, flags=0): for k, v in re.RegexFlag.__members__.items(): setattr(CompiledPattern, k, v) + +# matches `// pattern {varname} = re.compile(...)` +pattern_pattern = re.compile(r"^((\s*)\/\/\s*pattern\s+(\w+)\s+=\s+(.+?))$(?:.+?END GENERATED)?", re.M | re.S) +def replace_compiled(m): + line, indent, varname, pattern = m.groups() + pattern = eval(pattern, {"re": CompiledPattern}) + pattern = f"Pattern {{ code: &{json.dumps(pattern.code)}, flags: SreFlag::from_bits_truncate({int(pattern.flags)}) }}" + return f'''{line} +{indent}// START GENERATED by generate_tests.py +{indent}#[rustfmt::skip] let {varname} = {pattern}; +{indent}// END GENERATED''' + with os.scandir("tests") as d: for f in d: path = Path(f.path) - if path.suffix == ".py": - pattern = eval(path.read_text(), {"re": CompiledPattern}) - path.with_suffix(".re").write_text( - f"// {pattern.pattern!r}, flags={pattern.flags!r}\n" - f"Pattern {{ code: &{json.dumps(pattern.code)}, flags: SreFlag::from_bits_truncate({int(pattern.flags)}) }}" - ) + if path.suffix == ".rs": + replaced = pattern_pattern.sub(replace_compiled, path.read_text()) + path.write_text(replaced) diff --git a/tests/lookbehind.py b/tests/lookbehind.py deleted file mode 100644 index 3da6425959..0000000000 --- a/tests/lookbehind.py +++ /dev/null @@ -1 +0,0 @@ -re.compile(r'(? Date: Mon, 5 Apr 2021 11:10:32 -0500 Subject: [PATCH 040/893] Add more info to Cargo.toml --- Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 03db7aba4f..8e69ab5235 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,11 @@ name = "sre-engine" version = "0.1.0" authors = ["Kangzhi Shi ", "RustPython Team"] +description = "A low-level implementation of Python's SRE regex engine" +repository = "https://github.com/RustPython/sre-engine" license = "MIT" edition = "2018" +keywords = ["regex"] [dependencies] num_enum = "0.5" From 9728dd8699a255686deaedbdd0f23f549e62009f Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 16 Apr 2021 09:35:11 +0200 Subject: [PATCH 041/893] fix test_string_boundaries --- src/engine.rs | 14 +++++++++++--- tests/tests.rs | 11 +++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 5e0e0f4208..0de9f43844 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -292,6 +292,14 @@ trait MatchContextDrive { let this = !self.at_end() && word_checker(self.peek_char()); this != that } + fn at_non_boundary bool>(&self, mut word_checker: F) -> bool { + if self.at_beginning() && self.at_end() { + return false; + } + let that = !self.at_beginning() && word_checker(self.back_peek_char()); + let this = !self.at_end() && word_checker(self.peek_char()); + this == that + } fn back_peek_char(&self) -> u32 { self.state().string.back_peek(self.ctx().string_offset) } @@ -738,14 +746,14 @@ fn at(drive: &StackDrive, atcode: SreAtCode) -> bool { SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => drive.at_beginning(), SreAtCode::BEGINNING_LINE => drive.at_beginning() || is_linebreak(drive.back_peek_char()), SreAtCode::BOUNDARY => drive.at_boundary(is_word), - SreAtCode::NON_BOUNDARY => !drive.at_boundary(is_word), + SreAtCode::NON_BOUNDARY => drive.at_non_boundary(is_word), SreAtCode::END => (drive.remaining_chars() == 1 && drive.at_linebreak()) || drive.at_end(), SreAtCode::END_LINE => drive.at_linebreak() || drive.at_end(), SreAtCode::END_STRING => drive.at_end(), SreAtCode::LOC_BOUNDARY => drive.at_boundary(is_loc_word), - SreAtCode::LOC_NON_BOUNDARY => !drive.at_boundary(is_loc_word), + SreAtCode::LOC_NON_BOUNDARY => drive.at_non_boundary(is_loc_word), SreAtCode::UNI_BOUNDARY => drive.at_boundary(is_uni_word), - SreAtCode::UNI_NON_BOUNDARY => !drive.at_boundary(is_uni_word), + SreAtCode::UNI_NON_BOUNDARY => drive.at_non_boundary(is_uni_word), } } diff --git a/tests/tests.rs b/tests/tests.rs index f4cd091f0d..690c72861b 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -37,3 +37,14 @@ fn test_assert() { state = state.search(); assert!(state.has_matched == Some(true)); } + +#[test] +fn test_string_boundaries() { + // pattern big_b = re.compile(r'\B') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let big_b = Pattern { code: &[15, 4, 0, 0, 0, 6, 11, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + let mut state = big_b.state("", 0..usize::MAX); + state = state.search(); + assert!(state.has_matched == None) +} From d2b48fdea2986e8513d63d19ea30859c5a48a66e Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Fri, 16 Apr 2021 10:40:42 -0500 Subject: [PATCH 042/893] Release 0.1.1 sre-engine@0.1.1 Generated by cargo-workspaces --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 8e69ab5235..cf830a6053 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.1.0" +version = "0.1.1" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" From 73abbace85aebf063eb8c7ce4815cacd2449f522 Mon Sep 17 00:00:00 2001 From: Noah <33094578+coolreader18@users.noreply.github.com> Date: Fri, 16 Apr 2021 10:53:37 -0500 Subject: [PATCH 043/893] Add explicit include for Cargo files --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index cf830a6053..f0cd628f0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ repository = "https://github.com/RustPython/sre-engine" license = "MIT" edition = "2018" keywords = ["regex"] +include = ["LICENSE", "src/**/*.rs"] [dependencies] num_enum = "0.5" From df8453d387c0a4bc6b84131f40703b46000ba7ee Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 20 Apr 2021 10:19:27 +0200 Subject: [PATCH 044/893] fix zerowidth search --- Cargo.toml | 2 +- src/engine.rs | 99 +++++++++++++++++++++++++++++++++++--------------- tests/tests.rs | 18 +++++++-- 3 files changed, 85 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f0cd628f0e..614243eeb1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.1.1" +version = "0.1.2" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" diff --git a/src/engine.rs b/src/engine.rs index 0de9f43844..bded52448d 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -23,8 +23,9 @@ pub struct State<'a> { repeat_stack: Vec, pub string_position: usize, popped_context: Option, - pub has_matched: Option, + pub has_matched: bool, pub match_all: bool, + pub must_advance: bool, } impl<'a> State<'a> { @@ -50,8 +51,9 @@ impl<'a> State<'a> { marks: Vec::new(), string_position: start, popped_context: None, - has_matched: None, + has_matched: false, match_all: false, + must_advance: false, } } @@ -63,7 +65,7 @@ impl<'a> State<'a> { self.marks.clear(); self.string_position = self.start; self.popped_context = None; - self.has_matched = None; + self.has_matched = false; } fn set_mark(&mut self, mark_nr: usize, position: usize) { @@ -100,17 +102,7 @@ impl<'a> State<'a> { self.marks_stack.pop(); } - pub fn pymatch(mut self) -> Self { - let ctx = MatchContext { - string_position: self.start, - string_offset: self.string.offset(0, self.start), - code_position: 0, - has_matched: None, - toplevel: true, - }; - self.context_stack.push(ctx); - - let mut dispatcher = OpcodeDispatcher::new(); + fn _match(mut self, dispatcher: &mut OpcodeDispatcher) -> Self { let mut has_matched = None; loop { @@ -127,21 +119,58 @@ impl<'a> State<'a> { } } - self.has_matched = has_matched; + self.has_matched = has_matched == Some(true); self } + pub fn pymatch(mut self) -> Self { + let ctx = MatchContext { + string_position: self.start, + string_offset: self.string.offset(0, self.start), + code_position: 0, + has_matched: None, + toplevel: true, + }; + self.context_stack.push(ctx); + + let mut dispatcher = OpcodeDispatcher::new(); + + self._match(&mut dispatcher) + } + pub fn search(mut self) -> Self { // TODO: optimize by op info and skip prefix - while self.start <= self.end { - self.match_all = false; - self = self.pymatch(); - if self.has_matched == Some(true) { - return self; - } + if self.start > self.end { + return self; + } + + let mut dispatcher = OpcodeDispatcher::new(); + + let ctx = MatchContext { + string_position: self.start, + string_offset: self.string.offset(0, self.start), + code_position: 0, + has_matched: None, + toplevel: true, + }; + self.context_stack.push(ctx); + self = self._match(&mut dispatcher); + + self.must_advance = false; + while !self.has_matched && self.start < self.end { self.start += 1; self.reset(); + dispatcher.clear(); + let ctx = MatchContext { + string_position: self.start, + string_offset: self.string.offset(0, self.start), + code_position: 0, + has_matched: None, + toplevel: false, + }; + self.context_stack.push(ctx); + self = self._match(&mut dispatcher); } self @@ -310,6 +339,18 @@ trait MatchContextDrive { .string .back_offset(self.ctx().string_offset, skip_count); } + fn can_success(&self) -> bool { + if !self.ctx().toplevel { + return true; + } + if self.state().match_all && !self.at_end() { + return false; + } + if self.state().must_advance && self.ctx().string_position == self.state().start { + return false; + } + true + } } struct StackDrive<'a> { @@ -429,6 +470,9 @@ impl OpcodeDispatcher { executing_contexts: HashMap::new(), } } + fn clear(&mut self) { + self.executing_contexts.clear(); + } // Returns True if the current context matches, False if it doesn't and // None if matching is not finished, ie must be resumed after child // contexts have been matched. @@ -470,11 +514,9 @@ impl OpcodeDispatcher { drive.ctx_mut().has_matched = Some(false); }), SreOpcode::SUCCESS => once(|drive| { - if drive.ctx().toplevel && drive.state.match_all && !drive.at_end() { - drive.ctx_mut().has_matched = Some(false); - } else { + drive.ctx_mut().has_matched = Some(drive.can_success()); + if drive.ctx().has_matched == Some(true) { drive.state.string_position = drive.ctx().string_position; - drive.ctx_mut().has_matched = Some(true); } }), SreOpcode::ANY => once(|drive| { @@ -1152,9 +1194,7 @@ impl OpcodeExecutor for OpMinRepeatOne { }; let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 - && !(drive.ctx().toplevel && drive.state.match_all && !drive.at_end()) - { + if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { // tail is empty. we're finished drive.state.string_position = drive.ctx().string_position; drive.ctx_mut().has_matched = Some(true); @@ -1455,8 +1495,7 @@ impl OpcodeExecutor for OpRepeatOne { } let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && drive.at_end() && !drive.ctx().toplevel - { + if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { // tail is empty. we're finished drive.state.string_position = drive.ctx().string_position; drive.ctx_mut().has_matched = Some(true); diff --git a/tests/tests.rs b/tests/tests.rs index 690c72861b..d76ff3cfb5 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -24,7 +24,7 @@ fn test_2427() { // END GENERATED let mut state = lookbehind.state("x", 0..usize::MAX); state = state.pymatch(); - assert!(state.has_matched == Some(true)); + assert!(state.has_matched); } #[test] @@ -35,7 +35,7 @@ fn test_assert() { // END GENERATED let mut state = positive_lookbehind.state("abcdef", 0..usize::MAX); state = state.search(); - assert!(state.has_matched == Some(true)); + assert!(state.has_matched); } #[test] @@ -46,5 +46,17 @@ fn test_string_boundaries() { // END GENERATED let mut state = big_b.state("", 0..usize::MAX); state = state.search(); - assert!(state.has_matched == None) + assert!(!state.has_matched); +} + +#[test] +fn test_zerowidth() { + // pattern p = re.compile(r'\b|:+') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 7, 5, 6, 10, 16, 12, 10, 25, 6, 1, 4294967295, 17, 58, 1, 16, 2, 0, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + let mut state = p.state("a:", 0..usize::MAX); + state.must_advance = true; + state = state.search(); + assert!(state.string_position == 1); } From a3c3573d67f94d6119a8bb7126f385c38ba438e8 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 20 Apr 2021 17:36:32 +0200 Subject: [PATCH 045/893] optimize count --- src/engine.rs | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index bded52448d..5409baf66d 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -916,7 +916,7 @@ fn charset(set: &[u32], ch: u32) -> bool { } /* General case */ -fn count(drive: &mut StackDrive, maxcount: usize) -> usize { +fn general_count(drive: &mut StackDrive, maxcount: usize) -> usize { let mut count = 0; let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); @@ -937,18 +937,11 @@ fn count(drive: &mut StackDrive, maxcount: usize) -> usize { count } -/* TODO: check literal cases should improve the perfermance - -fn _count(stack_drive: &StackDrive, maxcount: usize) -> usize { +fn count(stack_drive: &mut StackDrive, maxcount: usize) -> usize { let mut drive = WrapDrive::drive(*stack_drive.ctx(), stack_drive); let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); let end = drive.ctx().string_position + maxcount; - let opcode = match SreOpcode::try_from(drive.peek_code(1)) { - Ok(code) => code, - Err(_) => { - panic!("FIXME:COUNT1"); - } - }; + let opcode = SreOpcode::try_from(drive.peek_code(0)).unwrap(); match opcode { SreOpcode::ANY => { @@ -960,7 +953,6 @@ fn _count(stack_drive: &StackDrive, maxcount: usize) -> usize { drive.skip_char(maxcount); } SreOpcode::IN => { - // TODO: pattern[2 or 1..]? while !drive.ctx().string_position < end && charset(&drive.pattern()[2..], drive.peek_char()) { @@ -992,7 +984,7 @@ fn _count(stack_drive: &StackDrive, maxcount: usize) -> usize { general_count_literal(&mut drive, end, |code, c| code != lower_unicode(c) as u32); } _ => { - todo!("repeated single character pattern?"); + return general_count(stack_drive, maxcount); } } @@ -1006,11 +998,6 @@ fn general_count_literal bool>(drive: &mut WrapDrive, end: } } -fn eq_loc_ignore(code: u32, ch: u32) -> bool { - code == ch || code == lower_locate(ch) || code == upper_locate(ch) -} -*/ - fn is_word(ch: u32) -> bool { ch == '_' as u32 || u8::try_from(ch) @@ -1028,7 +1015,7 @@ fn is_digit(ch: u32) -> bool { .unwrap_or(false) } fn is_loc_alnum(ch: u32) -> bool { - // TODO: check with cpython + // FIXME: Ignore the locales u8::try_from(ch) .map(|x| x.is_ascii_alphanumeric()) .unwrap_or(false) @@ -1045,13 +1032,11 @@ pub fn lower_ascii(ch: u32) -> u32 { .unwrap_or(ch) } fn lower_locate(ch: u32) -> u32 { - // TODO: check with cpython - // https://doc.rust-lang.org/std/primitive.char.html#method.to_lowercase + // FIXME: Ignore the locales lower_ascii(ch) } fn upper_locate(ch: u32) -> u32 { - // TODO: check with cpython - // https://doc.rust-lang.org/std/primitive.char.html#method.to_uppercase + // FIXME: Ignore the locales u8::try_from(ch) .map(|x| x.to_ascii_uppercase() as u32) .unwrap_or(ch) From 5bd6b672d089fa4dc8db1d5d3d8564f6a98dbacd Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 21 Apr 2021 11:09:10 +0200 Subject: [PATCH 046/893] optimize opcode that execute only once --- src/engine.rs | 183 ++++++++++++++++++++++++++++---------------------- 1 file changed, 102 insertions(+), 81 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 5409baf66d..2888b6930c 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -416,20 +416,6 @@ trait OpcodeExecutor { fn next(&mut self, drive: &mut StackDrive) -> Option<()>; } -struct OpOnce { - f: Option, -} -impl OpcodeExecutor for OpOnce { - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - let f = self.f.take()?; - f(drive); - None - } -} -fn once(f: F) -> Box> { - Box::new(OpOnce { f: Some(f) }) -} - struct OpTwice { f1: Option, f2: Option, @@ -496,48 +482,58 @@ impl OpcodeDispatcher { // Dispatches a context on a given opcode. Returns True if the context // is done matching, False if it must be resumed when next encountered. fn dispatch(&mut self, opcode: SreOpcode, drive: &mut StackDrive) -> bool { - let mut executor = match self.executing_contexts.remove_entry(&drive.id()) { - Some((_, executor)) => executor, - None => self.dispatch_table(opcode), - }; - if let Some(()) = executor.next(drive) { - self.executing_contexts.insert(drive.id(), executor); - false - } else { - true + let executor = self + .executing_contexts + .remove_entry(&drive.id()) + .map(|(_, x)| x) + .or_else(|| self.dispatch_table(opcode, drive)); + if let Some(mut executor) = executor { + if let Some(()) = executor.next(drive) { + self.executing_contexts.insert(drive.id(), executor); + return false; + } } + true } - fn dispatch_table(&mut self, opcode: SreOpcode) -> Box { + fn dispatch_table( + &mut self, + opcode: SreOpcode, + drive: &mut StackDrive, + ) -> Option> { match opcode { - SreOpcode::FAILURE => once(|drive| { + SreOpcode::FAILURE => { drive.ctx_mut().has_matched = Some(false); - }), - SreOpcode::SUCCESS => once(|drive| { + None + } + SreOpcode::SUCCESS => { drive.ctx_mut().has_matched = Some(drive.can_success()); if drive.ctx().has_matched == Some(true) { drive.state.string_position = drive.ctx().string_position; } - }), - SreOpcode::ANY => once(|drive| { + None + } + SreOpcode::ANY => { if drive.at_end() || drive.at_linebreak() { drive.ctx_mut().has_matched = Some(false); } else { drive.skip_code(1); drive.skip_char(1); } - }), - SreOpcode::ANY_ALL => once(|drive| { + None + } + SreOpcode::ANY_ALL => { if drive.at_end() { drive.ctx_mut().has_matched = Some(false); } else { drive.skip_code(1); drive.skip_char(1); } - }), + None + } /* assert subpattern */ /* */ - SreOpcode::ASSERT => twice( + SreOpcode::ASSERT => Some(twice( |drive| { let back = drive.peek_code(2) as usize; let passed = drive.ctx().string_position; @@ -568,8 +564,8 @@ impl OpcodeDispatcher { drive.ctx_mut().has_matched = Some(false); } }, - ), - SreOpcode::ASSERT_NOT => twice( + )), + SreOpcode::ASSERT_NOT => Some(twice( |drive| { let back = drive.peek_code(2) as usize; let passed = drive.ctx().string_position; @@ -600,17 +596,18 @@ impl OpcodeDispatcher { drive.skip_code(drive.peek_code(1) as usize + 1); } }, - ), - SreOpcode::AT => once(|drive| { + )), + SreOpcode::AT => { let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); if !at(drive, atcode) { drive.ctx_mut().has_matched = Some(false); } else { drive.skip_code(2); } - }), - SreOpcode::BRANCH => Box::new(OpBranch::default()), - SreOpcode::CATEGORY => once(|drive| { + None + } + SreOpcode::BRANCH => Some(Box::new(OpBranch::default())), + SreOpcode::CATEGORY => { let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); if drive.at_end() || !category(catcode, drive.peek_char()) { drive.ctx_mut().has_matched = Some(false); @@ -618,53 +615,68 @@ impl OpcodeDispatcher { drive.skip_code(2); drive.skip_char(1); } - }), - SreOpcode::IN => once(|drive| { + None + } + SreOpcode::IN => { general_op_in(drive, |set, c| charset(set, c)); - }), - SreOpcode::IN_IGNORE => once(|drive| { + None + } + SreOpcode::IN_IGNORE => { general_op_in(drive, |set, c| charset(set, lower_ascii(c))); - }), - SreOpcode::IN_UNI_IGNORE => once(|drive| { + None + } + SreOpcode::IN_UNI_IGNORE => { general_op_in(drive, |set, c| charset(set, lower_unicode(c))); - }), - SreOpcode::IN_LOC_IGNORE => once(|drive| { + None + } + SreOpcode::IN_LOC_IGNORE => { general_op_in(drive, |set, c| charset_loc_ignore(set, c)); - }), - SreOpcode::INFO | SreOpcode::JUMP => once(|drive| { + None + } + SreOpcode::INFO | SreOpcode::JUMP => { drive.skip_code(drive.peek_code(1) as usize + 1); - }), - SreOpcode::LITERAL => once(|drive| { + None + } + SreOpcode::LITERAL => { general_op_literal(drive, |code, c| code == c); - }), - SreOpcode::NOT_LITERAL => once(|drive| { + None + } + SreOpcode::NOT_LITERAL => { general_op_literal(drive, |code, c| code != c); - }), - SreOpcode::LITERAL_IGNORE => once(|drive| { + None + } + SreOpcode::LITERAL_IGNORE => { general_op_literal(drive, |code, c| code == lower_ascii(c)); - }), - SreOpcode::NOT_LITERAL_IGNORE => once(|drive| { + None + } + SreOpcode::NOT_LITERAL_IGNORE => { general_op_literal(drive, |code, c| code != lower_ascii(c)); - }), - SreOpcode::LITERAL_UNI_IGNORE => once(|drive| { + None + } + SreOpcode::LITERAL_UNI_IGNORE => { general_op_literal(drive, |code, c| code == lower_unicode(c)); - }), - SreOpcode::NOT_LITERAL_UNI_IGNORE => once(|drive| { + None + } + SreOpcode::NOT_LITERAL_UNI_IGNORE => { general_op_literal(drive, |code, c| code != lower_unicode(c)); - }), - SreOpcode::LITERAL_LOC_IGNORE => once(|drive| { + None + } + SreOpcode::LITERAL_LOC_IGNORE => { general_op_literal(drive, char_loc_ignore); - }), - SreOpcode::NOT_LITERAL_LOC_IGNORE => once(|drive| { + None + } + SreOpcode::NOT_LITERAL_LOC_IGNORE => { general_op_literal(drive, |code, c| !char_loc_ignore(code, c)); - }), - SreOpcode::MARK => once(|drive| { + None + } + SreOpcode::MARK => { drive .state .set_mark(drive.peek_code(1) as usize, drive.ctx().string_position); drive.skip_code(2); - }), - SreOpcode::REPEAT => twice( + None + } + SreOpcode::REPEAT => Some(twice( // create repeat context. all the hard work is done by the UNTIL // operator (MAX_UNTIL, MIN_UNTIL) // <1=min> <2=max> item tail @@ -687,20 +699,28 @@ impl OpcodeDispatcher { let child_ctx = drive.state.popped_context.unwrap(); drive.ctx_mut().has_matched = child_ctx.has_matched; }, - ), - SreOpcode::MAX_UNTIL => Box::new(OpMaxUntil::default()), - SreOpcode::MIN_UNTIL => Box::new(OpMinUntil::default()), - SreOpcode::REPEAT_ONE => Box::new(OpRepeatOne::default()), - SreOpcode::MIN_REPEAT_ONE => Box::new(OpMinRepeatOne::default()), - SreOpcode::GROUPREF => once(|drive| general_op_groupref(drive, |x| x)), - SreOpcode::GROUPREF_IGNORE => once(|drive| general_op_groupref(drive, lower_ascii)), + )), + SreOpcode::MAX_UNTIL => Some(Box::new(OpMaxUntil::default())), + SreOpcode::MIN_UNTIL => Some(Box::new(OpMinUntil::default())), + SreOpcode::REPEAT_ONE => Some(Box::new(OpRepeatOne::default())), + SreOpcode::MIN_REPEAT_ONE => Some(Box::new(OpMinRepeatOne::default())), + SreOpcode::GROUPREF => { + general_op_groupref(drive, |x| x); + None + } + SreOpcode::GROUPREF_IGNORE => { + general_op_groupref(drive, lower_ascii); + None + } SreOpcode::GROUPREF_LOC_IGNORE => { - once(|drive| general_op_groupref(drive, lower_locate)) + general_op_groupref(drive, lower_locate); + None } SreOpcode::GROUPREF_UNI_IGNORE => { - once(|drive| general_op_groupref(drive, lower_unicode)) + general_op_groupref(drive, lower_unicode); + None } - SreOpcode::GROUPREF_EXISTS => once(|drive| { + SreOpcode::GROUPREF_EXISTS => { let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); match (group_start, group_end) { (Some(start), Some(end)) if start <= end => { @@ -708,7 +728,8 @@ impl OpcodeDispatcher { } _ => drive.skip_code(drive.peek_code(2) as usize + 1), } - }), + None + } _ => { // TODO python expcetion? unreachable!("unexpected opcode") From 7324feef89dd692d909c98849e54969617728ecf Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 21 Apr 2021 11:13:58 +0200 Subject: [PATCH 047/893] optimize search cache the string offset --- src/engine.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 2888b6930c..2b85ea3514 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -147,9 +147,11 @@ impl<'a> State<'a> { let mut dispatcher = OpcodeDispatcher::new(); + let mut start_offset = self.string.offset(0, self.start); + let ctx = MatchContext { string_position: self.start, - string_offset: self.string.offset(0, self.start), + string_offset: start_offset, code_position: 0, has_matched: None, toplevel: true, @@ -160,11 +162,12 @@ impl<'a> State<'a> { self.must_advance = false; while !self.has_matched && self.start < self.end { self.start += 1; + start_offset = self.string.offset(start_offset, 1); self.reset(); dispatcher.clear(); let ctx = MatchContext { string_position: self.start, - string_offset: self.string.offset(0, self.start), + string_offset: start_offset, code_position: 0, has_matched: None, toplevel: false, From 86435b8a4b44d0b79109ace6acd66cbfcebaac66 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 22 Apr 2021 17:15:03 +0200 Subject: [PATCH 048/893] add benchmark --- benches/benches.rs | 112 +++++++++++++++++++++++++++++++++++++++++++++ generate_tests.py | 5 +- 2 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 benches/benches.rs diff --git a/benches/benches.rs b/benches/benches.rs new file mode 100644 index 0000000000..b86a592967 --- /dev/null +++ b/benches/benches.rs @@ -0,0 +1,112 @@ +#![feature(test)] + +extern crate test; +use test::Bencher; + +use sre_engine::constants::SreFlag; +use sre_engine::engine; +pub struct Pattern { + pub code: &'static [u32], + pub flags: SreFlag, +} + +impl Pattern { + pub fn state<'a>( + &self, + string: impl Into>, + range: std::ops::Range, + ) -> engine::State<'a> { + engine::State::new(string.into(), range.start, range.end, self.flags, self.code) + } +} +#[bench] +fn benchmarks(b: &mut Bencher) { + // # test common prefix + // pattern p1 = re.compile('Python|Perl') # , 'Perl'), # Alternation + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p1 = Pattern { code: &[15, 8, 1, 4, 6, 1, 1, 80, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern p2 = re.compile('(Python|Perl)') #, 'Perl'), # Grouped alternation + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p2 = Pattern { code: &[15, 8, 1, 4, 6, 1, 0, 80, 0, 18, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 18, 1, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('Python|Perl|Tcl') #, 'Perl'), # Alternation + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p3 = Pattern { code: &[15, 9, 4, 3, 6, 17, 80, 17, 84, 0, 7, 15, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 22, 11, 17, 80, 17, 101, 17, 114, 17, 108, 16, 11, 9, 17, 84, 17, 99, 17, 108, 16, 2, 0, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('(Python|Perl|Tcl)') #, 'Perl'), # Grouped alternation + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p4 = Pattern { code: &[15, 9, 4, 3, 6, 17, 80, 17, 84, 0, 18, 0, 7, 15, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 22, 11, 17, 80, 17, 101, 17, 114, 17, 108, 16, 11, 9, 17, 84, 17, 99, 17, 108, 16, 2, 0, 18, 1, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('(Python)\\1') #, 'PythonPython'), # Backreference + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p5 = Pattern { code: &[15, 18, 1, 12, 12, 6, 0, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 18, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 12, 0, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('([0a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # Disable the fastmap optimization + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p6 = Pattern { code: &[15, 4, 0, 2, 4294967295, 24, 31, 1, 4294967295, 18, 0, 14, 7, 17, 48, 23, 97, 122, 0, 25, 13, 0, 4294967295, 14, 8, 23, 97, 122, 23, 48, 57, 0, 1, 17, 44, 18, 1, 19, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('([a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # A few sets + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p7 = Pattern { code: &[15, 4, 0, 2, 4294967295, 24, 29, 1, 4294967295, 18, 0, 14, 5, 23, 97, 122, 0, 25, 13, 0, 4294967295, 14, 8, 23, 97, 122, 23, 48, 57, 0, 1, 17, 44, 18, 1, 19, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('Python') #, 'Python'), # Simple text literal + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p8 = Pattern { code: &[15, 18, 3, 6, 6, 6, 6, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('.*Python') #, 'Python'), # Bad text literal + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p9 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('.*Python.*') #, 'Python'), # Worse text literal + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p10 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 25, 5, 0, 4294967295, 2, 1, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + // pattern pn = re.compile('.*(Python)') #, 'Python'), # Bad text literal with grouping + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p11 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 18, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + + let tests = [ + (p1, "Perl"), + (p2, "Perl"), + (p3, "Perl"), + (p4, "Perl"), + (p5, "PythonPython"), + (p6, "a5,b7,c9,"), + (p7, "a5,b7,c9,"), + (p8, "Python"), + (p9, "Python"), + (p10, "Python"), + (p11, "Python"), + ]; + + b.iter(move || { + for (p, s) in &tests { + let mut state = p.state(s.clone(), 0..usize::MAX); + state = state.search(); + assert!(state.has_matched); + state = p.state(s.clone(), 0..usize::MAX); + state = state.pymatch(); + assert!(state.has_matched); + state = p.state(s.clone(), 0..usize::MAX); + state.match_all = true; + state = state.pymatch(); + assert!(state.has_matched); + let s2 = format!("{}{}{}", " ".repeat(10000), s, " ".repeat(10000)); + state = p.state(s2.as_str(), 0..usize::MAX); + state = state.search(); + assert!(state.has_matched); + state = p.state(s2.as_str(), 10000..usize::MAX); + state = state.pymatch(); + assert!(state.has_matched); + state = p.state(s2.as_str(), 10000..10000 + s.len()); + state = state.pymatch(); + assert!(state.has_matched); + state = p.state(s2.as_str(), 10000..10000 + s.len()); + state.match_all = true; + state = state.pymatch(); + assert!(state.has_matched); + } + }) +} diff --git a/generate_tests.py b/generate_tests.py index 7af1d2f0c2..b432720cd1 100644 --- a/generate_tests.py +++ b/generate_tests.py @@ -5,6 +5,7 @@ import sre_compile import sre_parse import json +from itertools import chain m = re.search(r"const SRE_MAGIC: usize = (\d+);", open("src/constants.rs").read()) sre_engine_magic = int(m.group(1)) @@ -38,8 +39,8 @@ def replace_compiled(m): {indent}#[rustfmt::skip] let {varname} = {pattern}; {indent}// END GENERATED''' -with os.scandir("tests") as d: - for f in d: +with os.scandir("tests") as t, os.scandir("benches") as b: + for f in chain(t, b): path = Path(f.path) if path.suffix == ".rs": replaced = pattern_pattern.sub(replace_compiled, path.read_text()) From 58981a41e99cbcb53002b559ea13c7131b597938 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 22 Apr 2021 19:27:24 +0200 Subject: [PATCH 049/893] optimize; replace hashmap with btreemap --- src/engine.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 2b85ea3514..3837974838 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -2,7 +2,7 @@ use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; use super::MAXREPEAT; -use std::collections::HashMap; +use std::collections::BTreeMap; use std::convert::TryFrom; const fn is_py_ascii_whitespace(b: u8) -> bool { @@ -204,7 +204,7 @@ impl<'a> StrDrive<'a> { .get(offset..) .and_then(|s| s.char_indices().nth(skip).map(|x| x.0 + offset)) .unwrap_or_else(|| s.len()), - StrDrive::Bytes(b) => std::cmp::min(offset + skip, b.len()), + StrDrive::Bytes(_) => offset + skip, } } @@ -294,8 +294,7 @@ trait MatchContextDrive { .state() .string .offset(self.ctx().string_offset, skip_count); - self.ctx_mut().string_position = - std::cmp::min(self.ctx().string_position + skip_count, self.state().end); + self.ctx_mut().string_position += skip_count; } fn skip_code(&mut self, skip_count: usize) { self.ctx_mut().code_position += skip_count; @@ -451,12 +450,12 @@ where } struct OpcodeDispatcher { - executing_contexts: HashMap>, + executing_contexts: BTreeMap>, } impl OpcodeDispatcher { fn new() -> Self { Self { - executing_contexts: HashMap::new(), + executing_contexts: BTreeMap::new(), } } fn clear(&mut self) { @@ -487,8 +486,7 @@ impl OpcodeDispatcher { fn dispatch(&mut self, opcode: SreOpcode, drive: &mut StackDrive) -> bool { let executor = self .executing_contexts - .remove_entry(&drive.id()) - .map(|(_, x)| x) + .remove(&drive.id()) .or_else(|| self.dispatch_table(opcode, drive)); if let Some(mut executor) = executor { if let Some(()) = executor.next(drive) { From 74ebdaf4e8a50330ea6195333bcb47e086ff3b5e Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 11 Jul 2022 21:30:48 +0200 Subject: [PATCH 050/893] fix panic OpMinUntil return before restore repeat --- .vscode/launch.json | 21 +++++++++++++++++++++ src/engine.rs | 10 ++++++---- tests/tests.rs | 11 +++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..5ebfe34f05 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,21 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug Unit Test", + "cargo": { + "args": [ + "test", + "--no-run" + ], + "filter": { + "kind": "test" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/src/engine.rs b/src/engine.rs index 3837974838..d4e036ff32 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -14,7 +14,7 @@ pub struct State<'a> { pub string: StrDrive<'a>, pub start: usize, pub end: usize, - flags: SreFlag, + _flags: SreFlag, pattern_codes: &'a [u32], pub marks: Vec>, pub lastindex: isize, @@ -42,7 +42,7 @@ impl<'a> State<'a> { string, start, end, - flags, + _flags: flags, pattern_codes, lastindex: -1, marks_stack: Vec::new(), @@ -1380,16 +1380,18 @@ impl OpcodeExecutor for OpMinUntil { None } 2 => { + // restore repeat before return + drive.state.repeat_stack.push(self.save_repeat.unwrap()); + let child_ctx = drive.state.popped_context.unwrap(); if child_ctx.has_matched == Some(true) { drive.ctx_mut().has_matched = Some(true); return None; } - drive.state.repeat_stack.push(self.save_repeat.unwrap()); drive.state.string_position = drive.ctx().string_position; drive.state.marks_pop(); - // match more unital tail matches + // match more until tail matches let RepeatContext { count: _, code_position, diff --git a/tests/tests.rs b/tests/tests.rs index d76ff3cfb5..b430947a9b 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -60,3 +60,14 @@ fn test_zerowidth() { state = state.search(); assert!(state.string_position == 1); } + +#[test] +fn test_repeat_context_panic() { + // pattern p = re.compile(r'(?:a*?(xx)??z)*') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 25, 0, 4294967295, 27, 6, 0, 4294967295, 17, 97, 1, 24, 11, 0, 1, 18, 0, 17, 120, 17, 120, 18, 1, 20, 17, 122, 19, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + let mut state = p.state("axxzaz", 0..usize::MAX); + state = state.pymatch(); + assert!(state.marks == vec![Some(1), Some(3)]); +} \ No newline at end of file From 919e1d7933b62bba0830b5b8dce63c1dfc758056 Mon Sep 17 00:00:00 2001 From: Steve Shi Date: Tue, 26 Jul 2022 20:38:03 +0200 Subject: [PATCH 051/893] Refactor and fix multiple max_until recusion (#10) * wip refactor engine * wip 2 refactor engine * wip 3 refactor engine * wip 3 refactor engine * wip 4 refactor engine * wip 5 refactor engine * refactor seperate Stacks * fix clippy * fix pymatch and search restore _stacks * fix toplevel * fix marks panic * fix double max_until repeat context * clearup * update version to 0.2.0 --- Cargo.toml | 2 +- src/engine.rs | 1700 ++++++++++++++++++++++++------------------------ src/lib.rs | 2 +- tests/tests.rs | 13 +- 4 files changed, 847 insertions(+), 870 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 614243eeb1..6ba3996947 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.1.2" +version = "0.2.0" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" diff --git a/src/engine.rs b/src/engine.rs index d4e036ff32..81903ccfdd 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -2,7 +2,6 @@ use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; use super::MAXREPEAT; -use std::collections::BTreeMap; use std::convert::TryFrom; const fn is_py_ascii_whitespace(b: u8) -> bool { @@ -20,7 +19,7 @@ pub struct State<'a> { pub lastindex: isize, marks_stack: Vec<(Vec>, isize)>, context_stack: Vec, - repeat_stack: Vec, + _stacks: Option>, pub string_position: usize, popped_context: Option, pub has_matched: bool, @@ -44,11 +43,11 @@ impl<'a> State<'a> { end, _flags: flags, pattern_codes, + marks: Vec::new(), lastindex: -1, marks_stack: Vec::new(), context_stack: Vec::new(), - repeat_stack: Vec::new(), - marks: Vec::new(), + _stacks: Default::default(), string_position: start, popped_context: None, has_matched: false, @@ -59,10 +58,12 @@ impl<'a> State<'a> { pub fn reset(&mut self) { self.lastindex = -1; + self.marks.clear(); self.marks_stack.clear(); self.context_stack.clear(); - self.repeat_stack.clear(); - self.marks.clear(); + if let Some(x) = self._stacks.as_mut() { + x.clear() + }; self.string_position = self.start; self.popped_context = None; self.has_matched = false; @@ -102,51 +103,71 @@ impl<'a> State<'a> { self.marks_stack.pop(); } - fn _match(mut self, dispatcher: &mut OpcodeDispatcher) -> Self { - let mut has_matched = None; + fn _match(mut self, stacks: &mut Stacks) -> Self { + while let Some(ctx) = self.context_stack.pop() { + let mut drive = StateContext { + state: self, + ctx, + next_ctx: None, + }; - loop { - if self.context_stack.is_empty() { - break; + if let Some(handler) = drive.ctx.handler { + handler(&mut drive, stacks); + } else if drive.remaining_codes() > 0 { + let code = drive.peek_code(0); + let code = SreOpcode::try_from(code).unwrap(); + dispatch(code, &mut drive, stacks); + } else { + drive.failure(); } - let ctx_id = self.context_stack.len() - 1; - let mut drive = StackDrive::drive(ctx_id, self); - has_matched = dispatcher.pymatch(&mut drive); - self = drive.take(); - if has_matched.is_some() { - self.popped_context = self.context_stack.pop(); + let StateContext { + mut state, + ctx, + next_ctx, + } = drive; + + if ctx.has_matched.is_some() { + state.popped_context = Some(ctx); + } else { + state.context_stack.push(ctx); + if let Some(next_ctx) = next_ctx { + state.context_stack.push(next_ctx); + } } + self = state } - - self.has_matched = has_matched == Some(true); + self.has_matched = self.popped_context.unwrap().has_matched == Some(true); self } pub fn pymatch(mut self) -> Self { + let mut stacks = self._stacks.take().unwrap_or_default(); + let ctx = MatchContext { string_position: self.start, string_offset: self.string.offset(0, self.start), code_position: 0, has_matched: None, toplevel: true, + handler: None, + repeat_ctx_id: usize::MAX, }; self.context_stack.push(ctx); - let mut dispatcher = OpcodeDispatcher::new(); - - self._match(&mut dispatcher) + self = self._match(&mut stacks); + self._stacks = Some(stacks); + self } pub fn search(mut self) -> Self { + let mut stacks = self._stacks.take().unwrap_or_default(); // TODO: optimize by op info and skip prefix if self.start > self.end { return self; } - let mut dispatcher = OpcodeDispatcher::new(); - let mut start_offset = self.string.offset(0, self.start); let ctx = MatchContext { @@ -155,31 +176,664 @@ impl<'a> State<'a> { code_position: 0, has_matched: None, toplevel: true, + handler: None, + repeat_ctx_id: usize::MAX, }; self.context_stack.push(ctx); - self = self._match(&mut dispatcher); + self = self._match(&mut stacks); self.must_advance = false; while !self.has_matched && self.start < self.end { self.start += 1; start_offset = self.string.offset(start_offset, 1); self.reset(); - dispatcher.clear(); + stacks.clear(); + let ctx = MatchContext { string_position: self.start, string_offset: start_offset, code_position: 0, has_matched: None, toplevel: false, + handler: None, + repeat_ctx_id: usize::MAX, }; self.context_stack.push(ctx); - self = self._match(&mut dispatcher); + self = self._match(&mut stacks); } + self._stacks = Some(stacks); self } } +fn dispatch(opcode: SreOpcode, drive: &mut StateContext, stacks: &mut Stacks) { + match opcode { + SreOpcode::FAILURE => { + drive.failure(); + } + SreOpcode::SUCCESS => { + drive.ctx.has_matched = Some(drive.can_success()); + if drive.ctx.has_matched == Some(true) { + drive.state.string_position = drive.ctx.string_position; + } + } + SreOpcode::ANY => { + if drive.at_end() || drive.at_linebreak() { + drive.failure(); + } else { + drive.skip_code(1); + drive.skip_char(1); + } + } + SreOpcode::ANY_ALL => { + if drive.at_end() { + drive.failure(); + } else { + drive.skip_code(1); + drive.skip_char(1); + } + } + SreOpcode::ASSERT => op_assert(drive), + SreOpcode::ASSERT_NOT => op_assert_not(drive), + SreOpcode::AT => { + let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); + if at(drive, atcode) { + drive.skip_code(2); + } else { + drive.failure(); + } + } + SreOpcode::BRANCH => op_branch(drive, stacks), + SreOpcode::CATEGORY => { + let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); + if drive.at_end() || !category(catcode, drive.peek_char()) { + drive.failure(); + } else { + drive.skip_code(2); + drive.skip_char(1); + } + } + SreOpcode::IN => general_op_in(drive, charset), + SreOpcode::IN_IGNORE => general_op_in(drive, |set, c| charset(set, lower_ascii(c))), + SreOpcode::IN_UNI_IGNORE => general_op_in(drive, |set, c| charset(set, lower_unicode(c))), + SreOpcode::IN_LOC_IGNORE => general_op_in(drive, charset_loc_ignore), + SreOpcode::INFO | SreOpcode::JUMP => drive.skip_code_from(1), + SreOpcode::LITERAL => general_op_literal(drive, |code, c| code == c), + SreOpcode::NOT_LITERAL => general_op_literal(drive, |code, c| code != c), + SreOpcode::LITERAL_IGNORE => general_op_literal(drive, |code, c| code == lower_ascii(c)), + SreOpcode::NOT_LITERAL_IGNORE => { + general_op_literal(drive, |code, c| code != lower_ascii(c)) + } + SreOpcode::LITERAL_UNI_IGNORE => { + general_op_literal(drive, |code, c| code == lower_unicode(c)) + } + SreOpcode::NOT_LITERAL_UNI_IGNORE => { + general_op_literal(drive, |code, c| code != lower_unicode(c)) + } + SreOpcode::LITERAL_LOC_IGNORE => general_op_literal(drive, char_loc_ignore), + SreOpcode::NOT_LITERAL_LOC_IGNORE => { + general_op_literal(drive, |code, c| !char_loc_ignore(code, c)) + } + SreOpcode::MARK => { + drive + .state + .set_mark(drive.peek_code(1) as usize, drive.ctx.string_position); + drive.skip_code(2); + } + SreOpcode::MAX_UNTIL => op_max_until(drive, stacks), + SreOpcode::MIN_UNTIL => op_min_until(drive, stacks), + SreOpcode::REPEAT => op_repeat(drive, stacks), + SreOpcode::REPEAT_ONE => op_repeat_one(drive, stacks), + SreOpcode::MIN_REPEAT_ONE => op_min_repeat_one(drive, stacks), + SreOpcode::GROUPREF => general_op_groupref(drive, |x| x), + SreOpcode::GROUPREF_IGNORE => general_op_groupref(drive, lower_ascii), + SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(drive, lower_locate), + SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(drive, lower_unicode), + SreOpcode::GROUPREF_EXISTS => { + let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); + match (group_start, group_end) { + (Some(start), Some(end)) if start <= end => { + drive.skip_code(3); + } + _ => drive.skip_code_from(2), + } + } + _ => unreachable!("unexpected opcode"), + } +} + +/* assert subpattern */ +/* */ +fn op_assert(drive: &mut StateContext) { + let back = drive.peek_code(2) as usize; + + if drive.ctx.string_position < back { + return drive.failure(); + } + + let offset = drive + .state + .string + .back_offset(drive.ctx.string_offset, back); + let position = drive.ctx.string_position - back; + + drive.state.string_position = position; + + let next_ctx = drive.next_ctx(3, |drive, _| { + if drive.popped_ctx().has_matched == Some(true) { + drive.ctx.handler = None; + drive.skip_code_from(1); + } else { + drive.failure(); + } + }); + next_ctx.string_position = position; + next_ctx.string_offset = offset; + next_ctx.toplevel = false; +} + +/* assert not subpattern */ +/* */ +fn op_assert_not(drive: &mut StateContext) { + let back = drive.peek_code(2) as usize; + + if drive.ctx.string_position < back { + return drive.skip_code_from(1); + } + + let offset = drive + .state + .string + .back_offset(drive.ctx.string_offset, back); + let position = drive.ctx.string_position - back; + + drive.state.string_position = position; + + let next_ctx = drive.next_ctx(3, |drive, _| { + if drive.popped_ctx().has_matched == Some(true) { + drive.failure(); + } else { + drive.ctx.handler = None; + drive.skip_code_from(1); + } + }); + next_ctx.string_position = position; + next_ctx.string_offset = offset; + next_ctx.toplevel = false; +} + +#[derive(Debug)] +struct BranchContext { + branch_offset: usize, +} + +// alternation +// <0=skip> code ... +fn op_branch(drive: &mut StateContext, stacks: &mut Stacks) { + drive.state.marks_push(); + stacks.branch.push(BranchContext { branch_offset: 1 }); + create_context(drive, stacks); + + fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { + let branch_offset = stacks.branch_last().branch_offset; + let next_length = drive.peek_code(branch_offset) as usize; + if next_length == 0 { + drive.state.marks_pop_discard(); + stacks.branch.pop(); + return drive.failure(); + } + + drive.sync_string_position(); + + stacks.branch_last().branch_offset += next_length; + drive.next_ctx(branch_offset + 1, callback); + } + + fn callback(drive: &mut StateContext, stacks: &mut Stacks) { + if drive.popped_ctx().has_matched == Some(true) { + stacks.branch.pop(); + return drive.success(); + } + drive.state.marks_pop_keep(); + drive.ctx.handler = Some(create_context) + } +} + +#[derive(Debug, Copy, Clone)] +struct MinRepeatOneContext { + count: usize, + max_count: usize, +} + +/* <1=min> <2=max> item tail */ +fn op_min_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { + let min_count = drive.peek_code(2) as usize; + let max_count = drive.peek_code(3) as usize; + + if drive.remaining_chars() < min_count { + return drive.failure(); + } + + drive.sync_string_position(); + + let count = if min_count == 0 { + 0 + } else { + let count = count(drive, stacks, min_count); + if count < min_count { + return drive.failure(); + } + drive.skip_char(count); + count + }; + + let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { + // tail is empty. we're finished + drive.sync_string_position(); + return drive.success(); + } + + drive.state.marks_push(); + stacks + .min_repeat_one + .push(MinRepeatOneContext { count, max_count }); + create_context(drive, stacks); + + fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { + let MinRepeatOneContext { count, max_count } = *stacks.min_repeat_one_last(); + + if max_count == MAXREPEAT || count <= max_count { + drive.sync_string_position(); + drive.next_ctx_from(1, callback); + } else { + drive.state.marks_pop_discard(); + stacks.min_repeat_one.pop(); + drive.failure(); + } + } + + fn callback(drive: &mut StateContext, stacks: &mut Stacks) { + if drive.popped_ctx().has_matched == Some(true) { + stacks.min_repeat_one.pop(); + return drive.success(); + } + + drive.sync_string_position(); + + if crate::engine::count(drive, stacks, 1) == 0 { + drive.state.marks_pop_discard(); + stacks.min_repeat_one.pop(); + return drive.failure(); + } + + drive.skip_char(1); + stacks.min_repeat_one_last().count += 1; + drive.state.marks_pop_keep(); + create_context(drive, stacks); + } +} + +#[derive(Debug, Copy, Clone)] +struct RepeatOneContext { + count: usize, + min_count: usize, + following_literal: Option, +} + +/* match repeated sequence (maximizing regexp) */ + +/* this operator only works if the repeated item is +exactly one character wide, and we're not already +collecting backtracking points. for other cases, +use the MAX_REPEAT operator */ + +/* <1=min> <2=max> item tail */ +fn op_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { + let min_count = drive.peek_code(2) as usize; + let max_count = drive.peek_code(3) as usize; + + if drive.remaining_chars() < min_count { + return drive.failure(); + } + + drive.sync_string_position(); + + let count = count(drive, stacks, max_count); + drive.skip_char(count); + if count < min_count { + return drive.failure(); + } + + let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { + // tail is empty. we're finished + drive.sync_string_position(); + return drive.success(); + } + + // Special case: Tail starts with a literal. Skip positions where + // the rest of the pattern cannot possibly match. + let following_literal = (next_code == SreOpcode::LITERAL as u32) + .then(|| drive.peek_code(drive.peek_code(1) as usize + 2)); + + drive.state.marks_push(); + stacks.repeat_one.push(RepeatOneContext { + count, + min_count, + following_literal, + }); + create_context(drive, stacks); + + fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { + let RepeatOneContext { + mut count, + min_count, + following_literal, + } = *stacks.repeat_one_last(); + + if let Some(c) = following_literal { + while drive.at_end() || drive.peek_char() != c { + if count <= min_count { + drive.state.marks_pop_discard(); + stacks.repeat_one.pop(); + return drive.failure(); + } + drive.back_skip_char(1); + count -= 1; + } + } + stacks.repeat_one_last().count = count; + + drive.sync_string_position(); + + // General case: backtracking + drive.next_ctx_from(1, callback); + } + + fn callback(drive: &mut StateContext, stacks: &mut Stacks) { + if drive.popped_ctx().has_matched == Some(true) { + stacks.repeat_one.pop(); + return drive.success(); + } + + let RepeatOneContext { + count, + min_count, + following_literal: _, + } = stacks.repeat_one_last(); + + if count <= min_count { + drive.state.marks_pop_discard(); + stacks.repeat_one.pop(); + return drive.failure(); + } + + drive.back_skip_char(1); + *count -= 1; + + drive.state.marks_pop_keep(); + create_context(drive, stacks); + } +} + +#[derive(Debug, Clone, Copy)] +struct RepeatContext { + count: isize, + min_count: usize, + max_count: usize, + code_position: usize, + last_position: usize, + prev_id: usize, +} + +/* create repeat context. all the hard work is done +by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ +/* <1=min> <2=max> item tail */ +fn op_repeat(drive: &mut StateContext, stacks: &mut Stacks) { + let repeat_ctx = RepeatContext { + count: -1, + min_count: drive.peek_code(2) as usize, + max_count: drive.peek_code(3) as usize, + code_position: drive.ctx.code_position, + last_position: std::usize::MAX, + prev_id: drive.ctx.repeat_ctx_id, + }; + + stacks.repeat.push(repeat_ctx); + + drive.sync_string_position(); + + let next_ctx = drive.next_ctx_from(1, |drive, stacks| { + drive.ctx.has_matched = drive.popped_ctx().has_matched; + stacks.repeat.pop(); + }); + next_ctx.repeat_ctx_id = stacks.repeat.len() - 1; +} + +#[derive(Debug, Clone, Copy)] +struct MinUntilContext { + count: isize, + save_repeat_ctx: Option, + save_last_position: usize, +} + +/* minimizing repeat */ +fn op_min_until(drive: &mut StateContext, stacks: &mut Stacks) { + let repeat_ctx = stacks.repeat.last_mut().unwrap(); + + drive.sync_string_position(); + + let count = repeat_ctx.count + 1; + + stacks.min_until.push(MinUntilContext { + count, + save_repeat_ctx: None, + save_last_position: repeat_ctx.last_position, + }); + + if (count as usize) < repeat_ctx.min_count { + // not enough matches + repeat_ctx.count = count; + drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { + if drive.popped_ctx().has_matched == Some(true) { + stacks.min_until.pop(); + return drive.success(); + } + + stacks.repeat_last().count = stacks.min_until_last().count - 1; + drive.sync_string_position(); + stacks.min_until.pop(); + drive.failure(); + }); + return; + } + + drive.state.marks_push(); + + // see if the tail matches + stacks.min_until_last().save_repeat_ctx = stacks.repeat.pop(); + + drive.next_ctx(1, |drive, stacks| { + let MinUntilContext { + count, + save_repeat_ctx, + save_last_position, + } = stacks.min_until_last(); + let count = *count; + + let mut repeat_ctx = save_repeat_ctx.take().unwrap(); + + if drive.popped_ctx().has_matched == Some(true) { + stacks.min_until.pop(); + // restore repeat before return + stacks.repeat.push(repeat_ctx); + return drive.success(); + } + + drive.sync_string_position(); + + drive.state.marks_pop(); + + // match more until tail matches + + if count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT + || drive.state.string_position == repeat_ctx.last_position + { + stacks.min_until.pop(); + // restore repeat before return + stacks.repeat.push(repeat_ctx); + return drive.failure(); + } + + repeat_ctx.count = count; + /* zero-width match protection */ + *save_last_position = repeat_ctx.last_position; + repeat_ctx.last_position = drive.state.string_position; + + stacks.repeat.push(repeat_ctx); + + drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { + if drive.popped_ctx().has_matched == Some(true) { + stacks.min_until.pop(); + drive.success(); + } else { + stacks.repeat_last().count = stacks.min_until_last().count - 1; + drive.sync_string_position(); + stacks.min_until.pop(); + drive.failure(); + } + }); + }); +} + +#[derive(Debug, Clone, Copy)] +struct MaxUntilContext { + save_last_position: usize, +} + +/* maximizing repeat */ +fn op_max_until(drive: &mut StateContext, stacks: &mut Stacks) { + // let repeat_ctx = stacks.repeat.last_mut().unwrap(); + let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; + + drive.sync_string_position(); + + repeat_ctx.count += 1; + + // let count = repeat_ctx.count + 1; + + if (repeat_ctx.count as usize) < repeat_ctx.min_count { + // not enough matches + // repeat_ctx.count = count; + drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { + if drive.popped_ctx().has_matched == Some(true) { + // stacks.max_until.pop(); + drive.success(); + } else { + // let count = stacks.max_until_last().count; + // stacks.repeat_last().count -= 1; + stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; + drive.sync_string_position(); + // stacks.max_until.pop(); + drive.failure(); + } + }); + return; + } + + stacks.max_until.push(MaxUntilContext { + save_last_position: repeat_ctx.last_position, + }); + + if ((repeat_ctx.count as usize) < repeat_ctx.max_count || repeat_ctx.max_count == MAXREPEAT) + && drive.state.string_position != repeat_ctx.last_position + { + /* we may have enough matches, but if we can + match another item, do so */ + repeat_ctx.last_position = drive.state.string_position; + + drive.state.marks_push(); + + drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { + let save_last_position = stacks.max_until_last().save_last_position; + let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; + repeat_ctx.last_position = save_last_position; + if drive.popped_ctx().has_matched == Some(true) { + drive.state.marks_pop_discard(); + stacks.max_until.pop(); + return drive.success(); + } + drive.state.marks_pop(); + repeat_ctx.count -= 1; + drive.sync_string_position(); + + /* cannot match more repeated items here. make sure the + tail matches */ + let next_ctx = drive.next_ctx(1, tail_callback); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + }); + return; + } + + /* cannot match more repeated items here. make sure the + tail matches */ + let next_ctx = drive.next_ctx(1, tail_callback); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + + fn tail_callback(drive: &mut StateContext, stacks: &mut Stacks) { + stacks.max_until.pop(); + + if drive.popped_ctx().has_matched == Some(true) { + drive.success(); + } else { + drive.sync_string_position(); + drive.failure(); + } + } +} + +#[derive(Debug, Default)] +struct Stacks { + branch: Vec, + min_repeat_one: Vec, + repeat_one: Vec, + repeat: Vec, + min_until: Vec, + max_until: Vec, +} + +impl Stacks { + fn clear(&mut self) { + self.branch.clear(); + self.min_repeat_one.clear(); + self.repeat_one.clear(); + self.repeat.clear(); + self.min_until.clear(); + self.max_until.clear(); + } + + fn branch_last(&mut self) -> &mut BranchContext { + self.branch.last_mut().unwrap() + } + fn min_repeat_one_last(&mut self) -> &mut MinRepeatOneContext { + self.min_repeat_one.last_mut().unwrap() + } + fn repeat_one_last(&mut self) -> &mut RepeatOneContext { + self.repeat_one.last_mut().unwrap() + } + fn repeat_last(&mut self) -> &mut RepeatContext { + self.repeat.last_mut().unwrap() + } + fn min_until_last(&mut self) -> &mut MinUntilContext { + self.min_until.last_mut().unwrap() + } + fn max_until_last(&mut self) -> &mut MaxUntilContext { + self.max_until.last_mut().unwrap() + } +} + #[derive(Debug, Clone, Copy)] pub enum StrDrive<'a> { Str(&'a str), @@ -203,7 +857,7 @@ impl<'a> StrDrive<'a> { StrDrive::Str(s) => s .get(offset..) .and_then(|s| s.char_indices().nth(skip).map(|x| x.0 + offset)) - .unwrap_or_else(|| s.len()), + .unwrap_or(s.len()), StrDrive::Bytes(_) => offset + skip, } } @@ -264,31 +918,63 @@ impl<'a> StrDrive<'a> { } } -#[derive(Debug, Clone, Copy)] +type OpcodeHandler = fn(&mut StateContext, &mut Stacks); + +#[derive(Clone, Copy)] struct MatchContext { string_position: usize, string_offset: usize, code_position: usize, has_matched: Option, toplevel: bool, + handler: Option, + repeat_ctx_id: usize, } -trait MatchContextDrive { - fn ctx_mut(&mut self) -> &mut MatchContext; +impl std::fmt::Debug for MatchContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MatchContext") + .field("string_position", &self.string_position) + .field("string_offset", &self.string_offset) + .field("code_position", &self.code_position) + .field("has_matched", &self.has_matched) + .field("toplevel", &self.toplevel) + .field("handler", &self.handler.map(|x| x as usize)) + .finish() + } +} + +trait ContextDrive { fn ctx(&self) -> &MatchContext; + fn ctx_mut(&mut self) -> &mut MatchContext; fn state(&self) -> &State; - fn repeat_ctx(&self) -> &RepeatContext { - self.state().repeat_stack.last().unwrap() + + fn popped_ctx(&self) -> &MatchContext { + self.state().popped_context.as_ref().unwrap() } + fn pattern(&self) -> &[u32] { &self.state().pattern_codes[self.ctx().code_position..] } + fn peek_char(&self) -> u32 { self.state().string.peek(self.ctx().string_offset) } fn peek_code(&self, peek: usize) -> u32 { self.state().pattern_codes[self.ctx().code_position + peek] } + + fn back_peek_char(&self) -> u32 { + self.state().string.back_peek(self.ctx().string_offset) + } + fn back_skip_char(&mut self, skip_count: usize) { + self.ctx_mut().string_position -= skip_count; + self.ctx_mut().string_offset = self + .state() + .string + .back_offset(self.ctx().string_offset, skip_count); + } + fn skip_char(&mut self, skip_count: usize) { self.ctx_mut().string_offset = self .state() @@ -299,12 +985,17 @@ trait MatchContextDrive { fn skip_code(&mut self, skip_count: usize) { self.ctx_mut().code_position += skip_count; } + fn skip_code_from(&mut self, peek: usize) { + self.skip_code(self.peek_code(peek) as usize + 1); + } + fn remaining_chars(&self) -> usize { self.state().end - self.ctx().string_position } fn remaining_codes(&self) -> usize { self.state().pattern_codes.len() - self.ctx().code_position } + fn at_beginning(&self) -> bool { // self.ctx().string_position == self.state().start self.ctx().string_position == 0 @@ -331,16 +1022,7 @@ trait MatchContextDrive { let this = !self.at_end() && word_checker(self.peek_char()); this == that } - fn back_peek_char(&self) -> u32 { - self.state().string.back_peek(self.ctx().string_offset) - } - fn back_skip_char(&mut self, skip_count: usize) { - self.ctx_mut().string_position -= skip_count; - self.ctx_mut().string_offset = self - .state() - .string - .back_offset(self.ctx().string_offset, skip_count); - } + fn can_success(&self) -> bool { if !self.ctx().toplevel { return true; @@ -353,389 +1035,71 @@ trait MatchContextDrive { } true } -} -struct StackDrive<'a> { - state: State<'a>, - ctx_id: usize, -} -impl<'a> StackDrive<'a> { - fn id(&self) -> usize { - self.ctx_id - } - fn drive(ctx_id: usize, state: State<'a>) -> Self { - Self { state, ctx_id } - } - fn take(self) -> State<'a> { - self.state - } - fn push_new_context(&mut self, pattern_offset: usize) { - self.push_new_context_at(self.ctx().code_position + pattern_offset); - } - fn push_new_context_at(&mut self, code_position: usize) { - let mut child_ctx = MatchContext { ..*self.ctx() }; - child_ctx.code_position = code_position; - self.state.context_stack.push(child_ctx); - } - fn repeat_ctx_mut(&mut self) -> &mut RepeatContext { - self.state.repeat_stack.last_mut().unwrap() - } -} -impl MatchContextDrive for StackDrive<'_> { - fn ctx_mut(&mut self) -> &mut MatchContext { - &mut self.state.context_stack[self.ctx_id] + fn success(&mut self) { + self.ctx_mut().has_matched = Some(true); } - fn ctx(&self) -> &MatchContext { - &self.state.context_stack[self.ctx_id] - } - fn state(&self) -> &State { - &self.state + + fn failure(&mut self) { + self.ctx_mut().has_matched = Some(false); } } -struct WrapDrive<'a> { - stack_drive: &'a StackDrive<'a>, +struct StateContext<'a> { + state: State<'a>, ctx: MatchContext, + next_ctx: Option, } -impl<'a> WrapDrive<'a> { - fn drive(ctx: MatchContext, stack_drive: &'a StackDrive<'a>) -> Self { - Self { stack_drive, ctx } + +impl ContextDrive for StateContext<'_> { + fn ctx(&self) -> &MatchContext { + &self.ctx } -} -impl MatchContextDrive for WrapDrive<'_> { fn ctx_mut(&mut self) -> &mut MatchContext { &mut self.ctx } - fn ctx(&self) -> &MatchContext { - &self.ctx - } fn state(&self) -> &State { - self.stack_drive.state() + &self.state } } -trait OpcodeExecutor { - fn next(&mut self, drive: &mut StackDrive) -> Option<()>; -} +impl StateContext<'_> { + fn next_ctx_from(&mut self, peek: usize, handler: OpcodeHandler) -> &mut MatchContext { + self.next_ctx(self.peek_code(peek) as usize + 1, handler) + } + fn next_ctx(&mut self, offset: usize, handler: OpcodeHandler) -> &mut MatchContext { + self.next_ctx_at(self.ctx.code_position + offset, handler) + } + fn next_ctx_at(&mut self, code_position: usize, handler: OpcodeHandler) -> &mut MatchContext { + self.next_ctx = Some(MatchContext { + code_position, + has_matched: None, + handler: None, + ..self.ctx + }); + self.ctx.handler = Some(handler); + self.next_ctx.as_mut().unwrap() + } -struct OpTwice { - f1: Option, - f2: Option, -} -impl OpcodeExecutor for OpTwice -where - F1: FnOnce(&mut StackDrive) -> Option<()>, - F2: FnOnce(&mut StackDrive), -{ - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - if let Some(f1) = self.f1.take() { - f1(drive) - } else if let Some(f2) = self.f2.take() { - f2(drive); - None - } else { - unreachable!() - } + fn sync_string_position(&mut self) { + self.state.string_position = self.ctx.string_position; } } -fn twice(f1: F1, f2: F2) -> Box> -where - F1: FnOnce(&mut StackDrive) -> Option<()>, - F2: FnOnce(&mut StackDrive), -{ - Box::new(OpTwice { - f1: Some(f1), - f2: Some(f2), - }) -} -struct OpcodeDispatcher { - executing_contexts: BTreeMap>, +struct StateRefContext<'a> { + entity: &'a StateContext<'a>, + ctx: MatchContext, } -impl OpcodeDispatcher { - fn new() -> Self { - Self { - executing_contexts: BTreeMap::new(), - } - } - fn clear(&mut self) { - self.executing_contexts.clear(); - } - // Returns True if the current context matches, False if it doesn't and - // None if matching is not finished, ie must be resumed after child - // contexts have been matched. - fn pymatch(&mut self, drive: &mut StackDrive) -> Option { - while drive.remaining_codes() > 0 && drive.ctx().has_matched.is_none() { - let code = drive.peek_code(0); - let opcode = SreOpcode::try_from(code).unwrap(); - if !self.dispatch(opcode, drive) { - return None; - } - } - match drive.ctx().has_matched { - Some(matched) => Some(matched), - None => { - drive.ctx_mut().has_matched = Some(false); - Some(false) - } - } - } - // Dispatches a context on a given opcode. Returns True if the context - // is done matching, False if it must be resumed when next encountered. - fn dispatch(&mut self, opcode: SreOpcode, drive: &mut StackDrive) -> bool { - let executor = self - .executing_contexts - .remove(&drive.id()) - .or_else(|| self.dispatch_table(opcode, drive)); - if let Some(mut executor) = executor { - if let Some(()) = executor.next(drive) { - self.executing_contexts.insert(drive.id(), executor); - return false; - } - } - true +impl ContextDrive for StateRefContext<'_> { + fn ctx(&self) -> &MatchContext { + &self.ctx } - - fn dispatch_table( - &mut self, - opcode: SreOpcode, - drive: &mut StackDrive, - ) -> Option> { - match opcode { - SreOpcode::FAILURE => { - drive.ctx_mut().has_matched = Some(false); - None - } - SreOpcode::SUCCESS => { - drive.ctx_mut().has_matched = Some(drive.can_success()); - if drive.ctx().has_matched == Some(true) { - drive.state.string_position = drive.ctx().string_position; - } - None - } - SreOpcode::ANY => { - if drive.at_end() || drive.at_linebreak() { - drive.ctx_mut().has_matched = Some(false); - } else { - drive.skip_code(1); - drive.skip_char(1); - } - None - } - SreOpcode::ANY_ALL => { - if drive.at_end() { - drive.ctx_mut().has_matched = Some(false); - } else { - drive.skip_code(1); - drive.skip_char(1); - } - None - } - /* assert subpattern */ - /* */ - SreOpcode::ASSERT => Some(twice( - |drive| { - let back = drive.peek_code(2) as usize; - let passed = drive.ctx().string_position; - if passed < back { - drive.ctx_mut().has_matched = Some(false); - return None; - } - let back_offset = drive - .state - .string - .back_offset(drive.ctx().string_offset, back); - - drive.state.string_position = drive.ctx().string_position - back; - - drive.push_new_context(3); - let child_ctx = drive.state.context_stack.last_mut().unwrap(); - child_ctx.toplevel = false; - child_ctx.string_position -= back; - child_ctx.string_offset = back_offset; - - Some(()) - }, - |drive| { - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.skip_code(drive.peek_code(1) as usize + 1); - } else { - drive.ctx_mut().has_matched = Some(false); - } - }, - )), - SreOpcode::ASSERT_NOT => Some(twice( - |drive| { - let back = drive.peek_code(2) as usize; - let passed = drive.ctx().string_position; - if passed < back { - drive.skip_code(drive.peek_code(1) as usize + 1); - return None; - } - let back_offset = drive - .state - .string - .back_offset(drive.ctx().string_offset, back); - - drive.state.string_position = drive.ctx().string_position - back; - - drive.push_new_context(3); - let child_ctx = drive.state.context_stack.last_mut().unwrap(); - child_ctx.toplevel = false; - child_ctx.string_position -= back; - child_ctx.string_offset = back_offset; - - Some(()) - }, - |drive| { - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.ctx_mut().has_matched = Some(false); - } else { - drive.skip_code(drive.peek_code(1) as usize + 1); - } - }, - )), - SreOpcode::AT => { - let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); - if !at(drive, atcode) { - drive.ctx_mut().has_matched = Some(false); - } else { - drive.skip_code(2); - } - None - } - SreOpcode::BRANCH => Some(Box::new(OpBranch::default())), - SreOpcode::CATEGORY => { - let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); - if drive.at_end() || !category(catcode, drive.peek_char()) { - drive.ctx_mut().has_matched = Some(false); - } else { - drive.skip_code(2); - drive.skip_char(1); - } - None - } - SreOpcode::IN => { - general_op_in(drive, |set, c| charset(set, c)); - None - } - SreOpcode::IN_IGNORE => { - general_op_in(drive, |set, c| charset(set, lower_ascii(c))); - None - } - SreOpcode::IN_UNI_IGNORE => { - general_op_in(drive, |set, c| charset(set, lower_unicode(c))); - None - } - SreOpcode::IN_LOC_IGNORE => { - general_op_in(drive, |set, c| charset_loc_ignore(set, c)); - None - } - SreOpcode::INFO | SreOpcode::JUMP => { - drive.skip_code(drive.peek_code(1) as usize + 1); - None - } - SreOpcode::LITERAL => { - general_op_literal(drive, |code, c| code == c); - None - } - SreOpcode::NOT_LITERAL => { - general_op_literal(drive, |code, c| code != c); - None - } - SreOpcode::LITERAL_IGNORE => { - general_op_literal(drive, |code, c| code == lower_ascii(c)); - None - } - SreOpcode::NOT_LITERAL_IGNORE => { - general_op_literal(drive, |code, c| code != lower_ascii(c)); - None - } - SreOpcode::LITERAL_UNI_IGNORE => { - general_op_literal(drive, |code, c| code == lower_unicode(c)); - None - } - SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_op_literal(drive, |code, c| code != lower_unicode(c)); - None - } - SreOpcode::LITERAL_LOC_IGNORE => { - general_op_literal(drive, char_loc_ignore); - None - } - SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_op_literal(drive, |code, c| !char_loc_ignore(code, c)); - None - } - SreOpcode::MARK => { - drive - .state - .set_mark(drive.peek_code(1) as usize, drive.ctx().string_position); - drive.skip_code(2); - None - } - SreOpcode::REPEAT => Some(twice( - // create repeat context. all the hard work is done by the UNTIL - // operator (MAX_UNTIL, MIN_UNTIL) - // <1=min> <2=max> item tail - |drive| { - let repeat = RepeatContext { - count: -1, - code_position: drive.ctx().code_position, - last_position: std::usize::MAX, - mincount: drive.peek_code(2) as usize, - maxcount: drive.peek_code(3) as usize, - }; - drive.state.repeat_stack.push(repeat); - drive.state.string_position = drive.ctx().string_position; - // execute UNTIL operator - drive.push_new_context(drive.peek_code(1) as usize + 1); - Some(()) - }, - |drive| { - drive.state.repeat_stack.pop(); - let child_ctx = drive.state.popped_context.unwrap(); - drive.ctx_mut().has_matched = child_ctx.has_matched; - }, - )), - SreOpcode::MAX_UNTIL => Some(Box::new(OpMaxUntil::default())), - SreOpcode::MIN_UNTIL => Some(Box::new(OpMinUntil::default())), - SreOpcode::REPEAT_ONE => Some(Box::new(OpRepeatOne::default())), - SreOpcode::MIN_REPEAT_ONE => Some(Box::new(OpMinRepeatOne::default())), - SreOpcode::GROUPREF => { - general_op_groupref(drive, |x| x); - None - } - SreOpcode::GROUPREF_IGNORE => { - general_op_groupref(drive, lower_ascii); - None - } - SreOpcode::GROUPREF_LOC_IGNORE => { - general_op_groupref(drive, lower_locate); - None - } - SreOpcode::GROUPREF_UNI_IGNORE => { - general_op_groupref(drive, lower_unicode); - None - } - SreOpcode::GROUPREF_EXISTS => { - let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); - match (group_start, group_end) { - (Some(start), Some(end)) if start <= end => { - drive.skip_code(3); - } - _ => drive.skip_code(drive.peek_code(2) as usize + 1), - } - None - } - _ => { - // TODO python expcetion? - unreachable!("unexpected opcode") - } - } + fn ctx_mut(&mut self) -> &mut MatchContext { + &mut self.ctx + } + fn state(&self) -> &State { + &self.entity.state } } @@ -752,60 +1116,63 @@ fn charset_loc_ignore(set: &[u32], c: u32) -> bool { up != lo && charset(set, up) } -fn general_op_groupref u32>(drive: &mut StackDrive, mut f: F) { +fn general_op_groupref u32>(drive: &mut StateContext, mut f: F) { let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); let (group_start, group_end) = match (group_start, group_end) { (Some(start), Some(end)) if start <= end => (start, end), _ => { - drive.ctx_mut().has_matched = Some(false); - return; + return drive.failure(); } }; - let mut wdrive = WrapDrive::drive(*drive.ctx(), &drive); - let mut gdrive = WrapDrive::drive( - MatchContext { + + let mut wdrive = StateRefContext { + entity: drive, + ctx: drive.ctx, + }; + let mut gdrive = StateRefContext { + entity: drive, + ctx: MatchContext { string_position: group_start, // TODO: cache the offset string_offset: drive.state.string.offset(0, group_start), - ..*drive.ctx() + ..drive.ctx }, - &drive, - ); + }; + for _ in group_start..group_end { if wdrive.at_end() || f(wdrive.peek_char()) != f(gdrive.peek_char()) { - drive.ctx_mut().has_matched = Some(false); - return; + return drive.failure(); } wdrive.skip_char(1); gdrive.skip_char(1); } - let position = wdrive.ctx().string_position; - let offset = wdrive.ctx().string_offset; + + let position = wdrive.ctx.string_position; + let offset = wdrive.ctx.string_offset; drive.skip_code(2); - drive.ctx_mut().string_position = position; - drive.ctx_mut().string_offset = offset; + drive.ctx.string_position = position; + drive.ctx.string_offset = offset; } -fn general_op_literal bool>(drive: &mut StackDrive, f: F) { +fn general_op_literal bool>(drive: &mut StateContext, f: F) { if drive.at_end() || !f(drive.peek_code(1), drive.peek_char()) { - drive.ctx_mut().has_matched = Some(false); + drive.failure(); } else { drive.skip_code(2); drive.skip_char(1); } } -fn general_op_in bool>(drive: &mut StackDrive, f: F) { - let skip = drive.peek_code(1) as usize; +fn general_op_in bool>(drive: &mut StateContext, f: F) { if drive.at_end() || !f(&drive.pattern()[2..], drive.peek_char()) { - drive.ctx_mut().has_matched = Some(false); + drive.failure(); } else { - drive.skip_code(skip + 1); + drive.skip_code_from(1); drive.skip_char(1); } } -fn at(drive: &StackDrive, atcode: SreAtCode) -> bool { +fn at(drive: &StateContext, atcode: SreAtCode) -> bool { match atcode { SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => drive.at_beginning(), SreAtCode::BEGINNING_LINE => drive.at_beginning() || is_linebreak(drive.back_peek_char()), @@ -938,84 +1305,91 @@ fn charset(set: &[u32], ch: u32) -> bool { } /* General case */ -fn general_count(drive: &mut StackDrive, maxcount: usize) -> usize { +fn general_count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { let mut count = 0; - let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); + let max_count = std::cmp::min(max_count, drive.remaining_chars()); - let save_ctx = *drive.ctx(); + let save_ctx = drive.ctx; drive.skip_code(4); - let reset_position = drive.ctx().code_position; - - let mut dispatcher = OpcodeDispatcher::new(); - while count < maxcount { - drive.ctx_mut().code_position = reset_position; - dispatcher.dispatch(SreOpcode::try_from(drive.peek_code(0)).unwrap(), drive); - if drive.ctx().has_matched == Some(false) { + let reset_position = drive.ctx.code_position; + + while count < max_count { + drive.ctx.code_position = reset_position; + let code = drive.peek_code(0); + let code = SreOpcode::try_from(code).unwrap(); + dispatch(code, drive, stacks); + if drive.ctx.has_matched == Some(false) { break; } count += 1; } - *drive.ctx_mut() = save_ctx; + drive.ctx = save_ctx; count } -fn count(stack_drive: &mut StackDrive, maxcount: usize) -> usize { - let mut drive = WrapDrive::drive(*stack_drive.ctx(), stack_drive); - let maxcount = std::cmp::min(maxcount, drive.remaining_chars()); - let end = drive.ctx().string_position + maxcount; +fn count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { + let save_ctx = drive.ctx; + let max_count = std::cmp::min(max_count, drive.remaining_chars()); + let end = drive.ctx.string_position + max_count; let opcode = SreOpcode::try_from(drive.peek_code(0)).unwrap(); match opcode { SreOpcode::ANY => { - while !drive.ctx().string_position < end && !drive.at_linebreak() { + while !drive.ctx.string_position < end && !drive.at_linebreak() { drive.skip_char(1); } } SreOpcode::ANY_ALL => { - drive.skip_char(maxcount); + drive.skip_char(max_count); } SreOpcode::IN => { - while !drive.ctx().string_position < end + while !drive.ctx.string_position < end && charset(&drive.pattern()[2..], drive.peek_char()) { drive.skip_char(1); } } SreOpcode::LITERAL => { - general_count_literal(&mut drive, end, |code, c| code == c as u32); + general_count_literal(drive, end, |code, c| code == c as u32); } SreOpcode::NOT_LITERAL => { - general_count_literal(&mut drive, end, |code, c| code != c as u32); + general_count_literal(drive, end, |code, c| code != c as u32); } SreOpcode::LITERAL_IGNORE => { - general_count_literal(&mut drive, end, |code, c| code == lower_ascii(c) as u32); + general_count_literal(drive, end, |code, c| code == lower_ascii(c) as u32); } SreOpcode::NOT_LITERAL_IGNORE => { - general_count_literal(&mut drive, end, |code, c| code != lower_ascii(c) as u32); + general_count_literal(drive, end, |code, c| code != lower_ascii(c) as u32); } SreOpcode::LITERAL_LOC_IGNORE => { - general_count_literal(&mut drive, end, char_loc_ignore); + general_count_literal(drive, end, char_loc_ignore); } SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_count_literal(&mut drive, end, |code, c| !char_loc_ignore(code, c)); + general_count_literal(drive, end, |code, c| !char_loc_ignore(code, c)); } SreOpcode::LITERAL_UNI_IGNORE => { - general_count_literal(&mut drive, end, |code, c| code == lower_unicode(c) as u32); + general_count_literal(drive, end, |code, c| code == lower_unicode(c) as u32); } SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_count_literal(&mut drive, end, |code, c| code != lower_unicode(c) as u32); + general_count_literal(drive, end, |code, c| code != lower_unicode(c) as u32); } _ => { - return general_count(stack_drive, maxcount); + return general_count(drive, stacks, max_count); } } - drive.ctx().string_position - drive.state().string_position + let count = drive.ctx.string_position - drive.state.string_position; + drive.ctx = save_ctx; + count } -fn general_count_literal bool>(drive: &mut WrapDrive, end: usize, mut f: F) { +fn general_count_literal bool>( + drive: &mut StateContext, + end: usize, + mut f: F, +) { let ch = drive.peek_code(1); - while !drive.ctx().string_position < end && f(ch, drive.peek_char()) { + while !drive.ctx.string_position < end && f(ch, drive.peek_char()) { drive.skip_char(1); } } @@ -1065,7 +1439,9 @@ fn upper_locate(ch: u32) -> u32 { } fn is_uni_digit(ch: u32) -> bool { // TODO: check with cpython - char::try_from(ch).map(|x| x.is_digit(10)).unwrap_or(false) + char::try_from(ch) + .map(|x| x.is_ascii_digit()) + .unwrap_or(false) } fn is_uni_space(ch: u32) -> bool { // TODO: check with cpython @@ -1155,413 +1531,3 @@ fn utf8_back_peek_offset(bytes: &[u8], offset: usize) -> usize { } offset } - -#[derive(Debug, Copy, Clone)] -struct RepeatContext { - count: isize, - code_position: usize, - // zero-width match protection - last_position: usize, - mincount: usize, - maxcount: usize, -} - -#[derive(Default)] -struct OpMinRepeatOne { - jump_id: usize, - mincount: usize, - maxcount: usize, - count: usize, -} -impl OpcodeExecutor for OpMinRepeatOne { - /* <1=min> <2=max> item tail */ - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => { - self.mincount = drive.peek_code(2) as usize; - self.maxcount = drive.peek_code(3) as usize; - - if drive.remaining_chars() < self.mincount { - drive.ctx_mut().has_matched = Some(false); - return None; - } - - drive.state.string_position = drive.ctx().string_position; - - self.count = if self.mincount == 0 { - 0 - } else { - let count = count(drive, self.mincount); - if count < self.mincount { - drive.ctx_mut().has_matched = Some(false); - return None; - } - drive.skip_char(count); - count - }; - - let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { - // tail is empty. we're finished - drive.state.string_position = drive.ctx().string_position; - drive.ctx_mut().has_matched = Some(true); - return None; - } - - drive.state.marks_push(); - self.jump_id = 1; - self.next(drive) - } - 1 => { - if self.maxcount == MAXREPEAT || self.count <= self.maxcount { - drive.state.string_position = drive.ctx().string_position; - drive.push_new_context(drive.peek_code(1) as usize + 1); - self.jump_id = 2; - return Some(()); - } - - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(false); - None - } - 2 => { - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.ctx_mut().has_matched = Some(true); - return None; - } - drive.state.string_position = drive.ctx().string_position; - if count(drive, 1) == 0 { - drive.ctx_mut().has_matched = Some(false); - return None; - } - drive.skip_char(1); - self.count += 1; - drive.state.marks_pop_keep(); - self.jump_id = 1; - self.next(drive) - } - _ => unreachable!(), - } - } -} - -#[derive(Default)] -struct OpMaxUntil { - jump_id: usize, - count: isize, - save_last_position: usize, -} -impl OpcodeExecutor for OpMaxUntil { - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => { - let RepeatContext { - count, - code_position, - last_position, - mincount, - maxcount, - } = *drive.repeat_ctx(); - - drive.state.string_position = drive.ctx().string_position; - self.count = count + 1; - - if (self.count as usize) < mincount { - // not enough matches - drive.repeat_ctx_mut().count = self.count; - drive.push_new_context_at(code_position + 4); - self.jump_id = 1; - return Some(()); - } - - if ((self.count as usize) < maxcount || maxcount == MAXREPEAT) - && drive.state.string_position != last_position - { - // we may have enough matches, if we can match another item, do so - drive.repeat_ctx_mut().count = self.count; - drive.state.marks_push(); - self.save_last_position = last_position; - drive.repeat_ctx_mut().last_position = drive.state.string_position; - drive.push_new_context_at(code_position + 4); - self.jump_id = 2; - return Some(()); - } - - self.jump_id = 3; - self.next(drive) - } - 1 => { - let child_ctx = drive.state.popped_context.unwrap(); - drive.ctx_mut().has_matched = child_ctx.has_matched; - if drive.ctx().has_matched != Some(true) { - drive.repeat_ctx_mut().count = self.count - 1; - drive.state.string_position = drive.ctx().string_position; - } - None - } - 2 => { - drive.repeat_ctx_mut().last_position = self.save_last_position; - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(true); - return None; - } - drive.state.marks_pop(); - drive.repeat_ctx_mut().count = self.count - 1; - drive.state.string_position = drive.ctx().string_position; - self.jump_id = 3; - self.next(drive) - } - 3 => { - // cannot match more repeated items here. make sure the tail matches - drive.push_new_context(1); - self.jump_id = 4; - Some(()) - } - 4 => { - let child_ctx = drive.state.popped_context.unwrap(); - drive.ctx_mut().has_matched = child_ctx.has_matched; - if drive.ctx().has_matched != Some(true) { - drive.state.string_position = drive.ctx().string_position; - } - None - } - _ => unreachable!(), - } - } -} - -#[derive(Default)] -struct OpMinUntil { - jump_id: usize, - count: isize, - save_repeat: Option, - save_last_position: usize, -} -impl OpcodeExecutor for OpMinUntil { - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => { - let RepeatContext { - count, - code_position, - last_position: _, - mincount, - maxcount: _, - } = *drive.repeat_ctx(); - drive.state.string_position = drive.ctx().string_position; - self.count = count + 1; - - if (self.count as usize) < mincount { - // not enough matches - drive.repeat_ctx_mut().count = self.count; - drive.push_new_context_at(code_position + 4); - self.jump_id = 1; - return Some(()); - } - - // see if the tail matches - drive.state.marks_push(); - self.save_repeat = drive.state.repeat_stack.pop(); - drive.push_new_context(1); - self.jump_id = 2; - Some(()) - } - 1 => { - let child_ctx = drive.state.popped_context.unwrap(); - drive.ctx_mut().has_matched = child_ctx.has_matched; - if drive.ctx().has_matched != Some(true) { - drive.repeat_ctx_mut().count = self.count - 1; - drive.repeat_ctx_mut().last_position = self.save_last_position; - drive.state.string_position = drive.ctx().string_position; - } - None - } - 2 => { - // restore repeat before return - drive.state.repeat_stack.push(self.save_repeat.unwrap()); - - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.ctx_mut().has_matched = Some(true); - return None; - } - drive.state.string_position = drive.ctx().string_position; - drive.state.marks_pop(); - - // match more until tail matches - let RepeatContext { - count: _, - code_position, - last_position, - mincount: _, - maxcount, - } = *drive.repeat_ctx(); - - if self.count as usize >= maxcount && maxcount != MAXREPEAT - || drive.state.string_position == last_position - { - drive.ctx_mut().has_matched = Some(false); - return None; - } - drive.repeat_ctx_mut().count = self.count; - - /* zero-width match protection */ - self.save_last_position = last_position; - drive.repeat_ctx_mut().last_position = drive.state.string_position; - - drive.push_new_context_at(code_position + 4); - self.jump_id = 1; - Some(()) - } - _ => unreachable!(), - } - } -} - -#[derive(Default)] -struct OpBranch { - jump_id: usize, - branch_offset: usize, -} -impl OpcodeExecutor for OpBranch { - // alternation - // <0=skip> code ... - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => { - drive.state.marks_push(); - // jump out the head - self.branch_offset = 1; - self.jump_id = 1; - self.next(drive) - } - 1 => { - let next_branch_length = drive.peek_code(self.branch_offset) as usize; - if next_branch_length == 0 { - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(false); - return None; - } - drive.state.string_position = drive.ctx().string_position; - drive.push_new_context(self.branch_offset + 1); - self.branch_offset += next_branch_length; - self.jump_id = 2; - Some(()) - } - 2 => { - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.ctx_mut().has_matched = Some(true); - return None; - } - drive.state.marks_pop_keep(); - self.jump_id = 1; - Some(()) - } - _ => unreachable!(), - } - } -} - -#[derive(Default)] -struct OpRepeatOne { - jump_id: usize, - mincount: usize, - maxcount: usize, - count: usize, - following_literal: Option, -} -impl OpcodeExecutor for OpRepeatOne { - /* match repeated sequence (maximizing regexp) */ - - /* this operator only works if the repeated item is - exactly one character wide, and we're not already - collecting backtracking points. for other cases, - use the MAX_REPEAT operator */ - - /* <1=min> <2=max> item tail */ - fn next(&mut self, drive: &mut StackDrive) -> Option<()> { - match self.jump_id { - 0 => { - self.mincount = drive.peek_code(2) as usize; - self.maxcount = drive.peek_code(3) as usize; - - if drive.remaining_chars() < self.mincount { - drive.ctx_mut().has_matched = Some(false); - return None; - } - - drive.state.string_position = drive.ctx().string_position; - - self.count = count(drive, self.maxcount); - drive.skip_char(self.count); - if self.count < self.mincount { - drive.ctx_mut().has_matched = Some(false); - return None; - } - - let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { - // tail is empty. we're finished - drive.state.string_position = drive.ctx().string_position; - drive.ctx_mut().has_matched = Some(true); - return None; - } - - drive.state.marks_push(); - - // Special case: Tail starts with a literal. Skip positions where - // the rest of the pattern cannot possibly match. - if next_code == SreOpcode::LITERAL as u32 { - self.following_literal = Some(drive.peek_code(drive.peek_code(1) as usize + 2)) - } - - self.jump_id = 1; - self.next(drive) - } - 1 => { - if let Some(c) = self.following_literal { - while drive.at_end() || drive.peek_char() != c { - if self.count <= self.mincount { - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(false); - return None; - } - drive.back_skip_char(1); - self.count -= 1; - } - } - - // General case: backtracking - drive.state.string_position = drive.ctx().string_position; - drive.push_new_context(drive.peek_code(1) as usize + 1); - self.jump_id = 2; - Some(()) - } - 2 => { - let child_ctx = drive.state.popped_context.unwrap(); - if child_ctx.has_matched == Some(true) { - drive.ctx_mut().has_matched = Some(true); - return None; - } - if self.count <= self.mincount { - drive.state.marks_pop_discard(); - drive.ctx_mut().has_matched = Some(false); - return None; - } - - drive.back_skip_char(1); - self.count -= 1; - - drive.state.marks_pop_keep(); - - self.jump_id = 1; - self.next(drive) - } - _ => unreachable!(), - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 4a3ed1b754..c23e807501 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ pub mod engine; pub const CODESIZE: usize = 4; #[cfg(target_pointer_width = "32")] -pub const MAXREPEAT: usize = usize::MAX; +pub const MAXREPEAT: usize = usize::MAX - 1; #[cfg(target_pointer_width = "64")] pub const MAXREPEAT: usize = u32::MAX as usize; diff --git a/tests/tests.rs b/tests/tests.rs index b430947a9b..e8ae487029 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -70,4 +70,15 @@ fn test_repeat_context_panic() { let mut state = p.state("axxzaz", 0..usize::MAX); state = state.pymatch(); assert!(state.marks == vec![Some(1), Some(3)]); -} \ No newline at end of file +} + +#[test] +fn test_double_max_until() { + // pattern p = re.compile(r'((1)?)*') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 18, 0, 4294967295, 18, 0, 24, 9, 0, 1, 18, 2, 17, 49, 18, 3, 19, 18, 1, 19, 1], flags: SreFlag::from_bits_truncate(32) }; + // END GENERATED + let mut state = p.state("1111", 0..usize::MAX); + state = state.pymatch(); + assert!(state.string_position == 4); +} From 4007f8276550efb6aa1ada429a3e89c042eae86c Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 27 Jul 2022 21:33:20 +0200 Subject: [PATCH 052/893] optimize max_until and min_until --- src/engine.rs | 96 +++++++++++++++------------------------------------ 1 file changed, 27 insertions(+), 69 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 81903ccfdd..223aa3425c 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -420,7 +420,7 @@ fn op_min_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { let count = if min_count == 0 { 0 } else { - let count = count(drive, stacks, min_count); + let count = _count(drive, stacks, min_count); if count < min_count { return drive.failure(); } @@ -462,7 +462,7 @@ fn op_min_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { drive.sync_string_position(); - if crate::engine::count(drive, stacks, 1) == 0 { + if _count(drive, stacks, 1) == 0 { drive.state.marks_pop_discard(); stacks.min_repeat_one.pop(); return drive.failure(); @@ -500,7 +500,7 @@ fn op_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { drive.sync_string_position(); - let count = count(drive, stacks, max_count); + let count = _count(drive, stacks, max_count); drive.skip_char(count); if count < min_count { return drive.failure(); @@ -614,9 +614,7 @@ fn op_repeat(drive: &mut StateContext, stacks: &mut Stacks) { #[derive(Debug, Clone, Copy)] struct MinUntilContext { - count: isize, - save_repeat_ctx: Option, - save_last_position: usize, + save_repeat_ctx_id: usize, } /* minimizing repeat */ @@ -625,50 +623,35 @@ fn op_min_until(drive: &mut StateContext, stacks: &mut Stacks) { drive.sync_string_position(); - let count = repeat_ctx.count + 1; - - stacks.min_until.push(MinUntilContext { - count, - save_repeat_ctx: None, - save_last_position: repeat_ctx.last_position, - }); + repeat_ctx.count += 1; - if (count as usize) < repeat_ctx.min_count { + if (repeat_ctx.count as usize) < repeat_ctx.min_count { // not enough matches - repeat_ctx.count = count; drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { if drive.popped_ctx().has_matched == Some(true) { - stacks.min_until.pop(); - return drive.success(); + drive.success(); + } else { + stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; + drive.sync_string_position(); + drive.failure(); } - - stacks.repeat_last().count = stacks.min_until_last().count - 1; - drive.sync_string_position(); - stacks.min_until.pop(); - drive.failure(); }); return; } drive.state.marks_push(); - // see if the tail matches - stacks.min_until_last().save_repeat_ctx = stacks.repeat.pop(); + stacks.min_until.push(MinUntilContext { + save_repeat_ctx_id: drive.ctx.repeat_ctx_id, + }); - drive.next_ctx(1, |drive, stacks| { - let MinUntilContext { - count, - save_repeat_ctx, - save_last_position, - } = stacks.min_until_last(); - let count = *count; + // see if the tail matches + let next_ctx = drive.next_ctx(1, |drive, stacks| { + drive.ctx.repeat_ctx_id = stacks.min_until.pop().unwrap().save_repeat_ctx_id; - let mut repeat_ctx = save_repeat_ctx.take().unwrap(); + let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; if drive.popped_ctx().has_matched == Some(true) { - stacks.min_until.pop(); - // restore repeat before return - stacks.repeat.push(repeat_ctx); return drive.success(); } @@ -678,34 +661,27 @@ fn op_min_until(drive: &mut StateContext, stacks: &mut Stacks) { // match more until tail matches - if count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT + if repeat_ctx.count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT || drive.state.string_position == repeat_ctx.last_position { - stacks.min_until.pop(); - // restore repeat before return - stacks.repeat.push(repeat_ctx); + repeat_ctx.count -= 1; return drive.failure(); } - repeat_ctx.count = count; /* zero-width match protection */ - *save_last_position = repeat_ctx.last_position; repeat_ctx.last_position = drive.state.string_position; - stacks.repeat.push(repeat_ctx); - drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { if drive.popped_ctx().has_matched == Some(true) { - stacks.min_until.pop(); drive.success(); } else { - stacks.repeat_last().count = stacks.min_until_last().count - 1; + stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; drive.sync_string_position(); - stacks.min_until.pop(); drive.failure(); } }); }); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; } #[derive(Debug, Clone, Copy)] @@ -715,28 +691,20 @@ struct MaxUntilContext { /* maximizing repeat */ fn op_max_until(drive: &mut StateContext, stacks: &mut Stacks) { - // let repeat_ctx = stacks.repeat.last_mut().unwrap(); let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; drive.sync_string_position(); repeat_ctx.count += 1; - // let count = repeat_ctx.count + 1; - if (repeat_ctx.count as usize) < repeat_ctx.min_count { // not enough matches - // repeat_ctx.count = count; drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { if drive.popped_ctx().has_matched == Some(true) { - // stacks.max_until.pop(); drive.success(); } else { - // let count = stacks.max_until_last().count; - // stacks.repeat_last().count -= 1; stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; drive.sync_string_position(); - // stacks.max_until.pop(); drive.failure(); } }); @@ -757,14 +725,15 @@ fn op_max_until(drive: &mut StateContext, stacks: &mut Stacks) { drive.state.marks_push(); drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { - let save_last_position = stacks.max_until_last().save_last_position; + let save_last_position = stacks.max_until.pop().unwrap().save_last_position; let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; repeat_ctx.last_position = save_last_position; + if drive.popped_ctx().has_matched == Some(true) { drive.state.marks_pop_discard(); - stacks.max_until.pop(); return drive.success(); } + drive.state.marks_pop(); repeat_ctx.count -= 1; drive.sync_string_position(); @@ -782,9 +751,7 @@ fn op_max_until(drive: &mut StateContext, stacks: &mut Stacks) { let next_ctx = drive.next_ctx(1, tail_callback); next_ctx.repeat_ctx_id = repeat_ctx.prev_id; - fn tail_callback(drive: &mut StateContext, stacks: &mut Stacks) { - stacks.max_until.pop(); - + fn tail_callback(drive: &mut StateContext, _stacks: &mut Stacks) { if drive.popped_ctx().has_matched == Some(true) { drive.success(); } else { @@ -823,15 +790,6 @@ impl Stacks { fn repeat_one_last(&mut self) -> &mut RepeatOneContext { self.repeat_one.last_mut().unwrap() } - fn repeat_last(&mut self) -> &mut RepeatContext { - self.repeat.last_mut().unwrap() - } - fn min_until_last(&mut self) -> &mut MinUntilContext { - self.min_until.last_mut().unwrap() - } - fn max_until_last(&mut self) -> &mut MaxUntilContext { - self.max_until.last_mut().unwrap() - } } #[derive(Debug, Clone, Copy)] @@ -1327,7 +1285,7 @@ fn general_count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize count } -fn count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { +fn _count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { let save_ctx = drive.ctx; let max_count = std::cmp::min(max_count, drive.remaining_chars()); let end = drive.ctx.string_position + max_count; From bf57f289bff1ec5633316a924e82fbb5f5ed6eb0 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 27 Jul 2022 21:33:43 +0200 Subject: [PATCH 053/893] update version to 0.2.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 6ba3996947..00123d92c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.2.0" +version = "0.2.1" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" From 9058f287881af7fc0f004759184a0ff5d0811967 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 28 Jul 2022 22:46:28 +0200 Subject: [PATCH 054/893] refactor trait StrDrive instead enum --- src/engine.rs | 2049 +++++++++++++++++++++++++------------------------ 1 file changed, 1055 insertions(+), 994 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 223aa3425c..8865eb6a39 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -9,8 +9,8 @@ const fn is_py_ascii_whitespace(b: u8) -> bool { } #[derive(Debug)] -pub struct State<'a> { - pub string: StrDrive<'a>, +pub struct State<'a, S: StrDrive> { + pub string: S, pub start: usize, pub end: usize, _flags: SreFlag, @@ -18,18 +18,25 @@ pub struct State<'a> { pub marks: Vec>, pub lastindex: isize, marks_stack: Vec<(Vec>, isize)>, - context_stack: Vec, - _stacks: Option>, + context_stack: Vec>, + // branch_stack: Vec, + // min_repeat_one_stack: Vec, + // repeat_one_stack: Vec, + // repeat_stack: Vec, + // min_until_stack: Vec, + // max_until_stack: Vec, + // _stacks: Option>, pub string_position: usize, - popped_context: Option, + popped_context: Option>, + next_context: Option>, pub has_matched: bool, pub match_all: bool, pub must_advance: bool, } -impl<'a> State<'a> { +impl<'a, S: StrDrive> State<'a, S> { pub fn new( - string: StrDrive<'a>, + string: S, start: usize, end: usize, flags: SreFlag, @@ -47,9 +54,16 @@ impl<'a> State<'a> { lastindex: -1, marks_stack: Vec::new(), context_stack: Vec::new(), - _stacks: Default::default(), + // branch_stack: Vec::new(), + // min_repeat_one_stack: Vec::new(), + // repeat_one_stack: Vec::new(), + // repeat_stack: Vec::new(), + // min_until_stack: Vec::new(), + // max_until_stack: Vec::new(), + // _stacks: Default::default(), string_position: start, popped_context: None, + next_context: None, has_matched: false, match_all: false, must_advance: false, @@ -61,11 +75,15 @@ impl<'a> State<'a> { self.marks.clear(); self.marks_stack.clear(); self.context_stack.clear(); - if let Some(x) = self._stacks.as_mut() { - x.clear() - }; + // self.branch_stack.clear(); + // self.min_repeat_one_stack.clear(); + // self.repeat_one_stack.clear(); + // self.repeat_stack.clear(); + // self.min_until_stack.clear(); + // self.max_until_stack.clear(); self.string_position = self.start; self.popped_context = None; + self.next_context = None; self.has_matched = false; } @@ -103,47 +121,46 @@ impl<'a> State<'a> { self.marks_stack.pop(); } - fn _match(mut self, stacks: &mut Stacks) -> Self { - while let Some(ctx) = self.context_stack.pop() { - let mut drive = StateContext { - state: self, - ctx, - next_ctx: None, - }; + fn _match(&mut self) { + while let Some(mut ctx) = self.context_stack.pop() { + // let mut drive = StateContext { + // state: self, + // ctx, + // next_ctx: None, + // }; + // let mut state = self; - if let Some(handler) = drive.ctx.handler { - handler(&mut drive, stacks); - } else if drive.remaining_codes() > 0 { - let code = drive.peek_code(0); + if let Some(handler) = ctx.handler { + handler(self, &mut ctx); + } else if ctx.remaining_codes(self) > 0 { + let code = ctx.peek_code(self, 0); let code = SreOpcode::try_from(code).unwrap(); - dispatch(code, &mut drive, stacks); + self.dispatch(code, &mut ctx); } else { - drive.failure(); + ctx.failure(); } - let StateContext { - mut state, - ctx, - next_ctx, - } = drive; + // let StateContext { + // mut state, + // ctx, + // next_ctx, + // } = drive; if ctx.has_matched.is_some() { - state.popped_context = Some(ctx); + self.popped_context = Some(ctx); } else { - state.context_stack.push(ctx); - if let Some(next_ctx) = next_ctx { - state.context_stack.push(next_ctx); + self.context_stack.push(ctx); + if let Some(next_ctx) = self.next_context.take() { + self.context_stack.push(next_ctx); } } - self = state + // self = state } - self.has_matched = self.popped_context.unwrap().has_matched == Some(true); - self + self.has_matched = self.popped_context.take().unwrap().has_matched == Some(true); + // self } pub fn pymatch(mut self) -> Self { - let mut stacks = self._stacks.take().unwrap_or_default(); - let ctx = MatchContext { string_position: self.start, string_offset: self.string.offset(0, self.start), @@ -155,13 +172,11 @@ impl<'a> State<'a> { }; self.context_stack.push(ctx); - self = self._match(&mut stacks); - self._stacks = Some(stacks); + self._match(); self } pub fn search(mut self) -> Self { - let mut stacks = self._stacks.take().unwrap_or_default(); // TODO: optimize by op info and skip prefix if self.start > self.end { @@ -180,14 +195,13 @@ impl<'a> State<'a> { repeat_ctx_id: usize::MAX, }; self.context_stack.push(ctx); - self = self._match(&mut stacks); + self._match(); self.must_advance = false; while !self.has_matched && self.start < self.end { self.start += 1; start_offset = self.string.offset(start_offset, 1); self.reset(); - stacks.clear(); let ctx = MatchContext { string_position: self.start, @@ -199,697 +213,730 @@ impl<'a> State<'a> { repeat_ctx_id: usize::MAX, }; self.context_stack.push(ctx); - self = self._match(&mut stacks); + self._match(); } - self._stacks = Some(stacks); self } -} -fn dispatch(opcode: SreOpcode, drive: &mut StateContext, stacks: &mut Stacks) { - match opcode { - SreOpcode::FAILURE => { - drive.failure(); - } - SreOpcode::SUCCESS => { - drive.ctx.has_matched = Some(drive.can_success()); - if drive.ctx.has_matched == Some(true) { - drive.state.string_position = drive.ctx.string_position; - } - } - SreOpcode::ANY => { - if drive.at_end() || drive.at_linebreak() { - drive.failure(); - } else { - drive.skip_code(1); - drive.skip_char(1); - } - } - SreOpcode::ANY_ALL => { - if drive.at_end() { - drive.failure(); - } else { - drive.skip_code(1); - drive.skip_char(1); - } - } - SreOpcode::ASSERT => op_assert(drive), - SreOpcode::ASSERT_NOT => op_assert_not(drive), - SreOpcode::AT => { - let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); - if at(drive, atcode) { - drive.skip_code(2); - } else { - drive.failure(); - } - } - SreOpcode::BRANCH => op_branch(drive, stacks), - SreOpcode::CATEGORY => { - let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); - if drive.at_end() || !category(catcode, drive.peek_char()) { - drive.failure(); - } else { - drive.skip_code(2); - drive.skip_char(1); - } - } - SreOpcode::IN => general_op_in(drive, charset), - SreOpcode::IN_IGNORE => general_op_in(drive, |set, c| charset(set, lower_ascii(c))), - SreOpcode::IN_UNI_IGNORE => general_op_in(drive, |set, c| charset(set, lower_unicode(c))), - SreOpcode::IN_LOC_IGNORE => general_op_in(drive, charset_loc_ignore), - SreOpcode::INFO | SreOpcode::JUMP => drive.skip_code_from(1), - SreOpcode::LITERAL => general_op_literal(drive, |code, c| code == c), - SreOpcode::NOT_LITERAL => general_op_literal(drive, |code, c| code != c), - SreOpcode::LITERAL_IGNORE => general_op_literal(drive, |code, c| code == lower_ascii(c)), - SreOpcode::NOT_LITERAL_IGNORE => { - general_op_literal(drive, |code, c| code != lower_ascii(c)) - } - SreOpcode::LITERAL_UNI_IGNORE => { - general_op_literal(drive, |code, c| code == lower_unicode(c)) - } - SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_op_literal(drive, |code, c| code != lower_unicode(c)) - } - SreOpcode::LITERAL_LOC_IGNORE => general_op_literal(drive, char_loc_ignore), - SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_op_literal(drive, |code, c| !char_loc_ignore(code, c)) - } - SreOpcode::MARK => { - drive - .state - .set_mark(drive.peek_code(1) as usize, drive.ctx.string_position); - drive.skip_code(2); - } - SreOpcode::MAX_UNTIL => op_max_until(drive, stacks), - SreOpcode::MIN_UNTIL => op_min_until(drive, stacks), - SreOpcode::REPEAT => op_repeat(drive, stacks), - SreOpcode::REPEAT_ONE => op_repeat_one(drive, stacks), - SreOpcode::MIN_REPEAT_ONE => op_min_repeat_one(drive, stacks), - SreOpcode::GROUPREF => general_op_groupref(drive, |x| x), - SreOpcode::GROUPREF_IGNORE => general_op_groupref(drive, lower_ascii), - SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(drive, lower_locate), - SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(drive, lower_unicode), - SreOpcode::GROUPREF_EXISTS => { - let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); - match (group_start, group_end) { - (Some(start), Some(end)) if start <= end => { - drive.skip_code(3); - } - _ => drive.skip_code_from(2), + fn dispatch(&mut self, opcode: SreOpcode, ctx: &mut MatchContext<'a, S>) { + match opcode { + SreOpcode::FAILURE => { + ctx.has_matched = Some(false); } + SreOpcode::SUCCESS => todo!(), + SreOpcode::ANY => todo!(), + SreOpcode::ANY_ALL => todo!(), + SreOpcode::ASSERT => todo!(), + SreOpcode::ASSERT_NOT => todo!(), + SreOpcode::AT => todo!(), + SreOpcode::BRANCH => todo!(), + SreOpcode::CALL => todo!(), + SreOpcode::CATEGORY => todo!(), + SreOpcode::CHARSET => todo!(), + SreOpcode::BIGCHARSET => todo!(), + SreOpcode::GROUPREF => todo!(), + SreOpcode::GROUPREF_EXISTS => todo!(), + SreOpcode::IN => todo!(), + SreOpcode::INFO => todo!(), + SreOpcode::JUMP => todo!(), + SreOpcode::LITERAL => todo!(), + SreOpcode::MARK => todo!(), + SreOpcode::MAX_UNTIL => todo!(), + SreOpcode::MIN_UNTIL => todo!(), + SreOpcode::NOT_LITERAL => todo!(), + SreOpcode::NEGATE => todo!(), + SreOpcode::RANGE => todo!(), + SreOpcode::REPEAT => todo!(), + SreOpcode::REPEAT_ONE => todo!(), + SreOpcode::SUBPATTERN => todo!(), + SreOpcode::MIN_REPEAT_ONE => todo!(), + SreOpcode::GROUPREF_IGNORE => todo!(), + SreOpcode::IN_IGNORE => todo!(), + SreOpcode::LITERAL_IGNORE => todo!(), + SreOpcode::NOT_LITERAL_IGNORE => todo!(), + SreOpcode::GROUPREF_LOC_IGNORE => todo!(), + SreOpcode::IN_LOC_IGNORE => todo!(), + SreOpcode::LITERAL_LOC_IGNORE => todo!(), + SreOpcode::NOT_LITERAL_LOC_IGNORE => todo!(), + SreOpcode::GROUPREF_UNI_IGNORE => todo!(), + SreOpcode::IN_UNI_IGNORE => todo!(), + SreOpcode::LITERAL_UNI_IGNORE => todo!(), + SreOpcode::NOT_LITERAL_UNI_IGNORE => todo!(), + SreOpcode::RANGE_UNI_IGNORE => todo!(), } - _ => unreachable!("unexpected opcode"), - } -} - -/* assert subpattern */ -/* */ -fn op_assert(drive: &mut StateContext) { - let back = drive.peek_code(2) as usize; - - if drive.ctx.string_position < back { - return drive.failure(); - } - - let offset = drive - .state - .string - .back_offset(drive.ctx.string_offset, back); - let position = drive.ctx.string_position - back; - - drive.state.string_position = position; - - let next_ctx = drive.next_ctx(3, |drive, _| { - if drive.popped_ctx().has_matched == Some(true) { - drive.ctx.handler = None; - drive.skip_code_from(1); - } else { - drive.failure(); - } - }); - next_ctx.string_position = position; - next_ctx.string_offset = offset; - next_ctx.toplevel = false; -} - -/* assert not subpattern */ -/* */ -fn op_assert_not(drive: &mut StateContext) { - let back = drive.peek_code(2) as usize; - - if drive.ctx.string_position < back { - return drive.skip_code_from(1); - } - - let offset = drive - .state - .string - .back_offset(drive.ctx.string_offset, back); - let position = drive.ctx.string_position - back; - - drive.state.string_position = position; - - let next_ctx = drive.next_ctx(3, |drive, _| { - if drive.popped_ctx().has_matched == Some(true) { - drive.failure(); - } else { - drive.ctx.handler = None; - drive.skip_code_from(1); - } - }); - next_ctx.string_position = position; - next_ctx.string_offset = offset; - next_ctx.toplevel = false; -} - -#[derive(Debug)] -struct BranchContext { - branch_offset: usize, -} - -// alternation -// <0=skip> code ... -fn op_branch(drive: &mut StateContext, stacks: &mut Stacks) { - drive.state.marks_push(); - stacks.branch.push(BranchContext { branch_offset: 1 }); - create_context(drive, stacks); - - fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { - let branch_offset = stacks.branch_last().branch_offset; - let next_length = drive.peek_code(branch_offset) as usize; - if next_length == 0 { - drive.state.marks_pop_discard(); - stacks.branch.pop(); - return drive.failure(); - } - - drive.sync_string_position(); - - stacks.branch_last().branch_offset += next_length; - drive.next_ctx(branch_offset + 1, callback); - } - - fn callback(drive: &mut StateContext, stacks: &mut Stacks) { - if drive.popped_ctx().has_matched == Some(true) { - stacks.branch.pop(); - return drive.success(); - } - drive.state.marks_pop_keep(); - drive.ctx.handler = Some(create_context) - } -} - -#[derive(Debug, Copy, Clone)] -struct MinRepeatOneContext { - count: usize, - max_count: usize, -} - -/* <1=min> <2=max> item tail */ -fn op_min_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { - let min_count = drive.peek_code(2) as usize; - let max_count = drive.peek_code(3) as usize; - - if drive.remaining_chars() < min_count { - return drive.failure(); - } - - drive.sync_string_position(); - - let count = if min_count == 0 { - 0 - } else { - let count = _count(drive, stacks, min_count); - if count < min_count { - return drive.failure(); - } - drive.skip_char(count); - count - }; - - let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { - // tail is empty. we're finished - drive.sync_string_position(); - return drive.success(); - } - - drive.state.marks_push(); - stacks - .min_repeat_one - .push(MinRepeatOneContext { count, max_count }); - create_context(drive, stacks); - - fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { - let MinRepeatOneContext { count, max_count } = *stacks.min_repeat_one_last(); - - if max_count == MAXREPEAT || count <= max_count { - drive.sync_string_position(); - drive.next_ctx_from(1, callback); - } else { - drive.state.marks_pop_discard(); - stacks.min_repeat_one.pop(); - drive.failure(); - } - } - - fn callback(drive: &mut StateContext, stacks: &mut Stacks) { - if drive.popped_ctx().has_matched == Some(true) { - stacks.min_repeat_one.pop(); - return drive.success(); - } - - drive.sync_string_position(); - - if _count(drive, stacks, 1) == 0 { - drive.state.marks_pop_discard(); - stacks.min_repeat_one.pop(); - return drive.failure(); - } - - drive.skip_char(1); - stacks.min_repeat_one_last().count += 1; - drive.state.marks_pop_keep(); - create_context(drive, stacks); } } -#[derive(Debug, Copy, Clone)] -struct RepeatOneContext { - count: usize, - min_count: usize, - following_literal: Option, +// fn dispatch(opcode: SreOpcode, drive: &mut StateContext, stacks: &mut Stacks) { +// match opcode { +// SreOpcode::FAILURE => { +// drive.failure(); +// } +// SreOpcode::SUCCESS => { +// drive.ctx.has_matched = Some(drive.can_success()); +// if drive.ctx.has_matched == Some(true) { +// drive.state.string_position = drive.ctx.string_position; +// } +// } +// SreOpcode::ANY => { +// if drive.at_end() || drive.at_linebreak() { +// drive.failure(); +// } else { +// drive.skip_code(1); +// drive.skip_char(1); +// } +// } +// SreOpcode::ANY_ALL => { +// if drive.at_end() { +// drive.failure(); +// } else { +// drive.skip_code(1); +// drive.skip_char(1); +// } +// } +// SreOpcode::ASSERT => op_assert(drive), +// SreOpcode::ASSERT_NOT => op_assert_not(drive), +// SreOpcode::AT => { +// let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); +// if at(drive, atcode) { +// drive.skip_code(2); +// } else { +// drive.failure(); +// } +// } +// SreOpcode::BRANCH => op_branch(drive, stacks), +// SreOpcode::CATEGORY => { +// let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); +// if drive.at_end() || !category(catcode, drive.peek_char()) { +// drive.failure(); +// } else { +// drive.skip_code(2); +// drive.skip_char(1); +// } +// } +// SreOpcode::IN => general_op_in(drive, charset), +// SreOpcode::IN_IGNORE => general_op_in(drive, |set, c| charset(set, lower_ascii(c))), +// SreOpcode::IN_UNI_IGNORE => general_op_in(drive, |set, c| charset(set, lower_unicode(c))), +// SreOpcode::IN_LOC_IGNORE => general_op_in(drive, charset_loc_ignore), +// SreOpcode::INFO | SreOpcode::JUMP => drive.skip_code_from(1), +// SreOpcode::LITERAL => general_op_literal(drive, |code, c| code == c), +// SreOpcode::NOT_LITERAL => general_op_literal(drive, |code, c| code != c), +// SreOpcode::LITERAL_IGNORE => general_op_literal(drive, |code, c| code == lower_ascii(c)), +// SreOpcode::NOT_LITERAL_IGNORE => { +// general_op_literal(drive, |code, c| code != lower_ascii(c)) +// } +// SreOpcode::LITERAL_UNI_IGNORE => { +// general_op_literal(drive, |code, c| code == lower_unicode(c)) +// } +// SreOpcode::NOT_LITERAL_UNI_IGNORE => { +// general_op_literal(drive, |code, c| code != lower_unicode(c)) +// } +// SreOpcode::LITERAL_LOC_IGNORE => general_op_literal(drive, char_loc_ignore), +// SreOpcode::NOT_LITERAL_LOC_IGNORE => { +// general_op_literal(drive, |code, c| !char_loc_ignore(code, c)) +// } +// SreOpcode::MARK => { +// drive +// .state +// .set_mark(drive.peek_code(1) as usize, drive.ctx.string_position); +// drive.skip_code(2); +// } +// SreOpcode::MAX_UNTIL => op_max_until(drive, stacks), +// SreOpcode::MIN_UNTIL => op_min_until(drive, stacks), +// SreOpcode::REPEAT => op_repeat(drive, stacks), +// SreOpcode::REPEAT_ONE => op_repeat_one(drive, stacks), +// SreOpcode::MIN_REPEAT_ONE => op_min_repeat_one(drive, stacks), +// SreOpcode::GROUPREF => general_op_groupref(drive, |x| x), +// SreOpcode::GROUPREF_IGNORE => general_op_groupref(drive, lower_ascii), +// SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(drive, lower_locate), +// SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(drive, lower_unicode), +// SreOpcode::GROUPREF_EXISTS => { +// let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); +// match (group_start, group_end) { +// (Some(start), Some(end)) if start <= end => { +// drive.skip_code(3); +// } +// _ => drive.skip_code_from(2), +// } +// } +// _ => unreachable!("unexpected opcode"), +// } +// } + +// /* assert subpattern */ +// /* */ +// fn op_assert(drive: &mut StateContext) { +// let back = drive.peek_code(2) as usize; + +// if drive.ctx.string_position < back { +// return drive.failure(); +// } + +// let offset = drive +// .state +// .string +// .back_offset(drive.ctx.string_offset, back); +// let position = drive.ctx.string_position - back; + +// drive.state.string_position = position; + +// let next_ctx = drive.next_ctx(3, |drive, _| { +// if drive.popped_ctx().has_matched == Some(true) { +// drive.ctx.handler = None; +// drive.skip_code_from(1); +// } else { +// drive.failure(); +// } +// }); +// next_ctx.string_position = position; +// next_ctx.string_offset = offset; +// next_ctx.toplevel = false; +// } + +// /* assert not subpattern */ +// /* */ +// fn op_assert_not(drive: &mut StateContext) { +// let back = drive.peek_code(2) as usize; + +// if drive.ctx.string_position < back { +// return drive.skip_code_from(1); +// } + +// let offset = drive +// .state +// .string +// .back_offset(drive.ctx.string_offset, back); +// let position = drive.ctx.string_position - back; + +// drive.state.string_position = position; + +// let next_ctx = drive.next_ctx(3, |drive, _| { +// if drive.popped_ctx().has_matched == Some(true) { +// drive.failure(); +// } else { +// drive.ctx.handler = None; +// drive.skip_code_from(1); +// } +// }); +// next_ctx.string_position = position; +// next_ctx.string_offset = offset; +// next_ctx.toplevel = false; +// } + +// #[derive(Debug)] +// struct BranchContext { +// branch_offset: usize, +// } + +// // alternation +// // <0=skip> code ... +// fn op_branch(drive: &mut StateContext, stacks: &mut Stacks) { +// drive.state.marks_push(); +// stacks.branch.push(BranchContext { branch_offset: 1 }); +// create_context(drive, stacks); + +// fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { +// let branch_offset = stacks.branch_last().branch_offset; +// let next_length = drive.peek_code(branch_offset) as usize; +// if next_length == 0 { +// drive.state.marks_pop_discard(); +// stacks.branch.pop(); +// return drive.failure(); +// } + +// drive.sync_string_position(); + +// stacks.branch_last().branch_offset += next_length; +// drive.next_ctx(branch_offset + 1, callback); +// } + +// fn callback(drive: &mut StateContext, stacks: &mut Stacks) { +// if drive.popped_ctx().has_matched == Some(true) { +// stacks.branch.pop(); +// return drive.success(); +// } +// drive.state.marks_pop_keep(); +// drive.ctx.handler = Some(create_context) +// } +// } + +// #[derive(Debug, Copy, Clone)] +// struct MinRepeatOneContext { +// count: usize, +// max_count: usize, +// } + +// /* <1=min> <2=max> item tail */ +// fn op_min_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { +// let min_count = drive.peek_code(2) as usize; +// let max_count = drive.peek_code(3) as usize; + +// if drive.remaining_chars() < min_count { +// return drive.failure(); +// } + +// drive.sync_string_position(); + +// let count = if min_count == 0 { +// 0 +// } else { +// let count = _count(drive, stacks, min_count); +// if count < min_count { +// return drive.failure(); +// } +// drive.skip_char(count); +// count +// }; + +// let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); +// if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { +// // tail is empty. we're finished +// drive.sync_string_position(); +// return drive.success(); +// } + +// drive.state.marks_push(); +// stacks +// .min_repeat_one +// .push(MinRepeatOneContext { count, max_count }); +// create_context(drive, stacks); + +// fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { +// let MinRepeatOneContext { count, max_count } = *stacks.min_repeat_one_last(); + +// if max_count == MAXREPEAT || count <= max_count { +// drive.sync_string_position(); +// drive.next_ctx_from(1, callback); +// } else { +// drive.state.marks_pop_discard(); +// stacks.min_repeat_one.pop(); +// drive.failure(); +// } +// } + +// fn callback(drive: &mut StateContext, stacks: &mut Stacks) { +// if drive.popped_ctx().has_matched == Some(true) { +// stacks.min_repeat_one.pop(); +// return drive.success(); +// } + +// drive.sync_string_position(); + +// if _count(drive, stacks, 1) == 0 { +// drive.state.marks_pop_discard(); +// stacks.min_repeat_one.pop(); +// return drive.failure(); +// } + +// drive.skip_char(1); +// stacks.min_repeat_one_last().count += 1; +// drive.state.marks_pop_keep(); +// create_context(drive, stacks); +// } +// } + +// #[derive(Debug, Copy, Clone)] +// struct RepeatOneContext { +// count: usize, +// min_count: usize, +// following_literal: Option, +// } + +// /* match repeated sequence (maximizing regexp) */ + +// /* this operator only works if the repeated item is +// exactly one character wide, and we're not already +// collecting backtracking points. for other cases, +// use the MAX_REPEAT operator */ + +// /* <1=min> <2=max> item tail */ +// fn op_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { +// let min_count = drive.peek_code(2) as usize; +// let max_count = drive.peek_code(3) as usize; + +// if drive.remaining_chars() < min_count { +// return drive.failure(); +// } + +// drive.sync_string_position(); + +// let count = _count(drive, stacks, max_count); +// drive.skip_char(count); +// if count < min_count { +// return drive.failure(); +// } + +// let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); +// if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { +// // tail is empty. we're finished +// drive.sync_string_position(); +// return drive.success(); +// } + +// // Special case: Tail starts with a literal. Skip positions where +// // the rest of the pattern cannot possibly match. +// let following_literal = (next_code == SreOpcode::LITERAL as u32) +// .then(|| drive.peek_code(drive.peek_code(1) as usize + 2)); + +// drive.state.marks_push(); +// stacks.repeat_one.push(RepeatOneContext { +// count, +// min_count, +// following_literal, +// }); +// create_context(drive, stacks); + +// fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { +// let RepeatOneContext { +// mut count, +// min_count, +// following_literal, +// } = *stacks.repeat_one_last(); + +// if let Some(c) = following_literal { +// while drive.at_end() || drive.peek_char() != c { +// if count <= min_count { +// drive.state.marks_pop_discard(); +// stacks.repeat_one.pop(); +// return drive.failure(); +// } +// drive.back_skip_char(1); +// count -= 1; +// } +// } +// stacks.repeat_one_last().count = count; + +// drive.sync_string_position(); + +// // General case: backtracking +// drive.next_ctx_from(1, callback); +// } + +// fn callback(drive: &mut StateContext, stacks: &mut Stacks) { +// if drive.popped_ctx().has_matched == Some(true) { +// stacks.repeat_one.pop(); +// return drive.success(); +// } + +// let RepeatOneContext { +// count, +// min_count, +// following_literal: _, +// } = stacks.repeat_one_last(); + +// if count <= min_count { +// drive.state.marks_pop_discard(); +// stacks.repeat_one.pop(); +// return drive.failure(); +// } + +// drive.back_skip_char(1); +// *count -= 1; + +// drive.state.marks_pop_keep(); +// create_context(drive, stacks); +// } +// } + +// #[derive(Debug, Clone, Copy)] +// struct RepeatContext { +// count: isize, +// min_count: usize, +// max_count: usize, +// code_position: usize, +// last_position: usize, +// prev_id: usize, +// } + +// /* create repeat context. all the hard work is done +// by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ +// /* <1=min> <2=max> item tail */ +// fn op_repeat(drive: &mut StateContext, stacks: &mut Stacks) { +// let repeat_ctx = RepeatContext { +// count: -1, +// min_count: drive.peek_code(2) as usize, +// max_count: drive.peek_code(3) as usize, +// code_position: drive.ctx.code_position, +// last_position: std::usize::MAX, +// prev_id: drive.ctx.repeat_ctx_id, +// }; + +// stacks.repeat.push(repeat_ctx); + +// drive.sync_string_position(); + +// let next_ctx = drive.next_ctx_from(1, |drive, stacks| { +// drive.ctx.has_matched = drive.popped_ctx().has_matched; +// stacks.repeat.pop(); +// }); +// next_ctx.repeat_ctx_id = stacks.repeat.len() - 1; +// } + +// #[derive(Debug, Clone, Copy)] +// struct MinUntilContext { +// save_repeat_ctx_id: usize, +// } + +// /* minimizing repeat */ +// fn op_min_until(drive: &mut StateContext, stacks: &mut Stacks) { +// let repeat_ctx = stacks.repeat.last_mut().unwrap(); + +// drive.sync_string_position(); + +// repeat_ctx.count += 1; + +// if (repeat_ctx.count as usize) < repeat_ctx.min_count { +// // not enough matches +// drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { +// if drive.popped_ctx().has_matched == Some(true) { +// drive.success(); +// } else { +// stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; +// drive.sync_string_position(); +// drive.failure(); +// } +// }); +// return; +// } + +// drive.state.marks_push(); + +// stacks.min_until.push(MinUntilContext { +// save_repeat_ctx_id: drive.ctx.repeat_ctx_id, +// }); + +// // see if the tail matches +// let next_ctx = drive.next_ctx(1, |drive, stacks| { +// drive.ctx.repeat_ctx_id = stacks.min_until.pop().unwrap().save_repeat_ctx_id; + +// let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; + +// if drive.popped_ctx().has_matched == Some(true) { +// return drive.success(); +// } + +// drive.sync_string_position(); + +// drive.state.marks_pop(); + +// // match more until tail matches + +// if repeat_ctx.count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT +// || drive.state.string_position == repeat_ctx.last_position +// { +// repeat_ctx.count -= 1; +// return drive.failure(); +// } + +// /* zero-width match protection */ +// repeat_ctx.last_position = drive.state.string_position; + +// drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { +// if drive.popped_ctx().has_matched == Some(true) { +// drive.success(); +// } else { +// stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; +// drive.sync_string_position(); +// drive.failure(); +// } +// }); +// }); +// next_ctx.repeat_ctx_id = repeat_ctx.prev_id; +// } + +// #[derive(Debug, Clone, Copy)] +// struct MaxUntilContext { +// save_last_position: usize, +// } + +// /* maximizing repeat */ +// fn op_max_until(drive: &mut StateContext, stacks: &mut Stacks) { +// let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; + +// drive.sync_string_position(); + +// repeat_ctx.count += 1; + +// if (repeat_ctx.count as usize) < repeat_ctx.min_count { +// // not enough matches +// drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { +// if drive.popped_ctx().has_matched == Some(true) { +// drive.success(); +// } else { +// stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; +// drive.sync_string_position(); +// drive.failure(); +// } +// }); +// return; +// } + +// stacks.max_until.push(MaxUntilContext { +// save_last_position: repeat_ctx.last_position, +// }); + +// if ((repeat_ctx.count as usize) < repeat_ctx.max_count || repeat_ctx.max_count == MAXREPEAT) +// && drive.state.string_position != repeat_ctx.last_position +// { +// /* we may have enough matches, but if we can +// match another item, do so */ +// repeat_ctx.last_position = drive.state.string_position; + +// drive.state.marks_push(); + +// drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { +// let save_last_position = stacks.max_until.pop().unwrap().save_last_position; +// let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; +// repeat_ctx.last_position = save_last_position; + +// if drive.popped_ctx().has_matched == Some(true) { +// drive.state.marks_pop_discard(); +// return drive.success(); +// } + +// drive.state.marks_pop(); +// repeat_ctx.count -= 1; +// drive.sync_string_position(); + +// /* cannot match more repeated items here. make sure the +// tail matches */ +// let next_ctx = drive.next_ctx(1, tail_callback); +// next_ctx.repeat_ctx_id = repeat_ctx.prev_id; +// }); +// return; +// } + +// /* cannot match more repeated items here. make sure the +// tail matches */ +// let next_ctx = drive.next_ctx(1, tail_callback); +// next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + +// fn tail_callback(drive: &mut StateContext, _stacks: &mut Stacks) { +// if drive.popped_ctx().has_matched == Some(true) { +// drive.success(); +// } else { +// drive.sync_string_position(); +// drive.failure(); +// } +// } +// } + +// #[derive(Debug, Default)] +// struct Stacks { +// } + +// impl Stacks { +// fn clear(&mut self) { +// self.branch.clear(); +// self.min_repeat_one.clear(); +// self.repeat_one.clear(); +// self.repeat.clear(); +// self.min_until.clear(); +// self.max_until.clear(); +// } + +// fn branch_last(&mut self) -> &mut BranchContext { +// self.branch.last_mut().unwrap() +// } +// fn min_repeat_one_last(&mut self) -> &mut MinRepeatOneContext { +// self.min_repeat_one.last_mut().unwrap() +// } +// fn repeat_one_last(&mut self) -> &mut RepeatOneContext { +// self.repeat_one.last_mut().unwrap() +// } +// } + +pub trait StrDrive { + fn offset(&self, offset: usize, skip: usize) -> usize; + fn count(&self) -> usize; + fn peek(&self, offset: usize) -> u32; + fn back_peek(&self, offset: usize) -> u32; + fn back_offset(&self, offset: usize, skip: usize) -> usize; } -/* match repeated sequence (maximizing regexp) */ - -/* this operator only works if the repeated item is -exactly one character wide, and we're not already -collecting backtracking points. for other cases, -use the MAX_REPEAT operator */ - -/* <1=min> <2=max> item tail */ -fn op_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { - let min_count = drive.peek_code(2) as usize; - let max_count = drive.peek_code(3) as usize; - - if drive.remaining_chars() < min_count { - return drive.failure(); - } - - drive.sync_string_position(); - let count = _count(drive, stacks, max_count); - drive.skip_char(count); - if count < min_count { - return drive.failure(); +impl<'a> StrDrive for &'a str { + fn offset(&self, offset: usize, skip: usize) -> usize { + self.get(offset..) + .and_then(|s| s.char_indices().nth(skip).map(|x| x.0 + offset)) + .unwrap_or(self.len()) } - let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { - // tail is empty. we're finished - drive.sync_string_position(); - return drive.success(); + fn count(&self) -> usize { + self.chars().count() } - // Special case: Tail starts with a literal. Skip positions where - // the rest of the pattern cannot possibly match. - let following_literal = (next_code == SreOpcode::LITERAL as u32) - .then(|| drive.peek_code(drive.peek_code(1) as usize + 2)); - - drive.state.marks_push(); - stacks.repeat_one.push(RepeatOneContext { - count, - min_count, - following_literal, - }); - create_context(drive, stacks); - - fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { - let RepeatOneContext { - mut count, - min_count, - following_literal, - } = *stacks.repeat_one_last(); - - if let Some(c) = following_literal { - while drive.at_end() || drive.peek_char() != c { - if count <= min_count { - drive.state.marks_pop_discard(); - stacks.repeat_one.pop(); - return drive.failure(); - } - drive.back_skip_char(1); - count -= 1; - } - } - stacks.repeat_one_last().count = count; - - drive.sync_string_position(); - - // General case: backtracking - drive.next_ctx_from(1, callback); + fn peek(&self, offset: usize) -> u32 { + unsafe { self.get_unchecked(offset..) } + .chars() + .next() + .unwrap() as u32 } - fn callback(drive: &mut StateContext, stacks: &mut Stacks) { - if drive.popped_ctx().has_matched == Some(true) { - stacks.repeat_one.pop(); - return drive.success(); - } - - let RepeatOneContext { - count, - min_count, - following_literal: _, - } = stacks.repeat_one_last(); - - if count <= min_count { - drive.state.marks_pop_discard(); - stacks.repeat_one.pop(); - return drive.failure(); + fn back_peek(&self, offset: usize) -> u32 { + let bytes = self.as_bytes(); + let back_offset = utf8_back_peek_offset(bytes, offset); + match offset - back_offset { + 1 => u32::from_be_bytes([0, 0, 0, bytes[offset - 1]]), + 2 => u32::from_be_bytes([0, 0, bytes[offset - 2], bytes[offset - 1]]), + 3 => u32::from_be_bytes([0, bytes[offset - 3], bytes[offset - 2], bytes[offset - 1]]), + 4 => u32::from_be_bytes([ + bytes[offset - 4], + bytes[offset - 3], + bytes[offset - 2], + bytes[offset - 1], + ]), + _ => unreachable!(), } - - drive.back_skip_char(1); - *count -= 1; - - drive.state.marks_pop_keep(); - create_context(drive, stacks); } -} - -#[derive(Debug, Clone, Copy)] -struct RepeatContext { - count: isize, - min_count: usize, - max_count: usize, - code_position: usize, - last_position: usize, - prev_id: usize, -} - -/* create repeat context. all the hard work is done -by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ -/* <1=min> <2=max> item tail */ -fn op_repeat(drive: &mut StateContext, stacks: &mut Stacks) { - let repeat_ctx = RepeatContext { - count: -1, - min_count: drive.peek_code(2) as usize, - max_count: drive.peek_code(3) as usize, - code_position: drive.ctx.code_position, - last_position: std::usize::MAX, - prev_id: drive.ctx.repeat_ctx_id, - }; - - stacks.repeat.push(repeat_ctx); - - drive.sync_string_position(); - - let next_ctx = drive.next_ctx_from(1, |drive, stacks| { - drive.ctx.has_matched = drive.popped_ctx().has_matched; - stacks.repeat.pop(); - }); - next_ctx.repeat_ctx_id = stacks.repeat.len() - 1; -} - -#[derive(Debug, Clone, Copy)] -struct MinUntilContext { - save_repeat_ctx_id: usize, -} - -/* minimizing repeat */ -fn op_min_until(drive: &mut StateContext, stacks: &mut Stacks) { - let repeat_ctx = stacks.repeat.last_mut().unwrap(); - - drive.sync_string_position(); - - repeat_ctx.count += 1; - - if (repeat_ctx.count as usize) < repeat_ctx.min_count { - // not enough matches - drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { - if drive.popped_ctx().has_matched == Some(true) { - drive.success(); - } else { - stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; - drive.sync_string_position(); - drive.failure(); - } - }); - return; - } - - drive.state.marks_push(); - - stacks.min_until.push(MinUntilContext { - save_repeat_ctx_id: drive.ctx.repeat_ctx_id, - }); - - // see if the tail matches - let next_ctx = drive.next_ctx(1, |drive, stacks| { - drive.ctx.repeat_ctx_id = stacks.min_until.pop().unwrap().save_repeat_ctx_id; - - let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; - - if drive.popped_ctx().has_matched == Some(true) { - return drive.success(); - } - - drive.sync_string_position(); - - drive.state.marks_pop(); - - // match more until tail matches - - if repeat_ctx.count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT - || drive.state.string_position == repeat_ctx.last_position - { - repeat_ctx.count -= 1; - return drive.failure(); - } - - /* zero-width match protection */ - repeat_ctx.last_position = drive.state.string_position; - - drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { - if drive.popped_ctx().has_matched == Some(true) { - drive.success(); - } else { - stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; - drive.sync_string_position(); - drive.failure(); - } - }); - }); - next_ctx.repeat_ctx_id = repeat_ctx.prev_id; -} - -#[derive(Debug, Clone, Copy)] -struct MaxUntilContext { - save_last_position: usize, -} - -/* maximizing repeat */ -fn op_max_until(drive: &mut StateContext, stacks: &mut Stacks) { - let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; - - drive.sync_string_position(); - repeat_ctx.count += 1; - - if (repeat_ctx.count as usize) < repeat_ctx.min_count { - // not enough matches - drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { - if drive.popped_ctx().has_matched == Some(true) { - drive.success(); - } else { - stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; - drive.sync_string_position(); - drive.failure(); - } - }); - return; - } - - stacks.max_until.push(MaxUntilContext { - save_last_position: repeat_ctx.last_position, - }); - - if ((repeat_ctx.count as usize) < repeat_ctx.max_count || repeat_ctx.max_count == MAXREPEAT) - && drive.state.string_position != repeat_ctx.last_position - { - /* we may have enough matches, but if we can - match another item, do so */ - repeat_ctx.last_position = drive.state.string_position; - - drive.state.marks_push(); - - drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { - let save_last_position = stacks.max_until.pop().unwrap().save_last_position; - let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; - repeat_ctx.last_position = save_last_position; - - if drive.popped_ctx().has_matched == Some(true) { - drive.state.marks_pop_discard(); - return drive.success(); - } - - drive.state.marks_pop(); - repeat_ctx.count -= 1; - drive.sync_string_position(); - - /* cannot match more repeated items here. make sure the - tail matches */ - let next_ctx = drive.next_ctx(1, tail_callback); - next_ctx.repeat_ctx_id = repeat_ctx.prev_id; - }); - return; - } - - /* cannot match more repeated items here. make sure the - tail matches */ - let next_ctx = drive.next_ctx(1, tail_callback); - next_ctx.repeat_ctx_id = repeat_ctx.prev_id; - - fn tail_callback(drive: &mut StateContext, _stacks: &mut Stacks) { - if drive.popped_ctx().has_matched == Some(true) { - drive.success(); - } else { - drive.sync_string_position(); - drive.failure(); + fn back_offset(&self, offset: usize, skip: usize) -> usize { + let bytes = self.as_bytes(); + let mut back_offset = offset; + for _ in 0..skip { + back_offset = utf8_back_peek_offset(bytes, back_offset); } + back_offset } } -#[derive(Debug, Default)] -struct Stacks { - branch: Vec, - min_repeat_one: Vec, - repeat_one: Vec, - repeat: Vec, - min_until: Vec, - max_until: Vec, -} - -impl Stacks { - fn clear(&mut self) { - self.branch.clear(); - self.min_repeat_one.clear(); - self.repeat_one.clear(); - self.repeat.clear(); - self.min_until.clear(); - self.max_until.clear(); - } - - fn branch_last(&mut self) -> &mut BranchContext { - self.branch.last_mut().unwrap() - } - fn min_repeat_one_last(&mut self) -> &mut MinRepeatOneContext { - self.min_repeat_one.last_mut().unwrap() - } - fn repeat_one_last(&mut self) -> &mut RepeatOneContext { - self.repeat_one.last_mut().unwrap() - } -} - -#[derive(Debug, Clone, Copy)] -pub enum StrDrive<'a> { - Str(&'a str), - Bytes(&'a [u8]), -} - -impl<'a> From<&'a str> for StrDrive<'a> { - fn from(s: &'a str) -> Self { - Self::Str(s) - } -} -impl<'a> From<&'a [u8]> for StrDrive<'a> { - fn from(b: &'a [u8]) -> Self { - Self::Bytes(b) - } -} - -impl<'a> StrDrive<'a> { +impl<'a> StrDrive for &'a [u8] { fn offset(&self, offset: usize, skip: usize) -> usize { - match *self { - StrDrive::Str(s) => s - .get(offset..) - .and_then(|s| s.char_indices().nth(skip).map(|x| x.0 + offset)) - .unwrap_or(s.len()), - StrDrive::Bytes(_) => offset + skip, - } + offset + skip } - pub fn count(&self) -> usize { - match *self { - StrDrive::Str(s) => s.chars().count(), - StrDrive::Bytes(b) => b.len(), - } + fn count(&self) -> usize { + self.len() } fn peek(&self, offset: usize) -> u32 { - match *self { - StrDrive::Str(s) => unsafe { s.get_unchecked(offset..) }.chars().next().unwrap() as u32, - StrDrive::Bytes(b) => b[offset] as u32, - } + self[offset] as u32 } fn back_peek(&self, offset: usize) -> u32 { - match *self { - StrDrive::Str(s) => { - let bytes = s.as_bytes(); - let back_offset = utf8_back_peek_offset(bytes, offset); - match offset - back_offset { - 1 => u32::from_be_bytes([0, 0, 0, bytes[offset - 1]]), - 2 => u32::from_be_bytes([0, 0, bytes[offset - 2], bytes[offset - 1]]), - 3 => u32::from_be_bytes([ - 0, - bytes[offset - 3], - bytes[offset - 2], - bytes[offset - 1], - ]), - 4 => u32::from_be_bytes([ - bytes[offset - 4], - bytes[offset - 3], - bytes[offset - 2], - bytes[offset - 1], - ]), - _ => unreachable!(), - } - } - StrDrive::Bytes(b) => b[offset - 1] as u32, - } + self[offset - 1] as u32 } fn back_offset(&self, offset: usize, skip: usize) -> usize { - match *self { - StrDrive::Str(s) => { - let bytes = s.as_bytes(); - let mut back_offset = offset; - for _ in 0..skip { - back_offset = utf8_back_peek_offset(bytes, back_offset); - } - back_offset - } - StrDrive::Bytes(_) => offset - skip, - } + offset - skip } } -type OpcodeHandler = fn(&mut StateContext, &mut Stacks); +// type OpcodeHandler = for<'a>fn(&mut StateContext<'a, S>, &mut Stacks); #[derive(Clone, Copy)] -struct MatchContext { +struct MatchContext<'a, S: StrDrive> { string_position: usize, string_offset: usize, code_position: usize, has_matched: Option, toplevel: bool, - handler: Option, + handler: Option, &mut Self)>, repeat_ctx_id: usize, } -impl std::fmt::Debug for MatchContext { +impl<'a, S: StrDrive> std::fmt::Debug for MatchContext<'a, S> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MatchContext") .field("string_position", &self.string_position) @@ -902,164 +949,178 @@ impl std::fmt::Debug for MatchContext { } } -trait ContextDrive { - fn ctx(&self) -> &MatchContext; - fn ctx_mut(&mut self) -> &mut MatchContext; - fn state(&self) -> &State; - - fn popped_ctx(&self) -> &MatchContext { - self.state().popped_context.as_ref().unwrap() +impl<'a, S: StrDrive> MatchContext<'a, S> { + fn remaining_codes(&self, state: &State<'a, S>) -> usize { + state.pattern_codes.len() - self.code_position } - - fn pattern(&self) -> &[u32] { - &self.state().pattern_codes[self.ctx().code_position..] - } - - fn peek_char(&self) -> u32 { - self.state().string.peek(self.ctx().string_offset) - } - fn peek_code(&self, peek: usize) -> u32 { - self.state().pattern_codes[self.ctx().code_position + peek] - } - - fn back_peek_char(&self) -> u32 { - self.state().string.back_peek(self.ctx().string_offset) - } - fn back_skip_char(&mut self, skip_count: usize) { - self.ctx_mut().string_position -= skip_count; - self.ctx_mut().string_offset = self - .state() - .string - .back_offset(self.ctx().string_offset, skip_count); - } - - fn skip_char(&mut self, skip_count: usize) { - self.ctx_mut().string_offset = self - .state() - .string - .offset(self.ctx().string_offset, skip_count); - self.ctx_mut().string_position += skip_count; - } - fn skip_code(&mut self, skip_count: usize) { - self.ctx_mut().code_position += skip_count; - } - fn skip_code_from(&mut self, peek: usize) { - self.skip_code(self.peek_code(peek) as usize + 1); - } - - fn remaining_chars(&self) -> usize { - self.state().end - self.ctx().string_position - } - fn remaining_codes(&self) -> usize { - self.state().pattern_codes.len() - self.ctx().code_position - } - - fn at_beginning(&self) -> bool { - // self.ctx().string_position == self.state().start - self.ctx().string_position == 0 - } - fn at_end(&self) -> bool { - self.ctx().string_position == self.state().end - } - fn at_linebreak(&self) -> bool { - !self.at_end() && is_linebreak(self.peek_char()) - } - fn at_boundary bool>(&self, mut word_checker: F) -> bool { - if self.at_beginning() && self.at_end() { - return false; - } - let that = !self.at_beginning() && word_checker(self.back_peek_char()); - let this = !self.at_end() && word_checker(self.peek_char()); - this != that - } - fn at_non_boundary bool>(&self, mut word_checker: F) -> bool { - if self.at_beginning() && self.at_end() { - return false; - } - let that = !self.at_beginning() && word_checker(self.back_peek_char()); - let this = !self.at_end() && word_checker(self.peek_char()); - this == that - } - - fn can_success(&self) -> bool { - if !self.ctx().toplevel { - return true; - } - if self.state().match_all && !self.at_end() { - return false; - } - if self.state().must_advance && self.ctx().string_position == self.state().start { - return false; - } - true - } - - fn success(&mut self) { - self.ctx_mut().has_matched = Some(true); + + fn peek_code(&self, state: &State<'a, S>, peek: usize) -> u32 { + state.pattern_codes[self.code_position + peek] } fn failure(&mut self) { - self.ctx_mut().has_matched = Some(false); + self.has_matched = Some(false); } } -struct StateContext<'a> { - state: State<'a>, - ctx: MatchContext, - next_ctx: Option, -} - -impl ContextDrive for StateContext<'_> { - fn ctx(&self) -> &MatchContext { - &self.ctx - } - fn ctx_mut(&mut self) -> &mut MatchContext { - &mut self.ctx - } - fn state(&self) -> &State { - &self.state - } -} - -impl StateContext<'_> { - fn next_ctx_from(&mut self, peek: usize, handler: OpcodeHandler) -> &mut MatchContext { - self.next_ctx(self.peek_code(peek) as usize + 1, handler) - } - fn next_ctx(&mut self, offset: usize, handler: OpcodeHandler) -> &mut MatchContext { - self.next_ctx_at(self.ctx.code_position + offset, handler) - } - fn next_ctx_at(&mut self, code_position: usize, handler: OpcodeHandler) -> &mut MatchContext { - self.next_ctx = Some(MatchContext { - code_position, - has_matched: None, - handler: None, - ..self.ctx - }); - self.ctx.handler = Some(handler); - self.next_ctx.as_mut().unwrap() - } - - fn sync_string_position(&mut self) { - self.state.string_position = self.ctx.string_position; - } -} - -struct StateRefContext<'a> { - entity: &'a StateContext<'a>, - ctx: MatchContext, -} - -impl ContextDrive for StateRefContext<'_> { - fn ctx(&self) -> &MatchContext { - &self.ctx - } - fn ctx_mut(&mut self) -> &mut MatchContext { - &mut self.ctx - } - fn state(&self) -> &State { - &self.entity.state - } -} +// trait ContextDrive<'a, T: StrDrive<'a>> { +// fn ctx(&self) -> &MatchContext; +// fn ctx_mut(&mut self) -> &mut MatchContext; +// fn state(&self) -> &State<'a, T>; + +// fn popped_ctx(&self) -> &MatchContext { +// self.state().popped_context.as_ref().unwrap() +// } + +// fn pattern(&self) -> &[u32] { +// &self.state().pattern_codes[self.ctx().code_position..] +// } + +// fn peek_char(&self) -> u32 { +// self.state().string.peek(self.ctx().string_offset) +// } +// fn peek_code(&self, peek: usize) -> u32 { +// self.state().pattern_codes[self.ctx().code_position + peek] +// } + +// fn back_peek_char(&self) -> u32 { +// self.state().string.back_peek(self.ctx().string_offset) +// } +// fn back_skip_char(&mut self, skip_count: usize) { +// self.ctx_mut().string_position -= skip_count; +// self.ctx_mut().string_offset = self +// .state() +// .string +// .back_offset(self.ctx().string_offset, skip_count); +// } + +// fn skip_char(&mut self, skip_count: usize) { +// self.ctx_mut().string_offset = self +// .state() +// .string +// .offset(self.ctx().string_offset, skip_count); +// self.ctx_mut().string_position += skip_count; +// } +// fn skip_code(&mut self, skip_count: usize) { +// self.ctx_mut().code_position += skip_count; +// } +// fn skip_code_from(&mut self, peek: usize) { +// self.skip_code(self.peek_code(peek) as usize + 1); +// } + +// fn remaining_chars(&self) -> usize { +// self.state().end - self.ctx().string_position +// } +// fn remaining_codes(&self) -> usize { +// self.state().pattern_codes.len() - self.ctx().code_position +// } + +// fn at_beginning(&self) -> bool { +// // self.ctx().string_position == self.state().start +// self.ctx().string_position == 0 +// } +// fn at_end(&self) -> bool { +// self.ctx().string_position == self.state().end +// } +// fn at_linebreak(&self) -> bool { +// !self.at_end() && is_linebreak(self.peek_char()) +// } +// fn at_boundary bool>(&self, mut word_checker: F) -> bool { +// if self.at_beginning() && self.at_end() { +// return false; +// } +// let that = !self.at_beginning() && word_checker(self.back_peek_char()); +// let this = !self.at_end() && word_checker(self.peek_char()); +// this != that +// } +// fn at_non_boundary bool>(&self, mut word_checker: F) -> bool { +// if self.at_beginning() && self.at_end() { +// return false; +// } +// let that = !self.at_beginning() && word_checker(self.back_peek_char()); +// let this = !self.at_end() && word_checker(self.peek_char()); +// this == that +// } + +// fn can_success(&self) -> bool { +// if !self.ctx().toplevel { +// return true; +// } +// if self.state().match_all && !self.at_end() { +// return false; +// } +// if self.state().must_advance && self.ctx().string_position == self.state().start { +// return false; +// } +// true +// } + +// fn success(&mut self) { +// self.ctx_mut().has_matched = Some(true); +// } + +// fn failure(&mut self) { +// self.ctx_mut().has_matched = Some(false); +// } +// } + +// struct StateContext<'a, S: StrDrive<'a>> { +// state: State<'a, S>, +// ctx: MatchContext, +// next_ctx: Option, +// } + +// impl<'a, S: StrDrive<'a>> ContextDrive<'a, S> for StateContext<'a, S> { +// fn ctx(&self) -> &MatchContext { +// &self.ctx +// } +// fn ctx_mut(&mut self) -> &mut MatchContext { +// &mut self.ctx +// } +// fn state(&self) -> &State<'a, S> { +// &self.state +// } +// } + +// impl StateContext<'_> { +// fn next_ctx_from(&mut self, peek: usize, handler: OpcodeHandler) -> &mut MatchContext { +// self.next_ctx(self.peek_code(peek) as usize + 1, handler) +// } +// fn next_ctx(&mut self, offset: usize, handler: OpcodeHandler) -> &mut MatchContext { +// self.next_ctx_at(self.ctx.code_position + offset, handler) +// } +// fn next_ctx_at(&mut self, code_position: usize, handler: OpcodeHandler) -> &mut MatchContext { +// self.next_ctx = Some(MatchContext { +// code_position, +// has_matched: None, +// handler: None, +// ..self.ctx +// }); +// self.ctx.handler = Some(handler); +// self.next_ctx.as_mut().unwrap() +// } + +// fn sync_string_position(&mut self) { +// self.state.string_position = self.ctx.string_position; +// } +// } + +// struct StateRefContext<'a> { +// entity: &'a StateContext<'a>, +// ctx: MatchContext, +// } + +// impl ContextDrive for StateRefContext<'_> { +// fn ctx(&self) -> &MatchContext { +// &self.ctx +// } +// fn ctx_mut(&mut self) -> &mut MatchContext { +// &mut self.ctx +// } +// fn state(&self) -> &State { +// &self.entity.state +// } +// } fn char_loc_ignore(code: u32, c: u32) -> bool { code == c || code == lower_locate(c) || code == upper_locate(c) @@ -1074,77 +1135,77 @@ fn charset_loc_ignore(set: &[u32], c: u32) -> bool { up != lo && charset(set, up) } -fn general_op_groupref u32>(drive: &mut StateContext, mut f: F) { - let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); - let (group_start, group_end) = match (group_start, group_end) { - (Some(start), Some(end)) if start <= end => (start, end), - _ => { - return drive.failure(); - } - }; - - let mut wdrive = StateRefContext { - entity: drive, - ctx: drive.ctx, - }; - let mut gdrive = StateRefContext { - entity: drive, - ctx: MatchContext { - string_position: group_start, - // TODO: cache the offset - string_offset: drive.state.string.offset(0, group_start), - ..drive.ctx - }, - }; - - for _ in group_start..group_end { - if wdrive.at_end() || f(wdrive.peek_char()) != f(gdrive.peek_char()) { - return drive.failure(); - } - wdrive.skip_char(1); - gdrive.skip_char(1); - } - - let position = wdrive.ctx.string_position; - let offset = wdrive.ctx.string_offset; - drive.skip_code(2); - drive.ctx.string_position = position; - drive.ctx.string_offset = offset; -} - -fn general_op_literal bool>(drive: &mut StateContext, f: F) { - if drive.at_end() || !f(drive.peek_code(1), drive.peek_char()) { - drive.failure(); - } else { - drive.skip_code(2); - drive.skip_char(1); - } -} - -fn general_op_in bool>(drive: &mut StateContext, f: F) { - if drive.at_end() || !f(&drive.pattern()[2..], drive.peek_char()) { - drive.failure(); - } else { - drive.skip_code_from(1); - drive.skip_char(1); - } -} - -fn at(drive: &StateContext, atcode: SreAtCode) -> bool { - match atcode { - SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => drive.at_beginning(), - SreAtCode::BEGINNING_LINE => drive.at_beginning() || is_linebreak(drive.back_peek_char()), - SreAtCode::BOUNDARY => drive.at_boundary(is_word), - SreAtCode::NON_BOUNDARY => drive.at_non_boundary(is_word), - SreAtCode::END => (drive.remaining_chars() == 1 && drive.at_linebreak()) || drive.at_end(), - SreAtCode::END_LINE => drive.at_linebreak() || drive.at_end(), - SreAtCode::END_STRING => drive.at_end(), - SreAtCode::LOC_BOUNDARY => drive.at_boundary(is_loc_word), - SreAtCode::LOC_NON_BOUNDARY => drive.at_non_boundary(is_loc_word), - SreAtCode::UNI_BOUNDARY => drive.at_boundary(is_uni_word), - SreAtCode::UNI_NON_BOUNDARY => drive.at_non_boundary(is_uni_word), - } -} +// fn general_op_groupref u32>(drive: &mut StateContext, mut f: F) { +// let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); +// let (group_start, group_end) = match (group_start, group_end) { +// (Some(start), Some(end)) if start <= end => (start, end), +// _ => { +// return drive.failure(); +// } +// }; + +// let mut wdrive = StateRefContext { +// entity: drive, +// ctx: drive.ctx, +// }; +// let mut gdrive = StateRefContext { +// entity: drive, +// ctx: MatchContext { +// string_position: group_start, +// // TODO: cache the offset +// string_offset: drive.state.string.offset(0, group_start), +// ..drive.ctx +// }, +// }; + +// for _ in group_start..group_end { +// if wdrive.at_end() || f(wdrive.peek_char()) != f(gdrive.peek_char()) { +// return drive.failure(); +// } +// wdrive.skip_char(1); +// gdrive.skip_char(1); +// } + +// let position = wdrive.ctx.string_position; +// let offset = wdrive.ctx.string_offset; +// drive.skip_code(2); +// drive.ctx.string_position = position; +// drive.ctx.string_offset = offset; +// } + +// fn general_op_literal bool>(drive: &mut StateContext, f: F) { +// if drive.at_end() || !f(drive.peek_code(1), drive.peek_char()) { +// drive.failure(); +// } else { +// drive.skip_code(2); +// drive.skip_char(1); +// } +// } + +// fn general_op_in bool>(drive: &mut StateContext, f: F) { +// if drive.at_end() || !f(&drive.pattern()[2..], drive.peek_char()) { +// drive.failure(); +// } else { +// drive.skip_code_from(1); +// drive.skip_char(1); +// } +// } + +// fn at(drive: &StateContext, atcode: SreAtCode) -> bool { +// match atcode { +// SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => drive.at_beginning(), +// SreAtCode::BEGINNING_LINE => drive.at_beginning() || is_linebreak(drive.back_peek_char()), +// SreAtCode::BOUNDARY => drive.at_boundary(is_word), +// SreAtCode::NON_BOUNDARY => drive.at_non_boundary(is_word), +// SreAtCode::END => (drive.remaining_chars() == 1 && drive.at_linebreak()) || drive.at_end(), +// SreAtCode::END_LINE => drive.at_linebreak() || drive.at_end(), +// SreAtCode::END_STRING => drive.at_end(), +// SreAtCode::LOC_BOUNDARY => drive.at_boundary(is_loc_word), +// SreAtCode::LOC_NON_BOUNDARY => drive.at_non_boundary(is_loc_word), +// SreAtCode::UNI_BOUNDARY => drive.at_boundary(is_uni_word), +// SreAtCode::UNI_NON_BOUNDARY => drive.at_non_boundary(is_uni_word), +// } +// } fn category(catcode: SreCatCode, c: u32) -> bool { match catcode { @@ -1262,95 +1323,95 @@ fn charset(set: &[u32], ch: u32) -> bool { false } -/* General case */ -fn general_count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { - let mut count = 0; - let max_count = std::cmp::min(max_count, drive.remaining_chars()); - - let save_ctx = drive.ctx; - drive.skip_code(4); - let reset_position = drive.ctx.code_position; - - while count < max_count { - drive.ctx.code_position = reset_position; - let code = drive.peek_code(0); - let code = SreOpcode::try_from(code).unwrap(); - dispatch(code, drive, stacks); - if drive.ctx.has_matched == Some(false) { - break; - } - count += 1; - } - drive.ctx = save_ctx; - count -} - -fn _count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { - let save_ctx = drive.ctx; - let max_count = std::cmp::min(max_count, drive.remaining_chars()); - let end = drive.ctx.string_position + max_count; - let opcode = SreOpcode::try_from(drive.peek_code(0)).unwrap(); - - match opcode { - SreOpcode::ANY => { - while !drive.ctx.string_position < end && !drive.at_linebreak() { - drive.skip_char(1); - } - } - SreOpcode::ANY_ALL => { - drive.skip_char(max_count); - } - SreOpcode::IN => { - while !drive.ctx.string_position < end - && charset(&drive.pattern()[2..], drive.peek_char()) - { - drive.skip_char(1); - } - } - SreOpcode::LITERAL => { - general_count_literal(drive, end, |code, c| code == c as u32); - } - SreOpcode::NOT_LITERAL => { - general_count_literal(drive, end, |code, c| code != c as u32); - } - SreOpcode::LITERAL_IGNORE => { - general_count_literal(drive, end, |code, c| code == lower_ascii(c) as u32); - } - SreOpcode::NOT_LITERAL_IGNORE => { - general_count_literal(drive, end, |code, c| code != lower_ascii(c) as u32); - } - SreOpcode::LITERAL_LOC_IGNORE => { - general_count_literal(drive, end, char_loc_ignore); - } - SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_count_literal(drive, end, |code, c| !char_loc_ignore(code, c)); - } - SreOpcode::LITERAL_UNI_IGNORE => { - general_count_literal(drive, end, |code, c| code == lower_unicode(c) as u32); - } - SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_count_literal(drive, end, |code, c| code != lower_unicode(c) as u32); - } - _ => { - return general_count(drive, stacks, max_count); - } - } - - let count = drive.ctx.string_position - drive.state.string_position; - drive.ctx = save_ctx; - count -} - -fn general_count_literal bool>( - drive: &mut StateContext, - end: usize, - mut f: F, -) { - let ch = drive.peek_code(1); - while !drive.ctx.string_position < end && f(ch, drive.peek_char()) { - drive.skip_char(1); - } -} +// /* General case */ +// fn general_count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { +// let mut count = 0; +// let max_count = std::cmp::min(max_count, drive.remaining_chars()); + +// let save_ctx = drive.ctx; +// drive.skip_code(4); +// let reset_position = drive.ctx.code_position; + +// while count < max_count { +// drive.ctx.code_position = reset_position; +// let code = drive.peek_code(0); +// let code = SreOpcode::try_from(code).unwrap(); +// dispatch(code, drive, stacks); +// if drive.ctx.has_matched == Some(false) { +// break; +// } +// count += 1; +// } +// drive.ctx = save_ctx; +// count +// } + +// fn _count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { +// let save_ctx = drive.ctx; +// let max_count = std::cmp::min(max_count, drive.remaining_chars()); +// let end = drive.ctx.string_position + max_count; +// let opcode = SreOpcode::try_from(drive.peek_code(0)).unwrap(); + +// match opcode { +// SreOpcode::ANY => { +// while !drive.ctx.string_position < end && !drive.at_linebreak() { +// drive.skip_char(1); +// } +// } +// SreOpcode::ANY_ALL => { +// drive.skip_char(max_count); +// } +// SreOpcode::IN => { +// while !drive.ctx.string_position < end +// && charset(&drive.pattern()[2..], drive.peek_char()) +// { +// drive.skip_char(1); +// } +// } +// SreOpcode::LITERAL => { +// general_count_literal(drive, end, |code, c| code == c as u32); +// } +// SreOpcode::NOT_LITERAL => { +// general_count_literal(drive, end, |code, c| code != c as u32); +// } +// SreOpcode::LITERAL_IGNORE => { +// general_count_literal(drive, end, |code, c| code == lower_ascii(c) as u32); +// } +// SreOpcode::NOT_LITERAL_IGNORE => { +// general_count_literal(drive, end, |code, c| code != lower_ascii(c) as u32); +// } +// SreOpcode::LITERAL_LOC_IGNORE => { +// general_count_literal(drive, end, char_loc_ignore); +// } +// SreOpcode::NOT_LITERAL_LOC_IGNORE => { +// general_count_literal(drive, end, |code, c| !char_loc_ignore(code, c)); +// } +// SreOpcode::LITERAL_UNI_IGNORE => { +// general_count_literal(drive, end, |code, c| code == lower_unicode(c) as u32); +// } +// SreOpcode::NOT_LITERAL_UNI_IGNORE => { +// general_count_literal(drive, end, |code, c| code != lower_unicode(c) as u32); +// } +// _ => { +// return general_count(drive, stacks, max_count); +// } +// } + +// let count = drive.ctx.string_position - drive.state.string_position; +// drive.ctx = save_ctx; +// count +// } + +// fn general_count_literal bool>( +// drive: &mut StateContext, +// end: usize, +// mut f: F, +// ) { +// let ch = drive.peek_code(1); +// while !drive.ctx.string_position < end && f(ch, drive.peek_char()) { +// drive.skip_char(1); +// } +// } fn is_word(ch: u32) -> bool { ch == '_' as u32 From 34bde45a2c0906a3a3e03dca79a4426b4cf57655 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 1 Aug 2022 22:19:49 +0200 Subject: [PATCH 055/893] pass compile --- benches/benches.rs | 22 +- src/engine.rs | 1802 ++++++++++++++++++++------------------------ tests/tests.rs | 20 +- 3 files changed, 821 insertions(+), 1023 deletions(-) diff --git a/benches/benches.rs b/benches/benches.rs index b86a592967..d000ceb62e 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -11,12 +11,12 @@ pub struct Pattern { } impl Pattern { - pub fn state<'a>( + pub fn state<'a, S: engine::StrDrive>( &self, - string: impl Into>, + string: S, range: std::ops::Range, - ) -> engine::State<'a> { - engine::State::new(string.into(), range.start, range.end, self.flags, self.code) + ) -> engine::State<'a, S> { + engine::State::new(string, range.start, range.end, self.flags, self.code) } } #[bench] @@ -84,28 +84,28 @@ fn benchmarks(b: &mut Bencher) { b.iter(move || { for (p, s) in &tests { let mut state = p.state(s.clone(), 0..usize::MAX); - state = state.search(); + state.search(); assert!(state.has_matched); state = p.state(s.clone(), 0..usize::MAX); - state = state.pymatch(); + state.pymatch(); assert!(state.has_matched); state = p.state(s.clone(), 0..usize::MAX); state.match_all = true; - state = state.pymatch(); + state.pymatch(); assert!(state.has_matched); let s2 = format!("{}{}{}", " ".repeat(10000), s, " ".repeat(10000)); state = p.state(s2.as_str(), 0..usize::MAX); - state = state.search(); + state.search(); assert!(state.has_matched); state = p.state(s2.as_str(), 10000..usize::MAX); - state = state.pymatch(); + state.pymatch(); assert!(state.has_matched); state = p.state(s2.as_str(), 10000..10000 + s.len()); - state = state.pymatch(); + state.pymatch(); assert!(state.has_matched); state = p.state(s2.as_str(), 10000..10000 + s.len()); state.match_all = true; - state = state.pymatch(); + state.pymatch(); assert!(state.has_matched); } }) diff --git a/src/engine.rs b/src/engine.rs index 8865eb6a39..b0717e1671 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -19,21 +19,45 @@ pub struct State<'a, S: StrDrive> { pub lastindex: isize, marks_stack: Vec<(Vec>, isize)>, context_stack: Vec>, - // branch_stack: Vec, - // min_repeat_one_stack: Vec, - // repeat_one_stack: Vec, - // repeat_stack: Vec, - // min_until_stack: Vec, - // max_until_stack: Vec, - // _stacks: Option>, + repeat_stack: Vec, pub string_position: usize, - popped_context: Option>, next_context: Option>, + popped_has_matched: bool, pub has_matched: bool, pub match_all: bool, pub must_advance: bool, } +macro_rules! next_ctx { + (offset $offset:expr, $state:expr, $ctx:expr, $handler:expr) => { + next_ctx!(position $ctx.code_position + $offset, $state, $ctx, $handler) + }; + (from $peek:expr, $state:expr, $ctx:expr, $handler:expr) => { + next_ctx!(position $ctx.peek_code($state, $peek) as usize + 1, $state, $ctx, $handler) + }; + (position $position:expr, $state:expr, $ctx:expr, $handler:expr) => { + {$state.next_context.insert(MatchContext { + code_position: $position, + has_matched: None, + handler: Some($handler), + ..*$ctx + })} + }; +} + +macro_rules! mark { + (push, $state:expr) => { + $state + .marks_stack + .push(($state.marks.clone(), $state.lastindex)) + }; + (pop, $state:expr) => { + let (marks, lastindex) = $state.marks_stack.pop().unwrap(); + $state.marks = marks; + $state.lastindex = lastindex; + }; +} + impl<'a, S: StrDrive> State<'a, S> { pub fn new( string: S, @@ -54,16 +78,10 @@ impl<'a, S: StrDrive> State<'a, S> { lastindex: -1, marks_stack: Vec::new(), context_stack: Vec::new(), - // branch_stack: Vec::new(), - // min_repeat_one_stack: Vec::new(), - // repeat_one_stack: Vec::new(), - // repeat_stack: Vec::new(), - // min_until_stack: Vec::new(), - // max_until_stack: Vec::new(), - // _stacks: Default::default(), + repeat_stack: Vec::new(), string_position: start, - popped_context: None, next_context: None, + popped_has_matched: false, has_matched: false, match_all: false, must_advance: false, @@ -75,15 +93,10 @@ impl<'a, S: StrDrive> State<'a, S> { self.marks.clear(); self.marks_stack.clear(); self.context_stack.clear(); - // self.branch_stack.clear(); - // self.min_repeat_one_stack.clear(); - // self.repeat_one_stack.clear(); - // self.repeat_stack.clear(); - // self.min_until_stack.clear(); - // self.max_until_stack.clear(); + self.repeat_stack.clear(); self.string_position = self.start; - self.popped_context = None; self.next_context = None; + self.popped_has_matched = false; self.has_matched = false; } @@ -104,14 +117,14 @@ impl<'a, S: StrDrive> State<'a, S> { (None, None) } } - fn marks_push(&mut self) { - self.marks_stack.push((self.marks.clone(), self.lastindex)); - } - fn marks_pop(&mut self) { - let (marks, lastindex) = self.marks_stack.pop().unwrap(); - self.marks = marks; - self.lastindex = lastindex; - } + // fn marks_push(&mut self) { + // self.marks_stack.push((self.marks.clone(), self.lastindex)); + // } + // fn marks_pop(&mut self) { + // let (marks, lastindex) = self.marks_stack.pop().unwrap(); + // self.marks = marks; + // self.lastindex = lastindex; + // } fn marks_pop_keep(&mut self) { let (marks, lastindex) = self.marks_stack.last().unwrap().clone(); self.marks = marks; @@ -121,46 +134,31 @@ impl<'a, S: StrDrive> State<'a, S> { self.marks_stack.pop(); } - fn _match(&mut self) { + fn _match(&mut self) { while let Some(mut ctx) = self.context_stack.pop() { - // let mut drive = StateContext { - // state: self, - // ctx, - // next_ctx: None, - // }; - // let mut state = self; - - if let Some(handler) = ctx.handler { + if let Some(handler) = ctx.handler.take() { handler(self, &mut ctx); } else if ctx.remaining_codes(self) > 0 { let code = ctx.peek_code(self, 0); let code = SreOpcode::try_from(code).unwrap(); - self.dispatch(code, &mut ctx); + dispatch(self, &mut ctx, code); } else { ctx.failure(); } - // let StateContext { - // mut state, - // ctx, - // next_ctx, - // } = drive; - - if ctx.has_matched.is_some() { - self.popped_context = Some(ctx); + if let Some(has_matched) = ctx.has_matched { + self.popped_has_matched = has_matched; } else { self.context_stack.push(ctx); if let Some(next_ctx) = self.next_context.take() { self.context_stack.push(next_ctx); } } - // self = state } - self.has_matched = self.popped_context.take().unwrap().has_matched == Some(true); - // self + self.has_matched = self.popped_has_matched; } - pub fn pymatch(mut self) -> Self { + pub fn pymatch(&mut self) { let ctx = MatchContext { string_position: self.start, string_offset: self.string.offset(0, self.start), @@ -169,18 +167,18 @@ impl<'a, S: StrDrive> State<'a, S> { toplevel: true, handler: None, repeat_ctx_id: usize::MAX, + count: -1, }; self.context_stack.push(ctx); self._match(); - self } - pub fn search(mut self) -> Self { + pub fn search(&mut self) { // TODO: optimize by op info and skip prefix if self.start > self.end { - return self; + return; } let mut start_offset = self.string.offset(0, self.start); @@ -193,6 +191,7 @@ impl<'a, S: StrDrive> State<'a, S> { toplevel: true, handler: None, repeat_ctx_id: usize::MAX, + count: -1, }; self.context_stack.push(ctx); self._match(); @@ -211,643 +210,503 @@ impl<'a, S: StrDrive> State<'a, S> { toplevel: false, handler: None, repeat_ctx_id: usize::MAX, + count: -1, }; self.context_stack.push(ctx); self._match(); } + } +} + +fn dispatch<'a, S: StrDrive>( + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, + opcode: SreOpcode, +) { + match opcode { + SreOpcode::FAILURE => { + ctx.failure(); + } + SreOpcode::SUCCESS => { + if ctx.can_success(state) { + state.string_position = ctx.string_position; + ctx.success(); + } else { + ctx.failure(); + } + } + SreOpcode::ANY => { + if ctx.at_end(state) || ctx.at_linebreak(state) { + ctx.failure(); + } else { + ctx.skip_code(1); + ctx.skip_char(state, 1); + } + } + SreOpcode::ANY_ALL => { + if ctx.at_end(state) { + ctx.failure(); + } else { + ctx.skip_code(1); + ctx.skip_char(state, 1); + } + } + /* assert subpattern */ + /* */ + SreOpcode::ASSERT => op_assert(state, ctx), + SreOpcode::ASSERT_NOT => op_assert_not(state, ctx), + SreOpcode::AT => { + let atcode = SreAtCode::try_from(ctx.peek_code(state, 1)).unwrap(); + if at(state, ctx, atcode) { + ctx.skip_code(2); + } else { + ctx.failure(); + } + } + SreOpcode::BRANCH => op_branch(state, ctx), + SreOpcode::CATEGORY => { + let catcode = SreCatCode::try_from(ctx.peek_code(state, 1)).unwrap(); + if ctx.at_end(state) || !category(catcode, ctx.peek_char(state)) { + ctx.failure(); + } else { + ctx.skip_code(2); + ctx.skip_char(state, 1); + } + } + SreOpcode::IN => general_op_in(state, ctx, charset), + SreOpcode::IN_IGNORE => general_op_in(state, ctx, |set, c| charset(set, lower_ascii(c))), + SreOpcode::IN_UNI_IGNORE => { + general_op_in(state, ctx, |set, c| charset(set, lower_unicode(c))) + } + SreOpcode::IN_LOC_IGNORE => general_op_in(state, ctx, charset_loc_ignore), + SreOpcode::INFO | SreOpcode::JUMP => ctx.skip_code_from(state, 1), + SreOpcode::LITERAL => general_op_literal(state, ctx, |code, c| code == c), + SreOpcode::NOT_LITERAL => general_op_literal(state, ctx, |code, c| code != c), + SreOpcode::LITERAL_IGNORE => { + general_op_literal(state, ctx, |code, c| code == lower_ascii(c)) + } + SreOpcode::NOT_LITERAL_IGNORE => { + general_op_literal(state, ctx, |code, c| code != lower_ascii(c)) + } + SreOpcode::LITERAL_UNI_IGNORE => { + general_op_literal(state, ctx, |code, c| code == lower_unicode(c)) + } + SreOpcode::NOT_LITERAL_UNI_IGNORE => { + general_op_literal(state, ctx, |code, c| code != lower_unicode(c)) + } + SreOpcode::LITERAL_LOC_IGNORE => general_op_literal(state, ctx, char_loc_ignore), + SreOpcode::NOT_LITERAL_LOC_IGNORE => { + general_op_literal(state, ctx, |code, c| !char_loc_ignore(code, c)) + } + SreOpcode::MARK => { + state.set_mark(ctx.peek_code(state, 1) as usize, ctx.string_position); + ctx.skip_code(2); + } + SreOpcode::MAX_UNTIL => op_max_until(state, ctx), + SreOpcode::MIN_UNTIL => op_min_until(state, ctx), + SreOpcode::REPEAT => op_repeat(state, ctx), + SreOpcode::REPEAT_ONE => op_repeat_one(state, ctx), + SreOpcode::MIN_REPEAT_ONE => op_min_repeat_one(state, ctx), + SreOpcode::GROUPREF => general_op_groupref(state, ctx, |x| x), + SreOpcode::GROUPREF_IGNORE => general_op_groupref(state, ctx, lower_ascii), + SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(state, ctx, lower_locate), + SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(state, ctx, lower_unicode), + SreOpcode::GROUPREF_EXISTS => { + let (group_start, group_end) = state.get_marks(ctx.peek_code(state, 1) as usize); + match (group_start, group_end) { + (Some(start), Some(end)) if start <= end => { + ctx.skip_code(3); + } + _ => ctx.skip_code_from(state, 2), + } + } + _ => unreachable!("unexpected opcode"), + } +} - self +/* assert subpattern */ +/* */ +fn op_assert<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let back = ctx.peek_code(state, 2) as usize; + if ctx.string_position < back { + return ctx.failure(); } - fn dispatch(&mut self, opcode: SreOpcode, ctx: &mut MatchContext<'a, S>) { - match opcode { - SreOpcode::FAILURE => { - ctx.has_matched = Some(false); + // let next_ctx = state.next_ctx(ctx, 3, |state, ctx| { + let next_ctx = next_ctx!(offset 3, state, ctx, |state, ctx| { + if state.popped_has_matched { + ctx.skip_code_from(state, 1); + } else { + ctx.failure(); + } + }); + next_ctx.back_skip_char(&state.string, back); + state.string_position = next_ctx.string_position; + next_ctx.toplevel = false; +} + +/* assert not subpattern */ +/* */ +fn op_assert_not<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let back = ctx.peek_code(state, 2) as usize; + + if ctx.string_position < back { + return ctx.skip_code_from(state, 1); + } + + let next_ctx = next_ctx!(offset 3, state, ctx, |state, ctx| { + if state.popped_has_matched { + ctx.failure(); + } else { + ctx.skip_code_from(state, 1); + } + }); + next_ctx.back_skip_char(&state.string, back); + state.string_position = next_ctx.string_position; + next_ctx.toplevel = false; +} + +// alternation +// <0=skip> code ... +fn op_branch<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + // state.marks_push(); + mark!(push, state); + + ctx.count = 1; + create_context(state, ctx); + + fn create_context<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let branch_offset = ctx.count as usize; + let next_length = ctx.peek_code(state, branch_offset) as isize; + if next_length == 0 { + state.marks_pop_discard(); + return ctx.failure(); + } + + state.string_position = ctx.string_position; + + ctx.count += next_length; + next_ctx!(offset branch_offset + 1, state, ctx, callback); + } + + fn callback<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + if state.popped_has_matched { + return ctx.success(); + } + state.marks_pop_keep(); + create_context(state, ctx); + } +} + +/* <1=min> <2=max> item tail */ +fn op_min_repeat_one<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let min_count = ctx.peek_code(state, 2) as usize; + // let max_count = ctx.peek_code(state, 3) as usize; + + if ctx.remaining_chars(state) < min_count { + return ctx.failure(); + } + + state.string_position = ctx.string_position; + + ctx.count = if min_count == 0 { + 0 + } else { + let count = _count(state, ctx, min_count); + if count < min_count { + return ctx.failure(); + } + ctx.skip_char(state, count); + count as isize + }; + + let next_code = ctx.peek_code(state, ctx.peek_code(state, 1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(state) { + // tail is empty. we're finished + state.string_position = ctx.string_position; + return ctx.success(); + } + + mark!(push, state); + create_context(state, ctx); + + fn create_context<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let max_count = ctx.peek_code(state, 3) as usize; + + if max_count == MAXREPEAT || ctx.count as usize <= max_count { + state.string_position = ctx.string_position; + next_ctx!(from 1, state, ctx, callback); + } else { + state.marks_pop_discard(); + ctx.failure(); + } + } + + fn callback<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + if state.popped_has_matched { + return ctx.success(); + } + + state.string_position = ctx.string_position; + + if _count(state, ctx, 1) == 0 { + state.marks_pop_discard(); + return ctx.failure(); + } + + ctx.skip_char(state, 1); + ctx.count += 1; + state.marks_pop_keep(); + create_context(state, ctx); + } +} + +/* match repeated sequence (maximizing regexp) */ +/* this operator only works if the repeated item is +exactly one character wide, and we're not already +collecting backtracking points. for other cases, +use the MAX_REPEAT operator */ +/* <1=min> <2=max> item tail */ +fn op_repeat_one<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let min_count = ctx.peek_code(state, 2) as usize; + let max_count = ctx.peek_code(state, 3) as usize; + + if ctx.remaining_chars(state) < min_count { + return ctx.failure(); + } + + state.string_position = ctx.string_position; + + let count = _count(state, ctx, max_count); + ctx.skip_char(state, count); + if count < min_count { + return ctx.failure(); + } + + let next_code = ctx.peek_code(state, ctx.peek_code(state, 1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(state) { + // tail is empty. we're finished + state.string_position = ctx.string_position; + return ctx.success(); + } + + mark!(push, state); + ctx.count = count as isize; + create_context(state, ctx); + + fn create_context<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let min_count = ctx.peek_code(state, 2) as isize; + let next_code = ctx.peek_code(state, ctx.peek_code(state, 1) as usize + 1); + if next_code == SreOpcode::LITERAL as u32 { + // Special case: Tail starts with a literal. Skip positions where + // the rest of the pattern cannot possibly match. + let c = ctx.peek_code(state, ctx.peek_code(state, 1) as usize + 2); + while ctx.at_end(state) || ctx.peek_char(state) != c { + if ctx.count <= min_count { + state.marks_pop_discard(); + return ctx.failure(); + } + ctx.back_skip_char(&state.string, 1); + ctx.count -= 1; + } + } + + state.string_position = ctx.string_position; + + // General case: backtracking + next_ctx!(from 1, state, ctx, callback); + } + + fn callback<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + if state.popped_has_matched { + return ctx.success(); + } + + let min_count = ctx.peek_code(state, 2) as isize; + + if ctx.count <= min_count { + state.marks_pop_discard(); + return ctx.failure(); + } + + ctx.back_skip_char(&state.string, 1); + ctx.count -= 1; + + state.marks_pop_keep(); + create_context(state, ctx); + } +} + +#[derive(Debug, Clone, Copy)] +struct RepeatContext { + count: isize, + min_count: usize, + max_count: usize, + code_position: usize, + last_position: usize, + prev_id: usize, +} + +/* create repeat context. all the hard work is done +by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ +/* <1=min> <2=max> item tail */ +fn op_repeat<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let repeat_ctx = RepeatContext { + count: -1, + min_count: ctx.peek_code(state, 2) as usize, + max_count: ctx.peek_code(state, 3) as usize, + code_position: ctx.code_position, + last_position: std::usize::MAX, + prev_id: ctx.repeat_ctx_id, + }; + + state.repeat_stack.push(repeat_ctx); + + state.string_position = ctx.string_position; + + let next_ctx = next_ctx!(from 1, state, ctx, |state, ctx| { + ctx.has_matched = Some(state.popped_has_matched); + state.repeat_stack.pop(); + }); + next_ctx.repeat_ctx_id = state.repeat_stack.len() - 1; +} + +/* minimizing repeat */ +fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let repeat_ctx = state.repeat_stack.last_mut().unwrap(); + + state.string_position = ctx.string_position; + + repeat_ctx.count += 1; + + if (repeat_ctx.count as usize) < repeat_ctx.min_count { + // not enough matches + next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + if state.popped_has_matched { + ctx.success(); + } else { + state.repeat_stack[ctx.repeat_ctx_id].count -= 1; + state.string_position = ctx.string_position; + ctx.failure(); + } + }); + return; + } + + mark!(push, state); + + ctx.count = ctx.repeat_ctx_id as isize; + + // see if the tail matches + let next_ctx = next_ctx!(offset 1, state, ctx, |state, ctx| { + if state.popped_has_matched { + return ctx.success(); + } + + ctx.repeat_ctx_id = ctx.count as usize; + + let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; + + state.string_position = ctx.string_position; + + mark!(pop, state); + + // match more until tail matches + + if repeat_ctx.count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT + || state.string_position == repeat_ctx.last_position + { + repeat_ctx.count -= 1; + return ctx.failure(); + } + + /* zero-width match protection */ + repeat_ctx.last_position = state.string_position; + + next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + if state.popped_has_matched { + ctx.success(); + } else { + state.repeat_stack[ctx.repeat_ctx_id].count -= 1; + state.string_position = ctx.string_position; + ctx.failure(); + } + }); + }); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; +} + +/* maximizing repeat */ +fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; + + state.string_position = ctx.string_position; + + repeat_ctx.count += 1; + + if (repeat_ctx.count as usize) < repeat_ctx.min_count { + // not enough matches + next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + if state.popped_has_matched { + ctx.success(); + } else { + state.repeat_stack[ctx.repeat_ctx_id].count -= 1; + state.string_position = ctx.string_position; + ctx.failure(); + } + }); + return; + } + + if ((repeat_ctx.count as usize) < repeat_ctx.max_count || repeat_ctx.max_count == MAXREPEAT) + && state.string_position != repeat_ctx.last_position + { + /* we may have enough matches, but if we can + match another item, do so */ + mark!(push, state); + + ctx.count = repeat_ctx.last_position as isize; + repeat_ctx.last_position = state.string_position; + + next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + let save_last_position = ctx.count as usize; + let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; + repeat_ctx.last_position = save_last_position; + + if state.popped_has_matched { + state.marks_pop_discard(); + return ctx.success(); } - SreOpcode::SUCCESS => todo!(), - SreOpcode::ANY => todo!(), - SreOpcode::ANY_ALL => todo!(), - SreOpcode::ASSERT => todo!(), - SreOpcode::ASSERT_NOT => todo!(), - SreOpcode::AT => todo!(), - SreOpcode::BRANCH => todo!(), - SreOpcode::CALL => todo!(), - SreOpcode::CATEGORY => todo!(), - SreOpcode::CHARSET => todo!(), - SreOpcode::BIGCHARSET => todo!(), - SreOpcode::GROUPREF => todo!(), - SreOpcode::GROUPREF_EXISTS => todo!(), - SreOpcode::IN => todo!(), - SreOpcode::INFO => todo!(), - SreOpcode::JUMP => todo!(), - SreOpcode::LITERAL => todo!(), - SreOpcode::MARK => todo!(), - SreOpcode::MAX_UNTIL => todo!(), - SreOpcode::MIN_UNTIL => todo!(), - SreOpcode::NOT_LITERAL => todo!(), - SreOpcode::NEGATE => todo!(), - SreOpcode::RANGE => todo!(), - SreOpcode::REPEAT => todo!(), - SreOpcode::REPEAT_ONE => todo!(), - SreOpcode::SUBPATTERN => todo!(), - SreOpcode::MIN_REPEAT_ONE => todo!(), - SreOpcode::GROUPREF_IGNORE => todo!(), - SreOpcode::IN_IGNORE => todo!(), - SreOpcode::LITERAL_IGNORE => todo!(), - SreOpcode::NOT_LITERAL_IGNORE => todo!(), - SreOpcode::GROUPREF_LOC_IGNORE => todo!(), - SreOpcode::IN_LOC_IGNORE => todo!(), - SreOpcode::LITERAL_LOC_IGNORE => todo!(), - SreOpcode::NOT_LITERAL_LOC_IGNORE => todo!(), - SreOpcode::GROUPREF_UNI_IGNORE => todo!(), - SreOpcode::IN_UNI_IGNORE => todo!(), - SreOpcode::LITERAL_UNI_IGNORE => todo!(), - SreOpcode::NOT_LITERAL_UNI_IGNORE => todo!(), - SreOpcode::RANGE_UNI_IGNORE => todo!(), + + mark!(pop, state); + repeat_ctx.count -= 1; + + state.string_position = ctx.string_position; + + /* cannot match more repeated items here. make sure the + tail matches */ + let next_ctx = next_ctx!(offset 1, state, ctx, tail_callback); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + }); + return; + } + + /* cannot match more repeated items here. make sure the + tail matches */ + let next_ctx = next_ctx!(offset 1, state, ctx, tail_callback); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + + fn tail_callback<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + if state.popped_has_matched { + ctx.success(); + } else { + state.string_position = ctx.string_position; + ctx.failure(); } } } -// fn dispatch(opcode: SreOpcode, drive: &mut StateContext, stacks: &mut Stacks) { -// match opcode { -// SreOpcode::FAILURE => { -// drive.failure(); -// } -// SreOpcode::SUCCESS => { -// drive.ctx.has_matched = Some(drive.can_success()); -// if drive.ctx.has_matched == Some(true) { -// drive.state.string_position = drive.ctx.string_position; -// } -// } -// SreOpcode::ANY => { -// if drive.at_end() || drive.at_linebreak() { -// drive.failure(); -// } else { -// drive.skip_code(1); -// drive.skip_char(1); -// } -// } -// SreOpcode::ANY_ALL => { -// if drive.at_end() { -// drive.failure(); -// } else { -// drive.skip_code(1); -// drive.skip_char(1); -// } -// } -// SreOpcode::ASSERT => op_assert(drive), -// SreOpcode::ASSERT_NOT => op_assert_not(drive), -// SreOpcode::AT => { -// let atcode = SreAtCode::try_from(drive.peek_code(1)).unwrap(); -// if at(drive, atcode) { -// drive.skip_code(2); -// } else { -// drive.failure(); -// } -// } -// SreOpcode::BRANCH => op_branch(drive, stacks), -// SreOpcode::CATEGORY => { -// let catcode = SreCatCode::try_from(drive.peek_code(1)).unwrap(); -// if drive.at_end() || !category(catcode, drive.peek_char()) { -// drive.failure(); -// } else { -// drive.skip_code(2); -// drive.skip_char(1); -// } -// } -// SreOpcode::IN => general_op_in(drive, charset), -// SreOpcode::IN_IGNORE => general_op_in(drive, |set, c| charset(set, lower_ascii(c))), -// SreOpcode::IN_UNI_IGNORE => general_op_in(drive, |set, c| charset(set, lower_unicode(c))), -// SreOpcode::IN_LOC_IGNORE => general_op_in(drive, charset_loc_ignore), -// SreOpcode::INFO | SreOpcode::JUMP => drive.skip_code_from(1), -// SreOpcode::LITERAL => general_op_literal(drive, |code, c| code == c), -// SreOpcode::NOT_LITERAL => general_op_literal(drive, |code, c| code != c), -// SreOpcode::LITERAL_IGNORE => general_op_literal(drive, |code, c| code == lower_ascii(c)), -// SreOpcode::NOT_LITERAL_IGNORE => { -// general_op_literal(drive, |code, c| code != lower_ascii(c)) -// } -// SreOpcode::LITERAL_UNI_IGNORE => { -// general_op_literal(drive, |code, c| code == lower_unicode(c)) -// } -// SreOpcode::NOT_LITERAL_UNI_IGNORE => { -// general_op_literal(drive, |code, c| code != lower_unicode(c)) -// } -// SreOpcode::LITERAL_LOC_IGNORE => general_op_literal(drive, char_loc_ignore), -// SreOpcode::NOT_LITERAL_LOC_IGNORE => { -// general_op_literal(drive, |code, c| !char_loc_ignore(code, c)) -// } -// SreOpcode::MARK => { -// drive -// .state -// .set_mark(drive.peek_code(1) as usize, drive.ctx.string_position); -// drive.skip_code(2); -// } -// SreOpcode::MAX_UNTIL => op_max_until(drive, stacks), -// SreOpcode::MIN_UNTIL => op_min_until(drive, stacks), -// SreOpcode::REPEAT => op_repeat(drive, stacks), -// SreOpcode::REPEAT_ONE => op_repeat_one(drive, stacks), -// SreOpcode::MIN_REPEAT_ONE => op_min_repeat_one(drive, stacks), -// SreOpcode::GROUPREF => general_op_groupref(drive, |x| x), -// SreOpcode::GROUPREF_IGNORE => general_op_groupref(drive, lower_ascii), -// SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(drive, lower_locate), -// SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(drive, lower_unicode), -// SreOpcode::GROUPREF_EXISTS => { -// let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); -// match (group_start, group_end) { -// (Some(start), Some(end)) if start <= end => { -// drive.skip_code(3); -// } -// _ => drive.skip_code_from(2), -// } -// } -// _ => unreachable!("unexpected opcode"), -// } -// } - -// /* assert subpattern */ -// /* */ -// fn op_assert(drive: &mut StateContext) { -// let back = drive.peek_code(2) as usize; - -// if drive.ctx.string_position < back { -// return drive.failure(); -// } - -// let offset = drive -// .state -// .string -// .back_offset(drive.ctx.string_offset, back); -// let position = drive.ctx.string_position - back; - -// drive.state.string_position = position; - -// let next_ctx = drive.next_ctx(3, |drive, _| { -// if drive.popped_ctx().has_matched == Some(true) { -// drive.ctx.handler = None; -// drive.skip_code_from(1); -// } else { -// drive.failure(); -// } -// }); -// next_ctx.string_position = position; -// next_ctx.string_offset = offset; -// next_ctx.toplevel = false; -// } - -// /* assert not subpattern */ -// /* */ -// fn op_assert_not(drive: &mut StateContext) { -// let back = drive.peek_code(2) as usize; - -// if drive.ctx.string_position < back { -// return drive.skip_code_from(1); -// } - -// let offset = drive -// .state -// .string -// .back_offset(drive.ctx.string_offset, back); -// let position = drive.ctx.string_position - back; - -// drive.state.string_position = position; - -// let next_ctx = drive.next_ctx(3, |drive, _| { -// if drive.popped_ctx().has_matched == Some(true) { -// drive.failure(); -// } else { -// drive.ctx.handler = None; -// drive.skip_code_from(1); -// } -// }); -// next_ctx.string_position = position; -// next_ctx.string_offset = offset; -// next_ctx.toplevel = false; -// } - -// #[derive(Debug)] -// struct BranchContext { -// branch_offset: usize, -// } - -// // alternation -// // <0=skip> code ... -// fn op_branch(drive: &mut StateContext, stacks: &mut Stacks) { -// drive.state.marks_push(); -// stacks.branch.push(BranchContext { branch_offset: 1 }); -// create_context(drive, stacks); - -// fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { -// let branch_offset = stacks.branch_last().branch_offset; -// let next_length = drive.peek_code(branch_offset) as usize; -// if next_length == 0 { -// drive.state.marks_pop_discard(); -// stacks.branch.pop(); -// return drive.failure(); -// } - -// drive.sync_string_position(); - -// stacks.branch_last().branch_offset += next_length; -// drive.next_ctx(branch_offset + 1, callback); -// } - -// fn callback(drive: &mut StateContext, stacks: &mut Stacks) { -// if drive.popped_ctx().has_matched == Some(true) { -// stacks.branch.pop(); -// return drive.success(); -// } -// drive.state.marks_pop_keep(); -// drive.ctx.handler = Some(create_context) -// } -// } - -// #[derive(Debug, Copy, Clone)] -// struct MinRepeatOneContext { -// count: usize, -// max_count: usize, -// } - -// /* <1=min> <2=max> item tail */ -// fn op_min_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { -// let min_count = drive.peek_code(2) as usize; -// let max_count = drive.peek_code(3) as usize; - -// if drive.remaining_chars() < min_count { -// return drive.failure(); -// } - -// drive.sync_string_position(); - -// let count = if min_count == 0 { -// 0 -// } else { -// let count = _count(drive, stacks, min_count); -// if count < min_count { -// return drive.failure(); -// } -// drive.skip_char(count); -// count -// }; - -// let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); -// if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { -// // tail is empty. we're finished -// drive.sync_string_position(); -// return drive.success(); -// } - -// drive.state.marks_push(); -// stacks -// .min_repeat_one -// .push(MinRepeatOneContext { count, max_count }); -// create_context(drive, stacks); - -// fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { -// let MinRepeatOneContext { count, max_count } = *stacks.min_repeat_one_last(); - -// if max_count == MAXREPEAT || count <= max_count { -// drive.sync_string_position(); -// drive.next_ctx_from(1, callback); -// } else { -// drive.state.marks_pop_discard(); -// stacks.min_repeat_one.pop(); -// drive.failure(); -// } -// } - -// fn callback(drive: &mut StateContext, stacks: &mut Stacks) { -// if drive.popped_ctx().has_matched == Some(true) { -// stacks.min_repeat_one.pop(); -// return drive.success(); -// } - -// drive.sync_string_position(); - -// if _count(drive, stacks, 1) == 0 { -// drive.state.marks_pop_discard(); -// stacks.min_repeat_one.pop(); -// return drive.failure(); -// } - -// drive.skip_char(1); -// stacks.min_repeat_one_last().count += 1; -// drive.state.marks_pop_keep(); -// create_context(drive, stacks); -// } -// } - -// #[derive(Debug, Copy, Clone)] -// struct RepeatOneContext { -// count: usize, -// min_count: usize, -// following_literal: Option, -// } - -// /* match repeated sequence (maximizing regexp) */ - -// /* this operator only works if the repeated item is -// exactly one character wide, and we're not already -// collecting backtracking points. for other cases, -// use the MAX_REPEAT operator */ - -// /* <1=min> <2=max> item tail */ -// fn op_repeat_one(drive: &mut StateContext, stacks: &mut Stacks) { -// let min_count = drive.peek_code(2) as usize; -// let max_count = drive.peek_code(3) as usize; - -// if drive.remaining_chars() < min_count { -// return drive.failure(); -// } - -// drive.sync_string_position(); - -// let count = _count(drive, stacks, max_count); -// drive.skip_char(count); -// if count < min_count { -// return drive.failure(); -// } - -// let next_code = drive.peek_code(drive.peek_code(1) as usize + 1); -// if next_code == SreOpcode::SUCCESS as u32 && drive.can_success() { -// // tail is empty. we're finished -// drive.sync_string_position(); -// return drive.success(); -// } - -// // Special case: Tail starts with a literal. Skip positions where -// // the rest of the pattern cannot possibly match. -// let following_literal = (next_code == SreOpcode::LITERAL as u32) -// .then(|| drive.peek_code(drive.peek_code(1) as usize + 2)); - -// drive.state.marks_push(); -// stacks.repeat_one.push(RepeatOneContext { -// count, -// min_count, -// following_literal, -// }); -// create_context(drive, stacks); - -// fn create_context(drive: &mut StateContext, stacks: &mut Stacks) { -// let RepeatOneContext { -// mut count, -// min_count, -// following_literal, -// } = *stacks.repeat_one_last(); - -// if let Some(c) = following_literal { -// while drive.at_end() || drive.peek_char() != c { -// if count <= min_count { -// drive.state.marks_pop_discard(); -// stacks.repeat_one.pop(); -// return drive.failure(); -// } -// drive.back_skip_char(1); -// count -= 1; -// } -// } -// stacks.repeat_one_last().count = count; - -// drive.sync_string_position(); - -// // General case: backtracking -// drive.next_ctx_from(1, callback); -// } - -// fn callback(drive: &mut StateContext, stacks: &mut Stacks) { -// if drive.popped_ctx().has_matched == Some(true) { -// stacks.repeat_one.pop(); -// return drive.success(); -// } - -// let RepeatOneContext { -// count, -// min_count, -// following_literal: _, -// } = stacks.repeat_one_last(); - -// if count <= min_count { -// drive.state.marks_pop_discard(); -// stacks.repeat_one.pop(); -// return drive.failure(); -// } - -// drive.back_skip_char(1); -// *count -= 1; - -// drive.state.marks_pop_keep(); -// create_context(drive, stacks); -// } -// } - -// #[derive(Debug, Clone, Copy)] -// struct RepeatContext { -// count: isize, -// min_count: usize, -// max_count: usize, -// code_position: usize, -// last_position: usize, -// prev_id: usize, -// } - -// /* create repeat context. all the hard work is done -// by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ -// /* <1=min> <2=max> item tail */ -// fn op_repeat(drive: &mut StateContext, stacks: &mut Stacks) { -// let repeat_ctx = RepeatContext { -// count: -1, -// min_count: drive.peek_code(2) as usize, -// max_count: drive.peek_code(3) as usize, -// code_position: drive.ctx.code_position, -// last_position: std::usize::MAX, -// prev_id: drive.ctx.repeat_ctx_id, -// }; - -// stacks.repeat.push(repeat_ctx); - -// drive.sync_string_position(); - -// let next_ctx = drive.next_ctx_from(1, |drive, stacks| { -// drive.ctx.has_matched = drive.popped_ctx().has_matched; -// stacks.repeat.pop(); -// }); -// next_ctx.repeat_ctx_id = stacks.repeat.len() - 1; -// } - -// #[derive(Debug, Clone, Copy)] -// struct MinUntilContext { -// save_repeat_ctx_id: usize, -// } - -// /* minimizing repeat */ -// fn op_min_until(drive: &mut StateContext, stacks: &mut Stacks) { -// let repeat_ctx = stacks.repeat.last_mut().unwrap(); - -// drive.sync_string_position(); - -// repeat_ctx.count += 1; - -// if (repeat_ctx.count as usize) < repeat_ctx.min_count { -// // not enough matches -// drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { -// if drive.popped_ctx().has_matched == Some(true) { -// drive.success(); -// } else { -// stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; -// drive.sync_string_position(); -// drive.failure(); -// } -// }); -// return; -// } - -// drive.state.marks_push(); - -// stacks.min_until.push(MinUntilContext { -// save_repeat_ctx_id: drive.ctx.repeat_ctx_id, -// }); - -// // see if the tail matches -// let next_ctx = drive.next_ctx(1, |drive, stacks| { -// drive.ctx.repeat_ctx_id = stacks.min_until.pop().unwrap().save_repeat_ctx_id; - -// let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; - -// if drive.popped_ctx().has_matched == Some(true) { -// return drive.success(); -// } - -// drive.sync_string_position(); - -// drive.state.marks_pop(); - -// // match more until tail matches - -// if repeat_ctx.count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT -// || drive.state.string_position == repeat_ctx.last_position -// { -// repeat_ctx.count -= 1; -// return drive.failure(); -// } - -// /* zero-width match protection */ -// repeat_ctx.last_position = drive.state.string_position; - -// drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { -// if drive.popped_ctx().has_matched == Some(true) { -// drive.success(); -// } else { -// stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; -// drive.sync_string_position(); -// drive.failure(); -// } -// }); -// }); -// next_ctx.repeat_ctx_id = repeat_ctx.prev_id; -// } - -// #[derive(Debug, Clone, Copy)] -// struct MaxUntilContext { -// save_last_position: usize, -// } - -// /* maximizing repeat */ -// fn op_max_until(drive: &mut StateContext, stacks: &mut Stacks) { -// let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; - -// drive.sync_string_position(); - -// repeat_ctx.count += 1; - -// if (repeat_ctx.count as usize) < repeat_ctx.min_count { -// // not enough matches -// drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { -// if drive.popped_ctx().has_matched == Some(true) { -// drive.success(); -// } else { -// stacks.repeat[drive.ctx.repeat_ctx_id].count -= 1; -// drive.sync_string_position(); -// drive.failure(); -// } -// }); -// return; -// } - -// stacks.max_until.push(MaxUntilContext { -// save_last_position: repeat_ctx.last_position, -// }); - -// if ((repeat_ctx.count as usize) < repeat_ctx.max_count || repeat_ctx.max_count == MAXREPEAT) -// && drive.state.string_position != repeat_ctx.last_position -// { -// /* we may have enough matches, but if we can -// match another item, do so */ -// repeat_ctx.last_position = drive.state.string_position; - -// drive.state.marks_push(); - -// drive.next_ctx_at(repeat_ctx.code_position + 4, |drive, stacks| { -// let save_last_position = stacks.max_until.pop().unwrap().save_last_position; -// let repeat_ctx = &mut stacks.repeat[drive.ctx.repeat_ctx_id]; -// repeat_ctx.last_position = save_last_position; - -// if drive.popped_ctx().has_matched == Some(true) { -// drive.state.marks_pop_discard(); -// return drive.success(); -// } - -// drive.state.marks_pop(); -// repeat_ctx.count -= 1; -// drive.sync_string_position(); - -// /* cannot match more repeated items here. make sure the -// tail matches */ -// let next_ctx = drive.next_ctx(1, tail_callback); -// next_ctx.repeat_ctx_id = repeat_ctx.prev_id; -// }); -// return; -// } - -// /* cannot match more repeated items here. make sure the -// tail matches */ -// let next_ctx = drive.next_ctx(1, tail_callback); -// next_ctx.repeat_ctx_id = repeat_ctx.prev_id; - -// fn tail_callback(drive: &mut StateContext, _stacks: &mut Stacks) { -// if drive.popped_ctx().has_matched == Some(true) { -// drive.success(); -// } else { -// drive.sync_string_position(); -// drive.failure(); -// } -// } -// } - -// #[derive(Debug, Default)] -// struct Stacks { -// } - -// impl Stacks { -// fn clear(&mut self) { -// self.branch.clear(); -// self.min_repeat_one.clear(); -// self.repeat_one.clear(); -// self.repeat.clear(); -// self.min_until.clear(); -// self.max_until.clear(); -// } - -// fn branch_last(&mut self) -> &mut BranchContext { -// self.branch.last_mut().unwrap() -// } -// fn min_repeat_one_last(&mut self) -> &mut MinRepeatOneContext { -// self.min_repeat_one.last_mut().unwrap() -// } -// fn repeat_one_last(&mut self) -> &mut RepeatOneContext { -// self.repeat_one.last_mut().unwrap() -// } -// } - -pub trait StrDrive { +pub trait StrDrive: Copy { fn offset(&self, offset: usize, skip: usize) -> usize; fn count(&self) -> usize; fn peek(&self, offset: usize) -> u32; @@ -855,7 +714,6 @@ pub trait StrDrive { fn back_offset(&self, offset: usize, skip: usize) -> usize; } - impl<'a> StrDrive for &'a str { fn offset(&self, offset: usize, skip: usize) -> usize { self.get(offset..) @@ -934,6 +792,7 @@ struct MatchContext<'a, S: StrDrive> { toplevel: bool, handler: Option, &mut Self)>, repeat_ctx_id: usize, + count: isize, } impl<'a, S: StrDrive> std::fmt::Debug for MatchContext<'a, S> { @@ -945,182 +804,188 @@ impl<'a, S: StrDrive> std::fmt::Debug for MatchContext<'a, S> { .field("has_matched", &self.has_matched) .field("toplevel", &self.toplevel) .field("handler", &self.handler.map(|x| x as usize)) + .field("count", &self.count) .finish() } } impl<'a, S: StrDrive> MatchContext<'a, S> { + fn pattern(&self, state: &State<'a, S>) -> &[u32] { + &state.pattern_codes[self.code_position..] + } + fn remaining_codes(&self, state: &State<'a, S>) -> usize { state.pattern_codes.len() - self.code_position } - + + fn remaining_chars(&self, state: &State<'a, S>) -> usize { + state.end - self.string_position + } + + fn peek_char(&self, state: &State<'a, S>) -> u32 { + state.string.peek(self.string_offset) + } + + fn skip_char(&mut self, state: &State<'a, S>, skip: usize) { + self.string_position += skip; + self.string_offset = state.string.offset(self.string_offset, skip); + } + + fn back_peek_char(&self, state: &State<'a, S>) -> u32 { + state.string.back_peek(self.string_offset) + } + + fn back_skip_char(&mut self, string: &S, skip: usize) { + self.string_position -= skip; + self.string_offset = string.back_offset(self.string_offset, skip); + } + fn peek_code(&self, state: &State<'a, S>, peek: usize) -> u32 { state.pattern_codes[self.code_position + peek] } + fn skip_code(&mut self, skip: usize) { + self.code_position += skip; + } + + fn skip_code_from(&mut self, state: &State<'a, S>, peek: usize) { + self.skip_code(self.peek_code(state, peek) as usize + 1); + } + + fn at_beginning(&self) -> bool { + // self.ctx().string_position == self.state().start + self.string_position == 0 + } + + fn at_end(&self, state: &State<'a, S>) -> bool { + self.string_position == state.end + } + + fn at_linebreak(&self, state: &State<'a, S>) -> bool { + !self.at_end(state) && is_linebreak(self.peek_char(state)) + } + + fn at_boundary bool>( + &self, + state: &State<'a, S>, + mut word_checker: F, + ) -> bool { + if self.at_beginning() && self.at_end(state) { + return false; + } + let that = !self.at_beginning() && word_checker(self.back_peek_char(state)); + let this = !self.at_end(state) && word_checker(self.peek_char(state)); + this != that + } + + fn at_non_boundary bool>( + &self, + state: &State<'a, S>, + mut word_checker: F, + ) -> bool { + if self.at_beginning() && self.at_end(state) { + return false; + } + let that = !self.at_beginning() && word_checker(self.back_peek_char(state)); + let this = !self.at_end(state) && word_checker(self.peek_char(state)); + this == that + } + + fn can_success(&self, state: &State<'a, S>) -> bool { + if !self.toplevel { + return true; + } + if state.match_all && !self.at_end(state) { + return false; + } + if state.must_advance && self.string_position == state.start { + return false; + } + true + } + + fn success(&mut self) { + self.has_matched = Some(true); + } + fn failure(&mut self) { self.has_matched = Some(false); } } -// trait ContextDrive<'a, T: StrDrive<'a>> { -// fn ctx(&self) -> &MatchContext; -// fn ctx_mut(&mut self) -> &mut MatchContext; -// fn state(&self) -> &State<'a, T>; - -// fn popped_ctx(&self) -> &MatchContext { -// self.state().popped_context.as_ref().unwrap() -// } - -// fn pattern(&self) -> &[u32] { -// &self.state().pattern_codes[self.ctx().code_position..] -// } - -// fn peek_char(&self) -> u32 { -// self.state().string.peek(self.ctx().string_offset) -// } -// fn peek_code(&self, peek: usize) -> u32 { -// self.state().pattern_codes[self.ctx().code_position + peek] -// } - -// fn back_peek_char(&self) -> u32 { -// self.state().string.back_peek(self.ctx().string_offset) -// } -// fn back_skip_char(&mut self, skip_count: usize) { -// self.ctx_mut().string_position -= skip_count; -// self.ctx_mut().string_offset = self -// .state() -// .string -// .back_offset(self.ctx().string_offset, skip_count); -// } - -// fn skip_char(&mut self, skip_count: usize) { -// self.ctx_mut().string_offset = self -// .state() -// .string -// .offset(self.ctx().string_offset, skip_count); -// self.ctx_mut().string_position += skip_count; -// } -// fn skip_code(&mut self, skip_count: usize) { -// self.ctx_mut().code_position += skip_count; -// } -// fn skip_code_from(&mut self, peek: usize) { -// self.skip_code(self.peek_code(peek) as usize + 1); -// } - -// fn remaining_chars(&self) -> usize { -// self.state().end - self.ctx().string_position -// } -// fn remaining_codes(&self) -> usize { -// self.state().pattern_codes.len() - self.ctx().code_position -// } - -// fn at_beginning(&self) -> bool { -// // self.ctx().string_position == self.state().start -// self.ctx().string_position == 0 -// } -// fn at_end(&self) -> bool { -// self.ctx().string_position == self.state().end -// } -// fn at_linebreak(&self) -> bool { -// !self.at_end() && is_linebreak(self.peek_char()) -// } -// fn at_boundary bool>(&self, mut word_checker: F) -> bool { -// if self.at_beginning() && self.at_end() { -// return false; -// } -// let that = !self.at_beginning() && word_checker(self.back_peek_char()); -// let this = !self.at_end() && word_checker(self.peek_char()); -// this != that -// } -// fn at_non_boundary bool>(&self, mut word_checker: F) -> bool { -// if self.at_beginning() && self.at_end() { -// return false; -// } -// let that = !self.at_beginning() && word_checker(self.back_peek_char()); -// let this = !self.at_end() && word_checker(self.peek_char()); -// this == that -// } - -// fn can_success(&self) -> bool { -// if !self.ctx().toplevel { -// return true; -// } -// if self.state().match_all && !self.at_end() { -// return false; -// } -// if self.state().must_advance && self.ctx().string_position == self.state().start { -// return false; -// } -// true -// } - -// fn success(&mut self) { -// self.ctx_mut().has_matched = Some(true); -// } - -// fn failure(&mut self) { -// self.ctx_mut().has_matched = Some(false); -// } -// } - -// struct StateContext<'a, S: StrDrive<'a>> { -// state: State<'a, S>, -// ctx: MatchContext, -// next_ctx: Option, -// } - -// impl<'a, S: StrDrive<'a>> ContextDrive<'a, S> for StateContext<'a, S> { -// fn ctx(&self) -> &MatchContext { -// &self.ctx -// } -// fn ctx_mut(&mut self) -> &mut MatchContext { -// &mut self.ctx -// } -// fn state(&self) -> &State<'a, S> { -// &self.state -// } -// } - -// impl StateContext<'_> { -// fn next_ctx_from(&mut self, peek: usize, handler: OpcodeHandler) -> &mut MatchContext { -// self.next_ctx(self.peek_code(peek) as usize + 1, handler) -// } -// fn next_ctx(&mut self, offset: usize, handler: OpcodeHandler) -> &mut MatchContext { -// self.next_ctx_at(self.ctx.code_position + offset, handler) -// } -// fn next_ctx_at(&mut self, code_position: usize, handler: OpcodeHandler) -> &mut MatchContext { -// self.next_ctx = Some(MatchContext { -// code_position, -// has_matched: None, -// handler: None, -// ..self.ctx -// }); -// self.ctx.handler = Some(handler); -// self.next_ctx.as_mut().unwrap() -// } - -// fn sync_string_position(&mut self) { -// self.state.string_position = self.ctx.string_position; -// } -// } - -// struct StateRefContext<'a> { -// entity: &'a StateContext<'a>, -// ctx: MatchContext, -// } - -// impl ContextDrive for StateRefContext<'_> { -// fn ctx(&self) -> &MatchContext { -// &self.ctx -// } -// fn ctx_mut(&mut self) -> &mut MatchContext { -// &mut self.ctx -// } -// fn state(&self) -> &State { -// &self.entity.state -// } -// } +fn at<'a, S: StrDrive>(state: &State<'a, S>, ctx: &MatchContext<'a, S>, atcode: SreAtCode) -> bool { + match atcode { + SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => ctx.at_beginning(), + SreAtCode::BEGINNING_LINE => ctx.at_beginning() || is_linebreak(ctx.back_peek_char(state)), + SreAtCode::BOUNDARY => ctx.at_boundary(state, is_word), + SreAtCode::NON_BOUNDARY => ctx.at_non_boundary(state, is_word), + SreAtCode::END => { + (ctx.remaining_chars(state) == 1 && ctx.at_linebreak(state)) || ctx.at_end(state) + } + SreAtCode::END_LINE => ctx.at_linebreak(state) || ctx.at_end(state), + SreAtCode::END_STRING => ctx.at_end(state), + SreAtCode::LOC_BOUNDARY => ctx.at_boundary(state, is_loc_word), + SreAtCode::LOC_NON_BOUNDARY => ctx.at_non_boundary(state, is_loc_word), + SreAtCode::UNI_BOUNDARY => ctx.at_boundary(state, is_uni_word), + SreAtCode::UNI_NON_BOUNDARY => ctx.at_non_boundary(state, is_uni_word), + } +} + +fn general_op_literal<'a, S: StrDrive, F: FnOnce(u32, u32) -> bool>( + state: &State<'a, S>, + ctx: &mut MatchContext<'a, S>, + f: F, +) { + if ctx.at_end(state) || !f(ctx.peek_code(state, 1), ctx.peek_char(state)) { + ctx.failure(); + } else { + ctx.skip_code(2); + ctx.skip_char(state, 1); + } +} + +fn general_op_in<'a, S: StrDrive, F: FnOnce(&[u32], u32) -> bool>( + state: &State<'a, S>, + ctx: &mut MatchContext<'a, S>, + f: F, +) { + if ctx.at_end(state) || !f(&ctx.pattern(state)[2..], ctx.peek_char(state)) { + ctx.failure(); + } else { + ctx.skip_code_from(state, 1); + ctx.skip_char(state, 1); + } +} + +fn general_op_groupref<'a, S: StrDrive, F: FnMut(u32) -> u32>( + state: &State<'a, S>, + ctx: &mut MatchContext<'a, S>, + mut f: F, +) { + let (group_start, group_end) = state.get_marks(ctx.peek_code(state, 1) as usize); + let (group_start, group_end) = match (group_start, group_end) { + (Some(start), Some(end)) if start <= end => (start, end), + _ => { + return ctx.failure(); + } + }; + + let mut gctx = MatchContext { + string_position: group_start, + string_offset: state.string.offset(0, group_start), + ..*ctx + }; + + for _ in group_start..group_end { + if ctx.at_end(state) || f(ctx.peek_char(state)) != f(gctx.peek_char(state)) { + return ctx.failure(); + } + ctx.skip_char(state, 1); + gctx.skip_char(state, 1); + } + + ctx.skip_code(2); +} fn char_loc_ignore(code: u32, c: u32) -> bool { code == c || code == lower_locate(c) || code == upper_locate(c) @@ -1135,78 +1000,6 @@ fn charset_loc_ignore(set: &[u32], c: u32) -> bool { up != lo && charset(set, up) } -// fn general_op_groupref u32>(drive: &mut StateContext, mut f: F) { -// let (group_start, group_end) = drive.state.get_marks(drive.peek_code(1) as usize); -// let (group_start, group_end) = match (group_start, group_end) { -// (Some(start), Some(end)) if start <= end => (start, end), -// _ => { -// return drive.failure(); -// } -// }; - -// let mut wdrive = StateRefContext { -// entity: drive, -// ctx: drive.ctx, -// }; -// let mut gdrive = StateRefContext { -// entity: drive, -// ctx: MatchContext { -// string_position: group_start, -// // TODO: cache the offset -// string_offset: drive.state.string.offset(0, group_start), -// ..drive.ctx -// }, -// }; - -// for _ in group_start..group_end { -// if wdrive.at_end() || f(wdrive.peek_char()) != f(gdrive.peek_char()) { -// return drive.failure(); -// } -// wdrive.skip_char(1); -// gdrive.skip_char(1); -// } - -// let position = wdrive.ctx.string_position; -// let offset = wdrive.ctx.string_offset; -// drive.skip_code(2); -// drive.ctx.string_position = position; -// drive.ctx.string_offset = offset; -// } - -// fn general_op_literal bool>(drive: &mut StateContext, f: F) { -// if drive.at_end() || !f(drive.peek_code(1), drive.peek_char()) { -// drive.failure(); -// } else { -// drive.skip_code(2); -// drive.skip_char(1); -// } -// } - -// fn general_op_in bool>(drive: &mut StateContext, f: F) { -// if drive.at_end() || !f(&drive.pattern()[2..], drive.peek_char()) { -// drive.failure(); -// } else { -// drive.skip_code_from(1); -// drive.skip_char(1); -// } -// } - -// fn at(drive: &StateContext, atcode: SreAtCode) -> bool { -// match atcode { -// SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => drive.at_beginning(), -// SreAtCode::BEGINNING_LINE => drive.at_beginning() || is_linebreak(drive.back_peek_char()), -// SreAtCode::BOUNDARY => drive.at_boundary(is_word), -// SreAtCode::NON_BOUNDARY => drive.at_non_boundary(is_word), -// SreAtCode::END => (drive.remaining_chars() == 1 && drive.at_linebreak()) || drive.at_end(), -// SreAtCode::END_LINE => drive.at_linebreak() || drive.at_end(), -// SreAtCode::END_STRING => drive.at_end(), -// SreAtCode::LOC_BOUNDARY => drive.at_boundary(is_loc_word), -// SreAtCode::LOC_NON_BOUNDARY => drive.at_non_boundary(is_loc_word), -// SreAtCode::UNI_BOUNDARY => drive.at_boundary(is_uni_word), -// SreAtCode::UNI_NON_BOUNDARY => drive.at_non_boundary(is_uni_word), -// } -// } - fn category(catcode: SreCatCode, c: u32) -> bool { match catcode { SreCatCode::DIGIT => is_digit(c), @@ -1323,95 +1116,100 @@ fn charset(set: &[u32], ch: u32) -> bool { false } -// /* General case */ -// fn general_count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { -// let mut count = 0; -// let max_count = std::cmp::min(max_count, drive.remaining_chars()); - -// let save_ctx = drive.ctx; -// drive.skip_code(4); -// let reset_position = drive.ctx.code_position; - -// while count < max_count { -// drive.ctx.code_position = reset_position; -// let code = drive.peek_code(0); -// let code = SreOpcode::try_from(code).unwrap(); -// dispatch(code, drive, stacks); -// if drive.ctx.has_matched == Some(false) { -// break; -// } -// count += 1; -// } -// drive.ctx = save_ctx; -// count -// } - -// fn _count(drive: &mut StateContext, stacks: &mut Stacks, max_count: usize) -> usize { -// let save_ctx = drive.ctx; -// let max_count = std::cmp::min(max_count, drive.remaining_chars()); -// let end = drive.ctx.string_position + max_count; -// let opcode = SreOpcode::try_from(drive.peek_code(0)).unwrap(); - -// match opcode { -// SreOpcode::ANY => { -// while !drive.ctx.string_position < end && !drive.at_linebreak() { -// drive.skip_char(1); -// } -// } -// SreOpcode::ANY_ALL => { -// drive.skip_char(max_count); -// } -// SreOpcode::IN => { -// while !drive.ctx.string_position < end -// && charset(&drive.pattern()[2..], drive.peek_char()) -// { -// drive.skip_char(1); -// } -// } -// SreOpcode::LITERAL => { -// general_count_literal(drive, end, |code, c| code == c as u32); -// } -// SreOpcode::NOT_LITERAL => { -// general_count_literal(drive, end, |code, c| code != c as u32); -// } -// SreOpcode::LITERAL_IGNORE => { -// general_count_literal(drive, end, |code, c| code == lower_ascii(c) as u32); -// } -// SreOpcode::NOT_LITERAL_IGNORE => { -// general_count_literal(drive, end, |code, c| code != lower_ascii(c) as u32); -// } -// SreOpcode::LITERAL_LOC_IGNORE => { -// general_count_literal(drive, end, char_loc_ignore); -// } -// SreOpcode::NOT_LITERAL_LOC_IGNORE => { -// general_count_literal(drive, end, |code, c| !char_loc_ignore(code, c)); -// } -// SreOpcode::LITERAL_UNI_IGNORE => { -// general_count_literal(drive, end, |code, c| code == lower_unicode(c) as u32); -// } -// SreOpcode::NOT_LITERAL_UNI_IGNORE => { -// general_count_literal(drive, end, |code, c| code != lower_unicode(c) as u32); -// } -// _ => { -// return general_count(drive, stacks, max_count); -// } -// } - -// let count = drive.ctx.string_position - drive.state.string_position; -// drive.ctx = save_ctx; -// count -// } - -// fn general_count_literal bool>( -// drive: &mut StateContext, -// end: usize, -// mut f: F, -// ) { -// let ch = drive.peek_code(1); -// while !drive.ctx.string_position < end && f(ch, drive.peek_char()) { -// drive.skip_char(1); -// } -// } +fn _count<'a, S: StrDrive>( + state: &mut State<'a, S>, + ctx: &MatchContext<'a, S>, + max_count: usize, +) -> usize { + let mut ctx = *ctx; + let max_count = std::cmp::min(max_count, ctx.remaining_chars(state)); + let end = ctx.string_position + max_count; + let opcode = SreOpcode::try_from(ctx.peek_code(state, 0)).unwrap(); + + match opcode { + SreOpcode::ANY => { + while !ctx.string_position < end && !ctx.at_linebreak(state) { + ctx.skip_char(state, 1); + } + } + SreOpcode::ANY_ALL => { + ctx.skip_char(state, max_count); + } + SreOpcode::IN => { + while !ctx.string_position < end + && charset(&ctx.pattern(state)[2..], ctx.peek_char(state)) + { + ctx.skip_char(state, 1); + } + } + SreOpcode::LITERAL => { + general_count_literal(state, &mut ctx, end, |code, c| code == c as u32); + } + SreOpcode::NOT_LITERAL => { + general_count_literal(state, &mut ctx, end, |code, c| code != c as u32); + } + SreOpcode::LITERAL_IGNORE => { + general_count_literal(state, &mut ctx, end, |code, c| { + code == lower_ascii(c) as u32 + }); + } + SreOpcode::NOT_LITERAL_IGNORE => { + general_count_literal(state, &mut ctx, end, |code, c| { + code != lower_ascii(c) as u32 + }); + } + SreOpcode::LITERAL_LOC_IGNORE => { + general_count_literal(state, &mut ctx, end, char_loc_ignore); + } + SreOpcode::NOT_LITERAL_LOC_IGNORE => { + general_count_literal(state, &mut ctx, end, |code, c| !char_loc_ignore(code, c)); + } + SreOpcode::LITERAL_UNI_IGNORE => { + general_count_literal(state, &mut ctx, end, |code, c| { + code == lower_unicode(c) as u32 + }); + } + SreOpcode::NOT_LITERAL_UNI_IGNORE => { + general_count_literal(state, &mut ctx, end, |code, c| { + code != lower_unicode(c) as u32 + }); + } + _ => { + /* General case */ + let mut count = 0; + + ctx.skip_code(4); + let reset_position = ctx.code_position; + + while count < max_count { + ctx.code_position = reset_position; + let code = ctx.peek_code(state, 0); + let code = SreOpcode::try_from(code).unwrap(); + dispatch(state, &mut ctx, code); + if ctx.has_matched == Some(false) { + break; + } + count += 1; + } + return count; + } + } + + // TODO: return offset + ctx.string_position - state.string_position +} + +fn general_count_literal<'a, S: StrDrive, F: FnMut(u32, u32) -> bool>( + state: &State<'a, S>, + ctx: &mut MatchContext<'a, S>, + end: usize, + mut f: F, +) { + let ch = ctx.peek_code(state, 1); + while !ctx.string_position < end && f(ch, ctx.peek_char(state)) { + ctx.skip_char(state, 1); + } +} fn is_word(ch: u32) -> bool { ch == '_' as u32 diff --git a/tests/tests.rs b/tests/tests.rs index e8ae487029..cc5c4d1f38 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -7,12 +7,12 @@ struct Pattern { } impl Pattern { - fn state<'a>( + fn state<'a, S: engine::StrDrive>( &self, - string: impl Into>, + string: S, range: std::ops::Range, - ) -> engine::State<'a> { - engine::State::new(string.into(), range.start, range.end, self.flags, self.code) + ) -> engine::State<'a, S> { + engine::State::new(string, range.start, range.end, self.flags, self.code) } } @@ -23,7 +23,7 @@ fn test_2427() { #[rustfmt::skip] let lookbehind = Pattern { code: &[15, 4, 0, 1, 1, 5, 5, 1, 17, 46, 1, 17, 120, 6, 10, 1], flags: SreFlag::from_bits_truncate(32) }; // END GENERATED let mut state = lookbehind.state("x", 0..usize::MAX); - state = state.pymatch(); + state.pymatch(); assert!(state.has_matched); } @@ -34,7 +34,7 @@ fn test_assert() { #[rustfmt::skip] let positive_lookbehind = Pattern { code: &[15, 4, 0, 3, 3, 4, 9, 3, 17, 97, 17, 98, 17, 99, 1, 17, 100, 17, 101, 17, 102, 1], flags: SreFlag::from_bits_truncate(32) }; // END GENERATED let mut state = positive_lookbehind.state("abcdef", 0..usize::MAX); - state = state.search(); + state.search(); assert!(state.has_matched); } @@ -45,7 +45,7 @@ fn test_string_boundaries() { #[rustfmt::skip] let big_b = Pattern { code: &[15, 4, 0, 0, 0, 6, 11, 1], flags: SreFlag::from_bits_truncate(32) }; // END GENERATED let mut state = big_b.state("", 0..usize::MAX); - state = state.search(); + state.search(); assert!(!state.has_matched); } @@ -57,7 +57,7 @@ fn test_zerowidth() { // END GENERATED let mut state = p.state("a:", 0..usize::MAX); state.must_advance = true; - state = state.search(); + state.search(); assert!(state.string_position == 1); } @@ -68,7 +68,7 @@ fn test_repeat_context_panic() { #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 25, 0, 4294967295, 27, 6, 0, 4294967295, 17, 97, 1, 24, 11, 0, 1, 18, 0, 17, 120, 17, 120, 18, 1, 20, 17, 122, 19, 1], flags: SreFlag::from_bits_truncate(32) }; // END GENERATED let mut state = p.state("axxzaz", 0..usize::MAX); - state = state.pymatch(); + state.pymatch(); assert!(state.marks == vec![Some(1), Some(3)]); } @@ -79,6 +79,6 @@ fn test_double_max_until() { #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 18, 0, 4294967295, 18, 0, 24, 9, 0, 1, 18, 2, 17, 49, 18, 3, 19, 18, 1, 19, 1], flags: SreFlag::from_bits_truncate(32) }; // END GENERATED let mut state = p.state("1111", 0..usize::MAX); - state = state.pymatch(); + state.pymatch(); assert!(state.string_position == 4); } From 982d8f53f2eea5eb80d37a0d22c807aebd694445 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 2 Aug 2022 21:10:31 +0200 Subject: [PATCH 056/893] fix next_ctx --- src/engine.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index b0717e1671..64043cee53 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -35,14 +35,16 @@ macro_rules! next_ctx { (from $peek:expr, $state:expr, $ctx:expr, $handler:expr) => { next_ctx!(position $ctx.peek_code($state, $peek) as usize + 1, $state, $ctx, $handler) }; - (position $position:expr, $state:expr, $ctx:expr, $handler:expr) => { - {$state.next_context.insert(MatchContext { + (position $position:expr, $state:expr, $ctx:expr, $handler:expr) => {{ + $ctx.handler = Some($handler); + $state.next_context.insert(MatchContext { code_position: $position, has_matched: None, - handler: Some($handler), + handler: None, + count: -1, ..*$ctx - })} - }; + }) + }}; } macro_rules! mark { @@ -678,7 +680,7 @@ fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex return ctx.success(); } - mark!(pop, state); + mark!(pop, state); repeat_ctx.count -= 1; state.string_position = ctx.string_position; From ccae898885496a1390a10fa6fca4cf394445dd35 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 2 Aug 2022 21:24:11 +0200 Subject: [PATCH 057/893] fix next_ctx bug --- src/engine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/engine.rs b/src/engine.rs index 64043cee53..4acc172bc6 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -33,7 +33,7 @@ macro_rules! next_ctx { next_ctx!(position $ctx.code_position + $offset, $state, $ctx, $handler) }; (from $peek:expr, $state:expr, $ctx:expr, $handler:expr) => { - next_ctx!(position $ctx.peek_code($state, $peek) as usize + 1, $state, $ctx, $handler) + next_ctx!(offset $ctx.peek_code($state, $peek) as usize + 1, $state, $ctx, $handler) }; (position $position:expr, $state:expr, $ctx:expr, $handler:expr) => {{ $ctx.handler = Some($handler); From f3b30443aab82095ed2ce786482309e659f9c107 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 5 Aug 2022 21:05:10 +0200 Subject: [PATCH 058/893] fix lifetime --- src/engine.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/engine.rs b/src/engine.rs index 4acc172bc6..810e011c95 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -716,7 +716,7 @@ pub trait StrDrive: Copy { fn back_offset(&self, offset: usize, skip: usize) -> usize; } -impl<'a> StrDrive for &'a str { +impl StrDrive for &str { fn offset(&self, offset: usize, skip: usize) -> usize { self.get(offset..) .and_then(|s| s.char_indices().nth(skip).map(|x| x.0 + offset)) From ca20b5951d7092cf0ec25198f01d3620a69e61e2 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 5 Aug 2022 21:08:14 +0200 Subject: [PATCH 059/893] update version to 0.3.0 --- Cargo.toml | 2 +- src/engine.rs | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 00123d92c5..98b632a4dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.2.1" +version = "0.3.0" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" diff --git a/src/engine.rs b/src/engine.rs index 810e011c95..1302667d1d 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -253,8 +253,6 @@ fn dispatch<'a, S: StrDrive>( ctx.skip_char(state, 1); } } - /* assert subpattern */ - /* */ SreOpcode::ASSERT => op_assert(state, ctx), SreOpcode::ASSERT_NOT => op_assert_not(state, ctx), SreOpcode::AT => { @@ -334,7 +332,6 @@ fn op_assert<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<' return ctx.failure(); } - // let next_ctx = state.next_ctx(ctx, 3, |state, ctx| { let next_ctx = next_ctx!(offset 3, state, ctx, |state, ctx| { if state.popped_has_matched { ctx.skip_code_from(state, 1); @@ -371,7 +368,6 @@ fn op_assert_not<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchConte // alternation // <0=skip> code ... fn op_branch<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { - // state.marks_push(); mark!(push, state); ctx.count = 1; @@ -403,7 +399,6 @@ fn op_branch<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<' /* <1=min> <2=max> item tail */ fn op_min_repeat_one<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { let min_count = ctx.peek_code(state, 2) as usize; - // let max_count = ctx.peek_code(state, 3) as usize; if ctx.remaining_chars(state) < min_count { return ctx.failure(); From a48f5b07c5671b690e262b935b81f0a7a832b566 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 5 Aug 2022 21:49:46 +0200 Subject: [PATCH 060/893] impl op_info --- src/engine.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/engine.rs b/src/engine.rs index 1302667d1d..fcade829f2 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -279,7 +279,15 @@ fn dispatch<'a, S: StrDrive>( general_op_in(state, ctx, |set, c| charset(set, lower_unicode(c))) } SreOpcode::IN_LOC_IGNORE => general_op_in(state, ctx, charset_loc_ignore), - SreOpcode::INFO | SreOpcode::JUMP => ctx.skip_code_from(state, 1), + SreOpcode::INFO => { + let min = ctx.peek_code(state, 3) as usize; + if ctx.remaining_chars(state) < min { + ctx.failure(); + } else { + ctx.skip_code_from(state, 1); + } + } + SreOpcode::JUMP => ctx.skip_code_from(state, 1), SreOpcode::LITERAL => general_op_literal(state, ctx, |code, c| code == c), SreOpcode::NOT_LITERAL => general_op_literal(state, ctx, |code, c| code != c), SreOpcode::LITERAL_IGNORE => { From 8b1fcea7ec27aa22f698d485ee203dbe2a552334 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 7 Aug 2022 08:10:52 +0200 Subject: [PATCH 061/893] update version to 0.3.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 98b632a4dc..4b403f2861 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.3.0" +version = "0.3.1" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" From c31462d51b7d3adbf8f121403c4e8b305a9dab6f Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 9 Aug 2022 16:29:51 +0200 Subject: [PATCH 062/893] refactor split State with Request --- src/engine.rs | 669 +++++++++++++++++++++++++++++++------------------- 1 file changed, 420 insertions(+), 249 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index fcade829f2..5e1f1457ec 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,6 +1,6 @@ // good luck to those that follow; here be dragons -use super::constants::{SreAtCode, SreCatCode, SreFlag, SreOpcode}; +use super::constants::{SreAtCode, SreCatCode, SreInfo, SreOpcode}; use super::MAXREPEAT; use std::convert::TryFrom; @@ -8,26 +8,37 @@ const fn is_py_ascii_whitespace(b: u8) -> bool { matches!(b, b'\t' | b'\n' | b'\x0C' | b'\r' | b' ' | b'\x0B') } -#[derive(Debug)] -pub struct State<'a, S: StrDrive> { +pub struct Request<'a, S: StrDrive> { pub string: S, pub start: usize, pub end: usize, - _flags: SreFlag, - pattern_codes: &'a [u32], - pub marks: Vec>, - pub lastindex: isize, - marks_stack: Vec<(Vec>, isize)>, - context_stack: Vec>, - repeat_stack: Vec, - pub string_position: usize, - next_context: Option>, - popped_has_matched: bool, - pub has_matched: bool, + pub pattern_codes: &'a [u32], pub match_all: bool, pub must_advance: bool, } +impl<'a, S: StrDrive> Request<'a, S> { + pub fn new( + string: S, + start: usize, + end: usize, + pattern_codes: &'a [u32], + match_all: bool, + ) -> Self { + let end = std::cmp::min(end, string.count()); + let start = std::cmp::min(start, end); + + Self { + string, + start, + end, + pattern_codes, + match_all, + must_advance: false, + } + } +} + macro_rules! next_ctx { (offset $offset:expr, $state:expr, $ctx:expr, $handler:expr) => { next_ctx!(position $ctx.code_position + $offset, $state, $ctx, $handler) @@ -60,43 +71,41 @@ macro_rules! mark { }; } +#[derive(Debug)] +pub struct State<'a, S: StrDrive> { + pub marks: Vec>, + pub lastindex: isize, + marks_stack: Vec<(Vec>, isize)>, + context_stack: Vec>, + repeat_stack: Vec, + pub string_position: usize, + next_context: Option>, + popped_has_matched: bool, + has_matched: bool, +} + impl<'a, S: StrDrive> State<'a, S> { - pub fn new( - string: S, - start: usize, - end: usize, - flags: SreFlag, - pattern_codes: &'a [u32], - ) -> Self { - let end = std::cmp::min(end, string.count()); - let start = std::cmp::min(start, end); + pub fn new(string_position: usize) -> Self { Self { - string, - start, - end, - _flags: flags, - pattern_codes, marks: Vec::new(), lastindex: -1, marks_stack: Vec::new(), context_stack: Vec::new(), repeat_stack: Vec::new(), - string_position: start, + string_position, next_context: None, popped_has_matched: false, has_matched: false, - match_all: false, - must_advance: false, } } - pub fn reset(&mut self) { + pub fn reset(&mut self, string_position: usize) { self.lastindex = -1; self.marks.clear(); self.marks_stack.clear(); self.context_stack.clear(); self.repeat_stack.clear(); - self.string_position = self.start; + self.string_position = string_position; self.next_context = None; self.popped_has_matched = false; self.has_matched = false; @@ -136,14 +145,14 @@ impl<'a, S: StrDrive> State<'a, S> { self.marks_stack.pop(); } - fn _match(&mut self) { + fn _match(&mut self, req: &mut Request<'a, S>) { while let Some(mut ctx) = self.context_stack.pop() { if let Some(handler) = ctx.handler.take() { - handler(self, &mut ctx); - } else if ctx.remaining_codes(self) > 0 { - let code = ctx.peek_code(self, 0); + handler(req, self, &mut ctx); + } else if ctx.remaining_codes(req) > 0 { + let code = ctx.peek_code(req, 0); let code = SreOpcode::try_from(code).unwrap(); - dispatch(self, &mut ctx, code); + dispatch(req, self, &mut ctx, code); } else { ctx.failure(); } @@ -160,10 +169,10 @@ impl<'a, S: StrDrive> State<'a, S> { self.has_matched = self.popped_has_matched; } - pub fn pymatch(&mut self) { + pub fn pymatch(&mut self, req: &mut Request<'a, S>) { let ctx = MatchContext { - string_position: self.start, - string_offset: self.string.offset(0, self.start), + string_position: req.start, + string_offset: req.string.offset(0, req.start), code_position: 0, has_matched: None, toplevel: true, @@ -173,20 +182,22 @@ impl<'a, S: StrDrive> State<'a, S> { }; self.context_stack.push(ctx); - self._match(); + self._match(req); } - pub fn search(&mut self) { + pub fn search(&mut self, req: &mut Request<'a, S>) { // TODO: optimize by op info and skip prefix - - if self.start > self.end { + if req.start > req.end { return; } - let mut start_offset = self.string.offset(0, self.start); + // let start = self.start; + // let end = self.end; - let ctx = MatchContext { - string_position: self.start, + let mut start_offset = req.string.offset(0, req.start); + + let mut ctx = MatchContext { + string_position: req.start, string_offset: start_offset, code_position: 0, has_matched: None, @@ -195,17 +206,26 @@ impl<'a, S: StrDrive> State<'a, S> { repeat_ctx_id: usize::MAX, count: -1, }; + + // if ctx.peek_code(self, 0) == SreOpcode::INFO as u32 { + // search_op_info(self, &mut ctx); + // if let Some(has_matched) = ctx.has_matched { + // self.has_matched = has_matched; + // return; + // } + // } + self.context_stack.push(ctx); - self._match(); + self._match(req); - self.must_advance = false; - while !self.has_matched && self.start < self.end { - self.start += 1; - start_offset = self.string.offset(start_offset, 1); - self.reset(); + req.must_advance = false; + while !self.has_matched && req.start < req.end { + req.start += 1; + start_offset = req.string.offset(start_offset, 1); + self.reset(req.start); let ctx = MatchContext { - string_position: self.start, + string_position: req.start, string_offset: start_offset, code_position: 0, has_matched: None, @@ -215,12 +235,13 @@ impl<'a, S: StrDrive> State<'a, S> { count: -1, }; self.context_stack.push(ctx); - self._match(); + self._match(req); } } } fn dispatch<'a, S: StrDrive>( + req: &Request<'a, S>, state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>, opcode: SreOpcode, @@ -230,7 +251,7 @@ fn dispatch<'a, S: StrDrive>( ctx.failure(); } SreOpcode::SUCCESS => { - if ctx.can_success(state) { + if ctx.can_success(req) { state.string_position = ctx.string_position; ctx.success(); } else { @@ -238,152 +259,224 @@ fn dispatch<'a, S: StrDrive>( } } SreOpcode::ANY => { - if ctx.at_end(state) || ctx.at_linebreak(state) { + if ctx.at_end(req) || ctx.at_linebreak(req) { ctx.failure(); } else { ctx.skip_code(1); - ctx.skip_char(state, 1); + ctx.skip_char(req, 1); } } SreOpcode::ANY_ALL => { - if ctx.at_end(state) { + if ctx.at_end(req) { ctx.failure(); } else { ctx.skip_code(1); - ctx.skip_char(state, 1); + ctx.skip_char(req, 1); } } - SreOpcode::ASSERT => op_assert(state, ctx), - SreOpcode::ASSERT_NOT => op_assert_not(state, ctx), + SreOpcode::ASSERT => op_assert(req, state, ctx), + SreOpcode::ASSERT_NOT => op_assert_not(req, state, ctx), SreOpcode::AT => { - let atcode = SreAtCode::try_from(ctx.peek_code(state, 1)).unwrap(); - if at(state, ctx, atcode) { + let atcode = SreAtCode::try_from(ctx.peek_code(req, 1)).unwrap(); + if at(req, ctx, atcode) { ctx.skip_code(2); } else { ctx.failure(); } } - SreOpcode::BRANCH => op_branch(state, ctx), + SreOpcode::BRANCH => op_branch(req, state, ctx), SreOpcode::CATEGORY => { - let catcode = SreCatCode::try_from(ctx.peek_code(state, 1)).unwrap(); - if ctx.at_end(state) || !category(catcode, ctx.peek_char(state)) { + let catcode = SreCatCode::try_from(ctx.peek_code(req, 1)).unwrap(); + if ctx.at_end(req) || !category(catcode, ctx.peek_char(req)) { ctx.failure(); } else { ctx.skip_code(2); - ctx.skip_char(state, 1); + ctx.skip_char(req, 1); } } - SreOpcode::IN => general_op_in(state, ctx, charset), - SreOpcode::IN_IGNORE => general_op_in(state, ctx, |set, c| charset(set, lower_ascii(c))), + SreOpcode::IN => general_op_in(req, ctx, charset), + SreOpcode::IN_IGNORE => general_op_in(req, ctx, |set, c| charset(set, lower_ascii(c))), SreOpcode::IN_UNI_IGNORE => { - general_op_in(state, ctx, |set, c| charset(set, lower_unicode(c))) + general_op_in(req, ctx, |set, c| charset(set, lower_unicode(c))) } - SreOpcode::IN_LOC_IGNORE => general_op_in(state, ctx, charset_loc_ignore), + SreOpcode::IN_LOC_IGNORE => general_op_in(req, ctx, charset_loc_ignore), SreOpcode::INFO => { - let min = ctx.peek_code(state, 3) as usize; - if ctx.remaining_chars(state) < min { + let min = ctx.peek_code(req, 3) as usize; + if ctx.remaining_chars(req) < min { ctx.failure(); } else { - ctx.skip_code_from(state, 1); + ctx.skip_code_from(req, 1); } } - SreOpcode::JUMP => ctx.skip_code_from(state, 1), - SreOpcode::LITERAL => general_op_literal(state, ctx, |code, c| code == c), - SreOpcode::NOT_LITERAL => general_op_literal(state, ctx, |code, c| code != c), - SreOpcode::LITERAL_IGNORE => { - general_op_literal(state, ctx, |code, c| code == lower_ascii(c)) - } + SreOpcode::JUMP => ctx.skip_code_from(req, 1), + SreOpcode::LITERAL => general_op_literal(req, ctx, |code, c| code == c), + SreOpcode::NOT_LITERAL => general_op_literal(req, ctx, |code, c| code != c), + SreOpcode::LITERAL_IGNORE => general_op_literal(req, ctx, |code, c| code == lower_ascii(c)), SreOpcode::NOT_LITERAL_IGNORE => { - general_op_literal(state, ctx, |code, c| code != lower_ascii(c)) + general_op_literal(req, ctx, |code, c| code != lower_ascii(c)) } SreOpcode::LITERAL_UNI_IGNORE => { - general_op_literal(state, ctx, |code, c| code == lower_unicode(c)) + general_op_literal(req, ctx, |code, c| code == lower_unicode(c)) } SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_op_literal(state, ctx, |code, c| code != lower_unicode(c)) + general_op_literal(req, ctx, |code, c| code != lower_unicode(c)) } - SreOpcode::LITERAL_LOC_IGNORE => general_op_literal(state, ctx, char_loc_ignore), + SreOpcode::LITERAL_LOC_IGNORE => general_op_literal(req, ctx, char_loc_ignore), SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_op_literal(state, ctx, |code, c| !char_loc_ignore(code, c)) + general_op_literal(req, ctx, |code, c| !char_loc_ignore(code, c)) } SreOpcode::MARK => { - state.set_mark(ctx.peek_code(state, 1) as usize, ctx.string_position); + state.set_mark(ctx.peek_code(req, 1) as usize, ctx.string_position); ctx.skip_code(2); } SreOpcode::MAX_UNTIL => op_max_until(state, ctx), SreOpcode::MIN_UNTIL => op_min_until(state, ctx), - SreOpcode::REPEAT => op_repeat(state, ctx), - SreOpcode::REPEAT_ONE => op_repeat_one(state, ctx), - SreOpcode::MIN_REPEAT_ONE => op_min_repeat_one(state, ctx), - SreOpcode::GROUPREF => general_op_groupref(state, ctx, |x| x), - SreOpcode::GROUPREF_IGNORE => general_op_groupref(state, ctx, lower_ascii), - SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(state, ctx, lower_locate), - SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(state, ctx, lower_unicode), + SreOpcode::REPEAT => op_repeat(req, state, ctx), + SreOpcode::REPEAT_ONE => op_repeat_one(req, state, ctx), + SreOpcode::MIN_REPEAT_ONE => op_min_repeat_one(req, state, ctx), + SreOpcode::GROUPREF => general_op_groupref(req, state, ctx, |x| x), + SreOpcode::GROUPREF_IGNORE => general_op_groupref(req, state, ctx, lower_ascii), + SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(req, state, ctx, lower_locate), + SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(req, state, ctx, lower_unicode), SreOpcode::GROUPREF_EXISTS => { - let (group_start, group_end) = state.get_marks(ctx.peek_code(state, 1) as usize); + let (group_start, group_end) = state.get_marks(ctx.peek_code(req, 1) as usize); match (group_start, group_end) { (Some(start), Some(end)) if start <= end => { ctx.skip_code(3); } - _ => ctx.skip_code_from(state, 2), + _ => ctx.skip_code_from(req, 2), } } _ => unreachable!("unexpected opcode"), } } +/* optimization info block */ +/* <1=skip> <2=flags> <3=min> <4=max> <5=prefix info> */ +// fn search_op_info<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { +// let min = ctx.peek_code(state, 3) as usize; + +// if ctx.remaining_chars(state) < min { +// return ctx.failure(); +// } + +// if min > 1 { +// /* adjust end point (but make sure we leave at least one +// character in there, so literal search will work) */ +// // no overflow can happen as remaining chars >= min +// state.end -= min - 1; + +// // adjust ctx position +// if state.end < ctx.string_position { +// ctx.string_position = state.end; +// ctx.string_offset = state.string.offset(0, ctx.string_position); +// } +// } + +// let flags = SreInfo::from_bits_truncate(ctx.peek_code(state, 2)); + +// if flags.contains(SreInfo::PREFIX) { +// /* pattern starts with a known prefix */ +// /* */ +// let len = ctx.peek_code(state, 5) as usize; +// let skip = ctx.peek_code(state, 6) as usize; +// let prefix = &ctx.pattern(state)[7..]; +// let overlap = &prefix[len - 1..]; + +// ctx.skip_code_from(state, 1); + +// if len == 1 { +// // pattern starts with a literal character +// let c = prefix[0]; +// let end = state.end; + +// while (!ctx.at_end(state)) { +// // find the next matched literal +// while (ctx.peek_char(state) != c) { +// ctx.skip_char(state, 1); +// if (ctx.at_end(state)) { +// return ctx.failure(); +// } +// } + +// // literal only +// if flags.contains(SreInfo::LITERAL) { +// return ctx.success(); +// } +// } +// } +// } +// } + /* assert subpattern */ /* */ -fn op_assert<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { - let back = ctx.peek_code(state, 2) as usize; +fn op_assert<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, +) { + let back = ctx.peek_code(req, 2) as usize; if ctx.string_position < back { return ctx.failure(); } - let next_ctx = next_ctx!(offset 3, state, ctx, |state, ctx| { + // let next_ctx = next_ctx!(offset 3, state, ctx, |req, state, ctx| { + let next_ctx = ctx.next_offset(3, state, |req, state, ctx| { if state.popped_has_matched { - ctx.skip_code_from(state, 1); + ctx.skip_code_from(req, 1); } else { ctx.failure(); } }); - next_ctx.back_skip_char(&state.string, back); - state.string_position = next_ctx.string_position; next_ctx.toplevel = false; + next_ctx.back_skip_char(req, back); + state.string_position = next_ctx.string_position; } /* assert not subpattern */ /* */ -fn op_assert_not<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { - let back = ctx.peek_code(state, 2) as usize; +fn op_assert_not<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, +) { + let back = ctx.peek_code(req, 2) as usize; if ctx.string_position < back { - return ctx.skip_code_from(state, 1); + return ctx.skip_code_from(req, 1); } - let next_ctx = next_ctx!(offset 3, state, ctx, |state, ctx| { + let next_ctx = next_ctx!(offset 3, state, ctx, |req, state, ctx| { if state.popped_has_matched { ctx.failure(); } else { - ctx.skip_code_from(state, 1); + ctx.skip_code_from(req, 1); } }); - next_ctx.back_skip_char(&state.string, back); - state.string_position = next_ctx.string_position; next_ctx.toplevel = false; + next_ctx.back_skip_char(req, back); + state.string_position = next_ctx.string_position; } // alternation // <0=skip> code ... -fn op_branch<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { +fn op_branch<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, +) { mark!(push, state); ctx.count = 1; - create_context(state, ctx); + create_context(req, state, ctx); - fn create_context<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + fn create_context<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, + ) { let branch_offset = ctx.count as usize; - let next_length = ctx.peek_code(state, branch_offset) as isize; + let next_length = ctx.peek_code(req, branch_offset) as isize; if next_length == 0 { state.marks_pop_discard(); return ctx.failure(); @@ -395,20 +488,28 @@ fn op_branch<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<' next_ctx!(offset branch_offset + 1, state, ctx, callback); } - fn callback<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + fn callback<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, + ) { if state.popped_has_matched { return ctx.success(); } state.marks_pop_keep(); - create_context(state, ctx); + create_context(req, state, ctx); } } /* <1=min> <2=max> item tail */ -fn op_min_repeat_one<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { - let min_count = ctx.peek_code(state, 2) as usize; +fn op_min_repeat_one<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, +) { + let min_count = ctx.peek_code(req, 2) as usize; - if ctx.remaining_chars(state) < min_count { + if ctx.remaining_chars(req) < min_count { return ctx.failure(); } @@ -417,52 +518,61 @@ fn op_min_repeat_one<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchC ctx.count = if min_count == 0 { 0 } else { - let count = _count(state, ctx, min_count); + let count = _count(req, state, ctx, min_count); if count < min_count { return ctx.failure(); } - ctx.skip_char(state, count); + ctx.skip_char(req, count); count as isize }; - let next_code = ctx.peek_code(state, ctx.peek_code(state, 1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(state) { + let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { // tail is empty. we're finished state.string_position = ctx.string_position; return ctx.success(); } mark!(push, state); - create_context(state, ctx); + create_context(req, state, ctx); - fn create_context<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { - let max_count = ctx.peek_code(state, 3) as usize; + fn create_context<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, + ) { + let max_count = ctx.peek_code(req, 3) as usize; if max_count == MAXREPEAT || ctx.count as usize <= max_count { state.string_position = ctx.string_position; - next_ctx!(from 1, state, ctx, callback); + // next_ctx!(from 1, state, ctx, callback); + ctx.next_from(1, req, state, callback); } else { state.marks_pop_discard(); ctx.failure(); } } - fn callback<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + fn callback<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, + ) { if state.popped_has_matched { return ctx.success(); } state.string_position = ctx.string_position; - if _count(state, ctx, 1) == 0 { + if _count(req, state, ctx, 1) == 0 { state.marks_pop_discard(); return ctx.failure(); } - ctx.skip_char(state, 1); + ctx.skip_char(req, 1); ctx.count += 1; state.marks_pop_keep(); - create_context(state, ctx); + create_context(req, state, ctx); } } @@ -472,24 +582,28 @@ exactly one character wide, and we're not already collecting backtracking points. for other cases, use the MAX_REPEAT operator */ /* <1=min> <2=max> item tail */ -fn op_repeat_one<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { - let min_count = ctx.peek_code(state, 2) as usize; - let max_count = ctx.peek_code(state, 3) as usize; +fn op_repeat_one<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, +) { + let min_count = ctx.peek_code(req, 2) as usize; + let max_count = ctx.peek_code(req, 3) as usize; - if ctx.remaining_chars(state) < min_count { + if ctx.remaining_chars(req) < min_count { return ctx.failure(); } state.string_position = ctx.string_position; - let count = _count(state, ctx, max_count); - ctx.skip_char(state, count); + let count = _count(req, state, ctx, max_count); + ctx.skip_char(req, count); if count < min_count { return ctx.failure(); } - let next_code = ctx.peek_code(state, ctx.peek_code(state, 1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(state) { + let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { // tail is empty. we're finished state.string_position = ctx.string_position; return ctx.success(); @@ -497,21 +611,25 @@ fn op_repeat_one<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchConte mark!(push, state); ctx.count = count as isize; - create_context(state, ctx); - - fn create_context<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { - let min_count = ctx.peek_code(state, 2) as isize; - let next_code = ctx.peek_code(state, ctx.peek_code(state, 1) as usize + 1); + create_context(req, state, ctx); + + fn create_context<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, + ) { + let min_count = ctx.peek_code(req, 2) as isize; + let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); if next_code == SreOpcode::LITERAL as u32 { // Special case: Tail starts with a literal. Skip positions where // the rest of the pattern cannot possibly match. - let c = ctx.peek_code(state, ctx.peek_code(state, 1) as usize + 2); - while ctx.at_end(state) || ctx.peek_char(state) != c { + let c = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 2); + while ctx.at_end(req) || ctx.peek_char(req) != c { if ctx.count <= min_count { state.marks_pop_discard(); return ctx.failure(); } - ctx.back_skip_char(&state.string, 1); + ctx.back_skip_char(req, 1); ctx.count -= 1; } } @@ -519,26 +637,31 @@ fn op_repeat_one<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchConte state.string_position = ctx.string_position; // General case: backtracking - next_ctx!(from 1, state, ctx, callback); + // next_ctx!(from 1, state, ctx, callback); + ctx.next_from(1, req, state, callback); } - fn callback<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + fn callback<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, + ) { if state.popped_has_matched { return ctx.success(); } - let min_count = ctx.peek_code(state, 2) as isize; + let min_count = ctx.peek_code(req, 2) as isize; if ctx.count <= min_count { state.marks_pop_discard(); return ctx.failure(); } - ctx.back_skip_char(&state.string, 1); + ctx.back_skip_char(req, 1); ctx.count -= 1; state.marks_pop_keep(); - create_context(state, ctx); + create_context(req, state, ctx); } } @@ -555,11 +678,15 @@ struct RepeatContext { /* create repeat context. all the hard work is done by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ /* <1=min> <2=max> item tail */ -fn op_repeat<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { +fn op_repeat<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, +) { let repeat_ctx = RepeatContext { count: -1, - min_count: ctx.peek_code(state, 2) as usize, - max_count: ctx.peek_code(state, 3) as usize, + min_count: ctx.peek_code(req, 2) as usize, + max_count: ctx.peek_code(req, 3) as usize, code_position: ctx.code_position, last_position: std::usize::MAX, prev_id: ctx.repeat_ctx_id, @@ -569,11 +696,14 @@ fn op_repeat<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<' state.string_position = ctx.string_position; - let next_ctx = next_ctx!(from 1, state, ctx, |state, ctx| { + let repeat_ctx_id = state.repeat_stack.len(); + + // let next_ctx = next_ctx!(from 1, state, ctx, |state, ctx| { + let next_ctx = ctx.next_from(1, req, state, |req, state, ctx| { ctx.has_matched = Some(state.popped_has_matched); state.repeat_stack.pop(); }); - next_ctx.repeat_ctx_id = state.repeat_stack.len() - 1; + next_ctx.repeat_ctx_id = repeat_ctx_id; } /* minimizing repeat */ @@ -586,7 +716,8 @@ fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex if (repeat_ctx.count as usize) < repeat_ctx.min_count { // not enough matches - next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + // next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + ctx.next_at(repeat_ctx.code_position + 4, state, |req, state, ctx| { if state.popped_has_matched { ctx.success(); } else { @@ -602,8 +733,11 @@ fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex ctx.count = ctx.repeat_ctx_id as isize; + let repeat_ctx_prev_id = repeat_ctx.prev_id; + // see if the tail matches - let next_ctx = next_ctx!(offset 1, state, ctx, |state, ctx| { + // let next_ctx = next_ctx!(offset 1, state, ctx, |state, ctx| { + let next_ctx = ctx.next_offset(1, state, |req, state, ctx| { if state.popped_has_matched { return ctx.success(); } @@ -628,7 +762,8 @@ fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex /* zero-width match protection */ repeat_ctx.last_position = state.string_position; - next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + // next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + ctx.next_at(repeat_ctx.code_position + 4, state, |req, state, ctx| { if state.popped_has_matched { ctx.success(); } else { @@ -638,7 +773,7 @@ fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex } }); }); - next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + next_ctx.repeat_ctx_id = repeat_ctx_prev_id; } /* maximizing repeat */ @@ -651,7 +786,8 @@ fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex if (repeat_ctx.count as usize) < repeat_ctx.min_count { // not enough matches - next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + // next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + ctx.next_at(repeat_ctx.code_position + 4, state, |req, state, ctx| { if state.popped_has_matched { ctx.success(); } else { @@ -673,7 +809,7 @@ fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex ctx.count = repeat_ctx.last_position as isize; repeat_ctx.last_position = state.string_position; - next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { + ctx.next_at(repeat_ctx.code_position + 4, state, |req, state, ctx| { let save_last_position = ctx.count as usize; let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; repeat_ctx.last_position = save_last_position; @@ -701,7 +837,11 @@ fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex let next_ctx = next_ctx!(offset 1, state, ctx, tail_callback); next_ctx.repeat_ctx_id = repeat_ctx.prev_id; - fn tail_callback<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { + fn tail_callback<'a, S: StrDrive>( + req: &Request<'a, S>, + state: &mut State<'a, S>, + ctx: &mut MatchContext<'a, S>, + ) { if state.popped_has_matched { ctx.success(); } else { @@ -786,8 +926,6 @@ impl<'a> StrDrive for &'a [u8] { } } -// type OpcodeHandler = for<'a>fn(&mut StateContext<'a, S>, &mut Stacks); - #[derive(Clone, Copy)] struct MatchContext<'a, S: StrDrive> { string_position: usize, @@ -795,7 +933,7 @@ struct MatchContext<'a, S: StrDrive> { code_position: usize, has_matched: Option, toplevel: bool, - handler: Option, &mut Self)>, + handler: Option, &mut State<'a, S>, &mut Self)>, repeat_ctx_id: usize, count: isize, } @@ -809,52 +947,53 @@ impl<'a, S: StrDrive> std::fmt::Debug for MatchContext<'a, S> { .field("has_matched", &self.has_matched) .field("toplevel", &self.toplevel) .field("handler", &self.handler.map(|x| x as usize)) + .field("repeat_ctx_id", &self.repeat_ctx_id) .field("count", &self.count) .finish() } } impl<'a, S: StrDrive> MatchContext<'a, S> { - fn pattern(&self, state: &State<'a, S>) -> &[u32] { - &state.pattern_codes[self.code_position..] + fn pattern(&self, req: &Request<'a, S>) -> &'a [u32] { + &req.pattern_codes[self.code_position..] } - fn remaining_codes(&self, state: &State<'a, S>) -> usize { - state.pattern_codes.len() - self.code_position + fn remaining_codes(&self, req: &Request<'a, S>) -> usize { + req.pattern_codes.len() - self.code_position } - fn remaining_chars(&self, state: &State<'a, S>) -> usize { - state.end - self.string_position + fn remaining_chars(&self, req: &Request<'a, S>) -> usize { + req.end - self.string_position } - fn peek_char(&self, state: &State<'a, S>) -> u32 { - state.string.peek(self.string_offset) + fn peek_char(&self, req: &Request<'a, S>) -> u32 { + req.string.peek(self.string_offset) } - fn skip_char(&mut self, state: &State<'a, S>, skip: usize) { + fn skip_char(&mut self, req: &Request<'a, S>, skip: usize) { self.string_position += skip; - self.string_offset = state.string.offset(self.string_offset, skip); + self.string_offset = req.string.offset(self.string_offset, skip); } - fn back_peek_char(&self, state: &State<'a, S>) -> u32 { - state.string.back_peek(self.string_offset) + fn back_peek_char(&self, req: &Request<'a, S>) -> u32 { + req.string.back_peek(self.string_offset) } - fn back_skip_char(&mut self, string: &S, skip: usize) { + fn back_skip_char(&mut self, req: &Request<'a, S>, skip: usize) { self.string_position -= skip; - self.string_offset = string.back_offset(self.string_offset, skip); + self.string_offset = req.string.back_offset(self.string_offset, skip); } - fn peek_code(&self, state: &State<'a, S>, peek: usize) -> u32 { - state.pattern_codes[self.code_position + peek] + fn peek_code(&self, req: &Request<'a, S>, peek: usize) -> u32 { + req.pattern_codes[self.code_position + peek] } fn skip_code(&mut self, skip: usize) { self.code_position += skip; } - fn skip_code_from(&mut self, state: &State<'a, S>, peek: usize) { - self.skip_code(self.peek_code(state, peek) as usize + 1); + fn skip_code_from(&mut self, req: &Request<'a, S>, peek: usize) { + self.skip_code(self.peek_code(req, peek) as usize + 1); } fn at_beginning(&self) -> bool { @@ -862,48 +1001,48 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { self.string_position == 0 } - fn at_end(&self, state: &State<'a, S>) -> bool { - self.string_position == state.end + fn at_end(&self, req: &Request<'a, S>) -> bool { + self.string_position == req.end } - fn at_linebreak(&self, state: &State<'a, S>) -> bool { - !self.at_end(state) && is_linebreak(self.peek_char(state)) + fn at_linebreak(&self, req: &Request<'a, S>) -> bool { + !self.at_end(req) && is_linebreak(self.peek_char(req)) } fn at_boundary bool>( &self, - state: &State<'a, S>, + req: &Request<'a, S>, mut word_checker: F, ) -> bool { - if self.at_beginning() && self.at_end(state) { + if self.at_beginning() && self.at_end(req) { return false; } - let that = !self.at_beginning() && word_checker(self.back_peek_char(state)); - let this = !self.at_end(state) && word_checker(self.peek_char(state)); + let that = !self.at_beginning() && word_checker(self.back_peek_char(req)); + let this = !self.at_end(req) && word_checker(self.peek_char(req)); this != that } fn at_non_boundary bool>( &self, - state: &State<'a, S>, + req: &Request<'a, S>, mut word_checker: F, ) -> bool { - if self.at_beginning() && self.at_end(state) { + if self.at_beginning() && self.at_end(req) { return false; } - let that = !self.at_beginning() && word_checker(self.back_peek_char(state)); - let this = !self.at_end(state) && word_checker(self.peek_char(state)); + let that = !self.at_beginning() && word_checker(self.back_peek_char(req)); + let this = !self.at_end(req) && word_checker(self.peek_char(req)); this == that } - fn can_success(&self, state: &State<'a, S>) -> bool { + fn can_success(&self, req: &Request<'a, S>) -> bool { if !self.toplevel { return true; } - if state.match_all && !self.at_end(state) { + if req.match_all && !self.at_end(req) { return false; } - if state.must_advance && self.string_position == state.start { + if req.must_advance && self.string_position == req.start { return false; } true @@ -916,58 +1055,94 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { fn failure(&mut self) { self.has_matched = Some(false); } + + fn next_from<'b>( + &mut self, + peek: usize, + req: &Request<'a, S>, + state: &'b mut State<'a, S>, + f: fn(&Request<'a, S>, &mut State<'a, S>, &mut Self), + ) -> &'b mut Self { + self.next_offset(self.peek_code(req, peek) as usize + 1, state, f) + } + + fn next_offset<'b>( + &mut self, + offset: usize, + state: &'b mut State<'a, S>, + f: fn(&Request<'a, S>, &mut State<'a, S>, &mut Self), + ) -> &'b mut Self { + self.next_at(self.code_position + offset, state, f) + } + + fn next_at<'b>( + &mut self, + code_position: usize, + state: &'b mut State<'a, S>, + f: fn(&Request<'a, S>, &mut State<'a, S>, &mut Self), + ) -> &'b mut Self { + self.handler = Some(f); + state.next_context.insert(MatchContext { + code_position, + has_matched: None, + handler: None, + count: -1, + ..*self + }) + } } -fn at<'a, S: StrDrive>(state: &State<'a, S>, ctx: &MatchContext<'a, S>, atcode: SreAtCode) -> bool { +fn at<'a, S: StrDrive>(req: &Request<'a, S>, ctx: &MatchContext<'a, S>, atcode: SreAtCode) -> bool { match atcode { SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => ctx.at_beginning(), - SreAtCode::BEGINNING_LINE => ctx.at_beginning() || is_linebreak(ctx.back_peek_char(state)), - SreAtCode::BOUNDARY => ctx.at_boundary(state, is_word), - SreAtCode::NON_BOUNDARY => ctx.at_non_boundary(state, is_word), + SreAtCode::BEGINNING_LINE => ctx.at_beginning() || is_linebreak(ctx.back_peek_char(req)), + SreAtCode::BOUNDARY => ctx.at_boundary(req, is_word), + SreAtCode::NON_BOUNDARY => ctx.at_non_boundary(req, is_word), SreAtCode::END => { - (ctx.remaining_chars(state) == 1 && ctx.at_linebreak(state)) || ctx.at_end(state) + (ctx.remaining_chars(req) == 1 && ctx.at_linebreak(req)) || ctx.at_end(req) } - SreAtCode::END_LINE => ctx.at_linebreak(state) || ctx.at_end(state), - SreAtCode::END_STRING => ctx.at_end(state), - SreAtCode::LOC_BOUNDARY => ctx.at_boundary(state, is_loc_word), - SreAtCode::LOC_NON_BOUNDARY => ctx.at_non_boundary(state, is_loc_word), - SreAtCode::UNI_BOUNDARY => ctx.at_boundary(state, is_uni_word), - SreAtCode::UNI_NON_BOUNDARY => ctx.at_non_boundary(state, is_uni_word), + SreAtCode::END_LINE => ctx.at_linebreak(req) || ctx.at_end(req), + SreAtCode::END_STRING => ctx.at_end(req), + SreAtCode::LOC_BOUNDARY => ctx.at_boundary(req, is_loc_word), + SreAtCode::LOC_NON_BOUNDARY => ctx.at_non_boundary(req, is_loc_word), + SreAtCode::UNI_BOUNDARY => ctx.at_boundary(req, is_uni_word), + SreAtCode::UNI_NON_BOUNDARY => ctx.at_non_boundary(req, is_uni_word), } } fn general_op_literal<'a, S: StrDrive, F: FnOnce(u32, u32) -> bool>( - state: &State<'a, S>, + req: &Request<'a, S>, ctx: &mut MatchContext<'a, S>, f: F, ) { - if ctx.at_end(state) || !f(ctx.peek_code(state, 1), ctx.peek_char(state)) { + if ctx.at_end(req) || !f(ctx.peek_code(req, 1), ctx.peek_char(req)) { ctx.failure(); } else { ctx.skip_code(2); - ctx.skip_char(state, 1); + ctx.skip_char(req, 1); } } fn general_op_in<'a, S: StrDrive, F: FnOnce(&[u32], u32) -> bool>( - state: &State<'a, S>, + req: &Request<'a, S>, ctx: &mut MatchContext<'a, S>, f: F, ) { - if ctx.at_end(state) || !f(&ctx.pattern(state)[2..], ctx.peek_char(state)) { + if ctx.at_end(req) || !f(&ctx.pattern(req)[2..], ctx.peek_char(req)) { ctx.failure(); } else { - ctx.skip_code_from(state, 1); - ctx.skip_char(state, 1); + ctx.skip_code_from(req, 1); + ctx.skip_char(req, 1); } } fn general_op_groupref<'a, S: StrDrive, F: FnMut(u32) -> u32>( + req: &Request<'a, S>, state: &State<'a, S>, ctx: &mut MatchContext<'a, S>, mut f: F, ) { - let (group_start, group_end) = state.get_marks(ctx.peek_code(state, 1) as usize); + let (group_start, group_end) = state.get_marks(ctx.peek_code(req, 1) as usize); let (group_start, group_end) = match (group_start, group_end) { (Some(start), Some(end)) if start <= end => (start, end), _ => { @@ -977,16 +1152,16 @@ fn general_op_groupref<'a, S: StrDrive, F: FnMut(u32) -> u32>( let mut gctx = MatchContext { string_position: group_start, - string_offset: state.string.offset(0, group_start), + string_offset: req.string.offset(0, group_start), ..*ctx }; for _ in group_start..group_end { - if ctx.at_end(state) || f(ctx.peek_char(state)) != f(gctx.peek_char(state)) { + if ctx.at_end(req) || f(ctx.peek_char(req)) != f(gctx.peek_char(req)) { return ctx.failure(); } - ctx.skip_char(state, 1); - gctx.skip_char(state, 1); + ctx.skip_char(req, 1); + gctx.skip_char(req, 1); } ctx.skip_code(2); @@ -1122,60 +1297,56 @@ fn charset(set: &[u32], ch: u32) -> bool { } fn _count<'a, S: StrDrive>( + req: &Request<'a, S>, state: &mut State<'a, S>, ctx: &MatchContext<'a, S>, max_count: usize, ) -> usize { let mut ctx = *ctx; - let max_count = std::cmp::min(max_count, ctx.remaining_chars(state)); + let max_count = std::cmp::min(max_count, ctx.remaining_chars(req)); let end = ctx.string_position + max_count; - let opcode = SreOpcode::try_from(ctx.peek_code(state, 0)).unwrap(); + let opcode = SreOpcode::try_from(ctx.peek_code(req, 0)).unwrap(); match opcode { SreOpcode::ANY => { - while !ctx.string_position < end && !ctx.at_linebreak(state) { - ctx.skip_char(state, 1); + while !ctx.string_position < end && !ctx.at_linebreak(req) { + ctx.skip_char(req, 1); } } SreOpcode::ANY_ALL => { - ctx.skip_char(state, max_count); + ctx.skip_char(req, max_count); } SreOpcode::IN => { - while !ctx.string_position < end - && charset(&ctx.pattern(state)[2..], ctx.peek_char(state)) + while !ctx.string_position < end && charset(&ctx.pattern(req)[2..], ctx.peek_char(req)) { - ctx.skip_char(state, 1); + ctx.skip_char(req, 1); } } SreOpcode::LITERAL => { - general_count_literal(state, &mut ctx, end, |code, c| code == c as u32); + general_count_literal(req, &mut ctx, end, |code, c| code == c as u32); } SreOpcode::NOT_LITERAL => { - general_count_literal(state, &mut ctx, end, |code, c| code != c as u32); + general_count_literal(req, &mut ctx, end, |code, c| code != c as u32); } SreOpcode::LITERAL_IGNORE => { - general_count_literal(state, &mut ctx, end, |code, c| { - code == lower_ascii(c) as u32 - }); + general_count_literal(req, &mut ctx, end, |code, c| code == lower_ascii(c) as u32); } SreOpcode::NOT_LITERAL_IGNORE => { - general_count_literal(state, &mut ctx, end, |code, c| { - code != lower_ascii(c) as u32 - }); + general_count_literal(req, &mut ctx, end, |code, c| code != lower_ascii(c) as u32); } SreOpcode::LITERAL_LOC_IGNORE => { - general_count_literal(state, &mut ctx, end, char_loc_ignore); + general_count_literal(req, &mut ctx, end, char_loc_ignore); } SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_count_literal(state, &mut ctx, end, |code, c| !char_loc_ignore(code, c)); + general_count_literal(req, &mut ctx, end, |code, c| !char_loc_ignore(code, c)); } SreOpcode::LITERAL_UNI_IGNORE => { - general_count_literal(state, &mut ctx, end, |code, c| { + general_count_literal(req, &mut ctx, end, |code, c| { code == lower_unicode(c) as u32 }); } SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_count_literal(state, &mut ctx, end, |code, c| { + general_count_literal(req, &mut ctx, end, |code, c| { code != lower_unicode(c) as u32 }); } @@ -1188,9 +1359,9 @@ fn _count<'a, S: StrDrive>( while count < max_count { ctx.code_position = reset_position; - let code = ctx.peek_code(state, 0); + let code = ctx.peek_code(req, 0); let code = SreOpcode::try_from(code).unwrap(); - dispatch(state, &mut ctx, code); + dispatch(req, state, &mut ctx, code); if ctx.has_matched == Some(false) { break; } @@ -1205,14 +1376,14 @@ fn _count<'a, S: StrDrive>( } fn general_count_literal<'a, S: StrDrive, F: FnMut(u32, u32) -> bool>( - state: &State<'a, S>, + req: &Request<'a, S>, ctx: &mut MatchContext<'a, S>, end: usize, mut f: F, ) { - let ch = ctx.peek_code(state, 1); - while !ctx.string_position < end && f(ch, ctx.peek_char(state)) { - ctx.skip_char(state, 1); + let ch = ctx.peek_code(req, 1); + while !ctx.string_position < end && f(ch, ctx.peek_char(req)) { + ctx.skip_char(req, 1); } } From c15387e97289386bbf0891a7ed18367220b59a15 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 9 Aug 2022 16:44:47 +0200 Subject: [PATCH 063/893] refactor tests --- generate_tests.py | 2 +- src/engine.rs | 4 ++-- tests/tests.rs | 47 +++++++++++++++++++++++------------------------ 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/generate_tests.py b/generate_tests.py index b432720cd1..8adf043f29 100644 --- a/generate_tests.py +++ b/generate_tests.py @@ -33,7 +33,7 @@ def compile(cls, pattern, flags=0): def replace_compiled(m): line, indent, varname, pattern = m.groups() pattern = eval(pattern, {"re": CompiledPattern}) - pattern = f"Pattern {{ code: &{json.dumps(pattern.code)}, flags: SreFlag::from_bits_truncate({int(pattern.flags)}) }}" + pattern = f"Pattern {{ code: &{json.dumps(pattern.code)} }}" return f'''{line} {indent}// START GENERATED by generate_tests.py {indent}#[rustfmt::skip] let {varname} = {pattern}; diff --git a/src/engine.rs b/src/engine.rs index 5e1f1457ec..b9487a2fd2 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -81,7 +81,7 @@ pub struct State<'a, S: StrDrive> { pub string_position: usize, next_context: Option>, popped_has_matched: bool, - has_matched: bool, + pub has_matched: bool, } impl<'a, S: StrDrive> State<'a, S> { @@ -696,7 +696,7 @@ fn op_repeat<'a, S: StrDrive>( state.string_position = ctx.string_position; - let repeat_ctx_id = state.repeat_stack.len(); + let repeat_ctx_id = state.repeat_stack.len() - 1; // let next_ctx = next_ctx!(from 1, state, ctx, |state, ctx| { let next_ctx = ctx.next_from(1, req, state, |req, state, ctx| { diff --git a/tests/tests.rs b/tests/tests.rs index cc5c4d1f38..b4ad09f7be 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,18 +1,17 @@ -use sre_engine::constants::SreFlag; use sre_engine::engine; struct Pattern { code: &'static [u32], - flags: SreFlag, } impl Pattern { fn state<'a, S: engine::StrDrive>( &self, string: S, - range: std::ops::Range, - ) -> engine::State<'a, S> { - engine::State::new(string, range.start, range.end, self.flags, self.code) + ) -> (engine::Request<'a, S>, engine::State<'a, S>) { + let req = engine::Request::new(string, 0, usize::MAX, self.code, false); + let state = engine::State::new(0); + (req, state) } } @@ -20,10 +19,10 @@ impl Pattern { fn test_2427() { // pattern lookbehind = re.compile(r'(? Date: Tue, 9 Aug 2022 16:54:41 +0200 Subject: [PATCH 064/893] refactor benches --- benches/benches.rs | 95 +++++++++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 43 deletions(-) diff --git a/benches/benches.rs b/benches/benches.rs index d000ceb62e..e24ea2f972 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -3,68 +3,77 @@ extern crate test; use test::Bencher; -use sre_engine::constants::SreFlag; use sre_engine::engine; -pub struct Pattern { - pub code: &'static [u32], - pub flags: SreFlag, + +struct Pattern { + code: &'static [u32], } impl Pattern { - pub fn state<'a, S: engine::StrDrive>( + fn state<'a, S: engine::StrDrive>( + &self, + string: S, + ) -> (engine::Request<'a, S>, engine::State<'a, S>) { + self.state_range(string, 0..usize::MAX) + } + + fn state_range<'a, S: engine::StrDrive>( &self, string: S, range: std::ops::Range, - ) -> engine::State<'a, S> { - engine::State::new(string, range.start, range.end, self.flags, self.code) + ) -> (engine::Request<'a, S>, engine::State<'a, S>) { + let req = engine::Request::new(string, range.start, range.end, self.code, false); + let state = engine::State::new(0); + (req, state) } } + #[bench] fn benchmarks(b: &mut Bencher) { // # test common prefix // pattern p1 = re.compile('Python|Perl') # , 'Perl'), # Alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p1 = Pattern { code: &[15, 8, 1, 4, 6, 1, 1, 80, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p1 = Pattern { code: &[15, 8, 1, 4, 6, 1, 1, 80, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 1] }; // END GENERATED // pattern p2 = re.compile('(Python|Perl)') #, 'Perl'), # Grouped alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p2 = Pattern { code: &[15, 8, 1, 4, 6, 1, 0, 80, 0, 18, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 18, 1, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p2 = Pattern { code: &[15, 8, 1, 4, 6, 1, 0, 80, 0, 18, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 18, 1, 1] }; // END GENERATED - // pattern pn = re.compile('Python|Perl|Tcl') #, 'Perl'), # Alternation + // pattern p3 = re.compile('Python|Perl|Tcl') #, 'Perl'), # Alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p3 = Pattern { code: &[15, 9, 4, 3, 6, 17, 80, 17, 84, 0, 7, 15, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 22, 11, 17, 80, 17, 101, 17, 114, 17, 108, 16, 11, 9, 17, 84, 17, 99, 17, 108, 16, 2, 0, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p3 = Pattern { code: &[15, 9, 4, 3, 6, 17, 80, 17, 84, 0, 7, 15, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 22, 11, 17, 80, 17, 101, 17, 114, 17, 108, 16, 11, 9, 17, 84, 17, 99, 17, 108, 16, 2, 0, 1] }; // END GENERATED - // pattern pn = re.compile('(Python|Perl|Tcl)') #, 'Perl'), # Grouped alternation + // pattern p4 = re.compile('(Python|Perl|Tcl)') #, 'Perl'), # Grouped alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p4 = Pattern { code: &[15, 9, 4, 3, 6, 17, 80, 17, 84, 0, 18, 0, 7, 15, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 22, 11, 17, 80, 17, 101, 17, 114, 17, 108, 16, 11, 9, 17, 84, 17, 99, 17, 108, 16, 2, 0, 18, 1, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p4 = Pattern { code: &[15, 9, 4, 3, 6, 17, 80, 17, 84, 0, 18, 0, 7, 15, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 22, 11, 17, 80, 17, 101, 17, 114, 17, 108, 16, 11, 9, 17, 84, 17, 99, 17, 108, 16, 2, 0, 18, 1, 1] }; // END GENERATED - // pattern pn = re.compile('(Python)\\1') #, 'PythonPython'), # Backreference + // pattern p5 = re.compile('(Python)\\1') #, 'PythonPython'), # Backreference // START GENERATED by generate_tests.py - #[rustfmt::skip] let p5 = Pattern { code: &[15, 18, 1, 12, 12, 6, 0, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 18, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 12, 0, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p5 = Pattern { code: &[15, 18, 1, 12, 12, 6, 0, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 18, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 12, 0, 1] }; // END GENERATED - // pattern pn = re.compile('([0a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # Disable the fastmap optimization + // pattern p6 = re.compile('([0a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # Disable the fastmap optimization // START GENERATED by generate_tests.py - #[rustfmt::skip] let p6 = Pattern { code: &[15, 4, 0, 2, 4294967295, 24, 31, 1, 4294967295, 18, 0, 14, 7, 17, 48, 23, 97, 122, 0, 25, 13, 0, 4294967295, 14, 8, 23, 97, 122, 23, 48, 57, 0, 1, 17, 44, 18, 1, 19, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p6 = Pattern { code: &[15, 4, 0, 2, 4294967295, 24, 31, 1, 4294967295, 18, 0, 14, 7, 17, 48, 23, 97, 122, 0, 25, 13, 0, 4294967295, 14, 8, 23, 97, 122, 23, 48, 57, 0, 1, 17, 44, 18, 1, 19, 1] }; // END GENERATED - // pattern pn = re.compile('([a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # A few sets + // pattern p7 = re.compile('([a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # A few sets // START GENERATED by generate_tests.py - #[rustfmt::skip] let p7 = Pattern { code: &[15, 4, 0, 2, 4294967295, 24, 29, 1, 4294967295, 18, 0, 14, 5, 23, 97, 122, 0, 25, 13, 0, 4294967295, 14, 8, 23, 97, 122, 23, 48, 57, 0, 1, 17, 44, 18, 1, 19, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p7 = Pattern { code: &[15, 4, 0, 2, 4294967295, 24, 29, 1, 4294967295, 18, 0, 14, 5, 23, 97, 122, 0, 25, 13, 0, 4294967295, 14, 8, 23, 97, 122, 23, 48, 57, 0, 1, 17, 44, 18, 1, 19, 1] }; // END GENERATED - // pattern pn = re.compile('Python') #, 'Python'), # Simple text literal + // pattern p8 = re.compile('Python') #, 'Python'), # Simple text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p8 = Pattern { code: &[15, 18, 3, 6, 6, 6, 6, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p8 = Pattern { code: &[15, 18, 3, 6, 6, 6, 6, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 1] }; // END GENERATED - // pattern pn = re.compile('.*Python') #, 'Python'), # Bad text literal + // pattern p9 = re.compile('.*Python') #, 'Python'), # Bad text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p9 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p9 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 1] }; // END GENERATED - // pattern pn = re.compile('.*Python.*') #, 'Python'), # Worse text literal + // pattern p10 = re.compile('.*Python.*') #, 'Python'), # Worse text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p10 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 25, 5, 0, 4294967295, 2, 1, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p10 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 25, 5, 0, 4294967295, 2, 1, 1] }; // END GENERATED - // pattern pn = re.compile('.*(Python)') #, 'Python'), # Bad text literal with grouping + // pattern p11 = re.compile('.*(Python)') #, 'Python'), # Bad text literal with grouping // START GENERATED by generate_tests.py - #[rustfmt::skip] let p11 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 18, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 1], flags: SreFlag::from_bits_truncate(32) }; + #[rustfmt::skip] let p11 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 18, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 1] }; // END GENERATED let tests = [ @@ -83,29 +92,29 @@ fn benchmarks(b: &mut Bencher) { b.iter(move || { for (p, s) in &tests { - let mut state = p.state(s.clone(), 0..usize::MAX); - state.search(); + let (mut req, mut state) = p.state(s.clone()); + state.search(&mut req); assert!(state.has_matched); - state = p.state(s.clone(), 0..usize::MAX); - state.pymatch(); + let (mut req, mut state) = p.state(s.clone()); + state.pymatch(&mut req); assert!(state.has_matched); - state = p.state(s.clone(), 0..usize::MAX); - state.match_all = true; - state.pymatch(); + let (mut req, mut state) = p.state(s.clone()); + req.match_all = true; + state.pymatch(&mut req); assert!(state.has_matched); let s2 = format!("{}{}{}", " ".repeat(10000), s, " ".repeat(10000)); - state = p.state(s2.as_str(), 0..usize::MAX); - state.search(); + let (mut req, mut state) = p.state_range(s2.as_str(), 0..usize::MAX); + state.search(&mut req); assert!(state.has_matched); - state = p.state(s2.as_str(), 10000..usize::MAX); - state.pymatch(); + let (mut req, mut state) = p.state_range(s2.as_str(), 10000..usize::MAX); + state.pymatch(&mut req); assert!(state.has_matched); - state = p.state(s2.as_str(), 10000..10000 + s.len()); - state.pymatch(); + let (mut req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); + state.pymatch(&mut req); assert!(state.has_matched); - state = p.state(s2.as_str(), 10000..10000 + s.len()); - state.match_all = true; - state.pymatch(); + let (mut req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); + req.match_all = true; + state.pymatch(&mut req); assert!(state.has_matched); } }) From de8973d77a40303693e8e15da70fe63f6f974546 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 9 Aug 2022 17:34:03 +0200 Subject: [PATCH 065/893] simplify lifetime --- benches/benches.rs | 6 +- src/engine.rs | 265 ++++++++++++++++++--------------------------- tests/tests.rs | 4 +- 3 files changed, 110 insertions(+), 165 deletions(-) diff --git a/benches/benches.rs b/benches/benches.rs index e24ea2f972..8e0e87935a 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -13,7 +13,7 @@ impl Pattern { fn state<'a, S: engine::StrDrive>( &self, string: S, - ) -> (engine::Request<'a, S>, engine::State<'a, S>) { + ) -> (engine::Request<'a, S>, engine::State) { self.state_range(string, 0..usize::MAX) } @@ -21,9 +21,9 @@ impl Pattern { &self, string: S, range: std::ops::Range, - ) -> (engine::Request<'a, S>, engine::State<'a, S>) { + ) -> (engine::Request<'a, S>, engine::State) { let req = engine::Request::new(string, range.start, range.end, self.code, false); - let state = engine::State::new(0); + let state = engine::State::new(); (req, state) } } diff --git a/src/engine.rs b/src/engine.rs index b9487a2fd2..ace75d1a36 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -39,25 +39,6 @@ impl<'a, S: StrDrive> Request<'a, S> { } } -macro_rules! next_ctx { - (offset $offset:expr, $state:expr, $ctx:expr, $handler:expr) => { - next_ctx!(position $ctx.code_position + $offset, $state, $ctx, $handler) - }; - (from $peek:expr, $state:expr, $ctx:expr, $handler:expr) => { - next_ctx!(offset $ctx.peek_code($state, $peek) as usize + 1, $state, $ctx, $handler) - }; - (position $position:expr, $state:expr, $ctx:expr, $handler:expr) => {{ - $ctx.handler = Some($handler); - $state.next_context.insert(MatchContext { - code_position: $position, - has_matched: None, - handler: None, - count: -1, - ..*$ctx - }) - }}; -} - macro_rules! mark { (push, $state:expr) => { $state @@ -72,27 +53,27 @@ macro_rules! mark { } #[derive(Debug)] -pub struct State<'a, S: StrDrive> { +pub struct State { pub marks: Vec>, pub lastindex: isize, marks_stack: Vec<(Vec>, isize)>, - context_stack: Vec>, + context_stack: Vec>, repeat_stack: Vec, pub string_position: usize, - next_context: Option>, + next_context: Option>, popped_has_matched: bool, pub has_matched: bool, } -impl<'a, S: StrDrive> State<'a, S> { - pub fn new(string_position: usize) -> Self { +impl State { + pub fn new() -> Self { Self { marks: Vec::new(), lastindex: -1, marks_stack: Vec::new(), context_stack: Vec::new(), repeat_stack: Vec::new(), - string_position, + string_position: 0, next_context: None, popped_has_matched: false, has_matched: false, @@ -145,7 +126,7 @@ impl<'a, S: StrDrive> State<'a, S> { self.marks_stack.pop(); } - fn _match(&mut self, req: &mut Request<'a, S>) { + fn _match(&mut self, req: &mut Request) { while let Some(mut ctx) = self.context_stack.pop() { if let Some(handler) = ctx.handler.take() { handler(req, self, &mut ctx); @@ -169,7 +150,9 @@ impl<'a, S: StrDrive> State<'a, S> { self.has_matched = self.popped_has_matched; } - pub fn pymatch(&mut self, req: &mut Request<'a, S>) { + pub fn pymatch(&mut self, req: &mut Request) { + self.string_position = req.start; + let ctx = MatchContext { string_position: req.start, string_offset: req.string.offset(0, req.start), @@ -185,7 +168,9 @@ impl<'a, S: StrDrive> State<'a, S> { self._match(req); } - pub fn search(&mut self, req: &mut Request<'a, S>) { + pub fn search(&mut self, req: &mut Request) { + self.string_position = req.start; + // TODO: optimize by op info and skip prefix if req.start > req.end { return; @@ -196,7 +181,7 @@ impl<'a, S: StrDrive> State<'a, S> { let mut start_offset = req.string.offset(0, req.start); - let mut ctx = MatchContext { + let ctx = MatchContext { string_position: req.start, string_offset: start_offset, code_position: 0, @@ -240,10 +225,10 @@ impl<'a, S: StrDrive> State<'a, S> { } } -fn dispatch<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, +fn dispatch( + req: &Request, + state: &mut State, + ctx: &mut MatchContext, opcode: SreOpcode, ) { match opcode { @@ -410,11 +395,7 @@ fn dispatch<'a, S: StrDrive>( /* assert subpattern */ /* */ -fn op_assert<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, -) { +fn op_assert(req: &Request, state: &mut State, ctx: &mut MatchContext) { let back = ctx.peek_code(req, 2) as usize; if ctx.string_position < back { return ctx.failure(); @@ -435,18 +416,14 @@ fn op_assert<'a, S: StrDrive>( /* assert not subpattern */ /* */ -fn op_assert_not<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, -) { +fn op_assert_not(req: &Request, state: &mut State, ctx: &mut MatchContext) { let back = ctx.peek_code(req, 2) as usize; if ctx.string_position < back { return ctx.skip_code_from(req, 1); } - let next_ctx = next_ctx!(offset 3, state, ctx, |req, state, ctx| { + let next_ctx = ctx.next_offset(3, state, |req, state, ctx| { if state.popped_has_matched { ctx.failure(); } else { @@ -460,20 +437,16 @@ fn op_assert_not<'a, S: StrDrive>( // alternation // <0=skip> code ... -fn op_branch<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, -) { +fn op_branch(req: &Request, state: &mut State, ctx: &mut MatchContext) { mark!(push, state); ctx.count = 1; create_context(req, state, ctx); - fn create_context<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, + fn create_context( + req: &Request, + state: &mut State, + ctx: &mut MatchContext, ) { let branch_offset = ctx.count as usize; let next_length = ctx.peek_code(req, branch_offset) as isize; @@ -485,14 +458,10 @@ fn op_branch<'a, S: StrDrive>( state.string_position = ctx.string_position; ctx.count += next_length; - next_ctx!(offset branch_offset + 1, state, ctx, callback); + ctx.next_offset(branch_offset + 1, state, callback); } - fn callback<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, - ) { + fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { if state.popped_has_matched { return ctx.success(); } @@ -502,10 +471,10 @@ fn op_branch<'a, S: StrDrive>( } /* <1=min> <2=max> item tail */ -fn op_min_repeat_one<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, +fn op_min_repeat_one( + req: &Request, + state: &mut State, + ctx: &mut MatchContext, ) { let min_count = ctx.peek_code(req, 2) as usize; @@ -536,10 +505,10 @@ fn op_min_repeat_one<'a, S: StrDrive>( mark!(push, state); create_context(req, state, ctx); - fn create_context<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, + fn create_context( + req: &Request, + state: &mut State, + ctx: &mut MatchContext, ) { let max_count = ctx.peek_code(req, 3) as usize; @@ -553,11 +522,7 @@ fn op_min_repeat_one<'a, S: StrDrive>( } } - fn callback<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, - ) { + fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { if state.popped_has_matched { return ctx.success(); } @@ -582,11 +547,7 @@ exactly one character wide, and we're not already collecting backtracking points. for other cases, use the MAX_REPEAT operator */ /* <1=min> <2=max> item tail */ -fn op_repeat_one<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, -) { +fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut MatchContext) { let min_count = ctx.peek_code(req, 2) as usize; let max_count = ctx.peek_code(req, 3) as usize; @@ -613,10 +574,10 @@ fn op_repeat_one<'a, S: StrDrive>( ctx.count = count as isize; create_context(req, state, ctx); - fn create_context<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, + fn create_context( + req: &Request, + state: &mut State, + ctx: &mut MatchContext, ) { let min_count = ctx.peek_code(req, 2) as isize; let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); @@ -641,11 +602,7 @@ fn op_repeat_one<'a, S: StrDrive>( ctx.next_from(1, req, state, callback); } - fn callback<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, - ) { + fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { if state.popped_has_matched { return ctx.success(); } @@ -678,11 +635,7 @@ struct RepeatContext { /* create repeat context. all the hard work is done by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ /* <1=min> <2=max> item tail */ -fn op_repeat<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, -) { +fn op_repeat(req: &Request, state: &mut State, ctx: &mut MatchContext) { let repeat_ctx = RepeatContext { count: -1, min_count: ctx.peek_code(req, 2) as usize, @@ -698,8 +651,7 @@ fn op_repeat<'a, S: StrDrive>( let repeat_ctx_id = state.repeat_stack.len() - 1; - // let next_ctx = next_ctx!(from 1, state, ctx, |state, ctx| { - let next_ctx = ctx.next_from(1, req, state, |req, state, ctx| { + let next_ctx = ctx.next_from(1, req, state, |_, state, ctx| { ctx.has_matched = Some(state.popped_has_matched); state.repeat_stack.pop(); }); @@ -707,7 +659,7 @@ fn op_repeat<'a, S: StrDrive>( } /* minimizing repeat */ -fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { +fn op_min_until(state: &mut State, ctx: &mut MatchContext) { let repeat_ctx = state.repeat_stack.last_mut().unwrap(); state.string_position = ctx.string_position; @@ -716,8 +668,7 @@ fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex if (repeat_ctx.count as usize) < repeat_ctx.min_count { // not enough matches - // next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { - ctx.next_at(repeat_ctx.code_position + 4, state, |req, state, ctx| { + ctx.next_at(repeat_ctx.code_position + 4, state, |_, state, ctx| { if state.popped_has_matched { ctx.success(); } else { @@ -736,8 +687,7 @@ fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex let repeat_ctx_prev_id = repeat_ctx.prev_id; // see if the tail matches - // let next_ctx = next_ctx!(offset 1, state, ctx, |state, ctx| { - let next_ctx = ctx.next_offset(1, state, |req, state, ctx| { + let next_ctx = ctx.next_offset(1, state, |_, state, ctx| { if state.popped_has_matched { return ctx.success(); } @@ -762,8 +712,7 @@ fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex /* zero-width match protection */ repeat_ctx.last_position = state.string_position; - // next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { - ctx.next_at(repeat_ctx.code_position + 4, state, |req, state, ctx| { + ctx.next_at(repeat_ctx.code_position + 4, state, |_, state, ctx| { if state.popped_has_matched { ctx.success(); } else { @@ -777,7 +726,7 @@ fn op_min_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex } /* maximizing repeat */ -fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { +fn op_max_until(state: &mut State, ctx: &mut MatchContext) { let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; state.string_position = ctx.string_position; @@ -786,8 +735,7 @@ fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex if (repeat_ctx.count as usize) < repeat_ctx.min_count { // not enough matches - // next_ctx!(position repeat_ctx.code_position + 4, state, ctx, |state, ctx| { - ctx.next_at(repeat_ctx.code_position + 4, state, |req, state, ctx| { + ctx.next_at(repeat_ctx.code_position + 4, state, |_, state, ctx| { if state.popped_has_matched { ctx.success(); } else { @@ -809,7 +757,7 @@ fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex ctx.count = repeat_ctx.last_position as isize; repeat_ctx.last_position = state.string_position; - ctx.next_at(repeat_ctx.code_position + 4, state, |req, state, ctx| { + ctx.next_at(repeat_ctx.code_position + 4, state, |_, state, ctx| { let save_last_position = ctx.count as usize; let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; repeat_ctx.last_position = save_last_position; @@ -826,22 +774,21 @@ fn op_max_until<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContex /* cannot match more repeated items here. make sure the tail matches */ - let next_ctx = next_ctx!(offset 1, state, ctx, tail_callback); - next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + let repeat_ctx_prev_id = repeat_ctx.prev_id; + let next_ctx = ctx.next_offset(1, state, tail_callback); + next_ctx.repeat_ctx_id = repeat_ctx_prev_id; }); return; } /* cannot match more repeated items here. make sure the tail matches */ - let next_ctx = next_ctx!(offset 1, state, ctx, tail_callback); - next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + // let next_ctx = next_ctx!(offset 1, state, ctx, tail_callback); + let repeat_ctx_prev_id = repeat_ctx.prev_id; + let next_ctx = ctx.next_offset(1, state, tail_callback); + next_ctx.repeat_ctx_id = repeat_ctx_prev_id; - fn tail_callback<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &mut MatchContext<'a, S>, - ) { + fn tail_callback(_: &Request, state: &mut State, ctx: &mut MatchContext) { if state.popped_has_matched { ctx.success(); } else { @@ -926,19 +873,21 @@ impl<'a> StrDrive for &'a [u8] { } } +type OpFunc = for<'a> fn(&Request<'a, S>, &mut State, &mut MatchContext); + #[derive(Clone, Copy)] -struct MatchContext<'a, S: StrDrive> { +struct MatchContext { string_position: usize, string_offset: usize, code_position: usize, has_matched: Option, toplevel: bool, - handler: Option, &mut State<'a, S>, &mut Self)>, + handler: Option>, repeat_ctx_id: usize, count: isize, } -impl<'a, S: StrDrive> std::fmt::Debug for MatchContext<'a, S> { +impl<'a, S: StrDrive> std::fmt::Debug for MatchContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MatchContext") .field("string_position", &self.string_position) @@ -953,38 +902,38 @@ impl<'a, S: StrDrive> std::fmt::Debug for MatchContext<'a, S> { } } -impl<'a, S: StrDrive> MatchContext<'a, S> { - fn pattern(&self, req: &Request<'a, S>) -> &'a [u32] { +impl MatchContext { + fn pattern<'a>(&self, req: &Request<'a, S>) -> &'a [u32] { &req.pattern_codes[self.code_position..] } - fn remaining_codes(&self, req: &Request<'a, S>) -> usize { + fn remaining_codes(&self, req: &Request) -> usize { req.pattern_codes.len() - self.code_position } - fn remaining_chars(&self, req: &Request<'a, S>) -> usize { + fn remaining_chars(&self, req: &Request) -> usize { req.end - self.string_position } - fn peek_char(&self, req: &Request<'a, S>) -> u32 { + fn peek_char(&self, req: &Request) -> u32 { req.string.peek(self.string_offset) } - fn skip_char(&mut self, req: &Request<'a, S>, skip: usize) { + fn skip_char(&mut self, req: &Request, skip: usize) { self.string_position += skip; self.string_offset = req.string.offset(self.string_offset, skip); } - fn back_peek_char(&self, req: &Request<'a, S>) -> u32 { + fn back_peek_char(&self, req: &Request) -> u32 { req.string.back_peek(self.string_offset) } - fn back_skip_char(&mut self, req: &Request<'a, S>, skip: usize) { + fn back_skip_char(&mut self, req: &Request, skip: usize) { self.string_position -= skip; self.string_offset = req.string.back_offset(self.string_offset, skip); } - fn peek_code(&self, req: &Request<'a, S>, peek: usize) -> u32 { + fn peek_code(&self, req: &Request, peek: usize) -> u32 { req.pattern_codes[self.code_position + peek] } @@ -992,7 +941,7 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { self.code_position += skip; } - fn skip_code_from(&mut self, req: &Request<'a, S>, peek: usize) { + fn skip_code_from(&mut self, req: &Request, peek: usize) { self.skip_code(self.peek_code(req, peek) as usize + 1); } @@ -1001,19 +950,15 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { self.string_position == 0 } - fn at_end(&self, req: &Request<'a, S>) -> bool { + fn at_end(&self, req: &Request) -> bool { self.string_position == req.end } - fn at_linebreak(&self, req: &Request<'a, S>) -> bool { + fn at_linebreak(&self, req: &Request) -> bool { !self.at_end(req) && is_linebreak(self.peek_char(req)) } - fn at_boundary bool>( - &self, - req: &Request<'a, S>, - mut word_checker: F, - ) -> bool { + fn at_boundary bool>(&self, req: &Request, mut word_checker: F) -> bool { if self.at_beginning() && self.at_end(req) { return false; } @@ -1024,7 +969,7 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { fn at_non_boundary bool>( &self, - req: &Request<'a, S>, + req: &Request, mut word_checker: F, ) -> bool { if self.at_beginning() && self.at_end(req) { @@ -1035,7 +980,7 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { this == that } - fn can_success(&self, req: &Request<'a, S>) -> bool { + fn can_success(&self, req: &Request) -> bool { if !self.toplevel { return true; } @@ -1059,9 +1004,9 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { fn next_from<'b>( &mut self, peek: usize, - req: &Request<'a, S>, - state: &'b mut State<'a, S>, - f: fn(&Request<'a, S>, &mut State<'a, S>, &mut Self), + req: &Request, + state: &'b mut State, + f: OpFunc, ) -> &'b mut Self { self.next_offset(self.peek_code(req, peek) as usize + 1, state, f) } @@ -1069,8 +1014,8 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { fn next_offset<'b>( &mut self, offset: usize, - state: &'b mut State<'a, S>, - f: fn(&Request<'a, S>, &mut State<'a, S>, &mut Self), + state: &'b mut State, + f: OpFunc, ) -> &'b mut Self { self.next_at(self.code_position + offset, state, f) } @@ -1078,8 +1023,8 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { fn next_at<'b>( &mut self, code_position: usize, - state: &'b mut State<'a, S>, - f: fn(&Request<'a, S>, &mut State<'a, S>, &mut Self), + state: &'b mut State, + f: OpFunc, ) -> &'b mut Self { self.handler = Some(f); state.next_context.insert(MatchContext { @@ -1092,7 +1037,7 @@ impl<'a, S: StrDrive> MatchContext<'a, S> { } } -fn at<'a, S: StrDrive>(req: &Request<'a, S>, ctx: &MatchContext<'a, S>, atcode: SreAtCode) -> bool { +fn at(req: &Request, ctx: &MatchContext, atcode: SreAtCode) -> bool { match atcode { SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => ctx.at_beginning(), SreAtCode::BEGINNING_LINE => ctx.at_beginning() || is_linebreak(ctx.back_peek_char(req)), @@ -1110,9 +1055,9 @@ fn at<'a, S: StrDrive>(req: &Request<'a, S>, ctx: &MatchContext<'a, S>, atcode: } } -fn general_op_literal<'a, S: StrDrive, F: FnOnce(u32, u32) -> bool>( - req: &Request<'a, S>, - ctx: &mut MatchContext<'a, S>, +fn general_op_literal bool>( + req: &Request, + ctx: &mut MatchContext, f: F, ) { if ctx.at_end(req) || !f(ctx.peek_code(req, 1), ctx.peek_char(req)) { @@ -1123,9 +1068,9 @@ fn general_op_literal<'a, S: StrDrive, F: FnOnce(u32, u32) -> bool>( } } -fn general_op_in<'a, S: StrDrive, F: FnOnce(&[u32], u32) -> bool>( - req: &Request<'a, S>, - ctx: &mut MatchContext<'a, S>, +fn general_op_in bool>( + req: &Request, + ctx: &mut MatchContext, f: F, ) { if ctx.at_end(req) || !f(&ctx.pattern(req)[2..], ctx.peek_char(req)) { @@ -1136,10 +1081,10 @@ fn general_op_in<'a, S: StrDrive, F: FnOnce(&[u32], u32) -> bool>( } } -fn general_op_groupref<'a, S: StrDrive, F: FnMut(u32) -> u32>( - req: &Request<'a, S>, - state: &State<'a, S>, - ctx: &mut MatchContext<'a, S>, +fn general_op_groupref u32>( + req: &Request, + state: &State, + ctx: &mut MatchContext, mut f: F, ) { let (group_start, group_end) = state.get_marks(ctx.peek_code(req, 1) as usize); @@ -1296,10 +1241,10 @@ fn charset(set: &[u32], ch: u32) -> bool { false } -fn _count<'a, S: StrDrive>( - req: &Request<'a, S>, - state: &mut State<'a, S>, - ctx: &MatchContext<'a, S>, +fn _count( + req: &Request, + state: &mut State, + ctx: &MatchContext, max_count: usize, ) -> usize { let mut ctx = *ctx; @@ -1375,9 +1320,9 @@ fn _count<'a, S: StrDrive>( ctx.string_position - state.string_position } -fn general_count_literal<'a, S: StrDrive, F: FnMut(u32, u32) -> bool>( - req: &Request<'a, S>, - ctx: &mut MatchContext<'a, S>, +fn general_count_literal bool>( + req: &Request, + ctx: &mut MatchContext, end: usize, mut f: F, ) { diff --git a/tests/tests.rs b/tests/tests.rs index b4ad09f7be..ead111c74a 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -8,9 +8,9 @@ impl Pattern { fn state<'a, S: engine::StrDrive>( &self, string: S, - ) -> (engine::Request<'a, S>, engine::State<'a, S>) { + ) -> (engine::Request<'a, S>, engine::State) { let req = engine::Request::new(string, 0, usize::MAX, self.code, false); - let state = engine::State::new(0); + let state = engine::State::new(); (req, state) } } From c494feb7f776e8e15185710f72d41b603db995d8 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 9 Aug 2022 21:34:06 +0200 Subject: [PATCH 066/893] refactor split Marks --- Cargo.toml | 1 + benches/benches.rs | 2 +- src/engine.rs | 242 ++++++++++++++++++++++++++++++--------------- tests/tests.rs | 5 +- 4 files changed, 166 insertions(+), 84 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4b403f2861..8993c1e71d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,4 @@ include = ["LICENSE", "src/**/*.rs"] [dependencies] num_enum = "0.5" bitflags = "1.2" +optional = "0.5" diff --git a/benches/benches.rs b/benches/benches.rs index 8e0e87935a..f19b92d64b 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -23,7 +23,7 @@ impl Pattern { range: std::ops::Range, ) -> (engine::Request<'a, S>, engine::State) { let req = engine::Request::new(string, range.start, range.end, self.code, false); - let state = engine::State::new(); + let state = engine::State::default(); (req, state) } } diff --git a/src/engine.rs b/src/engine.rs index ace75d1a36..087d64e8cd 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -2,7 +2,9 @@ use super::constants::{SreAtCode, SreCatCode, SreInfo, SreOpcode}; use super::MAXREPEAT; +use optional::Optioned; use std::convert::TryFrom; +use std::ops::Deref; const fn is_py_ascii_whitespace(b: u8) -> bool { matches!(b, b'\t' | b'\n' | b'\x0C' | b'\r' | b' ' | b'\x0B') @@ -39,24 +41,98 @@ impl<'a, S: StrDrive> Request<'a, S> { } } -macro_rules! mark { - (push, $state:expr) => { - $state - .marks_stack - .push(($state.marks.clone(), $state.lastindex)) - }; - (pop, $state:expr) => { - let (marks, lastindex) = $state.marks_stack.pop().unwrap(); - $state.marks = marks; - $state.lastindex = lastindex; - }; +// macro_rules! mark { +// (push, $state:expr) => { +// $state +// .marks_stack +// .push(($state.marks.clone(), $state.lastindex)) +// }; +// (pop, $state:expr) => { +// let (marks, lastindex) = $state.marks_stack.pop().unwrap(); +// $state.marks = marks; +// $state.lastindex = lastindex; +// }; +// } + +#[derive(Debug)] +pub struct Marks { + last_index: isize, + marks: Vec>, + marks_stack: Vec<(Vec>, isize)>, +} + +impl Default for Marks { + fn default() -> Self { + Self { + last_index: -1, + marks: Vec::new(), + marks_stack: Vec::new(), + } + } +} + +impl Deref for Marks { + type Target = Vec>; + + fn deref(&self) -> &Self::Target { + &self.marks + } +} + +impl Marks { + pub fn get(&self, group_index: usize) -> (Optioned, Optioned) { + let marks_index = 2 * group_index; + if marks_index + 1 < self.marks.len() { + (self.marks[marks_index], self.marks[marks_index + 1]) + } else { + (Optioned::none(), Optioned::none()) + } + } + + pub fn last_index(&self) -> isize { + self.last_index + } + + fn set(&mut self, mark_nr: usize, position: usize) { + if mark_nr & 1 != 0 { + self.last_index = mark_nr as isize / 2 + 1; + } + if mark_nr >= self.marks.len() { + self.marks.resize(mark_nr + 1, Optioned::none()); + } + self.marks[mark_nr] = Optioned::some(position); + } + + fn push(&mut self) { + self.marks_stack.push((self.marks.clone(), self.last_index)); + } + + fn pop(&mut self) { + let (marks, last_index) = self.marks_stack.pop().unwrap(); + self.marks = marks; + self.last_index = last_index; + } + + fn pop_keep(&mut self) { + let (marks, last_index) = self.marks_stack.last().unwrap().clone(); + self.marks = marks; + self.last_index = last_index; + } + + fn pop_discard(&mut self) { + self.marks_stack.pop(); + } + + fn clear(&mut self) { + self.last_index = -1; + self.marks.clear(); + self.marks_stack.clear(); + } } #[derive(Debug)] pub struct State { - pub marks: Vec>, - pub lastindex: isize, - marks_stack: Vec<(Vec>, isize)>, + pub marks: Marks, context_stack: Vec>, repeat_stack: Vec, pub string_position: usize, @@ -65,25 +141,23 @@ pub struct State { pub has_matched: bool, } -impl State { - pub fn new() -> Self { +impl Default for State { + fn default() -> Self { Self { - marks: Vec::new(), - lastindex: -1, - marks_stack: Vec::new(), - context_stack: Vec::new(), - repeat_stack: Vec::new(), - string_position: 0, - next_context: None, - popped_has_matched: false, - has_matched: false, + marks: Default::default(), + context_stack: Default::default(), + repeat_stack: Default::default(), + string_position: Default::default(), + next_context: Default::default(), + popped_has_matched: Default::default(), + has_matched: Default::default(), } } +} +impl State { pub fn reset(&mut self, string_position: usize) { - self.lastindex = -1; self.marks.clear(); - self.marks_stack.clear(); self.context_stack.clear(); self.repeat_stack.clear(); self.string_position = string_position; @@ -92,23 +166,23 @@ impl State { self.has_matched = false; } - fn set_mark(&mut self, mark_nr: usize, position: usize) { - if mark_nr & 1 != 0 { - self.lastindex = mark_nr as isize / 2 + 1; - } - if mark_nr >= self.marks.len() { - self.marks.resize(mark_nr + 1, None); - } - self.marks[mark_nr] = Some(position); - } - fn get_marks(&self, group_index: usize) -> (Option, Option) { - let marks_index = 2 * group_index; - if marks_index + 1 < self.marks.len() { - (self.marks[marks_index], self.marks[marks_index + 1]) - } else { - (None, None) - } - } + // fn set_mark(&mut self, mark_nr: usize, position: usize) { + // if mark_nr & 1 != 0 { + // self.lastindex = mark_nr as isize / 2 + 1; + // } + // if mark_nr >= self.marks.len() { + // self.marks.resize(mark_nr + 1, None); + // } + // self.marks[mark_nr] = Some(position); + // } + // fn get_marks(&self, group_index: usize) -> (Option, Option) { + // let marks_index = 2 * group_index; + // if marks_index + 1 < self.marks.len() { + // (self.marks[marks_index], self.marks[marks_index + 1]) + // } else { + // (None, None) + // } + // } // fn marks_push(&mut self) { // self.marks_stack.push((self.marks.clone(), self.lastindex)); // } @@ -117,14 +191,14 @@ impl State { // self.marks = marks; // self.lastindex = lastindex; // } - fn marks_pop_keep(&mut self) { - let (marks, lastindex) = self.marks_stack.last().unwrap().clone(); - self.marks = marks; - self.lastindex = lastindex; - } - fn marks_pop_discard(&mut self) { - self.marks_stack.pop(); - } + // fn marks_pop_keep(&mut self) { + // let (marks, lastindex) = self.marks_stack.last().unwrap().clone(); + // self.marks = marks; + // self.lastindex = lastindex; + // } + // fn marks_pop_discard(&mut self) { + // self.marks_stack.pop(); + // } fn _match(&mut self, req: &mut Request) { while let Some(mut ctx) = self.context_stack.pop() { @@ -311,7 +385,9 @@ fn dispatch( general_op_literal(req, ctx, |code, c| !char_loc_ignore(code, c)) } SreOpcode::MARK => { - state.set_mark(ctx.peek_code(req, 1) as usize, ctx.string_position); + state + .marks + .set(ctx.peek_code(req, 1) as usize, ctx.string_position); ctx.skip_code(2); } SreOpcode::MAX_UNTIL => op_max_until(state, ctx), @@ -324,12 +400,14 @@ fn dispatch( SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(req, state, ctx, lower_locate), SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(req, state, ctx, lower_unicode), SreOpcode::GROUPREF_EXISTS => { - let (group_start, group_end) = state.get_marks(ctx.peek_code(req, 1) as usize); - match (group_start, group_end) { - (Some(start), Some(end)) if start <= end => { - ctx.skip_code(3); - } - _ => ctx.skip_code_from(req, 2), + let (group_start, group_end) = state.marks.get(ctx.peek_code(req, 1) as usize); + if group_start.is_some() + && group_end.is_some() + && group_start.unpack() <= group_end.unpack() + { + ctx.skip_code(3); + } else { + ctx.skip_code_from(req, 2) } } _ => unreachable!("unexpected opcode"), @@ -438,7 +516,7 @@ fn op_assert_not(req: &Request, state: &mut State, ctx: &mut // alternation // <0=skip> code ... fn op_branch(req: &Request, state: &mut State, ctx: &mut MatchContext) { - mark!(push, state); + state.marks.push(); ctx.count = 1; create_context(req, state, ctx); @@ -451,7 +529,7 @@ fn op_branch(req: &Request, state: &mut State, ctx: &mut Matc let branch_offset = ctx.count as usize; let next_length = ctx.peek_code(req, branch_offset) as isize; if next_length == 0 { - state.marks_pop_discard(); + state.marks.pop_discard(); return ctx.failure(); } @@ -465,7 +543,7 @@ fn op_branch(req: &Request, state: &mut State, ctx: &mut Matc if state.popped_has_matched { return ctx.success(); } - state.marks_pop_keep(); + state.marks.pop_keep(); create_context(req, state, ctx); } } @@ -502,7 +580,7 @@ fn op_min_repeat_one( return ctx.success(); } - mark!(push, state); + state.marks.push(); create_context(req, state, ctx); fn create_context( @@ -517,7 +595,7 @@ fn op_min_repeat_one( // next_ctx!(from 1, state, ctx, callback); ctx.next_from(1, req, state, callback); } else { - state.marks_pop_discard(); + state.marks.pop_discard(); ctx.failure(); } } @@ -530,13 +608,13 @@ fn op_min_repeat_one( state.string_position = ctx.string_position; if _count(req, state, ctx, 1) == 0 { - state.marks_pop_discard(); + state.marks.pop_discard(); return ctx.failure(); } ctx.skip_char(req, 1); ctx.count += 1; - state.marks_pop_keep(); + state.marks.pop_keep(); create_context(req, state, ctx); } } @@ -570,7 +648,7 @@ fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut return ctx.success(); } - mark!(push, state); + state.marks.push(); ctx.count = count as isize; create_context(req, state, ctx); @@ -587,7 +665,7 @@ fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut let c = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 2); while ctx.at_end(req) || ctx.peek_char(req) != c { if ctx.count <= min_count { - state.marks_pop_discard(); + state.marks.pop_discard(); return ctx.failure(); } ctx.back_skip_char(req, 1); @@ -610,14 +688,14 @@ fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut let min_count = ctx.peek_code(req, 2) as isize; if ctx.count <= min_count { - state.marks_pop_discard(); + state.marks.pop_discard(); return ctx.failure(); } ctx.back_skip_char(req, 1); ctx.count -= 1; - state.marks_pop_keep(); + state.marks.pop_keep(); create_context(req, state, ctx); } } @@ -680,7 +758,7 @@ fn op_min_until(state: &mut State, ctx: &mut MatchContext) { return; } - mark!(push, state); + state.marks.push(); ctx.count = ctx.repeat_ctx_id as isize; @@ -698,7 +776,7 @@ fn op_min_until(state: &mut State, ctx: &mut MatchContext) { state.string_position = ctx.string_position; - mark!(pop, state); + state.marks.pop(); // match more until tail matches @@ -752,7 +830,7 @@ fn op_max_until(state: &mut State, ctx: &mut MatchContext) { { /* we may have enough matches, but if we can match another item, do so */ - mark!(push, state); + state.marks.push(); ctx.count = repeat_ctx.last_position as isize; repeat_ctx.last_position = state.string_position; @@ -763,11 +841,11 @@ fn op_max_until(state: &mut State, ctx: &mut MatchContext) { repeat_ctx.last_position = save_last_position; if state.popped_has_matched { - state.marks_pop_discard(); + state.marks.pop_discard(); return ctx.success(); } - mark!(pop, state); + state.marks.pop(); repeat_ctx.count -= 1; state.string_position = ctx.string_position; @@ -1087,12 +1165,14 @@ fn general_op_groupref u32>( ctx: &mut MatchContext, mut f: F, ) { - let (group_start, group_end) = state.get_marks(ctx.peek_code(req, 1) as usize); - let (group_start, group_end) = match (group_start, group_end) { - (Some(start), Some(end)) if start <= end => (start, end), - _ => { - return ctx.failure(); - } + let (group_start, group_end) = state.marks.get(ctx.peek_code(req, 1) as usize); + let (group_start, group_end) = if group_start.is_some() + && group_end.is_some() + && group_start.unpack() <= group_end.unpack() + { + (group_start.unpack(), group_end.unpack()) + } else { + return ctx.failure(); }; let mut gctx = MatchContext { diff --git a/tests/tests.rs b/tests/tests.rs index ead111c74a..cb11db3483 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -10,7 +10,7 @@ impl Pattern { string: S, ) -> (engine::Request<'a, S>, engine::State) { let req = engine::Request::new(string, 0, usize::MAX, self.code, false); - let state = engine::State::new(); + let state = engine::State::default(); (req, state) } } @@ -62,13 +62,14 @@ fn test_zerowidth() { #[test] fn test_repeat_context_panic() { + use optional::Optioned; // pattern p = re.compile(r'(?:a*?(xx)??z)*') // START GENERATED by generate_tests.py #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 25, 0, 4294967295, 27, 6, 0, 4294967295, 17, 97, 1, 24, 11, 0, 1, 18, 0, 17, 120, 17, 120, 18, 1, 20, 17, 122, 19, 1] }; // END GENERATED let (mut req, mut state) = p.state("axxzaz"); state.pymatch(&mut req); - assert!(state.marks == vec![Some(1), Some(3)]); + assert!(*state.marks == vec![Optioned::some(1), Optioned::some(3)]); } #[test] From 18258000cde2c848198b745e92f74a89d48c7fe8 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 9 Aug 2022 22:17:01 +0200 Subject: [PATCH 067/893] clearup --- src/engine.rs | 78 +++++++++++---------------------------------------- 1 file changed, 16 insertions(+), 62 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 087d64e8cd..652ca04c27 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,6 +1,6 @@ // good luck to those that follow; here be dragons -use super::constants::{SreAtCode, SreCatCode, SreInfo, SreOpcode}; +use super::constants::{SreAtCode, SreCatCode, SreOpcode}; use super::MAXREPEAT; use optional::Optioned; use std::convert::TryFrom; @@ -41,19 +41,6 @@ impl<'a, S: StrDrive> Request<'a, S> { } } -// macro_rules! mark { -// (push, $state:expr) => { -// $state -// .marks_stack -// .push(($state.marks.clone(), $state.lastindex)) -// }; -// (pop, $state:expr) => { -// let (marks, lastindex) = $state.marks_stack.pop().unwrap(); -// $state.marks = marks; -// $state.lastindex = lastindex; -// }; -// } - #[derive(Debug)] pub struct Marks { last_index: isize, @@ -135,6 +122,7 @@ pub struct State { pub marks: Marks, context_stack: Vec>, repeat_stack: Vec, + pub start: usize, pub string_position: usize, next_context: Option>, popped_has_matched: bool, @@ -144,62 +132,30 @@ pub struct State { impl Default for State { fn default() -> Self { Self { - marks: Default::default(), - context_stack: Default::default(), - repeat_stack: Default::default(), - string_position: Default::default(), - next_context: Default::default(), - popped_has_matched: Default::default(), - has_matched: Default::default(), + marks: Marks::default(), + context_stack: Vec::new(), + repeat_stack: Vec::new(), + start: 0, + string_position: 0, + next_context: None, + popped_has_matched: false, + has_matched: false, } } } impl State { - pub fn reset(&mut self, string_position: usize) { + pub fn reset(&mut self, start: usize) { self.marks.clear(); self.context_stack.clear(); self.repeat_stack.clear(); - self.string_position = string_position; + self.start = start; + self.string_position = start; self.next_context = None; self.popped_has_matched = false; self.has_matched = false; } - // fn set_mark(&mut self, mark_nr: usize, position: usize) { - // if mark_nr & 1 != 0 { - // self.lastindex = mark_nr as isize / 2 + 1; - // } - // if mark_nr >= self.marks.len() { - // self.marks.resize(mark_nr + 1, None); - // } - // self.marks[mark_nr] = Some(position); - // } - // fn get_marks(&self, group_index: usize) -> (Option, Option) { - // let marks_index = 2 * group_index; - // if marks_index + 1 < self.marks.len() { - // (self.marks[marks_index], self.marks[marks_index + 1]) - // } else { - // (None, None) - // } - // } - // fn marks_push(&mut self) { - // self.marks_stack.push((self.marks.clone(), self.lastindex)); - // } - // fn marks_pop(&mut self) { - // let (marks, lastindex) = self.marks_stack.pop().unwrap(); - // self.marks = marks; - // self.lastindex = lastindex; - // } - // fn marks_pop_keep(&mut self) { - // let (marks, lastindex) = self.marks_stack.last().unwrap().clone(); - // self.marks = marks; - // self.lastindex = lastindex; - // } - // fn marks_pop_discard(&mut self) { - // self.marks_stack.pop(); - // } - fn _match(&mut self, req: &mut Request) { while let Some(mut ctx) = self.context_stack.pop() { if let Some(handler) = ctx.handler.take() { @@ -225,6 +181,7 @@ impl State { } pub fn pymatch(&mut self, req: &mut Request) { + self.start = req.start; self.string_position = req.start; let ctx = MatchContext { @@ -243,6 +200,7 @@ impl State { } pub fn search(&mut self, req: &mut Request) { + self.start = req.start; self.string_position = req.start; // TODO: optimize by op info and skip prefix @@ -479,7 +437,6 @@ fn op_assert(req: &Request, state: &mut State, ctx: &mut Matc return ctx.failure(); } - // let next_ctx = next_ctx!(offset 3, state, ctx, |req, state, ctx| { let next_ctx = ctx.next_offset(3, state, |req, state, ctx| { if state.popped_has_matched { ctx.skip_code_from(req, 1); @@ -592,7 +549,6 @@ fn op_min_repeat_one( if max_count == MAXREPEAT || ctx.count as usize <= max_count { state.string_position = ctx.string_position; - // next_ctx!(from 1, state, ctx, callback); ctx.next_from(1, req, state, callback); } else { state.marks.pop_discard(); @@ -676,7 +632,6 @@ fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut state.string_position = ctx.string_position; // General case: backtracking - // next_ctx!(from 1, state, ctx, callback); ctx.next_from(1, req, state, callback); } @@ -861,7 +816,6 @@ fn op_max_until(state: &mut State, ctx: &mut MatchContext) { /* cannot match more repeated items here. make sure the tail matches */ - // let next_ctx = next_ctx!(offset 1, state, ctx, tail_callback); let repeat_ctx_prev_id = repeat_ctx.prev_id; let next_ctx = ctx.next_offset(1, state, tail_callback); next_ctx.repeat_ctx_id = repeat_ctx_prev_id; @@ -965,7 +919,7 @@ struct MatchContext { count: isize, } -impl<'a, S: StrDrive> std::fmt::Debug for MatchContext { +impl std::fmt::Debug for MatchContext { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("MatchContext") .field("string_position", &self.string_position) From e42df1d8597bf964fb7f70b606df9e0be696c624 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Tue, 9 Aug 2022 22:18:01 +0200 Subject: [PATCH 068/893] update version to 0.4.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 8993c1e71d..373166d6db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.3.1" +version = "0.4.0" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" From c4f10edc95ab7dc1f20bed2c9b4bddfc520fa660 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 14 Aug 2022 20:46:20 +0200 Subject: [PATCH 069/893] impl opinfo single literal --- benches/benches.rs | 24 +++---- src/engine.rs | 173 +++++++++++++++++++++++---------------------- tests/tests.rs | 51 +++++++++---- 3 files changed, 139 insertions(+), 109 deletions(-) diff --git a/benches/benches.rs b/benches/benches.rs index f19b92d64b..604cf91f42 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -92,29 +92,29 @@ fn benchmarks(b: &mut Bencher) { b.iter(move || { for (p, s) in &tests { - let (mut req, mut state) = p.state(s.clone()); - state.search(&mut req); + let (req, mut state) = p.state(s.clone()); + state.search(req); assert!(state.has_matched); - let (mut req, mut state) = p.state(s.clone()); - state.pymatch(&mut req); + let (req, mut state) = p.state(s.clone()); + state.pymatch(req); assert!(state.has_matched); let (mut req, mut state) = p.state(s.clone()); req.match_all = true; - state.pymatch(&mut req); + state.pymatch(req); assert!(state.has_matched); let s2 = format!("{}{}{}", " ".repeat(10000), s, " ".repeat(10000)); - let (mut req, mut state) = p.state_range(s2.as_str(), 0..usize::MAX); - state.search(&mut req); + let (req, mut state) = p.state_range(s2.as_str(), 0..usize::MAX); + state.search(req); assert!(state.has_matched); - let (mut req, mut state) = p.state_range(s2.as_str(), 10000..usize::MAX); - state.pymatch(&mut req); + let (req, mut state) = p.state_range(s2.as_str(), 10000..usize::MAX); + state.pymatch(req); assert!(state.has_matched); - let (mut req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); - state.pymatch(&mut req); + let (req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); + state.pymatch(req); assert!(state.has_matched); let (mut req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); req.match_all = true; - state.pymatch(&mut req); + state.pymatch(req); assert!(state.has_matched); } }) diff --git a/src/engine.rs b/src/engine.rs index 652ca04c27..0d645e046d 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,5 +1,7 @@ // good luck to those that follow; here be dragons +use crate::constants::SreInfo; + use super::constants::{SreAtCode, SreCatCode, SreOpcode}; use super::MAXREPEAT; use optional::Optioned; @@ -10,6 +12,7 @@ const fn is_py_ascii_whitespace(b: u8) -> bool { matches!(b, b'\t' | b'\n' | b'\x0C' | b'\r' | b' ' | b'\x0B') } +#[derive(Debug, Clone, Copy)] pub struct Request<'a, S: StrDrive> { pub string: S, pub start: usize, @@ -180,7 +183,7 @@ impl State { self.has_matched = self.popped_has_matched; } - pub fn pymatch(&mut self, req: &mut Request) { + pub fn pymatch(&mut self, mut req: Request) { self.start = req.start; self.string_position = req.start; @@ -196,10 +199,10 @@ impl State { }; self.context_stack.push(ctx); - self._match(req); + self._match(&mut req); } - pub fn search(&mut self, req: &mut Request) { + pub fn search(&mut self, mut req: Request) { self.start = req.start; self.string_position = req.start; @@ -208,12 +211,11 @@ impl State { return; } - // let start = self.start; - // let end = self.end; + let mut end = req.end; let mut start_offset = req.string.offset(0, req.start); - let ctx = MatchContext { + let mut ctx = MatchContext { string_position: req.start, string_offset: start_offset, code_position: 0, @@ -224,35 +226,97 @@ impl State { count: -1, }; - // if ctx.peek_code(self, 0) == SreOpcode::INFO as u32 { - // search_op_info(self, &mut ctx); - // if let Some(has_matched) = ctx.has_matched { - // self.has_matched = has_matched; - // return; - // } - // } + if ctx.peek_code(&req, 0) == SreOpcode::INFO as u32 { + /* optimization info block */ + /* <1=skip> <2=flags> <3=min> <4=max> <5=prefix info> */ + let req = &mut req; + let min = ctx.peek_code(req, 3) as usize; + + if ctx.remaining_chars(req) < min { + return; + } + + if min > 1 { + /* adjust end point (but make sure we leave at least one + character in there, so literal search will work) */ + // no overflow can happen as remaining chars >= min + end -= min - 1; + + // adjust ctx position + if end < ctx.string_position { + ctx.string_position = end; + ctx.string_offset = req.string.offset(0, ctx.string_position); + } + } + + let flags = SreInfo::from_bits_truncate(ctx.peek_code(req, 2)); + + if flags.contains(SreInfo::PREFIX) { + /* pattern starts with a known prefix */ + /* */ + let len = ctx.peek_code(req, 5) as usize; + let skip = ctx.peek_code(req, 6) as usize; + let prefix = &ctx.pattern(req)[7..]; + let overlap = &prefix[len - 1..]; + + if len == 1 { + // pattern starts with a literal character + ctx.skip_code_from(req, 1); + let c = prefix[0]; + req.must_advance = false; + + while !ctx.at_end(req) { + // find the next matched literal + while ctx.peek_char(req) != c { + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + } + + req.start = ctx.string_position; + self.reset(req.start); + // self.start = ctx.string_position; + self.string_position += skip; + + // literal only + if flags.contains(SreInfo::LITERAL) { + self.has_matched = true; + return; + } + + let mut next_ctx = ctx; + next_ctx.skip_char(req, skip); + next_ctx.skip_code(2 * skip); + + self.context_stack.push(next_ctx); + self._match(req); + + if self.has_matched { + return; + } + + ctx.skip_char(req, 1); + } + return; + } + } + } self.context_stack.push(ctx); - self._match(req); + self._match(&mut req); req.must_advance = false; - while !self.has_matched && req.start < req.end { + ctx.toplevel = false; + while !self.has_matched && req.start < end { req.start += 1; start_offset = req.string.offset(start_offset, 1); self.reset(req.start); + ctx.string_position = req.start; + ctx.string_offset = start_offset; - let ctx = MatchContext { - string_position: req.start, - string_offset: start_offset, - code_position: 0, - has_matched: None, - toplevel: false, - handler: None, - repeat_ctx_id: usize::MAX, - count: -1, - }; self.context_stack.push(ctx); - self._match(req); + self._match(&mut req); } } } @@ -372,63 +436,6 @@ fn dispatch( } } -/* optimization info block */ -/* <1=skip> <2=flags> <3=min> <4=max> <5=prefix info> */ -// fn search_op_info<'a, S: StrDrive>(state: &mut State<'a, S>, ctx: &mut MatchContext<'a, S>) { -// let min = ctx.peek_code(state, 3) as usize; - -// if ctx.remaining_chars(state) < min { -// return ctx.failure(); -// } - -// if min > 1 { -// /* adjust end point (but make sure we leave at least one -// character in there, so literal search will work) */ -// // no overflow can happen as remaining chars >= min -// state.end -= min - 1; - -// // adjust ctx position -// if state.end < ctx.string_position { -// ctx.string_position = state.end; -// ctx.string_offset = state.string.offset(0, ctx.string_position); -// } -// } - -// let flags = SreInfo::from_bits_truncate(ctx.peek_code(state, 2)); - -// if flags.contains(SreInfo::PREFIX) { -// /* pattern starts with a known prefix */ -// /* */ -// let len = ctx.peek_code(state, 5) as usize; -// let skip = ctx.peek_code(state, 6) as usize; -// let prefix = &ctx.pattern(state)[7..]; -// let overlap = &prefix[len - 1..]; - -// ctx.skip_code_from(state, 1); - -// if len == 1 { -// // pattern starts with a literal character -// let c = prefix[0]; -// let end = state.end; - -// while (!ctx.at_end(state)) { -// // find the next matched literal -// while (ctx.peek_char(state) != c) { -// ctx.skip_char(state, 1); -// if (ctx.at_end(state)) { -// return ctx.failure(); -// } -// } - -// // literal only -// if flags.contains(SreInfo::LITERAL) { -// return ctx.success(); -// } -// } -// } -// } -// } - /* assert subpattern */ /* */ fn op_assert(req: &Request, state: &mut State, ctx: &mut MatchContext) { diff --git a/tests/tests.rs b/tests/tests.rs index cb11db3483..31f032f0a4 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -21,8 +21,8 @@ fn test_2427() { // START GENERATED by generate_tests.py #[rustfmt::skip] let lookbehind = Pattern { code: &[15, 4, 0, 1, 1, 5, 5, 1, 17, 46, 1, 17, 120, 6, 10, 1] }; // END GENERATED - let (mut req, mut state) = lookbehind.state("x"); - state.pymatch(&mut req); + let (req, mut state) = lookbehind.state("x"); + state.pymatch(req); assert!(state.has_matched); } @@ -32,8 +32,8 @@ fn test_assert() { // START GENERATED by generate_tests.py #[rustfmt::skip] let positive_lookbehind = Pattern { code: &[15, 4, 0, 3, 3, 4, 9, 3, 17, 97, 17, 98, 17, 99, 1, 17, 100, 17, 101, 17, 102, 1] }; // END GENERATED - let (mut req, mut state) = positive_lookbehind.state("abcdef"); - state.search(&mut req); + let (req, mut state) = positive_lookbehind.state("abcdef"); + state.search(req); assert!(state.has_matched); } @@ -43,8 +43,8 @@ fn test_string_boundaries() { // START GENERATED by generate_tests.py #[rustfmt::skip] let big_b = Pattern { code: &[15, 4, 0, 0, 0, 6, 11, 1] }; // END GENERATED - let (mut req, mut state) = big_b.state(""); - state.search(&mut req); + let (req, mut state) = big_b.state(""); + state.search(req); assert!(!state.has_matched); } @@ -56,8 +56,8 @@ fn test_zerowidth() { // END GENERATED let (mut req, mut state) = p.state("a:"); req.must_advance = true; - state.search(&mut req); - assert!(state.string_position == 1); + state.search(req); + assert_eq!(state.string_position, 1); } #[test] @@ -67,9 +67,9 @@ fn test_repeat_context_panic() { // START GENERATED by generate_tests.py #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 25, 0, 4294967295, 27, 6, 0, 4294967295, 17, 97, 1, 24, 11, 0, 1, 18, 0, 17, 120, 17, 120, 18, 1, 20, 17, 122, 19, 1] }; // END GENERATED - let (mut req, mut state) = p.state("axxzaz"); - state.pymatch(&mut req); - assert!(*state.marks == vec![Optioned::some(1), Optioned::some(3)]); + let (req, mut state) = p.state("axxzaz"); + state.pymatch(req); + assert_eq!(*state.marks, vec![Optioned::some(1), Optioned::some(3)]); } #[test] @@ -78,7 +78,30 @@ fn test_double_max_until() { // START GENERATED by generate_tests.py #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 18, 0, 4294967295, 18, 0, 24, 9, 0, 1, 18, 2, 17, 49, 18, 3, 19, 18, 1, 19, 1] }; // END GENERATED - let (mut req, mut state) = p.state("1111"); - state.pymatch(&mut req); - assert!(state.string_position == 4); + let (req, mut state) = p.state("1111"); + state.pymatch(req); + assert_eq!(state.string_position, 4); +} + +#[test] +fn test_info_single() { + // pattern p = re.compile(r'aa*') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[15, 8, 1, 1, 4294967295, 1, 1, 97, 0, 17, 97, 25, 6, 0, 4294967295, 17, 97, 1, 1] }; + // END GENERATED + let (req, mut state) = p.state("baaaa"); + state.search(req); + assert_eq!(state.start, 1); + assert_eq!(state.string_position, 5); +} + +#[test] +fn test_info_single2() { + // pattern p = re.compile(r'Python|Perl') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[15, 8, 1, 4, 6, 1, 1, 80, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 1] }; + // END GENERATED + let (req, mut state) = p.state("Perl"); + state.search(req); + assert!(state.has_matched); } From 236631141fa3d2681e9a53eb3550c96219cb0cf7 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 15 Aug 2022 19:55:48 +0200 Subject: [PATCH 070/893] impl opinfo literal --- src/engine.rs | 81 +++++++++++++++++++++++++++++++++++++++++++++----- tests/tests.rs | 22 ++++++++++++++ 2 files changed, 96 insertions(+), 7 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 0d645e046d..53181c5043 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -256,13 +256,17 @@ impl State { /* */ let len = ctx.peek_code(req, 5) as usize; let skip = ctx.peek_code(req, 6) as usize; - let prefix = &ctx.pattern(req)[7..]; - let overlap = &prefix[len - 1..]; + let prefix = &ctx.pattern(req)[7..7 + len]; + let overlap = &ctx.pattern(req)[7 + len - 1..7 + len * 2]; if len == 1 { // pattern starts with a literal character - ctx.skip_code_from(req, 1); let c = prefix[0]; + + // code_position ready for tail match + ctx.skip_code_from(req, 1); + ctx.skip_code(2 * skip); + req.must_advance = false; while !ctx.at_end(req) { @@ -275,9 +279,8 @@ impl State { } req.start = ctx.string_position; - self.reset(req.start); - // self.start = ctx.string_position; - self.string_position += skip; + self.start = ctx.string_position; + self.string_position = ctx.string_position + skip; // literal only if flags.contains(SreInfo::LITERAL) { @@ -287,7 +290,6 @@ impl State { let mut next_ctx = ctx; next_ctx.skip_char(req, skip); - next_ctx.skip_code(2 * skip); self.context_stack.push(next_ctx); self._match(req); @@ -297,6 +299,71 @@ impl State { } ctx.skip_char(req, 1); + self.marks.clear(); + } + return; + } else if len > 1 { + // code_position ready for tail match + ctx.skip_code_from(req, 1); + ctx.skip_code(2 * skip); + + req.must_advance = false; + + while !ctx.at_end(req) { + let c = prefix[0]; + while ctx.peek_char(req) != c { + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + } + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + + let mut i = 1; + loop { + if ctx.peek_char(req) == prefix[i] { + i += 1; + if i != len { + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + continue; + } + + req.start = ctx.string_position - (len - 1); + self.start = req.start; + self.string_position = self.start + skip; + + if flags.contains(SreInfo::LITERAL) { + self.has_matched = true; + return; + } + + let mut next_ctx = ctx; + // next_ctx.skip_char(req, 1); + next_ctx.string_position = self.string_position; + next_ctx.string_offset = req.string.offset(0, self.string_position); + self.context_stack.push(next_ctx); + self._match(req); + if self.has_matched { + return; + } + + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + self.marks.clear(); + } + i = overlap[i] as usize; + if i == 0 { + break; + } + } } return; } diff --git a/tests/tests.rs b/tests/tests.rs index 31f032f0a4..21bc89d40c 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -105,3 +105,25 @@ fn test_info_single2() { state.search(req); assert!(state.has_matched); } + +#[test] +fn test_info_literal() { + // pattern p = re.compile(r'ababc+') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[15, 14, 1, 5, 4294967295, 4, 4, 97, 98, 97, 98, 0, 0, 1, 2, 17, 97, 17, 98, 17, 97, 17, 98, 25, 6, 1, 4294967295, 17, 99, 1, 1] }; + // END GENERATED + let (req, mut state) = p.state("!ababc"); + state.search(req); + assert!(state.has_matched); +} + +#[test] +fn test_info_literal2() { + // pattern p = re.compile(r'(python)\1') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[15, 18, 1, 12, 12, 6, 0, 112, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 18, 0, 17, 112, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 12, 0, 1] }; + // END GENERATED + let (req, mut state) = p.state("pythonpython"); + state.search(req); + assert!(state.has_matched); +} \ No newline at end of file From 646c8ac6578977b847e8f871e1f83fb422b3e39a Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 15 Aug 2022 20:09:51 +0200 Subject: [PATCH 071/893] impl opinfo charset --- src/engine.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/engine.rs b/src/engine.rs index 53181c5043..fe7ee438e9 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -367,6 +367,31 @@ impl State { } return; } + } else if flags.contains(SreInfo::CHARSET) { + let set = &ctx.pattern(req)[5..]; + ctx.skip_code_from(req, 1); + req.must_advance = false; + loop { + while !ctx.at_end(req) && !charset(set, ctx.peek_char(req)) { + ctx.skip_char(req, 1); + } + if ctx.at_end(req) { + return; + } + req.start = ctx.string_position; + self.start = ctx.string_position; + self.string_position = ctx.string_position; + + self.context_stack.push(ctx); + self._match(req); + + if self.has_matched { + return; + } + + ctx.skip_char(req, 1); + self.marks.clear(); + } } } From 7e7b9734810947155876cc52028d873ba953b5f4 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 15 Aug 2022 20:36:32 +0200 Subject: [PATCH 072/893] clearup --- src/engine.rs | 301 +++++++++++++++++++++++++++----------------------- 1 file changed, 163 insertions(+), 138 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index fe7ee438e9..7a644f0671 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -252,147 +252,16 @@ impl State { let flags = SreInfo::from_bits_truncate(ctx.peek_code(req, 2)); if flags.contains(SreInfo::PREFIX) { - /* pattern starts with a known prefix */ - /* */ - let len = ctx.peek_code(req, 5) as usize; - let skip = ctx.peek_code(req, 6) as usize; - let prefix = &ctx.pattern(req)[7..7 + len]; - let overlap = &ctx.pattern(req)[7 + len - 1..7 + len * 2]; - - if len == 1 { - // pattern starts with a literal character - let c = prefix[0]; - - // code_position ready for tail match - ctx.skip_code_from(req, 1); - ctx.skip_code(2 * skip); - - req.must_advance = false; - - while !ctx.at_end(req) { - // find the next matched literal - while ctx.peek_char(req) != c { - ctx.skip_char(req, 1); - if ctx.at_end(req) { - return; - } - } - - req.start = ctx.string_position; - self.start = ctx.string_position; - self.string_position = ctx.string_position + skip; - - // literal only - if flags.contains(SreInfo::LITERAL) { - self.has_matched = true; - return; - } - - let mut next_ctx = ctx; - next_ctx.skip_char(req, skip); - - self.context_stack.push(next_ctx); - self._match(req); - - if self.has_matched { - return; - } - - ctx.skip_char(req, 1); - self.marks.clear(); - } - return; - } else if len > 1 { - // code_position ready for tail match - ctx.skip_code_from(req, 1); - ctx.skip_code(2 * skip); - - req.must_advance = false; - - while !ctx.at_end(req) { - let c = prefix[0]; - while ctx.peek_char(req) != c { - ctx.skip_char(req, 1); - if ctx.at_end(req) { - return; - } - } - ctx.skip_char(req, 1); - if ctx.at_end(req) { - return; - } - - let mut i = 1; - loop { - if ctx.peek_char(req) == prefix[i] { - i += 1; - if i != len { - ctx.skip_char(req, 1); - if ctx.at_end(req) { - return; - } - continue; - } - - req.start = ctx.string_position - (len - 1); - self.start = req.start; - self.string_position = self.start + skip; - - if flags.contains(SreInfo::LITERAL) { - self.has_matched = true; - return; - } - - let mut next_ctx = ctx; - // next_ctx.skip_char(req, 1); - next_ctx.string_position = self.string_position; - next_ctx.string_offset = req.string.offset(0, self.string_position); - self.context_stack.push(next_ctx); - self._match(req); - if self.has_matched { - return; - } - - ctx.skip_char(req, 1); - if ctx.at_end(req) { - return; - } - self.marks.clear(); - } - i = overlap[i] as usize; - if i == 0 { - break; - } - } - } - return; + if flags.contains(SreInfo::LITERAL) { + search_info_literal::(req, self, ctx); + } else { + search_info_literal::(req, self, ctx); } + return; } else if flags.contains(SreInfo::CHARSET) { - let set = &ctx.pattern(req)[5..]; - ctx.skip_code_from(req, 1); - req.must_advance = false; - loop { - while !ctx.at_end(req) && !charset(set, ctx.peek_char(req)) { - ctx.skip_char(req, 1); - } - if ctx.at_end(req) { - return; - } - req.start = ctx.string_position; - self.start = ctx.string_position; - self.string_position = ctx.string_position; - - self.context_stack.push(ctx); - self._match(req); - - if self.has_matched { - return; - } - - ctx.skip_char(req, 1); - self.marks.clear(); - } + return search_info_charset(req, self, ctx); } + // fallback to general search } self.context_stack.push(ctx); @@ -528,6 +397,162 @@ fn dispatch( } } +fn search_info_literal( + req: &mut Request, + state: &mut State, + mut ctx: MatchContext, +) { + /* pattern starts with a known prefix */ + /* */ + let len = ctx.peek_code(req, 5) as usize; + let skip = ctx.peek_code(req, 6) as usize; + let prefix = &ctx.pattern(req)[7..7 + len]; + let overlap = &ctx.pattern(req)[7 + len - 1..7 + len * 2]; + + // code_position ready for tail match + ctx.skip_code_from(req, 1); + ctx.skip_code(2 * skip); + + req.must_advance = false; + + if len == 1 { + // pattern starts with a literal character + let c = prefix[0]; + + while !ctx.at_end(req) { + // find the next matched literal + while ctx.peek_char(req) != c { + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + } + + req.start = ctx.string_position; + state.start = ctx.string_position; + state.string_position = ctx.string_position + skip; + + // literal only + if LITERAL { + state.has_matched = true; + return; + } + + let mut next_ctx = ctx; + next_ctx.skip_char(req, skip); + + state.context_stack.push(next_ctx); + state._match(req); + + if state.has_matched { + return; + } + + ctx.skip_char(req, 1); + state.marks.clear(); + } + } else { + while !ctx.at_end(req) { + let c = prefix[0]; + while ctx.peek_char(req) != c { + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + } + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + + let mut i = 1; + loop { + if ctx.peek_char(req) == prefix[i] { + i += 1; + if i != len { + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + continue; + } + + req.start = ctx.string_position - (len - 1); + state.start = req.start; + state.string_position = state.start + skip; + + // literal only + if LITERAL { + state.has_matched = true; + return; + } + + let mut next_ctx = ctx; + if skip != 0 { + next_ctx.skip_char(req, 1); + } else { + next_ctx.string_position = state.string_position; + next_ctx.string_offset = req.string.offset(0, state.string_position); + } + + state.context_stack.push(next_ctx); + state._match(req); + + if state.has_matched { + return; + } + + ctx.skip_char(req, 1); + if ctx.at_end(req) { + return; + } + state.marks.clear(); + } + + i = overlap[i] as usize; + if i == 0 { + break; + } + } + } + } +} + +fn search_info_charset( + req: &mut Request, + state: &mut State, + mut ctx: MatchContext, +) { + let set = &ctx.pattern(req)[5..]; + + ctx.skip_code_from(req, 1); + + req.must_advance = false; + + loop { + while !ctx.at_end(req) && !charset(set, ctx.peek_char(req)) { + ctx.skip_char(req, 1); + } + if ctx.at_end(req) { + return; + } + + req.start = ctx.string_position; + state.start = ctx.string_position; + state.string_position = ctx.string_position; + + state.context_stack.push(ctx); + state._match(req); + + if state.has_matched { + return; + } + + ctx.skip_char(req, 1); + state.marks.clear(); + } +} + /* assert subpattern */ /* */ fn op_assert(req: &Request, state: &mut State, ctx: &mut MatchContext) { From 26a78dbaa4e4e78f1e5d8dcea9c32054ae07fdd3 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 15 Aug 2022 20:36:46 +0200 Subject: [PATCH 073/893] update to 0.4.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 373166d6db..53524f1446 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.4.0" +version = "0.4.1" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" From 4e6b27144a407bb4daf341048310cad26570a7b3 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 15 Aug 2022 21:30:40 +0200 Subject: [PATCH 074/893] introduce SearchIter --- src/engine.rs | 26 ++++++++++++++++++++++++++ tests/tests.rs | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/engine.rs b/src/engine.rs index 7a644f0671..7334516c2f 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -282,6 +282,32 @@ impl State { } } +pub struct SearchIter<'a, S: StrDrive> { + pub req: Request<'a, S>, + pub state: State, +} + +impl<'a, S: StrDrive> Iterator for SearchIter<'a, S> { + type Item = (); + + fn next(&mut self) -> Option { + if self.req.start > self.req.end { + return None; + } + + self.state.reset(self.req.start); + self.state.search(self.req); + if !self.state.has_matched { + return None; + } + + self.req.must_advance = self.state.string_position == self.state.start; + self.req.start = self.state.string_position; + + Some(()) + } +} + fn dispatch( req: &Request, state: &mut State, diff --git a/tests/tests.rs b/tests/tests.rs index 21bc89d40c..5212226f4e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -126,4 +126,4 @@ fn test_info_literal2() { let (req, mut state) = p.state("pythonpython"); state.search(req); assert!(state.has_matched); -} \ No newline at end of file +} From 6363940d6cdfa6b7165ee3fe6bd167ae07d8dee8 Mon Sep 17 00:00:00 2001 From: Zanie Date: Tue, 11 Jul 2023 15:27:16 -0500 Subject: [PATCH 075/893] Delete stale ASDL update script --- scripts/update_asdl.sh | 7 ------- 1 file changed, 7 deletions(-) delete mode 100755 scripts/update_asdl.sh diff --git a/scripts/update_asdl.sh b/scripts/update_asdl.sh deleted file mode 100755 index 0f12735477..0000000000 --- a/scripts/update_asdl.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash -set -e - -cd "$(dirname "$(dirname "$0")")" - -python compiler/ast/asdl_rs.py -D compiler/ast/src/ast_gen.rs -M vm/src/stdlib/ast/gen.rs compiler/ast/Python.asdl -rustfmt compiler/ast/src/ast_gen.rs vm/src/stdlib/ast/gen.rs From b088787f7b3aa358860f08c7c3a17765a14e4262 Mon Sep 17 00:00:00 2001 From: Zanie Date: Tue, 11 Jul 2023 15:44:22 -0500 Subject: [PATCH 076/893] Remove commented use of ASDL update script from CI workflow --- .github/workflows/ci.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c71f464cbb..559756971b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -323,8 +323,6 @@ jobs: - name: check wasm code with prettier # prettier doesn't handle ignore files very well: https://github.com/prettier/prettier/issues/8506 run: cd wasm && git ls-files -z | xargs -0 prettier --check -u - # - name: Check update_asdl.sh consistency - # run: bash scripts/update_asdl.sh && git diff --exit-code miri: if: ${{ !contains(github.event.pull_request.labels.*.name, 'skip:ci') }} From c28cb3941fde35f9931b995aa5a9a54c8bfe7a7c Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 19 Jul 2023 23:43:26 +0900 Subject: [PATCH 077/893] prettier *.js + fix miri build (#5028) * prettier *.js * bump up proc-macro2 --- Cargo.lock | 4 ++-- derive-impl/Cargo.toml | 2 +- wasm/README.md | 2 +- wasm/demo/webpack.config.js | 8 ++++---- wasm/example/index.html | 2 +- wasm/notebook/src/index.js | 12 ++++++------ wasm/notebook/webpack.config.js | 4 ++-- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f1b8b18104..f71ff7c8de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1714,9 +1714,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" dependencies = [ "unicode-ident", ] diff --git a/derive-impl/Cargo.toml b/derive-impl/Cargo.toml index e50d92d869..1c6b214cc6 100644 --- a/derive-impl/Cargo.toml +++ b/derive-impl/Cargo.toml @@ -13,7 +13,7 @@ once_cell = { workspace = true } syn = { workspace = true, features = ["full", "extra-traits"] } maplit = "1.0.2" -proc-macro2 = "1.0.37" +proc-macro2 = "1.0.60" quote = "1.0.18" syn-ext = { version = "0.4.0", features = ["full"] } textwrap = { version = "0.15.0", default-features = false } diff --git a/wasm/README.md b/wasm/README.md index de71e66045..e8f8799f39 100644 --- a/wasm/README.md +++ b/wasm/README.md @@ -51,7 +51,7 @@ print(js_vars['a'] * 9) vars: { a: 9, }, - } + }, ); ``` diff --git a/wasm/demo/webpack.config.js b/wasm/demo/webpack.config.js index c431056a75..e20b477241 100644 --- a/wasm/demo/webpack.config.js +++ b/wasm/demo/webpack.config.js @@ -18,7 +18,7 @@ module.exports = (env = {}) => { alias: { rustpython: path.resolve( __dirname, - env.rustpythonPkg || '../lib/pkg' + env.rustpythonPkg || '../lib/pkg', ), }, }, @@ -39,11 +39,11 @@ module.exports = (env = {}) => { snippets: fs .readdirSync(path.join(__dirname, 'snippets')) .map((filename) => - path.basename(filename, path.extname(filename)) + path.basename(filename, path.extname(filename)), ), defaultSnippetName: 'fibonacci', defaultSnippet: fs.readFileSync( - path.join(__dirname, 'snippets/fibonacci.py') + path.join(__dirname, 'snippets/fibonacci.py'), ), }, }), @@ -56,7 +56,7 @@ module.exports = (env = {}) => { config.plugins.push( new WasmPackPlugin({ crateDirectory: path.join(__dirname, '../lib'), - }) + }), ); } return config; diff --git a/wasm/example/index.html b/wasm/example/index.html index 86a99c42f1..4c274469f3 100644 --- a/wasm/example/index.html +++ b/wasm/example/index.html @@ -1,4 +1,4 @@ - + diff --git a/wasm/notebook/src/index.js b/wasm/notebook/src/index.js index 799c49ae72..422bc4d0d6 100644 --- a/wasm/notebook/src/index.js +++ b/wasm/notebook/src/index.js @@ -68,7 +68,7 @@ const secondaryEditor = CodeMirror( { lineNumbers: true, lineWrapping: true, - } + }, ); const buffers = {}; @@ -88,7 +88,7 @@ openBuffer( '# python code or code blocks that start with %%py, %%md %%math.', 'notebook', buffersDropDown, - buffersList + buffersList, ); openBuffer( @@ -97,7 +97,7 @@ openBuffer( '# Python code', 'python', buffersDropDown, - buffersList + buffersList, ); openBuffer( @@ -106,7 +106,7 @@ openBuffer( '// Javascript code goes here', 'javascript', buffersDropDown, - buffersList + buffersList, ); openBuffer( @@ -115,7 +115,7 @@ openBuffer( '/* CSS goes here */', 'css', buffersDropDown, - buffersList + buffersList, ); // select main buffer by default and set the main tab to active @@ -286,7 +286,7 @@ CodeMirror.on(buffersDropDown, 'change', function () { selectBuffer( secondaryEditor, buffers, - buffersDropDown.options[buffersDropDown.selectedIndex].value + buffersDropDown.options[buffersDropDown.selectedIndex].value, ); }); diff --git a/wasm/notebook/webpack.config.js b/wasm/notebook/webpack.config.js index e6e1aae0d1..9fda3cf4aa 100644 --- a/wasm/notebook/webpack.config.js +++ b/wasm/notebook/webpack.config.js @@ -18,7 +18,7 @@ module.exports = (env = {}) => { alias: { rustpython: path.resolve( __dirname, - env.rustpythonPkg || '../lib/pkg' + env.rustpythonPkg || '../lib/pkg', ), }, }, @@ -64,7 +64,7 @@ module.exports = (env = {}) => { new WasmPackPlugin({ crateDirectory: path.join(__dirname, '../lib'), forceMode: 'release', - }) + }), ); } return config; From bdb0c8f64557e0822f0bcfd63defbad54625c17a Mon Sep 17 00:00:00 2001 From: Zanie Blue Date: Wed, 19 Jul 2023 23:54:07 -0500 Subject: [PATCH 078/893] Add parser support for PEP 695 (#5026) * Add generated content for PEP 695 ASDL * Bump RustPython/Parser to https://github.com/RustPython/Parser/commit/704eb40108239a8faf9bd1d4217e8dad0ac7edb3 * Add stubs for type aliases and parameters --- Cargo.lock | 14 +- Cargo.toml | 10 +- compiler/codegen/src/compile.rs | 1 + compiler/codegen/src/symboltable.rs | 2 + vm/src/stdlib/ast/gen.rs | 279 ++++++++++++++++++++++++++++ 5 files changed, 294 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f71ff7c8de..f79b0748a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1953,7 +1953,7 @@ dependencies = [ [[package]] name = "ruff_source_location" version = "0.0.0" -source = "git+https://github.com/RustPython/Parser.git?rev=69d27d924c877b6f2fa5dc75c9589ab505d5b3f1#69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" +source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" dependencies = [ "memchr", "once_cell", @@ -1963,7 +1963,7 @@ dependencies = [ [[package]] name = "ruff_text_size" version = "0.0.0" -source = "git+https://github.com/RustPython/Parser.git?rev=69d27d924c877b6f2fa5dc75c9589ab505d5b3f1#69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" +source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" [[package]] name = "rustc-hash" @@ -2021,7 +2021,7 @@ dependencies = [ [[package]] name = "rustpython-ast" version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=69d27d924c877b6f2fa5dc75c9589ab505d5b3f1#69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" +source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" dependencies = [ "is-macro", "malachite-bigint", @@ -2133,7 +2133,7 @@ dependencies = [ [[package]] name = "rustpython-format" version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=69d27d924c877b6f2fa5dc75c9589ab505d5b3f1#69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" +source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" dependencies = [ "bitflags 2.3.1", "itertools 0.10.5", @@ -2160,7 +2160,7 @@ dependencies = [ [[package]] name = "rustpython-literal" version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=69d27d924c877b6f2fa5dc75c9589ab505d5b3f1#69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" +source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" dependencies = [ "hexf-parse", "is-macro", @@ -2172,7 +2172,7 @@ dependencies = [ [[package]] name = "rustpython-parser" version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=69d27d924c877b6f2fa5dc75c9589ab505d5b3f1#69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" +source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" dependencies = [ "anyhow", "is-macro", @@ -2195,7 +2195,7 @@ dependencies = [ [[package]] name = "rustpython-parser-core" version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=69d27d924c877b6f2fa5dc75c9589ab505d5b3f1#69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" +source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" dependencies = [ "is-macro", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 5548879580..8fd494a67f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,11 +29,11 @@ rustpython-pylib = { path = "pylib" } rustpython-stdlib = { path = "stdlib" } rustpython-doc = { git = "https://github.com/RustPython/__doc__", branch = "main" } -rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" } -rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", rev = "69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" } -rustpython-parser = { git = "https://github.com/RustPython/Parser.git", rev = "69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" } -rustpython-ast = { git = "https://github.com/RustPython/Parser.git", rev = "69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" } -rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "69d27d924c877b6f2fa5dc75c9589ab505d5b3f1" } +rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } +rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } +rustpython-parser = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } +rustpython-ast = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } +rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } # rustpython-literal = { path = "../RustPython-parser/literal" } # rustpython-parser-core = { path = "../RustPython-parser/core" } # rustpython-parser = { path = "../RustPython-parser/parser" } diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index cc4d82ec8e..359ecd8600 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -888,6 +888,7 @@ impl Compiler { Stmt::Pass(_) => { // No need to emit any code here :) } + Stmt::TypeAlias(_) => {} } Ok(()) } diff --git a/compiler/codegen/src/symboltable.rs b/compiler/codegen/src/symboltable.rs index a868a6ffb5..e09f0c4aab 100644 --- a/compiler/codegen/src/symboltable.rs +++ b/compiler/codegen/src/symboltable.rs @@ -686,6 +686,7 @@ impl SymbolTableBuilder { bases, keywords, decorator_list, + type_params: _, range, }) => { self.enter_scope(name.as_str(), SymbolTableType::Class, range.start.row.get()); @@ -863,6 +864,7 @@ impl SymbolTableBuilder { self.scan_expression(expression, ExpressionContext::Load)?; } } + Stmt::TypeAlias(StmtTypeAlias { .. }) => {} } Ok(()) } diff --git a/vm/src/stdlib/ast/gen.rs b/vm/src/stdlib/ast/gen.rs index d053e41e39..d3969b9024 100644 --- a/vm/src/stdlib/ast/gen.rs +++ b/vm/src/stdlib/ast/gen.rs @@ -89,6 +89,7 @@ impl NodeStmtFunctionDef { ctx.new_str(ascii!("decorator_list")).into(), ctx.new_str(ascii!("returns")).into(), ctx.new_str(ascii!("type_comment")).into(), + ctx.new_str(ascii!("type_params")).into(), ]) .into(), ); @@ -119,6 +120,7 @@ impl NodeStmtAsyncFunctionDef { ctx.new_str(ascii!("decorator_list")).into(), ctx.new_str(ascii!("returns")).into(), ctx.new_str(ascii!("type_comment")).into(), + ctx.new_str(ascii!("type_params")).into(), ]) .into(), ); @@ -148,6 +150,7 @@ impl NodeStmtClassDef { ctx.new_str(ascii!("keywords")).into(), ctx.new_str(ascii!("body")).into(), ctx.new_str(ascii!("decorator_list")).into(), + ctx.new_str(ascii!("type_params")).into(), ]) .into(), ); @@ -236,6 +239,33 @@ impl NodeStmtAssign { ); } } +#[pyclass(module = "_ast", name = "TypeAlias", base = "NodeStmt")] +struct NodeStmtTypeAlias; +#[pyclass(flags(HAS_DICT, BASETYPE))] +impl NodeStmtTypeAlias { + #[extend_class] + fn extend_class_with_fields(ctx: &Context, class: &'static Py) { + class.set_attr( + identifier!(ctx, _fields), + ctx.new_tuple(vec![ + ctx.new_str(ascii!("name")).into(), + ctx.new_str(ascii!("type_params")).into(), + ctx.new_str(ascii!("value")).into(), + ]) + .into(), + ); + class.set_attr( + identifier!(ctx, _attributes), + ctx.new_list(vec![ + ctx.new_str(ascii!("lineno")).into(), + ctx.new_str(ascii!("col_offset")).into(), + ctx.new_str(ascii!("end_lineno")).into(), + ctx.new_str(ascii!("end_col_offset")).into(), + ]) + .into(), + ); + } +} #[pyclass(module = "_ast", name = "AugAssign", base = "NodeStmt")] struct NodeStmtAugAssign; #[pyclass(flags(HAS_DICT, BASETYPE))] @@ -2214,6 +2244,82 @@ impl NodeTypeIgnoreTypeIgnore { class.set_attr(identifier!(ctx, _attributes), ctx.new_list(vec![]).into()); } } +#[pyclass(module = "_ast", name = "type_param", base = "NodeAst")] +struct NodeTypeParam; +#[pyclass(flags(HAS_DICT, BASETYPE))] +impl NodeTypeParam {} +#[pyclass(module = "_ast", name = "TypeVar", base = "NodeTypeParam")] +struct NodeTypeParamTypeVar; +#[pyclass(flags(HAS_DICT, BASETYPE))] +impl NodeTypeParamTypeVar { + #[extend_class] + fn extend_class_with_fields(ctx: &Context, class: &'static Py) { + class.set_attr( + identifier!(ctx, _fields), + ctx.new_tuple(vec![ + ctx.new_str(ascii!("name")).into(), + ctx.new_str(ascii!("bound")).into(), + ]) + .into(), + ); + class.set_attr( + identifier!(ctx, _attributes), + ctx.new_list(vec![ + ctx.new_str(ascii!("lineno")).into(), + ctx.new_str(ascii!("col_offset")).into(), + ctx.new_str(ascii!("end_lineno")).into(), + ctx.new_str(ascii!("end_col_offset")).into(), + ]) + .into(), + ); + } +} +#[pyclass(module = "_ast", name = "ParamSpec", base = "NodeTypeParam")] +struct NodeTypeParamParamSpec; +#[pyclass(flags(HAS_DICT, BASETYPE))] +impl NodeTypeParamParamSpec { + #[extend_class] + fn extend_class_with_fields(ctx: &Context, class: &'static Py) { + class.set_attr( + identifier!(ctx, _fields), + ctx.new_tuple(vec![ctx.new_str(ascii!("name")).into()]) + .into(), + ); + class.set_attr( + identifier!(ctx, _attributes), + ctx.new_list(vec![ + ctx.new_str(ascii!("lineno")).into(), + ctx.new_str(ascii!("col_offset")).into(), + ctx.new_str(ascii!("end_lineno")).into(), + ctx.new_str(ascii!("end_col_offset")).into(), + ]) + .into(), + ); + } +} +#[pyclass(module = "_ast", name = "TypeVarTuple", base = "NodeTypeParam")] +struct NodeTypeParamTypeVarTuple; +#[pyclass(flags(HAS_DICT, BASETYPE))] +impl NodeTypeParamTypeVarTuple { + #[extend_class] + fn extend_class_with_fields(ctx: &Context, class: &'static Py) { + class.set_attr( + identifier!(ctx, _fields), + ctx.new_tuple(vec![ctx.new_str(ascii!("name")).into()]) + .into(), + ); + class.set_attr( + identifier!(ctx, _attributes), + ctx.new_list(vec![ + ctx.new_str(ascii!("lineno")).into(), + ctx.new_str(ascii!("col_offset")).into(), + ctx.new_str(ascii!("end_lineno")).into(), + ctx.new_str(ascii!("end_col_offset")).into(), + ]) + .into(), + ); + } +} // sum impl Node for ast::located::Mod { @@ -2364,6 +2470,7 @@ impl Node for ast::located::Stmt { ast::located::Stmt::Return(cons) => cons.ast_to_object(vm), ast::located::Stmt::Delete(cons) => cons.ast_to_object(vm), ast::located::Stmt::Assign(cons) => cons.ast_to_object(vm), + ast::located::Stmt::TypeAlias(cons) => cons.ast_to_object(vm), ast::located::Stmt::AugAssign(cons) => cons.ast_to_object(vm), ast::located::Stmt::AnnAssign(cons) => cons.ast_to_object(vm), ast::located::Stmt::For(cons) => cons.ast_to_object(vm), @@ -2405,6 +2512,10 @@ impl Node for ast::located::Stmt { ast::located::Stmt::Delete(ast::located::StmtDelete::ast_from_object(_vm, _object)?) } else if _cls.is(NodeStmtAssign::static_type()) { ast::located::Stmt::Assign(ast::located::StmtAssign::ast_from_object(_vm, _object)?) + } else if _cls.is(NodeStmtTypeAlias::static_type()) { + ast::located::Stmt::TypeAlias(ast::located::StmtTypeAlias::ast_from_object( + _vm, _object, + )?) } else if _cls.is(NodeStmtAugAssign::static_type()) { ast::located::Stmt::AugAssign(ast::located::StmtAugAssign::ast_from_object( _vm, _object, @@ -2473,6 +2584,7 @@ impl Node for ast::located::StmtFunctionDef { decorator_list, returns, type_comment, + type_params, range: _range, } = self; let node = NodeAst @@ -2488,6 +2600,8 @@ impl Node for ast::located::StmtFunctionDef { .unwrap(); dict.set_item("type_comment", type_comment.ast_to_object(_vm), _vm) .unwrap(); + dict.set_item("type_params", type_params.ast_to_object(_vm), _vm) + .unwrap(); node_add_location(&dict, _range, _vm); node.into() } @@ -2515,6 +2629,10 @@ impl Node for ast::located::StmtFunctionDef { type_comment: get_node_field_opt(_vm, &_object, "type_comment")? .map(|obj| Node::ast_from_object(_vm, obj)) .transpose()?, + type_params: Node::ast_from_object( + _vm, + get_node_field(_vm, &_object, "type_params", "FunctionDef")?, + )?, range: range_from_object(_vm, _object, "FunctionDef")?, }) } @@ -2529,6 +2647,7 @@ impl Node for ast::located::StmtAsyncFunctionDef { decorator_list, returns, type_comment, + type_params, range: _range, } = self; let node = NodeAst @@ -2544,6 +2663,8 @@ impl Node for ast::located::StmtAsyncFunctionDef { .unwrap(); dict.set_item("type_comment", type_comment.ast_to_object(_vm), _vm) .unwrap(); + dict.set_item("type_params", type_params.ast_to_object(_vm), _vm) + .unwrap(); node_add_location(&dict, _range, _vm); node.into() } @@ -2571,6 +2692,10 @@ impl Node for ast::located::StmtAsyncFunctionDef { type_comment: get_node_field_opt(_vm, &_object, "type_comment")? .map(|obj| Node::ast_from_object(_vm, obj)) .transpose()?, + type_params: Node::ast_from_object( + _vm, + get_node_field(_vm, &_object, "type_params", "AsyncFunctionDef")?, + )?, range: range_from_object(_vm, _object, "AsyncFunctionDef")?, }) } @@ -2584,6 +2709,7 @@ impl Node for ast::located::StmtClassDef { keywords, body, decorator_list, + type_params, range: _range, } = self; let node = NodeAst @@ -2598,6 +2724,8 @@ impl Node for ast::located::StmtClassDef { dict.set_item("body", body.ast_to_object(_vm), _vm).unwrap(); dict.set_item("decorator_list", decorator_list.ast_to_object(_vm), _vm) .unwrap(); + dict.set_item("type_params", type_params.ast_to_object(_vm), _vm) + .unwrap(); node_add_location(&dict, _range, _vm); node.into() } @@ -2614,6 +2742,10 @@ impl Node for ast::located::StmtClassDef { _vm, get_node_field(_vm, &_object, "decorator_list", "ClassDef")?, )?, + type_params: Node::ast_from_object( + _vm, + get_node_field(_vm, &_object, "type_params", "ClassDef")?, + )?, range: range_from_object(_vm, _object, "ClassDef")?, }) } @@ -2706,6 +2838,42 @@ impl Node for ast::located::StmtAssign { } } // constructor +impl Node for ast::located::StmtTypeAlias { + fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef { + let ast::located::StmtTypeAlias { + name, + type_params, + value, + range: _range, + } = self; + let node = NodeAst + .into_ref_with_type(_vm, NodeStmtTypeAlias::static_type().to_owned()) + .unwrap(); + let dict = node.as_object().dict().unwrap(); + dict.set_item("name", name.ast_to_object(_vm), _vm).unwrap(); + dict.set_item("type_params", type_params.ast_to_object(_vm), _vm) + .unwrap(); + dict.set_item("value", value.ast_to_object(_vm), _vm) + .unwrap(); + node_add_location(&dict, _range, _vm); + node.into() + } + fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult { + Ok(ast::located::StmtTypeAlias { + name: Node::ast_from_object(_vm, get_node_field(_vm, &_object, "name", "TypeAlias")?)?, + type_params: Node::ast_from_object( + _vm, + get_node_field(_vm, &_object, "type_params", "TypeAlias")?, + )?, + value: Node::ast_from_object( + _vm, + get_node_field(_vm, &_object, "value", "TypeAlias")?, + )?, + range: range_from_object(_vm, _object, "TypeAlias")?, + }) + } +} +// constructor impl Node for ast::located::StmtAugAssign { fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef { let ast::located::StmtAugAssign { @@ -5092,6 +5260,112 @@ impl Node for ast::located::TypeIgnoreTypeIgnore { }) } } +// sum +impl Node for ast::located::TypeParam { + fn ast_to_object(self, vm: &VirtualMachine) -> PyObjectRef { + match self { + ast::located::TypeParam::TypeVar(cons) => cons.ast_to_object(vm), + ast::located::TypeParam::ParamSpec(cons) => cons.ast_to_object(vm), + ast::located::TypeParam::TypeVarTuple(cons) => cons.ast_to_object(vm), + } + } + fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult { + let _cls = _object.class(); + Ok(if _cls.is(NodeTypeParamTypeVar::static_type()) { + ast::located::TypeParam::TypeVar(ast::located::TypeParamTypeVar::ast_from_object( + _vm, _object, + )?) + } else if _cls.is(NodeTypeParamParamSpec::static_type()) { + ast::located::TypeParam::ParamSpec(ast::located::TypeParamParamSpec::ast_from_object( + _vm, _object, + )?) + } else if _cls.is(NodeTypeParamTypeVarTuple::static_type()) { + ast::located::TypeParam::TypeVarTuple( + ast::located::TypeParamTypeVarTuple::ast_from_object(_vm, _object)?, + ) + } else { + return Err(_vm.new_type_error(format!( + "expected some sort of type_param, but got {}", + _object.repr(_vm)? + ))); + }) + } +} +// constructor +impl Node for ast::located::TypeParamTypeVar { + fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef { + let ast::located::TypeParamTypeVar { + name, + bound, + range: _range, + } = self; + let node = NodeAst + .into_ref_with_type(_vm, NodeTypeParamTypeVar::static_type().to_owned()) + .unwrap(); + let dict = node.as_object().dict().unwrap(); + dict.set_item("name", name.ast_to_object(_vm), _vm).unwrap(); + dict.set_item("bound", bound.ast_to_object(_vm), _vm) + .unwrap(); + node_add_location(&dict, _range, _vm); + node.into() + } + fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult { + Ok(ast::located::TypeParamTypeVar { + name: Node::ast_from_object(_vm, get_node_field(_vm, &_object, "name", "TypeVar")?)?, + bound: get_node_field_opt(_vm, &_object, "bound")? + .map(|obj| Node::ast_from_object(_vm, obj)) + .transpose()?, + range: range_from_object(_vm, _object, "TypeVar")?, + }) + } +} +// constructor +impl Node for ast::located::TypeParamParamSpec { + fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef { + let ast::located::TypeParamParamSpec { + name, + range: _range, + } = self; + let node = NodeAst + .into_ref_with_type(_vm, NodeTypeParamParamSpec::static_type().to_owned()) + .unwrap(); + let dict = node.as_object().dict().unwrap(); + dict.set_item("name", name.ast_to_object(_vm), _vm).unwrap(); + node_add_location(&dict, _range, _vm); + node.into() + } + fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult { + Ok(ast::located::TypeParamParamSpec { + name: Node::ast_from_object(_vm, get_node_field(_vm, &_object, "name", "ParamSpec")?)?, + range: range_from_object(_vm, _object, "ParamSpec")?, + }) + } +} +// constructor +impl Node for ast::located::TypeParamTypeVarTuple { + fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef { + let ast::located::TypeParamTypeVarTuple { + name, + range: _range, + } = self; + let node = NodeAst + .into_ref_with_type(_vm, NodeTypeParamTypeVarTuple::static_type().to_owned()) + .unwrap(); + let dict = node.as_object().dict().unwrap(); + dict.set_item("name", name.ast_to_object(_vm), _vm).unwrap(); + node_add_location(&dict, _range, _vm); + node.into() + } + fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult { + Ok(ast::located::TypeParamTypeVarTuple { + name: Node::ast_from_object( + _vm, + get_node_field(_vm, &_object, "name", "TypeVarTuple")?, + )?, + range: range_from_object(_vm, _object, "TypeVarTuple")?, + }) + } +} pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) { extend_module!(vm, module, { @@ -5107,6 +5381,7 @@ pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) { "Return" => NodeStmtReturn::make_class(&vm.ctx), "Delete" => NodeStmtDelete::make_class(&vm.ctx), "Assign" => NodeStmtAssign::make_class(&vm.ctx), + "TypeAlias" => NodeStmtTypeAlias::make_class(&vm.ctx), "AugAssign" => NodeStmtAugAssign::make_class(&vm.ctx), "AnnAssign" => NodeStmtAnnAssign::make_class(&vm.ctx), "For" => NodeStmtFor::make_class(&vm.ctx), @@ -5213,5 +5488,9 @@ pub fn extend_module_nodes(vm: &VirtualMachine, module: &Py) { "MatchOr" => NodePatternMatchOr::make_class(&vm.ctx), "type_ignore" => NodeTypeIgnore::make_class(&vm.ctx), "TypeIgnore" => NodeTypeIgnoreTypeIgnore::make_class(&vm.ctx), + "type_param" => NodeTypeParam::make_class(&vm.ctx), + "TypeVar" => NodeTypeParamTypeVar::make_class(&vm.ctx), + "ParamSpec" => NodeTypeParamParamSpec::make_class(&vm.ctx), + "TypeVarTuple" => NodeTypeParamTypeVarTuple::make_class(&vm.ctx), }) } From aea4d509c91d89ea432e833222801013fca86436 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Sat, 26 Aug 2023 23:21:32 +0900 Subject: [PATCH 079/893] Fix win_lib_path to actually work --- pylib/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylib/build.rs b/pylib/build.rs index 1ea80e25a5..52a2e3a632 100644 --- a/pylib/build.rs +++ b/pylib/build.rs @@ -9,7 +9,7 @@ fn main() { if cfg!(windows) { if let Ok(real_path) = std::fs::read_to_string("Lib") { - println!("rustc-env:win_lib_path={real_path:?}"); + println!("cargo:rustc-env=win_lib_path={real_path:?}"); } } } From d4d362a9ca0ce6b1cdfa4b631cdc5d427a99005b Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Sat, 26 Aug 2023 17:22:17 +0900 Subject: [PATCH 080/893] Update to Rust 1.72.0 --- compiler/codegen/src/compile.rs | 2 +- compiler/core/src/marshal.rs | 4 +- derive-impl/src/pyclass.rs | 5 +- derive-impl/src/util.rs | 22 ++++++--- src/interpreter.rs | 2 +- stdlib/src/array.rs | 16 +++---- stdlib/src/socket.rs | 4 +- stdlib/src/sqlite.rs | 84 +++++++++++++++++++++++++-------- vm/src/builtins/dict.rs | 1 + vm/src/builtins/range.rs | 2 +- vm/src/function/mod.rs | 2 +- vm/src/function/protocol.rs | 4 +- vm/src/object/core.rs | 8 ++-- vm/src/protocol/callable.rs | 7 ++- vm/src/stdlib/posix.rs | 2 +- 15 files changed, 114 insertions(+), 51 deletions(-) diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 359ecd8600..217c2dc02b 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -976,7 +976,7 @@ impl Compiler { .chain(&args.posonlyargs) .chain(&args.args) .map(|arg| arg.as_arg()) - .chain(kw_without_defaults.into_iter()) + .chain(kw_without_defaults) .chain(kw_with_defaults.into_iter().map(|(arg, _)| arg)); for name in args_iter { self.varname(name.arg.as_str())?; diff --git a/compiler/core/src/marshal.rs b/compiler/core/src/marshal.rs index a5d3089ff6..9f3bb8e22e 100644 --- a/compiler/core/src/marshal.rs +++ b/compiler/core/src/marshal.rs @@ -484,7 +484,9 @@ impl Write for Vec { } pub(crate) fn write_len(buf: &mut W, len: usize) { - let Ok(len) = len.try_into() else { panic!("too long to serialize") }; + let Ok(len) = len.try_into() else { + panic!("too long to serialize") + }; buf.write_u32(len); } diff --git a/derive-impl/src/pyclass.rs b/derive-impl/src/pyclass.rs index 3f78811073..55705c742d 100644 --- a/derive-impl/src/pyclass.rs +++ b/derive-impl/src/pyclass.rs @@ -1327,7 +1327,10 @@ impl ItemMeta for SlotItemMeta { Some(HashMap::default()) }; let (Some(meta_map), None) = (meta_map, nested.next()) else { - bail_span!(meta_ident, "#[pyslot] must be of the form #[pyslot] or #[pyslot(slot_name)]") + bail_span!( + meta_ident, + "#[pyslot] must be of the form #[pyslot] or #[pyslot(slot_name)]" + ) }; Ok(Self::from_inner(ItemMetaInner { item_ident, diff --git a/derive-impl/src/util.rs b/derive-impl/src/util.rs index c7b98ab2ce..9f827f9e9a 100644 --- a/derive-impl/src/util.rs +++ b/derive-impl/src/util.rs @@ -167,9 +167,16 @@ impl ItemMetaInner { pub fn _optional_str(&self, key: &str) -> Result> { let value = if let Some((_, meta)) = self.meta_map.get(key) { let Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(lit), .. - }) = meta else { - bail_span!(meta, "#[{}({} = ...)] must exist as a string", self.meta_name(), key) + lit: syn::Lit::Str(lit), + .. + }) = meta + else { + bail_span!( + meta, + "#[{}({} = ...)] must exist as a string", + self.meta_name(), + key + ) }; Some(lit.value()) } else { @@ -203,7 +210,10 @@ impl ItemMetaInner { key: &str, ) -> Result>> { let value = if let Some((_, meta)) = self.meta_map.get(key) { - let Meta::List(syn::MetaList { path: _, nested, .. }) = meta else { + let Meta::List(syn::MetaList { + path: _, nested, .. + }) = meta + else { bail_span!(meta, "#[{}({}(...))] must be a list", self.meta_name(), key) }; Some(nested.into_iter()) @@ -445,11 +455,11 @@ impl ExceptionItemMeta { return Ok({ let type_name = inner.item_name(); let Some(py_name) = type_name.as_str().strip_prefix("Py") else { - bail_span!( + bail_span!( inner.item_ident, "#[pyexception] expects its underlying type to be named `Py` prefixed" ) - }; + }; py_name.to_string() }) } diff --git a/src/interpreter.rs b/src/interpreter.rs index 89b0bc6e00..39a9662cd3 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -66,6 +66,6 @@ pub fn init_stdlib(vm: &mut VirtualMachine) { .push(rustpython_pylib::LIB_PATH.to_owned()) } - settings.path_list.extend(path_list.into_iter()); + settings.path_list.extend(path_list); } } diff --git a/stdlib/src/array.rs b/stdlib/src/array.rs index b0739e768f..a0beb5d1ee 100644 --- a/stdlib/src/array.rs +++ b/stdlib/src/array.rs @@ -1484,14 +1484,14 @@ mod array { _ => false, }; match code { - 0 | 1 => Ok(Self::Int8 { signed }), - 2 | 3 | 4 | 5 => Ok(Self::Int16 { signed, big_endian }), - 6 | 7 | 8 | 9 => Ok(Self::Int32 { signed, big_endian }), - 10 | 11 | 12 | 13 => Ok(Self::Int64 { signed, big_endian }), - 14 | 15 => Ok(Self::Ieee754Float { big_endian }), - 16 | 17 => Ok(Self::Ieee754Double { big_endian }), - 18 | 19 => Ok(Self::Utf16 { big_endian }), - 20 | 21 => Ok(Self::Utf32 { big_endian }), + 0..=1 => Ok(Self::Int8 { signed }), + 2..=5 => Ok(Self::Int16 { signed, big_endian }), + 6..=9 => Ok(Self::Int32 { signed, big_endian }), + 10..=13 => Ok(Self::Int64 { signed, big_endian }), + 14..=15 => Ok(Self::Ieee754Float { big_endian }), + 16..=17 => Ok(Self::Ieee754Double { big_endian }), + 18..=19 => Ok(Self::Utf16 { big_endian }), + 20..=21 => Ok(Self::Utf32 { big_endian }), _ => Err(code), } } diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index 1543d56e55..b417a1739c 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -935,7 +935,7 @@ mod _socket { )) })?; match tuple.len() { - 2 | 3 | 4 => {} + 2..=4 => {} _ => return Err(vm.new_type_error( "AF_INET6 address must be a tuple (host, port[, flowinfo[, scopeid]])" .to_owned(), @@ -1914,7 +1914,7 @@ mod _socket { vm: &VirtualMachine, ) -> Result<(String, String), IoOrPyException> { match address.len() { - 2 | 3 | 4 => {} + 2..=4 => {} _ => { return Err(vm .new_type_error("illegal sockaddr argument".to_owned()) diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index cb8c5f89cb..ecc5fea71c 100644 --- a/stdlib/src/sqlite.rs +++ b/stdlib/src/sqlite.rs @@ -99,8 +99,11 @@ mod _sqlite { )* fn setup_module_exceptions(module: &PyObject, vm: &VirtualMachine) { $( - let exception = [<$x:snake:upper>].get_or_init( - || vm.ctx.new_exception_type("_sqlite3", stringify!($x), Some(vec![$base(vm).to_owned()]))); + #[allow(clippy::redundant_closure_call)] + let exception = [<$x:snake:upper>].get_or_init(|| { + let base = $base(vm); + vm.ctx.new_exception_type("_sqlite3", stringify!($x), Some(vec![base.to_owned()])) + }); module.set_attr(stringify!($x), exception.clone().into_object(), vm).unwrap(); )* } @@ -455,7 +458,9 @@ mod _sqlite { let context = SqliteContext::from(context); let (_, vm) = (*context.user_data::()).retrieve(); let instance = context.aggregate_context::<*const PyObject>(); - let Some(instance) = (*instance).as_ref() else { return; }; + let Some(instance) = (*instance).as_ref() else { + return; + }; Self::callback_result_from_method(context, instance, "finalize", vm); } @@ -895,7 +900,7 @@ mod _sqlite { let cursor = cursor.downcast::().map_err(|x| { vm.new_type_error(format!("factory must return a cursor, not {}", x.class())) })?; - unsafe { cursor.row_factory.swap(zelf.row_factory.to_owned()) }; + let _ = unsafe { cursor.row_factory.swap(zelf.row_factory.to_owned()) }; cursor } else { let row_factory = zelf.row_factory.to_owned(); @@ -1077,7 +1082,17 @@ mod _sqlite { }; let db = self.db_lock(vm)?; let Some(data) = CallbackData::new(args.func, vm) else { - return db.create_function(name.as_ptr(), args.narg, flags, null_mut(), None, None, None, None, vm); + return db.create_function( + name.as_ptr(), + args.narg, + flags, + null_mut(), + None, + None, + None, + None, + vm, + ); }; db.create_function( @@ -1098,7 +1113,17 @@ mod _sqlite { let name = args.name.to_cstring(vm)?; let db = self.db_lock(vm)?; let Some(data) = CallbackData::new(args.aggregate_class, vm) else { - return db.create_function(name.as_ptr(), args.narg, SQLITE_UTF8, null_mut(), None, None, None, None, vm); + return db.create_function( + name.as_ptr(), + args.narg, + SQLITE_UTF8, + null_mut(), + None, + None, + None, + None, + vm, + ); }; db.create_function( @@ -1125,7 +1150,14 @@ mod _sqlite { let db = self.db_lock(vm)?; let Some(data) = CallbackData::new(callable.clone(), vm) else { unsafe { - sqlite3_create_collation_v2(db.db, name.as_ptr(), SQLITE_UTF8, null_mut(), None, None); + sqlite3_create_collation_v2( + db.db, + name.as_ptr(), + SQLITE_UTF8, + null_mut(), + None, + None, + ); } return Ok(()); }; @@ -1149,7 +1181,7 @@ mod _sqlite { // TODO: replace with Result.inspect_err when stable if let Err(exc) = db.check(ret, vm) { // create_collation do not call destructor if error occur - unsafe { Box::from_raw(data) }; + let _ = unsafe { Box::from_raw(data) }; Err(exc) } else { Ok(()) @@ -1168,7 +1200,18 @@ mod _sqlite { let db = self.db_lock(vm)?; let Some(data) = CallbackData::new(aggregate_class, vm) else { unsafe { - sqlite3_create_window_function(db.db, name.as_ptr(), narg, SQLITE_UTF8, null_mut(), None, None, None, None, None) + sqlite3_create_window_function( + db.db, + name.as_ptr(), + narg, + SQLITE_UTF8, + null_mut(), + None, + None, + None, + None, + None, + ) }; return Ok(()); }; @@ -1214,10 +1257,8 @@ mod _sqlite { #[pymethod] fn set_trace_callback(&self, callable: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { let db = self.db_lock(vm)?; - let Some(data )= CallbackData::new(callable, vm) else { - unsafe { - sqlite3_trace_v2(db.db, SQLITE_TRACE_STMT as u32, None, null_mut()) - }; + let Some(data) = CallbackData::new(callable, vm) else { + unsafe { sqlite3_trace_v2(db.db, SQLITE_TRACE_STMT as u32, None, null_mut()) }; return Ok(()); }; @@ -1241,7 +1282,7 @@ mod _sqlite { vm: &VirtualMachine, ) -> PyResult<()> { let db = self.db_lock(vm)?; - let Some(data )= CallbackData::new(callable, vm) else { + let Some(data) = CallbackData::new(callable, vm) else { unsafe { sqlite3_progress_handler(db.db, n, None, null_mut()) }; return Ok(()); }; @@ -1310,7 +1351,7 @@ mod _sqlite { if let Some(val) = &val { begin_statement_ptr_from_isolation_level(val, vm)?; } - unsafe { self.isolation_level.swap(val) }; + let _ = unsafe { self.isolation_level.swap(val) }; Ok(()) } @@ -1320,7 +1361,7 @@ mod _sqlite { } #[pygetset(setter)] fn set_text_factory(&self, val: PyObjectRef) { - unsafe { self.text_factory.swap(val) }; + let _ = unsafe { self.text_factory.swap(val) }; } #[pygetset] @@ -1329,7 +1370,7 @@ mod _sqlite { } #[pygetset(setter)] fn set_row_factory(&self, val: Option) { - unsafe { self.row_factory.swap(val) }; + let _ = unsafe { self.row_factory.swap(val) }; } fn check_thread(&self, vm: &VirtualMachine) -> PyResult<()> { @@ -1841,7 +1882,9 @@ mod _sqlite { } else if let Some(name) = needle.payload::() { for (obj, i) in self.description.iter().zip(0..) { let obj = &obj.payload::().unwrap().as_slice()[0]; - let Some(obj) = obj.payload::() else { break; }; + let Some(obj) = obj.payload::() else { + break; + }; let a_iter = name.as_str().chars().flat_map(|x| x.to_uppercase()); let b_iter = obj.as_str().chars().flat_map(|x| x.to_uppercase()); @@ -2153,7 +2196,10 @@ mod _sqlite { if let Some(index) = needle.try_index_opt(vm) { let Some(value) = value.payload::() else { - return Err(vm.new_type_error(format!("'{}' object cannot be interpreted as an integer", value.class()))); + return Err(vm.new_type_error(format!( + "'{}' object cannot be interpreted as an integer", + value.class() + ))); }; let value = value.try_to_primitive::(vm)?; let blob_len = inner.blob.bytes(); diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index a164fc11b6..1a323b4c47 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -784,6 +784,7 @@ macro_rules! dict_view { &self.dict } fn item(vm: &VirtualMachine, key: PyObjectRef, value: PyObjectRef) -> PyObjectRef { + #[allow(clippy::redundant_closure_call)] $result_fn(vm, key, value) } fn reversed(&self) -> Self::ReverseIter { diff --git a/vm/src/builtins/range.rs b/vm/src/builtins/range.rs index 1543e2e564..348924a7a9 100644 --- a/vm/src/builtins/range.rs +++ b/vm/src/builtins/range.rs @@ -287,7 +287,7 @@ impl PyRange { #[pymethod(magic)] fn reduce(&self, vm: &VirtualMachine) -> (PyTypeRef, PyTupleRef) { - let range_parameters: Vec = vec![&self.start, &self.stop, &self.step] + let range_parameters: Vec = [&self.start, &self.stop, &self.step] .iter() .map(|x| x.as_object().to_owned()) .collect(); diff --git a/vm/src/function/mod.rs b/vm/src/function/mod.rs index 409e4f1dbd..1c3babd808 100644 --- a/vm/src/function/mod.rs +++ b/vm/src/function/mod.rs @@ -15,7 +15,6 @@ pub use argument::{ }; pub use arithmetic::{PyArithmeticValue, PyComparisonValue}; pub use buffer::{ArgAsciiBuffer, ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike}; -pub(self) use builtin::{BorrowedParam, OwnedParam, RefParam}; pub use builtin::{IntoPyNativeFn, PyNativeFn}; pub use either::Either; pub use fspath::FsPath; @@ -26,6 +25,7 @@ pub use number::{ArgIndex, ArgIntoBool, ArgIntoComplex, ArgIntoFloat, ArgPrimiti pub use protocol::{ArgCallable, ArgIterable, ArgMapping, ArgSequence}; use crate::{builtins::PyStr, convert::TryFromBorrowedObject, PyObject, PyResult, VirtualMachine}; +use builtin::{BorrowedParam, OwnedParam, RefParam}; #[derive(Clone, Copy, PartialEq, Eq)] pub enum ArgByteOrder { diff --git a/vm/src/function/protocol.rs b/vm/src/function/protocol.rs index a70dd9e530..14aa5ad841 100644 --- a/vm/src/function/protocol.rs +++ b/vm/src/function/protocol.rs @@ -58,7 +58,9 @@ impl From for PyObjectRef { impl TryFromObject for ArgCallable { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { let Some(callable) = obj.to_callable() else { - return Err(vm.new_type_error(format!("'{}' object is not callable", obj.class().name()))); + return Err( + vm.new_type_error(format!("'{}' object is not callable", obj.class().name())) + ); }; let call = callable.call; Ok(ArgCallable { obj, call }) diff --git a/vm/src/object/core.rs b/vm/src/object/core.rs index c5e42229d3..bda8aaba2f 100644 --- a/vm/src/object/core.rs +++ b/vm/src/object/core.rs @@ -1112,6 +1112,8 @@ macro_rules! partially_init { Uninit { $($uninit_field:ident),*$(,)? }$(,)? ) => {{ // check all the fields are there but *don't* actually run it + + #[allow(clippy::diverging_sub_expression)] // FIXME: better way than using `if false`? if false { #[allow(invalid_value, dead_code, unreachable_code)] let _ = {$ty { @@ -1185,10 +1187,8 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { Uninit { typ }, ))); - let object_type_ptr = - object_type_ptr as *mut MaybeUninit> as *mut PyInner; - let type_type_ptr = - type_type_ptr as *mut MaybeUninit> as *mut PyInner; + let object_type_ptr = object_type_ptr as *mut PyInner; + let type_type_ptr = type_type_ptr as *mut PyInner; unsafe { (*type_type_ptr).ref_count.inc(); diff --git a/vm/src/protocol/callable.rs b/vm/src/protocol/callable.rs index e8c8b45c7c..8a04e2021e 100644 --- a/vm/src/protocol/callable.rs +++ b/vm/src/protocol/callable.rs @@ -26,10 +26,9 @@ impl PyObject { pub fn call_with_args(&self, args: FuncArgs, vm: &VirtualMachine) -> PyResult { vm_trace!("Invoke: {:?} {:?}", callable, args); let Some(callable) = self.to_callable() else { - return Err(vm.new_type_error(format!( - "'{}' object is not callable", - self.class().name() - ))); + return Err( + vm.new_type_error(format!("'{}' object is not callable", self.class().name())) + ); }; callable.invoke(args, vm) } diff --git a/vm/src/stdlib/posix.rs b/vm/src/stdlib/posix.rs index e9d705b4e6..1adf0006ef 100644 --- a/vm/src/stdlib/posix.rs +++ b/vm/src/stdlib/posix.rs @@ -1343,7 +1343,7 @@ pub mod module { .to_vec(); keys.into_iter() - .zip(values.into_iter()) + .zip(values) .map(|(k, v)| { let k = OsPath::try_from_object(vm, k)?.into_bytes(); let v = OsPath::try_from_object(vm, v)?.into_bytes(); From 03f954408a72a7948c4ee1bcb2b27e258a20a913 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Sun, 27 Aug 2023 00:32:17 +0900 Subject: [PATCH 081/893] Fix windows SSL bug --- stdlib/src/ssl.rs | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index a6ba30375a..746bc53911 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -636,11 +636,19 @@ mod _ssl { ); } + #[cold] + fn invalid_cadata(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_type_error( + "cadata should be an ASCII string or a bytes-like object".to_owned(), + ) + } + + // validate cadata type and load cadata if let Some(cadata) = args.cadata { let certs = match cadata { Either::A(s) => { if !s.is_ascii() { - return Err(vm.new_type_error("Must be an ascii string".to_owned())); + return Err(invalid_cadata(vm)); } X509::stack_from_pem(s.as_str().as_bytes()) } @@ -1191,6 +1199,7 @@ mod _ssl { vm.new_exception_msg(cls, msg.to_owned()) } + // SSL_FILETYPE_ASN1 part of _add_ca_certs in CPython fn x509_stack_from_der(der: &[u8]) -> Result, ErrorStack> { unsafe { openssl::init(); @@ -1198,20 +1207,25 @@ mod _ssl { let mut certs = vec![]; loop { - let r = sys::d2i_X509_bio(bio.as_ptr(), std::ptr::null_mut()); - if r.is_null() { - let err = sys::ERR_peek_last_error(); - if sys::ERR_GET_LIB(err) == sys::ERR_LIB_ASN1 - && sys::ERR_GET_REASON(err) == sys::ASN1_R_HEADER_TOO_LONG - { - sys::ERR_clear_error(); - break; - } - - return Err(ErrorStack::get()); - } else { - certs.push(X509::from_ptr(r)); + let cert = sys::d2i_X509_bio(bio.as_ptr(), std::ptr::null_mut()); + if cert.is_null() { + break; } + certs.push(X509::from_ptr(cert)); + } + + let err = sys::ERR_peek_last_error(); + + if certs.is_empty() { + // let msg = if filetype == sys::SSL_FILETYPE_PEM { + // "no start line: cadata does not contain a certificate" + // } else { + // "not enough data: cadata does not contain a certificate" + // }; + return Err(ErrorStack::get()); + } + if err != 0 { + return Err(ErrorStack::get()); } Ok(certs) From 287f89aa04aff0666c6a0e1eee1f745ced74def3 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Sun, 27 Aug 2023 02:11:01 +0900 Subject: [PATCH 082/893] Add with/without ssl targets to debugger configuration --- .vscode/launch.json | 18 ++++++++++++------ .vscode/tasks.json | 18 +++++++++++++++++- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 6a632e0f10..fa6f96c5fd 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -8,12 +8,6 @@ "type": "lldb", "request": "launch", "name": "Debug executable 'rustpython'", - "cargo": { - "args": [ - "build", - "--package=rustpython" - ], - }, "preLaunchTask": "Build RustPython Debug", "program": "target/debug/rustpython", "args": [], @@ -22,6 +16,18 @@ }, "cwd": "${workspaceFolder}" }, + { + "type": "lldb", + "request": "launch", + "name": "Debug executable 'rustpython' without SSL", + "preLaunchTask": "Build RustPython Debug without SSL", + "program": "target/debug/rustpython", + "args": [], + "env": { + "RUST_BACKTRACE": "1" + }, + "cwd": "${workspaceFolder}" + }, { "type": "lldb", "request": "launch", diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 415356ac87..18a3d6010d 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -1,12 +1,28 @@ { "version": "2.0.0", "tasks": [ + { + "label": "Build RustPython Debug without SSL", + "type": "shell", + "command": "cargo", + "args": [ + "build", + ], + "problemMatcher": [ + "$rustc", + ], + "group": { + "kind": "build", + "isDefault": true, + }, + }, { "label": "Build RustPython Debug", "type": "shell", "command": "cargo", "args": [ "build", + "--features=ssl" ], "problemMatcher": [ "$rustc", @@ -15,6 +31,6 @@ "kind": "build", "isDefault": true, }, - } + }, ], } \ No newline at end of file From b1238ab4eb9352c3452a64563af930b4bc731a45 Mon Sep 17 00:00:00 2001 From: Dan Nasman Date: Fri, 11 Aug 2023 20:54:39 +0200 Subject: [PATCH 083/893] change import_encodings error message --- vm/src/vm/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 05ad7ca5a0..5d40c05e07 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -224,7 +224,7 @@ impl VirtualMachine { fn import_encodings(&mut self) -> PyResult<()> { self.import("encodings", None, 0).map_err(|import_err| { let err = self.new_runtime_error( - "Could not import encodings. Is your RUSTPYTHONPATH set? If you don't have \ + "Could not import encodings. Is your RUSTPYTHONPATH set? You can also try adding your path to Setting struct's field path_list. If you don't have \ access to a consistent external environment (e.g. if you're embedding \ rustpython in another application), try enabling the freeze-stdlib feature" .to_owned(), From ba8d7b541fd58b15900b127c816355efb9ec0243 Mon Sep 17 00:00:00 2001 From: Dan Nasman Date: Fri, 11 Aug 2023 21:24:44 +0200 Subject: [PATCH 084/893] change conditional expression --- vm/src/vm/mod.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index 5d40c05e07..a72dfbea8b 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -223,12 +223,18 @@ impl VirtualMachine { #[cfg(feature = "encodings")] fn import_encodings(&mut self) -> PyResult<()> { self.import("encodings", None, 0).map_err(|import_err| { - let err = self.new_runtime_error( - "Could not import encodings. Is your RUSTPYTHONPATH set? You can also try adding your path to Setting struct's field path_list. If you don't have \ + let msg = if !self.state.settings.path_list.iter().any(|s| s == "PYTHONPATH" || s == "RUSTPYTHONPATH"){ + "Could not import encodings. Is your RUSTPYTHONPATH or PYTHONPATH set? If you don't have \ access to a consistent external environment (e.g. if you're embedding \ rustpython in another application), try enabling the freeze-stdlib feature" - .to_owned(), - ); + .to_owned() + } else { + "Could not import encodings. Try adding your path to Setting struct's path_list field. If you don't have \ + access to a consistent external environment (e.g. if you're embedding \ + rustpython in another application), try enabling the freeze-stdlib feature".to_owned() + }; + + let err = self.new_runtime_error(msg); err.set_cause(Some(import_err)); err })?; From 5f6059ef736fd40878dd0ec321ca768ffb0e7461 Mon Sep 17 00:00:00 2001 From: Dan Nasman Date: Fri, 11 Aug 2023 21:25:32 +0200 Subject: [PATCH 085/893] fix formatting --- vm/src/vm/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index a72dfbea8b..b01e28821e 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -231,7 +231,8 @@ impl VirtualMachine { } else { "Could not import encodings. Try adding your path to Setting struct's path_list field. If you don't have \ access to a consistent external environment (e.g. if you're embedding \ - rustpython in another application), try enabling the freeze-stdlib feature".to_owned() + rustpython in another application), try enabling the freeze-stdlib feature" + .to_owned() }; let err = self.new_runtime_error(msg); From 0fbe57f96b12f711e003b80ea54eeba291fe3d5d Mon Sep 17 00:00:00 2001 From: zhangbl Date: Mon, 7 Aug 2023 10:19:55 +0800 Subject: [PATCH 086/893] Update test_code.py code.py from CPython v3.11. --- Lib/code.py | 5 ++-- Lib/test/test_code.py | 53 ++++++++++++++++++++++++++++++------------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/Lib/code.py b/Lib/code.py index 23295f4cf5..76000f8c8b 100644 --- a/Lib/code.py +++ b/Lib/code.py @@ -7,7 +7,6 @@ import sys import traceback -import argparse from codeop import CommandCompiler, compile_command __all__ = ["InteractiveInterpreter", "InteractiveConsole", "interact", @@ -41,7 +40,7 @@ def runsource(self, source, filename="", symbol="single"): Arguments are as for compile_command(). - One several things can happen: + One of several things can happen: 1) The input is incorrect; compile_command() raised an exception (SyntaxError or OverflowError). A syntax traceback @@ -303,6 +302,8 @@ def interact(banner=None, readfunc=None, local=None, exitmsg=None): if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() parser.add_argument('-q', action='store_true', help="don't print version and copyright messages") diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py index 65630a2715..2661cbaa1f 100644 --- a/Lib/test/test_code.py +++ b/Lib/test/test_code.py @@ -139,13 +139,14 @@ import unittest import textwrap import weakref +import dis try: import ctypes except ImportError: ctypes = None from test.support import (cpython_only, - check_impl_detail, + check_impl_detail, requires_debug_ranges, gc_collect) from test.support.script_helper import assert_python_ok from test.support import threading_helper @@ -165,9 +166,8 @@ def consts(t): def dump(co): """Print out a text representation of a code object.""" for attr in ["name", "argcount", "posonlyargcount", - "kwonlyargcount", "names", "varnames",]: - # TODO: RUSTPYTHON - # "cellvars","freevars", "nlocals", "flags"]: + "kwonlyargcount", "names", "varnames", + "cellvars", "freevars", "nlocals", "flags"]: print("%s: %s" % (attr, getattr(co, "co_" + attr))) print("consts:", tuple(consts(co.co_consts))) @@ -357,18 +357,6 @@ def func(): new_code = code = func.__code__.replace(co_linetable=b'') self.assertEqual(list(new_code.co_lines()), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_invalid_bytecode(self): - def foo(): pass - foo.__code__ = co = foo.__code__.replace(co_code=b'\xee\x00d\x00S\x00') - - with self.assertRaises(SystemError) as se: - foo() - self.assertEqual( - f"{co.co_filename}:{co.co_firstlineno}: unknown opcode 238", - str(se.exception)) - # TODO: RUSTPYTHON @unittest.expectedFailure # @requires_debug_ranges() @@ -717,6 +705,38 @@ def test_lines(self): self.check_lines(misshappen) self.check_lines(bug93662) + @cpython_only + def test_code_new_empty(self): + # If this test fails, it means that the construction of PyCode_NewEmpty + # needs to be modified! Please update this test *and* PyCode_NewEmpty, + # so that they both stay in sync. + def f(): + pass + PY_CODE_LOCATION_INFO_NO_COLUMNS = 13 + f.__code__ = f.__code__.replace( + co_firstlineno=42, + co_code=bytes( + [ + dis.opmap["RESUME"], 0, + dis.opmap["LOAD_ASSERTION_ERROR"], 0, + dis.opmap["RAISE_VARARGS"], 1, + ] + ), + co_linetable=bytes( + [ + (1 << 7) + | (PY_CODE_LOCATION_INFO_NO_COLUMNS << 3) + | (3 - 1), + 0, + ] + ), + ) + self.assertRaises(AssertionError, f) + self.assertEqual( + list(f.__code__.co_positions()), + 3 * [(42, 42, None, None)], + ) + if check_impl_detail(cpython=True) and ctypes is not None: py = ctypes.pythonapi @@ -811,6 +831,7 @@ def run(self): tt.join() self.assertEqual(LAST_FREED, 500) + def load_tests(loader, tests, pattern): tests.addTest(doctest.DocTestSuite()) return tests From 94029386ae85f31a7ebd29395b28af686d9785b0 Mon Sep 17 00:00:00 2001 From: LucaSforza Date: Sun, 23 Jul 2023 16:52:34 +0200 Subject: [PATCH 087/893] added __reduce__ method for itertools.permutations --- vm/src/stdlib/itertools.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 91dbd7489d..526ab61af3 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1670,7 +1670,15 @@ mod decl { } #[pyclass(with(IterNext, Iterable, Constructor))] - impl PyItertoolsPermutations {} + impl PyItertoolsPermutations { + #[pymethod(magic)] + fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyRef { + vm.new_tuple(( + zelf.class().to_owned(), + vm.new_tuple((zelf.pool.clone(), vm.ctx.new_int(zelf.r.load()))), + )) + } + } impl SelfIter for PyItertoolsPermutations {} impl IterNext for PyItertoolsPermutations { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { From 9c2355117c8a9e9881dc5b02f865f4a8cf931233 Mon Sep 17 00:00:00 2001 From: Dominic Elm Date: Wed, 16 Aug 2023 09:27:52 +0200 Subject: [PATCH 088/893] Add ability to initialize cwd from PWD when targeting WASI --- src/lib.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 4c83a0cd3c..d2d942a70e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,6 +63,13 @@ pub use settings::{opts_with_clap, RunMode}; pub fn run(init: impl FnOnce(&mut VirtualMachine) + 'static) -> ExitCode { env_logger::init(); + #[cfg(target_os = "wasi")] + { + if let Ok(pwd) = env::var("PWD") { + let _ = env::set_current_dir(pwd); + }; + } + let (settings, run_mode) = opts_with_clap(); // Be quiet if "quiet" arg is set OR stdin is not connected to a terminal From 3131d56298b056f98aa0c8688e0ccd89a012d3c0 Mon Sep 17 00:00:00 2001 From: Dominic Elm Date: Wed, 23 Aug 2023 11:28:17 +0200 Subject: [PATCH 089/893] fixup: add note --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index d2d942a70e..c9ff153547 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,6 +63,7 @@ pub use settings::{opts_with_clap, RunMode}; pub fn run(init: impl FnOnce(&mut VirtualMachine) + 'static) -> ExitCode { env_logger::init(); + // NOTE: This is not a WASI convention. But it will be convenient since POSIX shell always defines it. #[cfg(target_os = "wasi")] { if let Ok(pwd) = env::var("PWD") { From 9417eec81e6d1656c35e33fdcf841e731ef5edc8 Mon Sep 17 00:00:00 2001 From: Reid00 <38450639+Reid00@users.noreply.github.com> Date: Tue, 29 Aug 2023 23:09:21 +0800 Subject: [PATCH 090/893] feat: update test_hashlib fro CPython3.11 (#5053) Co-authored-by: zhangbl --- Lib/hashlib.py | 4 ++-- Lib/test/test_hashlib.py | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/Lib/hashlib.py b/Lib/hashlib.py index 52a4f8a77b..e5f81d754a 100644 --- a/Lib/hashlib.py +++ b/Lib/hashlib.py @@ -114,8 +114,8 @@ def __get_builtin_constructor(name): cache['shake_128'] = _sha3.shake_128 cache['shake_256'] = _sha3.shake_256 except ImportError: - pass # no extension module, this hash is unsupported.''' - + pass # no extension module, this hash is unsupported. + constructor = cache.get(name) if constructor is not None: return constructor diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py index 6f12aaa936..c03208427a 100644 --- a/Lib/test/test_hashlib.py +++ b/Lib/test/test_hashlib.py @@ -400,9 +400,6 @@ def check_file_digest(self, name, data, hexdigest): for digest in digests: buf = io.BytesIO(data) buf.seek(0) - '''self.assertEqual( - dir(hashlib), None - )''' self.assertEqual( hashlib.file_digest(buf, digest).hexdigest(), hexdigest ) From e2f7d5b2f97f6c95487ea6e95332c19cfadc330f Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 30 Aug 2023 18:25:27 +0900 Subject: [PATCH 091/893] fix wasm prettier (#5055) --- wasm/demo/src/index.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wasm/demo/src/index.js b/wasm/demo/src/index.js index cc604b1d43..aa1619eb2f 100644 --- a/wasm/demo/src/index.js +++ b/wasm/demo/src/index.js @@ -78,9 +78,9 @@ function updateSnippet() { // the require here creates a webpack context; it's fine to use it // dynamically. // https://webpack.js.org/guides/dependency-management/ - const { - default: snippet, - } = require(`raw-loader!../snippets/${selected}.py`); + const { default: snippet } = require( + `raw-loader!../snippets/${selected}.py`, + ); editor.setValue(snippet); runCodeFromTextarea(); From d4be55c2ea02b67d92c04b3e9f682d465acd59b3 Mon Sep 17 00:00:00 2001 From: Junho Lee <46811505+naonus@users.noreply.github.com> Date: Wed, 30 Aug 2023 19:32:27 +0900 Subject: [PATCH 092/893] Add command line parameter -P (#4611) * Add command line parameter -P * Modify the value of safe_path to be set --------- Co-authored-by: Jeong YunWon --- Lib/test/test_support.py | 2 -- extra_tests/snippets/stdlib_sys.py | 23 ++++++++++++++++++++++- src/settings.rs | 11 +++++++++++ vm/src/stdlib/sys.rs | 2 +- vm/src/vm/setting.rs | 4 ++++ 5 files changed, 38 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 8158bee302..6ad272697b 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -513,8 +513,6 @@ def check_options(self, args, func, expected=None): self.assertEqual(proc.stdout.rstrip(), repr(expected)) self.assertEqual(proc.returncode, 0) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_args_from_interpreter_flags(self): # Test test.support.args_from_interpreter_flags() for opts in ( diff --git a/extra_tests/snippets/stdlib_sys.py b/extra_tests/snippets/stdlib_sys.py index cb92fc6416..5d8859ac8e 100644 --- a/extra_tests/snippets/stdlib_sys.py +++ b/extra_tests/snippets/stdlib_sys.py @@ -1,4 +1,6 @@ import sys +import os +import subprocess from testutils import assert_raises @@ -105,4 +107,23 @@ def recursive_call(n): sys.set_int_max_str_digits(1) sys.set_int_max_str_digits(1000) -assert sys.get_int_max_str_digits() == 1000 \ No newline at end of file +assert sys.get_int_max_str_digits() == 1000 + +# Test the PYTHONSAFEPATH environment variable +code = "import sys; print(sys.flags.safe_path)" +env = dict(os.environ) +env.pop('PYTHONSAFEPATH', None) +args = (sys.executable, '-P', '-c', code) + +proc = subprocess.run( + args, stdout=subprocess.PIPE, + universal_newlines=True, env=env) +assert proc.stdout.rstrip() == 'True', proc +assert proc.returncode == 0, proc + +env['PYTHONSAFEPATH'] = '1' +proc = subprocess.run( + args, stdout=subprocess.PIPE, + universal_newlines=True, env=env) +assert proc.stdout.rstrip() == 'True' +assert proc.returncode == 0, proc diff --git a/src/settings.rs b/src/settings.rs index 6a47233353..494d594de4 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -100,6 +100,11 @@ fn parse_arguments<'a>(app: App<'a, '_>) -> ArgMatches<'a> { .short("B") .help("don't write .pyc files on import"), ) + .arg( + Arg::with_name("safe-path") + .short("P") + .help("don’t prepend a potentially unsafe path to sys.path"), + ) .arg( Arg::with_name("ignore-environment") .short("E") @@ -237,6 +242,12 @@ fn settings_from(matches: &ArgMatches) -> (Settings, RunMode) { }; } + if matches.is_present("safe-path") + || (!ignore_environment && env::var_os("PYTHONSAFEPATH").is_some()) + { + settings.safe_path = true; + } + settings.check_hash_based_pycs = matches .value_of("check-hash-based-pycs") .unwrap_or("default") diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index 3a3824b310..47f9376f17 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -754,7 +754,7 @@ mod sys { dev_mode: settings.dev_mode, utf8_mode: settings.utf8_mode, int_max_str_digits: settings.int_max_str_digits, - safe_path: false, + safe_path: settings.safe_path, warn_default_encoding: settings.warn_default_encoding as u8, } } diff --git a/vm/src/vm/setting.rs b/vm/src/vm/setting.rs index 508c50da78..c4d1f85b2a 100644 --- a/vm/src/vm/setting.rs +++ b/vm/src/vm/setting.rs @@ -37,6 +37,9 @@ pub struct Settings { /// -B pub dont_write_bytecode: bool, + /// -P + pub safe_path: bool, + /// -b pub bytes_warning: u64, @@ -108,6 +111,7 @@ impl Default for Settings { verbose: 0, quiet: false, dont_write_bytecode: false, + safe_path: false, bytes_warning: 0, xopts: vec![], isolated: false, From aee68d20bb03d3c371ec1460f0fabbcf1188bb1d Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 30 Aug 2023 19:50:20 +0900 Subject: [PATCH 093/893] Fix `freeze-stdlib` + `Interpreter::without_stdlib` (#5051) * Fix pylib invalidation config * Fix Interpreter::without_stdlib with frozen-stdlib feature --- pylib/build.rs | 13 ++++++++----- vm/src/vm/mod.rs | 16 ++++++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/pylib/build.rs b/pylib/build.rs index 52a2e3a632..1aca8b3318 100644 --- a/pylib/build.rs +++ b/pylib/build.rs @@ -1,11 +1,14 @@ fn main() { - process_python_libs("../Lib/python_builtins/*"); + process_python_libs("../vm/Lib/python_builtins/*"); #[cfg(not(feature = "stdlib"))] - process_python_libs("../Lib/core_modules/*"); - - #[cfg(feature = "stdlib")] - process_python_libs("../../Lib/**/*"); + process_python_libs("../vm/Lib/core_modules/*"); + #[cfg(feature = "freeze-stdlib")] + if cfg!(windows) { + process_python_libs("../Lib/**/*"); + } else { + process_python_libs("./Lib/**/*"); + } if cfg!(windows) { if let Ok(real_path) = std::fs::read_to_string("Lib") { diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index b01e28821e..fa2dfb9073 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -244,11 +244,13 @@ impl VirtualMachine { fn import_utf8_encodings(&mut self) -> PyResult<()> { import::import_frozen(self, "codecs")?; - let encoding_module_name = if cfg!(feature = "freeze-stdlib") { - "encodings.utf_8" - } else { - "encodings_utf_8" - }; + // FIXME: See corresponding part of `core_frozen_inits` + // let encoding_module_name = if cfg!(feature = "freeze-stdlib") { + // "encodings.utf_8" + // } else { + // "encodings_utf_8" + // }; + let encoding_module_name = "encodings_utf_8"; let encoding_module = import::import_frozen(self, encoding_module_name)?; let getregentry = encoding_module.get_attr("getregentry", self)?; let codec_info = getregentry.call((), self)?; @@ -875,7 +877,9 @@ fn core_frozen_inits() -> impl Iterator { // core stdlib Python modules that the vm calls into, but are still used in Python // application code, e.g. copyreg - #[cfg(not(feature = "freeze-stdlib"))] + // FIXME: Initializing core_modules here results duplicated frozen module generation for core_modules. + // We need a way to initialize this modules for both `Interpreter::without_stdlib()` and `InterpreterConfig::new().init_stdlib().interpreter()` + // #[cfg(not(feature = "freeze-stdlib"))] ext_modules!( iter, dir = "./Lib/core_modules", From 4ba2892168f7491e182c299e3a4c81619d000809 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 30 Aug 2023 20:00:07 +0900 Subject: [PATCH 094/893] Better tips for Interpreter & InterpreterConfig (#5047) Co-authored-by: fanninpm --- src/interpreter.rs | 34 ++++++++++++++++++++++++++++++++++ vm/src/vm/interpreter.rs | 5 ++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/interpreter.rs b/src/interpreter.rs index 39a9662cd3..d2ec119b4e 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -2,6 +2,40 @@ use rustpython_vm::{Interpreter, Settings, VirtualMachine}; pub type InitHook = Box; +/// The convenient way to create [rustpython_vm::Interpreter] with stdlib and other stuffs. +/// +/// Basic usage: +/// ``` +/// let interpreter = rustpython::InterpreterConfig::new() +/// .init_stdlib() +/// .interpreter(); +/// ``` +/// +/// To override [rustpython_vm::Settings]: +/// ``` +/// use rustpython_vm::Settings; +/// // Override your settings here. +/// let mut settings = Settings::default(); +/// settings.debug = true; +/// // You may want to add paths to `rustpython_vm::Settings::path_list` to allow import python libraries. +/// settings.path_list.push("".to_owned()); // add current working directory +/// let interpreter = rustpython::InterpreterConfig::new() +/// .settings(settings) +/// .interpreter(); +/// ``` +/// +/// To add native modules: +/// ```compile_fail +/// let interpreter = rustpython::InterpreterConfig::new() +/// .init_stdlib() +/// .init_hook(Box::new(|vm| { +/// vm.add_native_module( +/// "your_module_name".to_owned(), +/// Box::new(your_module::make_module), +/// ); +/// })) +/// .interpreter(); +/// ``` #[derive(Default)] pub struct InterpreterConfig { settings: Option, diff --git a/vm/src/vm/interpreter.rs b/vm/src/vm/interpreter.rs index a27c1f5dd8..9fcd2ea1d2 100644 --- a/vm/src/vm/interpreter.rs +++ b/vm/src/vm/interpreter.rs @@ -28,7 +28,10 @@ pub struct Interpreter { } impl Interpreter { - /// To create with stdlib, use `with_init` + /// This is a bare unit to build up an interpreter without the standard library. + /// To create an interpreter with the standard library with the `rustpython` crate, use `rustpython::InterpreterConfig`. + /// To create an interpreter without the `rustpython` crate, but only with `rustpython-vm`, + /// try to build one from the source code of `InterpreterConfig`. It will not be a one-liner but it also will not be too hard. pub fn without_stdlib(settings: Settings) -> Self { Self::with_init(settings, |_| {}) } From 64c66e00d68c70fed261f8b46720438efda746b5 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 30 Aug 2023 21:45:36 +0900 Subject: [PATCH 095/893] Fix `encodings` related error messages to be less confusing (#5049) * Move cwd setup to interpreter code * rework import encodings failure message * try import_encodings only when `path_list` is set * std::mem::take instead of drain(..).collect() * Add empty path_list warnings to import_encodings * Prepend current working directory when !safe_path Co-authored-by: fanninpm --- src/interpreter.rs | 3 +-- src/lib.rs | 11 +++++++++- vm/src/vm/mod.rs | 51 +++++++++++++++++++++++++++++++------------- vm/src/vm/setting.rs | 7 +++--- 4 files changed, 50 insertions(+), 22 deletions(-) diff --git a/src/interpreter.rs b/src/interpreter.rs index d2ec119b4e..b84f167ae4 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -84,8 +84,7 @@ pub fn init_stdlib(vm: &mut VirtualMachine) { let state = PyRc::get_mut(&mut vm.state).unwrap(); let settings = &mut state.settings; - #[allow(clippy::needless_collect)] // false positive - let path_list: Vec<_> = settings.path_list.drain(..).collect(); + let path_list = std::mem::take(&mut settings.path_list); // BUILDTIME_RUSTPYTHONPATH should be set when distributing if let Some(paths) = option_env!("BUILDTIME_RUSTPYTHONPATH") { diff --git a/src/lib.rs b/src/lib.rs index c9ff153547..f916dd6828 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -165,8 +165,17 @@ fn run_rustpython(vm: &VirtualMachine, run_mode: RunMode, quiet: bool) -> PyResu let scope = setup_main_module(vm)?; - let site_result = vm.import("site", None, 0); + if !vm.state.settings.safe_path { + // TODO: The prepending path depends on running mode + // See https://docs.python.org/3/using/cmdline.html#cmdoption-P + vm.run_code_string( + vm.new_scope_with_builtins(), + "import sys; sys.path.insert(0, '')", + "".to_owned(), + )?; + } + let site_result = vm.import("site", None, 0); if site_result.is_err() { warn!( "Failed to import site, consider adding the Lib directory to your RUSTPYTHONPATH \ diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index fa2dfb9073..e943ab7fab 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -223,18 +223,30 @@ impl VirtualMachine { #[cfg(feature = "encodings")] fn import_encodings(&mut self) -> PyResult<()> { self.import("encodings", None, 0).map_err(|import_err| { - let msg = if !self.state.settings.path_list.iter().any(|s| s == "PYTHONPATH" || s == "RUSTPYTHONPATH"){ - "Could not import encodings. Is your RUSTPYTHONPATH or PYTHONPATH set? If you don't have \ - access to a consistent external environment (e.g. if you're embedding \ - rustpython in another application), try enabling the freeze-stdlib feature" - .to_owned() + let rustpythonpath_env = std::env::var("RUSTPYTHONPATH").ok(); + let pythonpath_env = std::env::var("PYTHONPATH").ok(); + let env_set = rustpythonpath_env.as_ref().is_some() || pythonpath_env.as_ref().is_some(); + let path_contains_env = self.state.settings.path_list.iter().any(|s| { + Some(s.as_str()) == rustpythonpath_env.as_deref() || Some(s.as_str()) == pythonpath_env.as_deref() + }); + + let guide_message = if !env_set { + "Neither RUSTPYTHONPATH nor PYTHONPATH is set. Try setting one of them to the stdlib directory." + } else if path_contains_env { + "RUSTPYTHONPATH or PYTHONPATH is set, but it doesn't contain the encodings library. If you are customizing the RustPython vm/interpreter, try adding the stdlib directory to the path. If you are developing the RustPython interpreter, it might be a bug during development." } else { - "Could not import encodings. Try adding your path to Setting struct's path_list field. If you don't have \ - access to a consistent external environment (e.g. if you're embedding \ - rustpython in another application), try enabling the freeze-stdlib feature" - .to_owned() + "RUSTPYTHONPATH or PYTHONPATH is set, but it wasn't loaded to `Settings::path_list`. If you are going to customize the RustPython vm/interpreter, those environment variables are not loaded in the Settings struct by default. Please try creating a customized instance of the Settings struct. If you are developing the RustPython interpreter, it might be a bug during development." }; + let msg = format!( + "RustPython could not import the encodings module. It usually means something went wrong. Please carefully read the following messages and follow the steps.\n\ + \n\ + {guide_message}\n\ + If you don't have access to a consistent external environment (e.g. targeting wasm, embedding \ + rustpython in another application), try enabling the `freeze-stdlib` feature.\n\ + If this is intended and you want to exclude the encodings module from your interpreter, please remove the `encodings` feature from `rustpython-vm` crate." + ); + let err = self.new_runtime_error(msg); err.set_cause(Some(import_err)); err @@ -267,9 +279,6 @@ impl VirtualMachine { panic!("Double Initialize Error"); } - // add the current directory to sys.path - self.state_mut().settings.path_list.insert(0, "".to_owned()); - stdlib::builtins::init_module(self, &self.builtins); stdlib::sys::init_module(self, &self.sys_module, &self.builtins); @@ -327,9 +336,21 @@ impl VirtualMachine { } #[cfg(feature = "encodings")] - if let Err(e) = self.import_encodings() { - eprintln!("encodings initialization failed. Only utf-8 encoding will be supported."); - self.print_exception(e); + if cfg!(feature = "freeze-stdlib") || !self.state.settings.path_list.is_empty() { + if let Err(e) = self.import_encodings() { + eprintln!( + "encodings initialization failed. Only utf-8 encoding will be supported." + ); + self.print_exception(e); + } + } else { + // Here may not be the best place to give general `path_list` advice, + // but bare rustpython_vm::VirtualMachine users skipped proper settings must hit here while properly setup vm never enters here. + eprintln!( + "feature `encodings` is enabled but `settings.path_list` is empty. \ + Please add the library path to `settings.path_list`. If you intended to disable the entire standard library (including the `encodings` feature), please also make sure to disable the `encodings` feature.\n\ + Tip: You may also want to add `\"\"` to `settings.path_list` in order to enable importing from the current working directory." + ); } self.initialized = true; diff --git a/vm/src/vm/setting.rs b/vm/src/vm/setting.rs index c4d1f85b2a..a30c75560c 100644 --- a/vm/src/vm/setting.rs +++ b/vm/src/vm/setting.rs @@ -89,10 +89,9 @@ pub struct Settings { } impl Settings { - pub fn with_path(path: String) -> Self { - let mut settings = Self::default(); - settings.path_list.push(path); - settings + pub fn with_path(mut self, path: String) -> Self { + self.path_list.push(path); + self } } From 3900a086b8b8945405217c2579eb432fc78a2216 Mon Sep 17 00:00:00 2001 From: Reid00 <38450639+Reid00@users.noreply.github.com> Date: Thu, 31 Aug 2023 16:44:31 +0800 Subject: [PATCH 096/893] update imp.py from CPython 3.11 (#5054) --- Lib/imp.py | 4 +-- Lib/test/test_imp.py | 64 +++++++++++++++++++++++++++++++++----------- vm/src/stdlib/imp.rs | 5 ++++ 3 files changed, 56 insertions(+), 17 deletions(-) diff --git a/Lib/imp.py b/Lib/imp.py index e02aaef344..fc42c15765 100644 --- a/Lib/imp.py +++ b/Lib/imp.py @@ -9,7 +9,7 @@ from _imp import (lock_held, acquire_lock, release_lock, get_frozen_object, is_frozen_package, init_frozen, is_builtin, is_frozen, - _fix_co_filename) + _fix_co_filename, _frozen_module_names) try: from _imp import create_dynamic except ImportError: @@ -226,7 +226,7 @@ def load_module(name, file, filename, details): """ suffix, mode, type_ = details - if mode and (not mode.startswith(('r', 'U')) or '+' in mode): + if mode and (not mode.startswith('r') or '+' in mode): raise ValueError('invalid file open mode {!r}'.format(mode)) elif file is None and type_ in {PY_SOURCE, PY_COMPILED}: msg = 'file object required for import (type code {})'.format(type_) diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py index 9b602582e2..1ccd072b3f 100644 --- a/Lib/test/test_imp.py +++ b/Lib/test/test_imp.py @@ -1,3 +1,4 @@ +import gc import importlib import importlib.util import os @@ -8,19 +9,21 @@ from test.support import import_helper from test.support import os_helper from test.support import script_helper +from test.support import warnings_helper import unittest import warnings -with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - import imp +imp = warnings_helper.import_deprecated('imp') import _imp +OS_PATH_NAME = os.path.__name__ + + def requires_load_dynamic(meth): """Decorator to skip a test if not running under CPython or lacking imp.load_dynamic().""" meth = support.cpython_only(meth) - return unittest.skipIf(not hasattr(imp, 'load_dynamic'), + return unittest.skipIf(getattr(imp, 'load_dynamic', None) is None, 'imp.load_dynamic() required')(meth) @@ -216,15 +219,17 @@ def test_load_from_source(self): # state after reversion. Reinitialising the module contents # and just reverting os.environ to its previous state is an OK # workaround - orig_path = os.path - orig_getenv = os.getenv - with os_helper.EnvironmentVarGuard(): - x = imp.find_module("os") - self.addCleanup(x[0].close) - new_os = imp.load_module("os", *x) - self.assertIs(os, new_os) - self.assertIs(orig_path, new_os.path) - self.assertIsNot(orig_getenv, new_os.getenv) + with import_helper.CleanImport('os', 'os.path', OS_PATH_NAME): + import os + orig_path = os.path + orig_getenv = os.getenv + with os_helper.EnvironmentVarGuard(): + x = imp.find_module("os") + self.addCleanup(x[0].close) + new_os = imp.load_module("os", *x) + self.assertIs(os, new_os) + self.assertIs(orig_path, new_os.path) + self.assertIsNot(orig_getenv, new_os.getenv) @requires_load_dynamic def test_issue15828_load_extensions(self): @@ -351,8 +356,8 @@ def test_issue_35321(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_source_hash(self): - self.assertEqual(_imp.source_hash(42, b'hi'), b'\xc6\xe7Z\r\x03:}\xab') - self.assertEqual(_imp.source_hash(43, b'hi'), b'\x85\x9765\xf8\x9a\x8b9') + self.assertEqual(_imp.source_hash(42, b'hi'), b'\xfb\xd9G\x05\xaf$\x9b~') + self.assertEqual(_imp.source_hash(43, b'hi'), b'\xd0/\x87C\xccC\xff\xe2') def test_pyc_invalidation_mode_from_cmdline(self): cases = [ @@ -384,6 +389,35 @@ def test_find_and_load_checked_pyc(self): self.assertEqual(mod.x, 42) + @support.cpython_only + def test_create_builtin_subinterp(self): + # gh-99578: create_builtin() behavior changes after the creation of the + # first sub-interpreter. Test both code paths, before and after the + # creation of a sub-interpreter. Previously, create_builtin() had + # a reference leak after the creation of the first sub-interpreter. + + import builtins + create_builtin = support.get_attribute(_imp, "create_builtin") + class Spec: + name = "builtins" + spec = Spec() + + def check_get_builtins(): + refcnt = sys.getrefcount(builtins) + mod = _imp.create_builtin(spec) + self.assertIs(mod, builtins) + self.assertEqual(sys.getrefcount(builtins), refcnt + 1) + # Check that a GC collection doesn't crash + gc.collect() + + check_get_builtins() + + ret = support.run_in_subinterp("import builtins") + self.assertEqual(ret, 0) + + check_get_builtins() + + class ReloadTests(unittest.TestCase): """Very basic tests to make sure that imp.reload() operates just like diff --git a/vm/src/stdlib/imp.rs b/vm/src/stdlib/imp.rs index 773afed6db..e2e735acd1 100644 --- a/vm/src/stdlib/imp.rs +++ b/vm/src/stdlib/imp.rs @@ -154,6 +154,11 @@ mod _imp { // TODO: } + #[pyfunction] + fn _frozen_module_names(_code: PyObjectRef) { + // TODO: + } + #[allow(clippy::type_complexity)] #[pyfunction] fn find_frozen( From c25aa1add4737d1f855550121b0eab60bf1b400d Mon Sep 17 00:00:00 2001 From: Dominic Elm Date: Thu, 31 Aug 2023 10:46:00 +0200 Subject: [PATCH 097/893] Print version and info text when executing the shell (#5057) --- src/shell.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/shell.rs b/src/shell.rs index 91becb0bbf..75d980b44c 100644 --- a/src/shell.rs +++ b/src/shell.rs @@ -6,7 +6,7 @@ use rustpython_vm::{ compiler::{self, CompileError, CompileErrorType}, readline::{Readline, ReadlineResult}, scope::Scope, - AsObject, PyResult, VirtualMachine, + version, AsObject, PyResult, VirtualMachine, }; enum ShellExecResult { @@ -93,6 +93,15 @@ pub fn run_shell(vm: &VirtualMachine, scope: Scope) -> PyResult<()> { let mut continuing = false; + println!( + "RustPython {}.{}.{}", + version::MAJOR, + version::MINOR, + version::MICRO, + ); + + println!("Type \"help\", \"copyright\", \"credits\" or \"license\" for more information."); + loop { let prompt_name = if continuing { "ps2" } else { "ps1" }; let prompt = vm From 9cf18a8bdca86b53e24247ef8820750e281808d9 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Fri, 1 Sep 2023 01:44:05 +0900 Subject: [PATCH 098/893] feature `encodings` is dependency of stdlib (#5061) --- Cargo.toml | 4 ++-- vm/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8fd494a67f..cd10c10f92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,10 +80,10 @@ unicode_names2 = { version = "0.6.0", git = "https://github.com/youknowone/unico widestring = "0.5.1" [features] -default = ["threading", "stdlib", "zlib", "importlib", "encodings"] +default = ["threading", "stdlib", "zlib", "importlib"] importlib = ["rustpython-vm/importlib"] encodings = ["rustpython-vm/encodings"] -stdlib = ["rustpython-stdlib", "rustpython-pylib"] +stdlib = ["rustpython-stdlib", "rustpython-pylib", "encodings"] flame-it = ["rustpython-vm/flame-it", "flame", "flamescope"] freeze-stdlib = ["rustpython-vm/freeze-stdlib", "rustpython-pylib?/freeze-stdlib"] jit = ["rustpython-vm/jit"] diff --git a/vm/Cargo.toml b/vm/Cargo.toml index d25096bd4e..fb3ab1e554 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -14,7 +14,7 @@ importlib = [] encodings = ["importlib"] vm-tracing-logging = [] flame-it = ["flame", "flamer"] -freeze-stdlib = [] +freeze-stdlib = ["encodings"] jit = ["rustpython-jit"] threading = ["rustpython-common/threading"] compiler = ["parser", "codegen", "rustpython-compiler"] From 77939d2ca5f9a87a2c752c37b08163ac6293d8b9 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Fri, 1 Sep 2023 13:55:28 +0900 Subject: [PATCH 099/893] Update platform from CPython 3.11.5 (#5060) * Update platform and test from CPython 3.11.5 * sys.dllhandle (=0) * Unmark fixed test of test_sysconfig --------- Co-authored-by: CPython Developers <> --- Lib/platform.py | 35 ++++++++++++++++++++++++----------- Lib/test/test_platform.py | 13 ++++++++++++- Lib/test/test_sysconfig.py | 1 - vm/src/stdlib/sys.rs | 5 +++++ 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/Lib/platform.py b/Lib/platform.py index fe88fa9d52..58b66078e1 100755 --- a/Lib/platform.py +++ b/Lib/platform.py @@ -5,7 +5,7 @@ If called from the command line, it prints the platform information concatenated as single string to stdout. The output - format is useable as part of a filename. + format is usable as part of a filename. """ # This module is maintained by Marc-Andre Lemburg . @@ -116,7 +116,6 @@ import os import re import sys -import subprocess import functools import itertools @@ -169,7 +168,7 @@ def libc_ver(executable=None, lib='', version='', chunksize=16384): Note that the function has intimate knowledge of how different libc versions add symbols to the executable and thus is probably - only useable for executables compiled using gcc. + only usable for executables compiled using gcc. The file is read and scanned in chunks of chunksize bytes. @@ -187,12 +186,15 @@ def libc_ver(executable=None, lib='', version='', chunksize=16384): executable = sys.executable + if not executable: + # sys.executable is not set. + return lib, version + V = _comparable_version - if hasattr(os.path, 'realpath'): - # Python 2.2 introduced os.path.realpath(); it is used - # here to work around problems with Cygwin not being - # able to open symlinks for reading - executable = os.path.realpath(executable) + # We use os.path.realpath() + # here to work around problems with Cygwin not being + # able to open symlinks for reading + executable = os.path.realpath(executable) with open(executable, 'rb') as f: binary = f.read(chunksize) pos = 0 @@ -283,6 +285,7 @@ def _syscmd_ver(system='', release='', version='', stdin=subprocess.DEVNULL, stderr=subprocess.DEVNULL, text=True, + encoding="locale", shell=True) except (OSError, subprocess.CalledProcessError) as why: #print('Command %s failed: %s' % (cmd, why)) @@ -609,7 +612,10 @@ def _syscmd_file(target, default=''): # XXX Others too ? return default - import subprocess + try: + import subprocess + except ImportError: + return default target = _follow_symlinks(target) # "file" output is locale dependent: force the usage of the C locale # to get deterministic behavior. @@ -748,11 +754,16 @@ def from_subprocess(): """ Fall back to `uname -p` """ + try: + import subprocess + except ImportError: + return None try: return subprocess.check_output( ['uname', '-p'], stderr=subprocess.DEVNULL, text=True, + encoding="utf8", ).strip() except (OSError, subprocess.CalledProcessError): pass @@ -776,6 +787,8 @@ class uname_result( except when needed. """ + _fields = ('system', 'node', 'release', 'version', 'machine', 'processor') + @functools.cached_property def processor(self): return _unknown_as_blank(_Processor.get()) @@ -789,7 +802,7 @@ def __iter__(self): @classmethod def _make(cls, iterable): # override factory to affect length check - num_fields = len(cls._fields) + num_fields = len(cls._fields) - 1 result = cls.__new__(cls, *iterable) if len(result) != num_fields + 1: msg = f'Expected {num_fields} arguments, got {len(result)}' @@ -803,7 +816,7 @@ def __len__(self): return len(tuple(iter(self))) def __reduce__(self): - return uname_result, tuple(self)[:len(self._fields)] + return uname_result, tuple(self)[:len(self._fields) - 1] _uname_cache = None diff --git a/Lib/test/test_platform.py b/Lib/test/test_platform.py index c6a1cc3417..c9f27575b5 100644 --- a/Lib/test/test_platform.py +++ b/Lib/test/test_platform.py @@ -78,8 +78,8 @@ def clear_caches(self): def test_architecture(self): res = platform.architecture() - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") @os_helper.skip_unless_symlink + @support.requires_subprocess() def test_architecture_via_symlink(self): # issue3762 with support.PythonSymlink() as py: cmd = "-c", "import platform; print(platform.architecture())" @@ -269,7 +269,16 @@ def test_uname_slices(self): self.assertEqual(res[:], expected) self.assertEqual(res[:5], expected[:5]) + def test_uname_fields(self): + self.assertIn('processor', platform.uname()._fields) + + def test_uname_asdict(self): + res = platform.uname()._asdict() + self.assertEqual(len(res), 6) + self.assertIn('processor', res) + @unittest.skipIf(sys.platform in ['win32', 'OpenVMS'], "uname -p not used") + @support.requires_subprocess() def test_uname_processor(self): """ On some systems, the processor must match the output @@ -346,6 +355,7 @@ def test_mac_ver(self): else: self.assertEqual(res[2], 'PowerPC') + @unittest.skipUnless(sys.platform == 'darwin', "OSX only test") def test_mac_ver_with_fork(self): # Issue7895: platform.mac_ver() crashes when using fork without exec @@ -362,6 +372,7 @@ def test_mac_ver_with_fork(self): # parent support.wait_process(pid, exitcode=0) + @unittest.skipIf(support.is_emscripten, "Does not apply to Emscripten") def test_libc_ver(self): # check that libc_ver(executable) doesn't raise an exception if os.path.isdir(sys.executable) and \ diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py index 2d662f94ab..6db1442980 100644 --- a/Lib/test/test_sysconfig.py +++ b/Lib/test/test_sysconfig.py @@ -346,7 +346,6 @@ def test_get_scheme_names(self): wanted.extend(['nt_user', 'osx_framework_user', 'posix_user']) self.assertEqual(get_scheme_names(), tuple(sorted(wanted))) - @unittest.expectedFailureIfWindows("TODO: RUSTPYTHON") @skip_unless_symlink @requires_subprocess() def test_symlink(self): # Issue 7880 diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index 47f9376f17..81498a7b06 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -71,10 +71,15 @@ mod sys { const PS1: &str = ">>>>> "; #[pyattr(name = "ps2")] const PS2: &str = "..... "; + #[cfg(windows)] #[pyattr(name = "_vpath")] const VPATH: Option<&'static str> = None; // TODO: actual VPATH value + #[cfg(windows)] + #[pyattr(name = "dllhandle")] + const DLLHANDLE: usize = 0; + #[pyattr] fn default_prefix(_vm: &VirtualMachine) -> &'static str { // TODO: the windows one doesn't really make sense From b864e5da1f18897fc884180b7093df5aa170024f Mon Sep 17 00:00:00 2001 From: Reid00 <38450639+Reid00@users.noreply.github.com> Date: Sun, 3 Sep 2023 20:33:42 +0800 Subject: [PATCH 100/893] feat: Implement _imp._frozen_module_names (#5062) --- vm/src/stdlib/imp.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vm/src/stdlib/imp.rs b/vm/src/stdlib/imp.rs index e2e735acd1..f59f8e4e63 100644 --- a/vm/src/stdlib/imp.rs +++ b/vm/src/stdlib/imp.rs @@ -155,8 +155,14 @@ mod _imp { } #[pyfunction] - fn _frozen_module_names(_code: PyObjectRef) { - // TODO: + fn _frozen_module_names(vm: &VirtualMachine) -> PyResult> { + let names = vm + .state + .frozen + .keys() + .map(|&name| vm.ctx.new_str(name).into()) + .collect(); + Ok(names) } #[allow(clippy::type_complexity)] From 8c11992e598dce48e7e82f10a755fc0d8e673353 Mon Sep 17 00:00:00 2001 From: Reid00 <38450639+Reid00@users.noreply.github.com> Date: Tue, 5 Sep 2023 16:18:00 +0800 Subject: [PATCH 101/893] test: update test_difflib from CPython3.11.2 (#5063) --- Lib/difflib.py | 53 +++++++-------------------------- Lib/test/test_difflib.py | 64 +++++++++++++++++++++++++++++++++++----- 2 files changed, 66 insertions(+), 51 deletions(-) diff --git a/Lib/difflib.py b/Lib/difflib.py index 0b14d3c779..ba0b256969 100644 --- a/Lib/difflib.py +++ b/Lib/difflib.py @@ -62,7 +62,7 @@ class SequenceMatcher: notion, pairing up elements that appear uniquely in each sequence. That, and the method here, appear to yield more intuitive difference reports than does diff. This method appears to be the least vulnerable - to synching up on blocks of "junk lines", though (like blank lines in + to syncing up on blocks of "junk lines", though (like blank lines in ordinary text files, or maybe "

" lines in HTML files). That may be because this is the only method of the 3 that has a *concept* of "junk" . @@ -115,38 +115,6 @@ class SequenceMatcher: case. SequenceMatcher is quadratic time for the worst case and has expected-case behavior dependent in a complicated way on how many elements the sequences have in common; best case time is linear. - - Methods: - - __init__(isjunk=None, a='', b='') - Construct a SequenceMatcher. - - set_seqs(a, b) - Set the two sequences to be compared. - - set_seq1(a) - Set the first sequence to be compared. - - set_seq2(b) - Set the second sequence to be compared. - - find_longest_match(alo, ahi, blo, bhi) - Find longest matching block in a[alo:ahi] and b[blo:bhi]. - - get_matching_blocks() - Return list of triples describing matching subsequences. - - get_opcodes() - Return list of 5-tuples describing how to turn a into b. - - ratio() - Return a measure of the sequences' similarity (float in [0,1]). - - quick_ratio() - Return an upper bound on .ratio() relatively quickly. - - real_quick_ratio() - Return an upper bound on ratio() very quickly. """ def __init__(self, isjunk=None, a='', b='', autojunk=True): @@ -334,9 +302,11 @@ def __chain_b(self): for elt in popular: # ditto; as fast for 1% deletion del b2j[elt] - def find_longest_match(self, alo, ahi, blo, bhi): + def find_longest_match(self, alo=0, ahi=None, blo=0, bhi=None): """Find longest matching block in a[alo:ahi] and b[blo:bhi]. + By default it will find the longest match in the entirety of a and b. + If isjunk is not defined: Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where @@ -391,6 +361,10 @@ def find_longest_match(self, alo, ahi, blo, bhi): # the unique 'b's and then matching the first two 'a's. a, b, b2j, isbjunk = self.a, self.b, self.b2j, self.bjunk.__contains__ + if ahi is None: + ahi = len(a) + if bhi is None: + bhi = len(b) besti, bestj, bestsize = alo, blo, 0 # find longest junk-free match # during an iteration of the loop, j2len[j] = length of longest @@ -688,6 +662,7 @@ def real_quick_ratio(self): __class_getitem__ = classmethod(GenericAlias) + def get_close_matches(word, possibilities, n=3, cutoff=0.6): """Use SequenceMatcher to return list of the best "good enough" matches. @@ -830,14 +805,6 @@ class Differ: + 4. Complicated is better than complex. ? ++++ ^ ^ + 5. Flat is better than nested. - - Methods: - - __init__(linejunk=None, charjunk=None) - Construct a text differencer, with optional filters. - - compare(a, b) - Compare two sequences of lines; generate the resulting delta. """ def __init__(self, linejunk=None, charjunk=None): @@ -870,7 +837,7 @@ def compare(self, a, b): Each sequence must contain individual single-line strings ending with newlines. Such sequences can be obtained from the `readlines()` method of file-like objects. The delta generated also consists of newline- - terminated strings, ready to be printed as-is via the writeline() + terminated strings, ready to be printed as-is via the writelines() method of a file-like object. Example: diff --git a/Lib/test/test_difflib.py b/Lib/test/test_difflib.py index 208208fd68..ed41074f7e 100644 --- a/Lib/test/test_difflib.py +++ b/Lib/test/test_difflib.py @@ -1,5 +1,5 @@ import difflib -from test.support import run_unittest, findfile +from test.support import findfile import unittest import doctest import sys @@ -241,7 +241,7 @@ def test_html_diff(self): #with open('test_difflib_expect.html','w') as fp: # fp.write(actual) - with open(findfile('test_difflib_expect.html')) as fp: + with open(findfile('test_difflib_expect.html'), encoding="utf-8") as fp: self.assertEqual(actual, fp.read()) def test_recursion_limit(self): @@ -503,12 +503,60 @@ def test_is_character_junk_false(self): for char in ['a', '#', '\n', '\f', '\r', '\v']: self.assertFalse(difflib.IS_CHARACTER_JUNK(char), repr(char)) -def test_main(): +class TestFindLongest(unittest.TestCase): + def longer_match_exists(self, a, b, n): + return any(b_part in a for b_part in + [b[i:i + n + 1] for i in range(0, len(b) - n - 1)]) + + def test_default_args(self): + a = 'foo bar' + b = 'foo baz bar' + sm = difflib.SequenceMatcher(a=a, b=b) + match = sm.find_longest_match() + self.assertEqual(match.a, 0) + self.assertEqual(match.b, 0) + self.assertEqual(match.size, 6) + self.assertEqual(a[match.a: match.a + match.size], + b[match.b: match.b + match.size]) + self.assertFalse(self.longer_match_exists(a, b, match.size)) + + match = sm.find_longest_match(alo=2, blo=4) + self.assertEqual(match.a, 3) + self.assertEqual(match.b, 7) + self.assertEqual(match.size, 4) + self.assertEqual(a[match.a: match.a + match.size], + b[match.b: match.b + match.size]) + self.assertFalse(self.longer_match_exists(a[2:], b[4:], match.size)) + + match = sm.find_longest_match(bhi=5, blo=1) + self.assertEqual(match.a, 1) + self.assertEqual(match.b, 1) + self.assertEqual(match.size, 4) + self.assertEqual(a[match.a: match.a + match.size], + b[match.b: match.b + match.size]) + self.assertFalse(self.longer_match_exists(a, b[1:5], match.size)) + + def test_longest_match_with_popular_chars(self): + a = 'dabcd' + b = 'd'*100 + 'abc' + 'd'*100 # length over 200 so popular used + sm = difflib.SequenceMatcher(a=a, b=b) + match = sm.find_longest_match(0, len(a), 0, len(b)) + self.assertEqual(match.a, 0) + self.assertEqual(match.b, 99) + self.assertEqual(match.size, 5) + self.assertEqual(a[match.a: match.a + match.size], + b[match.b: match.b + match.size]) + self.assertFalse(self.longer_match_exists(a, b, match.size)) + + +def setUpModule(): difflib.HtmlDiff._default_prefix = 0 - Doctests = doctest.DocTestSuite(difflib) - run_unittest( - TestWithAscii, TestAutojunk, TestSFpatches, TestSFbugs, - TestOutputFormat, TestBytes, TestJunkAPIs, Doctests) + + +def load_tests(loader, tests, pattern): + tests.addTest(doctest.DocTestSuite(difflib)) + return tests + if __name__ == '__main__': - test_main() + unittest.main() From e5cea3ad37f5a0ae5ac0153360b411aed335990e Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Tue, 29 Aug 2023 20:20:06 +0900 Subject: [PATCH 102/893] Update parser to 0.3.0 --- Cargo.lock | 47 ++++++++++++++++++++--------------------------- Cargo.toml | 10 +++++----- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f79b0748a6..2d2e02ce85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1950,21 +1950,6 @@ dependencies = [ "syn-ext", ] -[[package]] -name = "ruff_source_location" -version = "0.0.0" -source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" -dependencies = [ - "memchr", - "once_cell", - "ruff_text_size", -] - -[[package]] -name = "ruff_text_size" -version = "0.0.0" -source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" - [[package]] name = "rustc-hash" version = "1.1.0" @@ -2020,8 +2005,8 @@ dependencies = [ [[package]] name = "rustpython-ast" -version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" +version = "0.3.0" +source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ "is-macro", "malachite-bigint", @@ -2132,8 +2117,8 @@ dependencies = [ [[package]] name = "rustpython-format" -version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" +version = "0.3.0" +source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ "bitflags 2.3.1", "itertools 0.10.5", @@ -2159,8 +2144,8 @@ dependencies = [ [[package]] name = "rustpython-literal" -version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" +version = "0.3.0" +source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ "hexf-parse", "is-macro", @@ -2171,8 +2156,8 @@ dependencies = [ [[package]] name = "rustpython-parser" -version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" +version = "0.3.0" +source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ "anyhow", "is-macro", @@ -2194,13 +2179,21 @@ dependencies = [ [[package]] name = "rustpython-parser-core" -version = "0.2.0" -source = "git+https://github.com/RustPython/Parser.git?rev=704eb40108239a8faf9bd1d4217e8dad0ac7edb3#704eb40108239a8faf9bd1d4217e8dad0ac7edb3" +version = "0.3.0" +source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ "is-macro", "memchr", - "ruff_source_location", - "ruff_text_size", + "rustpython-parser-vendored", +] + +[[package]] +name = "rustpython-parser-vendored" +version = "0.3.0" +source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" +dependencies = [ + "memchr", + "once_cell", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index cd10c10f92..10898d2cae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,11 +29,11 @@ rustpython-pylib = { path = "pylib" } rustpython-stdlib = { path = "stdlib" } rustpython-doc = { git = "https://github.com/RustPython/__doc__", branch = "main" } -rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } -rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } -rustpython-parser = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } -rustpython-ast = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } -rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "704eb40108239a8faf9bd1d4217e8dad0ac7edb3" } +rustpython-literal = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } +rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } +rustpython-parser = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } +rustpython-ast = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } +rustpython-format = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } # rustpython-literal = { path = "../RustPython-parser/literal" } # rustpython-parser-core = { path = "../RustPython-parser/core" } # rustpython-parser = { path = "../RustPython-parser/parser" } From 1208416b92aa20268d2eabbf5092e7b411c5555b Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Wed, 30 Aug 2023 22:21:46 +0900 Subject: [PATCH 103/893] 0.3.0 release --- Cargo.lock | 32 ++++++++++++------------ Cargo.toml | 49 ++++++++++++++++++------------------- benches/execution.rs | 5 ++-- common/Cargo.toml | 2 +- compiler/Cargo.toml | 4 ++- compiler/codegen/Cargo.toml | 2 +- compiler/core/Cargo.toml | 5 ++-- derive-impl/Cargo.toml | 6 ++++- derive/Cargo.toml | 6 ++--- jit/Cargo.toml | 4 +-- pylib/Cargo.toml | 4 +-- stdlib/Cargo.toml | 12 ++++++--- vm/Cargo.toml | 14 +++++------ wapm.toml | 2 +- wasm/lib/Cargo.toml | 10 ++++---- 15 files changed, 83 insertions(+), 74 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2d2e02ce85..36335bb395 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -548,9 +548,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.14" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" dependencies = [ "cfg-if", ] @@ -1981,7 +1981,7 @@ dependencies = [ [[package]] name = "rustpython" -version = "0.2.0" +version = "0.3.0" dependencies = [ "atty", "cfg-if", @@ -2017,7 +2017,7 @@ dependencies = [ [[package]] name = "rustpython-codegen" -version = "0.2.0" +version = "0.3.0" dependencies = [ "ahash", "bitflags 2.3.1", @@ -2035,7 +2035,7 @@ dependencies = [ [[package]] name = "rustpython-common" -version = "0.2.0" +version = "0.3.0" dependencies = [ "ascii", "bitflags 2.3.1", @@ -2061,7 +2061,7 @@ dependencies = [ [[package]] name = "rustpython-compiler" -version = "0.2.0" +version = "0.3.0" dependencies = [ "rustpython-codegen", "rustpython-compiler-core", @@ -2070,7 +2070,7 @@ dependencies = [ [[package]] name = "rustpython-compiler-core" -version = "0.2.0" +version = "0.3.0" dependencies = [ "bitflags 2.3.1", "itertools 0.10.5", @@ -2083,7 +2083,7 @@ dependencies = [ [[package]] name = "rustpython-derive" -version = "0.2.0" +version = "0.3.0" dependencies = [ "rustpython-compiler", "rustpython-derive-impl", @@ -2092,7 +2092,7 @@ dependencies = [ [[package]] name = "rustpython-derive-impl" -version = "0.2.0" +version = "0.3.0" dependencies = [ "itertools 0.10.5", "maplit", @@ -2109,8 +2109,8 @@ dependencies = [ [[package]] name = "rustpython-doc" -version = "0.1.0" -source = "git+https://github.com/RustPython/__doc__?branch=main#d927debd491e4c45b88e953e6e50e4718e0f2965" +version = "0.3.0" +source = "git+https://github.com/RustPython/__doc__?tag=0.3.0#8b62ce5d796d68a091969c9fa5406276cb483f79" dependencies = [ "once_cell", ] @@ -2129,7 +2129,7 @@ dependencies = [ [[package]] name = "rustpython-jit" -version = "0.2.0" +version = "0.3.0" dependencies = [ "approx", "cranelift", @@ -2198,7 +2198,7 @@ dependencies = [ [[package]] name = "rustpython-pylib" -version = "0.2.0" +version = "0.3.0" dependencies = [ "glob", "rustpython-compiler-core", @@ -2207,7 +2207,7 @@ dependencies = [ [[package]] name = "rustpython-stdlib" -version = "0.2.0" +version = "0.3.0" dependencies = [ "adler32", "ahash", @@ -2278,7 +2278,7 @@ dependencies = [ [[package]] name = "rustpython-vm" -version = "0.2.0" +version = "0.3.0" dependencies = [ "ahash", "ascii", @@ -2354,7 +2354,7 @@ dependencies = [ [[package]] name = "rustpython_wasm" -version = "0.2.0" +version = "0.3.0" dependencies = [ "console_error_panic_hook", "js-sys", diff --git a/Cargo.toml b/Cargo.toml index 10898d2cae..d56f61ee59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpython" -version = "0.2.0" +version = "0.3.0" authors = ["RustPython Team"] edition = "2021" rust-version = "1.67.1" @@ -17,23 +17,23 @@ members = [ ] [workspace.dependencies] -rustpython-compiler-core = { path = "compiler/core" } -rustpython-compiler = { path = "compiler" } -rustpython-codegen = { path = "compiler/codegen" } -rustpython-common = { path = "common" } -rustpython-derive = { path = "derive" } -rustpython-derive-impl = { path = "derive-impl" } -rustpython-jit = { path = "jit" } -rustpython-vm = { path = "vm" } -rustpython-pylib = { path = "pylib" } -rustpython-stdlib = { path = "stdlib" } -rustpython-doc = { git = "https://github.com/RustPython/__doc__", branch = "main" } - -rustpython-literal = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } -rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } -rustpython-parser = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } -rustpython-ast = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } -rustpython-format = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0" } +rustpython-compiler-core = { path = "compiler/core", version = "0.3.0" } +rustpython-compiler = { path = "compiler", version = "0.3.0" } +rustpython-codegen = { path = "compiler/codegen", version = "0.3.0" } +rustpython-common = { path = "common", version = "0.3.0" } +rustpython-derive = { path = "derive", version = "0.3.0" } +rustpython-derive-impl = { path = "derive-impl", version = "0.3.0" } +rustpython-jit = { path = "jit", version = "0.3.0" } +rustpython-vm = { path = "vm", version = "0.3.0" } +rustpython-pylib = { path = "pylib", version = "0.3.0" } +rustpython-stdlib = { path = "stdlib", version = "0.3.0" } +rustpython-doc = { git = "https://github.com/RustPython/__doc__", tag = "0.3.0", version = "0.3.0" } + +rustpython-literal = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } +rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } +rustpython-parser = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } +rustpython-ast = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } +rustpython-format = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } # rustpython-literal = { path = "../RustPython-parser/literal" } # rustpython-parser-core = { path = "../RustPython-parser/core" } # rustpython-parser = { path = "../RustPython-parser/parser" } @@ -48,7 +48,7 @@ bitflags = "2.2.1" bstr = "0.2.17" cfg-if = "1.0" chrono = "0.4.19" -crossbeam-utils = "0.8.9" +crossbeam-utils = "0.8.16" flame = "0.2.2" glob = "0.3" hex = "0.4.3" @@ -94,11 +94,10 @@ ssl = ["rustpython-stdlib/ssl"] ssl-vendor = ["rustpython-stdlib/ssl-vendor"] [dependencies] -rustpython-compiler = { path = "compiler", version = "0.2.0" } -rustpython-pylib = { path = "pylib", optional = true, default-features = false } -rustpython-stdlib = { path = "stdlib", optional = true, default-features = false } -rustpython-vm = { path = "vm", version = "0.2.0", default-features = false, features = ["compiler"] } - +rustpython-compiler = { workspace = true } +rustpython-pylib = { workspace = true, optional = true } +rustpython-stdlib = { workspace = true, optional = true } +rustpython-vm = { workspace = true, default-features = false, features = ["compiler"] } rustpython-parser = { workspace = true } atty = { workspace = true } @@ -120,7 +119,7 @@ rustyline = { workspace = true } [dev-dependencies] cpython = "0.7.0" criterion = "0.3.5" -python3-sys = "0.7.0" +python3-sys = "0.7.1" [[bench]] name = "execution" diff --git a/benches/execution.rs b/benches/execution.rs index 6e0c1503e5..14fadfc2a5 100644 --- a/benches/execution.rs +++ b/benches/execution.rs @@ -3,7 +3,8 @@ use criterion::{ criterion_group, criterion_main, Bencher, BenchmarkGroup, BenchmarkId, Criterion, Throughput, }; use rustpython_compiler::Mode; -use rustpython_parser::parse_program; +use rustpython_parser::ast; +use rustpython_parser::Parse; use rustpython_vm::{Interpreter, PyResult}; use std::collections::HashMap; use std::path::Path; @@ -51,7 +52,7 @@ pub fn benchmark_file_execution( pub fn benchmark_file_parsing(group: &mut BenchmarkGroup, name: &str, contents: &str) { group.throughput(Throughput::Bytes(contents.len() as u64)); group.bench_function(BenchmarkId::new("rustpython", name), |b| { - b.iter(|| parse_program(contents, name).unwrap()) + b.iter(|| ast::Suite::parse(contents, name).unwrap()) }); group.bench_function(BenchmarkId::new("cpython", name), |b| { let gil = cpython::Python::acquire_gil(); diff --git a/common/Cargo.toml b/common/Cargo.toml index 42b46cb514..397897f491 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpython-common" -version = "0.2.0" +version = "0.3.0" description = "General python functions and algorithms for use in RustPython" authors = ["RustPython Team"] edition = "2021" diff --git a/compiler/Cargo.toml b/compiler/Cargo.toml index 8d5d81e0bd..8dec426b2b 100644 --- a/compiler/Cargo.toml +++ b/compiler/Cargo.toml @@ -1,8 +1,10 @@ [package] name = "rustpython-compiler" -version = "0.2.0" +version = "0.3.0" description = "A usability wrapper around rustpython-parser and rustpython-compiler-core" authors = ["RustPython Team"] +repository = "https://github.com/RustPython/RustPython" +license = "MIT" edition = "2021" [dependencies] diff --git a/compiler/codegen/Cargo.toml b/compiler/codegen/Cargo.toml index 6d33f48831..8cfd485b82 100644 --- a/compiler/codegen/Cargo.toml +++ b/compiler/codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpython-codegen" -version = "0.2.0" +version = "0.3.0" description = "Compiler for python code into bytecode for the rustpython VM." authors = ["RustPython Team"] repository = "https://github.com/RustPython/RustPython" diff --git a/compiler/core/Cargo.toml b/compiler/core/Cargo.toml index 4fc05510be..d515d20707 100644 --- a/compiler/core/Cargo.toml +++ b/compiler/core/Cargo.toml @@ -1,14 +1,14 @@ [package] name = "rustpython-compiler-core" description = "RustPython specific bytecode." -version = "0.2.0" +version = "0.3.0" authors = ["RustPython Team"] edition = "2021" repository = "https://github.com/RustPython/RustPython" license = "MIT" [dependencies] -rustpython-parser-core = { workspace = true } +rustpython-parser-core = { workspace = true, features=["location"] } bitflags = { workspace = true } itertools = { workspace = true } @@ -17,4 +17,3 @@ num-complex = { workspace = true } serde = { version = "1.0.133", optional = true, default-features = false, features = ["derive"] } lz4_flex = "0.9.2" - diff --git a/derive-impl/Cargo.toml b/derive-impl/Cargo.toml index 1c6b214cc6..2b2b4131b3 100644 --- a/derive-impl/Cargo.toml +++ b/derive-impl/Cargo.toml @@ -1,6 +1,10 @@ [package] name = "rustpython-derive-impl" -version = "0.2.0" +version = "0.3.0" +description = "Rust language extensions and macros specific to rustpython." +authors = ["RustPython Team"] +repository = "https://github.com/RustPython/RustPython" +license = "MIT" edition = "2021" [dependencies] diff --git a/derive/Cargo.toml b/derive/Cargo.toml index f14d6db49e..5010877a6c 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpython-derive" -version = "0.2.0" +version = "0.3.0" description = "Rust language extensions and macros specific to rustpython." authors = ["RustPython Team"] repository = "https://github.com/RustPython/RustPython" @@ -11,6 +11,6 @@ edition = "2021" proc-macro = true [dependencies] -rustpython-compiler = { path = "../compiler", version = "0.2.0" } -rustpython-derive-impl = { path = "../derive-impl" } +rustpython-compiler = { workspace = true } +rustpython-derive-impl = { workspace = true } syn = { workspace = true } diff --git a/jit/Cargo.toml b/jit/Cargo.toml index 0a7ff67fb2..71f8822cc9 100644 --- a/jit/Cargo.toml +++ b/jit/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpython-jit" -version = "0.2.0" +version = "0.3.0" description = "Experimental JIT(just in time) compiler for python code." authors = ["RustPython Team"] repository = "https://github.com/RustPython/RustPython" @@ -21,7 +21,7 @@ cranelift-module = "0.88.0" libffi = "3.1.0" [dev-dependencies] -rustpython-derive = { path = "../derive", version = "0.2.0" } +rustpython-derive = { path = "../derive", version = "0.3.0" } approx = "0.5.1" diff --git a/pylib/Cargo.toml b/pylib/Cargo.toml index b2e6ab6f09..256dbe9d61 100644 --- a/pylib/Cargo.toml +++ b/pylib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpython-pylib" -version = "0.2.0" +version = "0.3.0" authors = ["RustPython Team"] description = "A subset of the Python standard library for use with RustPython" repository = "https://github.com/RustPython/RustPython" @@ -13,7 +13,7 @@ freeze-stdlib = [] [dependencies] rustpython-compiler-core = { workspace = true } -rustpython-derive = { version = "0.2.0", path = "../derive" } +rustpython-derive = { version = "0.3.0", path = "../derive" } [build-dependencies] glob = { workspace = true } diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index aa63df0eb6..cdb3902e52 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -1,6 +1,10 @@ [package] name = "rustpython-stdlib" -version = "0.2.0" +version = "0.3.0" +description = "RustPython standard libraries in Rust." +authors = ["RustPython Team"] +repository = "https://github.com/RustPython/RustPython" +license = "MIT" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -14,9 +18,9 @@ ssl-vendor = ["ssl", "openssl/vendored", "openssl-probe"] [dependencies] # rustpython crates -rustpython-derive = { path = "../derive" } -rustpython-vm = { path = "../vm" } -rustpython-common = { path = "../common" } +rustpython-derive = { workspace = true } +rustpython-vm = { workspace = true } +rustpython-common = { workspace = true } ahash = { workspace = true } ascii = { workspace = true } diff --git a/vm/Cargo.toml b/vm/Cargo.toml index fb3ab1e554..167bb956dc 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "rustpython-vm" -version = "0.2.0" -description = "Rust Python virtual machine." +version = "0.3.0" +description = "RustPython virtual machine." authors = ["RustPython Team"] repository = "https://github.com/RustPython/RustPython" license = "MIT" @@ -24,11 +24,11 @@ parser = ["rustpython-parser", "ast"] serde = ["dep:serde"] [dependencies] -rustpython-compiler = { path = "../compiler", optional = true, version = "0.2.0" } -rustpython-codegen = { path = "../compiler/codegen", optional = true, version = "0.2.0" } -rustpython-common = { path = "../common" } -rustpython-derive = { path = "../derive", version = "0.2.0" } -rustpython-jit = { path = "../jit", optional = true, version = "0.2.0" } +rustpython-compiler = { workspace = true, optional = true } +rustpython-codegen = { workspace = true, optional = true } +rustpython-common = { workspace = true } +rustpython-derive = { workspace = true } +rustpython-jit = { workspace = true, optional = true } rustpython-ast = { workspace = true, optional = true } rustpython-parser = { workspace = true, optional = true } diff --git a/wapm.toml b/wapm.toml index c8aabc9cce..98bf96bed1 100644 --- a/wapm.toml +++ b/wapm.toml @@ -1,6 +1,6 @@ [package] name = "rustpython" -version = "0.2.0" +version = "0.3.0" description = "A Python-3 (CPython >= 3.5.0) Interpreter written in Rust 🐍 😱 🤘" license-file = "LICENSE" readme = "README.md" diff --git a/wasm/lib/Cargo.toml b/wasm/lib/Cargo.toml index 15276ae74b..98be000fee 100644 --- a/wasm/lib/Cargo.toml +++ b/wasm/lib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpython_wasm" -version = "0.2.0" +version = "0.3.0" authors = ["RustPython Team"] license = "MIT" description = "A Python-3 (CPython >= 3.5.0) Interpreter written in Rust, compiled to WASM" @@ -16,11 +16,11 @@ freeze-stdlib = ["rustpython-vm/freeze-stdlib", "rustpython-pylib/freeze-stdlib" no-start-func = [] [dependencies] -rustpython-common = { path = "../../common" } -rustpython-pylib = { path = "../../pylib", default-features = false, optional = true } -rustpython-stdlib = { path = "../../stdlib", default-features = false, optional = true } +rustpython-common = { workspace = true } +rustpython-pylib = { workspace = true, optional = true } +rustpython-stdlib = { workspace = true, default-features = false, optional = true } # make sure no threading! otherwise wasm build will fail -rustpython-vm = { path = "../../vm", default-features = false, features = ["compiler", "encodings", "serde"] } +rustpython-vm = { workspace = true, default-features = false, features = ["compiler", "encodings", "serde"] } rustpython-parser = { workspace = true } From 21cff29c31c127b69eb5b8899452c09f4ede4a56 Mon Sep 17 00:00:00 2001 From: Caleb Cartwright Date: Thu, 14 Sep 2023 03:11:44 -0500 Subject: [PATCH 104/893] ci: simplify rustfmt invocation (#5064) --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 559756971b..8e89357884 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -308,7 +308,7 @@ jobs: with: components: rustfmt, clippy - name: run rustfmt - run: cargo fmt --all -- --check + run: cargo fmt --check - name: run clippy on wasm run: cargo clippy --manifest-path=wasm/lib/Cargo.toml -- -Dwarnings - uses: actions/setup-python@v4 From 39169de63a749eba24efb0d36b5b8d8e7ecb5c66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jelmer=20Vernoo=C4=B3?= Date: Fri, 22 Sep 2023 18:06:47 +0200 Subject: [PATCH 105/893] bump is-macro to 0.3 (#5066) --- Cargo.lock | 91 +++++++++++++++++++++++++++++++++++---------------- vm/Cargo.toml | 2 +- 2 files changed, 64 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 36335bb395..0826a3b526 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -616,7 +616,7 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn", + "syn 1.0.107", ] [[package]] @@ -633,7 +633,7 @@ checksum = "086c685979a698443656e5cf7856c95c642295a38599f12fb1ff76fb28d19892" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -646,7 +646,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn", + "syn 1.0.107", ] [[package]] @@ -723,7 +723,7 @@ dependencies = [ "base64", "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -818,7 +818,7 @@ checksum = "36b732da54fd4ea34452f2431cf464ac7be94ca4b339c9cd3d3d12eb06fe7aab" dependencies = [ "flame", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -1043,10 +1043,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a7d079e129b77477a49c5c4f1cfe9ce6c2c909ef52520693e8e811a714c7b20" dependencies = [ "Inflector", - "pmutil", + "pmutil 0.5.3", "proc-macro2", "quote", - "syn", + "syn 1.0.107", +] + +[[package]] +name = "is-macro" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4467ed1321b310c2625c5aa6c1b1ffc5de4d9e42668cf697a08fb033ee8265e" +dependencies = [ + "Inflector", + "pmutil 0.6.1", + "proc-macro2", + "quote", + "syn 2.0.32", ] [[package]] @@ -1487,7 +1500,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -1531,7 +1544,7 @@ checksum = "b501e44f11665960c7e7fcf062c7d96a14ade4aa98116c004b2e37b5be7d736c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -1693,7 +1706,18 @@ checksum = "3894e5d549cccbe44afecf72922f277f603cd4bb0219c8342631ef18fffbe004" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", +] + +[[package]] +name = "pmutil" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52a40bc70c2c58040d2d8b167ba9a5ff59fc9dab7ad44771cfde3dcfde7a09c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.32", ] [[package]] @@ -1739,9 +1763,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.23" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -1943,10 +1967,10 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a29c8a4ac7839f1dcb8b899263b501e0d6932f210300c8a0d271323727b35c1" dependencies = [ - "pmutil", + "pmutil 0.5.3", "proc-macro2", "quote", - "syn", + "syn 1.0.107", "syn-ext", ] @@ -2008,7 +2032,7 @@ name = "rustpython-ast" version = "0.3.0" source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ - "is-macro", + "is-macro 0.2.2", "malachite-bigint", "rustpython-literal", "rustpython-parser-core", @@ -2087,7 +2111,7 @@ version = "0.3.0" dependencies = [ "rustpython-compiler", "rustpython-derive-impl", - "syn", + "syn 1.0.107", ] [[package]] @@ -2102,7 +2126,7 @@ dependencies = [ "rustpython-compiler-core", "rustpython-doc", "rustpython-parser-core", - "syn", + "syn 1.0.107", "syn-ext", "textwrap 0.15.2", ] @@ -2148,7 +2172,7 @@ version = "0.3.0" source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ "hexf-parse", - "is-macro", + "is-macro 0.2.2", "lexical-parse-float", "num-traits", "unic-ucd-category", @@ -2160,7 +2184,7 @@ version = "0.3.0" source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ "anyhow", - "is-macro", + "is-macro 0.2.2", "itertools 0.10.5", "lalrpop-util", "log", @@ -2182,7 +2206,7 @@ name = "rustpython-parser-core" version = "0.3.0" source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" dependencies = [ - "is-macro", + "is-macro 0.2.2", "memchr", "rustpython-parser-vendored", ] @@ -2297,7 +2321,7 @@ dependencies = [ "half", "hex", "indexmap", - "is-macro", + "is-macro 0.3.0", "itertools 0.10.5", "libc", "log", @@ -2480,7 +2504,7 @@ checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -2617,7 +2641,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn", + "syn 1.0.107", ] [[package]] @@ -2637,13 +2661,24 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn" +version = "2.0.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn-ext" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b86cb2b68c5b3c078cac02588bc23f3c04bb828c5d3aedd17980876ec6a7be6" dependencies = [ - "syn", + "syn 1.0.107", ] [[package]] @@ -2723,7 +2758,7 @@ checksum = "1fb327af4685e4d03fa8cbcf1716380da910eeb2bb8be417e7f9fd3fb164f36f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -3029,7 +3064,7 @@ checksum = "c1b300a878652a387d2a0de915bdae8f1a548f0c6d45e072fe2688794b656cc9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", ] [[package]] @@ -3106,7 +3141,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn", + "syn 1.0.107", "wasm-bindgen-shared", ] @@ -3140,7 +3175,7 @@ checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.107", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 167bb956dc..966858c087 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -70,7 +70,7 @@ caseless = "0.2.1" getrandom = { version = "0.2.6", features = ["js"] } flamer = { version = "0.4", optional = true } half = "1.8.2" -is-macro = "0.2.2" +is-macro = "0.3" memchr = "2.4.1" memoffset = "0.6.5" optional = "0.5.0" From 37ce45fffb28da54c6f91756cf28af50af51e4f0 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 27 Sep 2023 17:42:54 +0900 Subject: [PATCH 106/893] Fix windows CI error (#5068) --- .github/workflows/ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8e89357884..5c1b1b6637 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -127,7 +127,7 @@ jobs: shell: bash run: | choco install llvm openssl --no-progress - echo "OPENSSL_DIR=C:\Program Files\OpenSSL" >>$GITHUB_ENV + echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64" >> $GITHUB_ENV if: runner.os == 'Windows' - name: Set up the Mac environment run: brew install autoconf automake libtool @@ -252,7 +252,7 @@ jobs: shell: bash run: | choco install llvm openssl --no-progress - echo "OPENSSL_DIR=C:\Program Files\OpenSSL" >>$GITHUB_ENV + echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64" >>$GITHUB_ENV if: runner.os == 'Windows' - name: Set up the Mac environment run: brew install autoconf automake libtool openssl@3 From 9031a0ac9fdfc194f239229dedb9bcbc5478b64f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jelmer=20Vernoo=C4=B3?= Date: Wed, 27 Sep 2023 09:43:17 +0100 Subject: [PATCH 107/893] bump lz4_flex dependency to 0.11 (#5067) --- Cargo.lock | 4 ++-- compiler/core/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0826a3b526..ee6aab2125 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1242,9 +1242,9 @@ dependencies = [ [[package]] name = "lz4_flex" -version = "0.9.5" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a8cbbb2831780bc3b9c15a41f5b49222ef756b6730a95f3decfdd15903eb5a3" +checksum = "3ea9b256699eda7b0387ffbc776dd625e28bde3918446381781245b7a50349d8" dependencies = [ "twox-hash", ] diff --git a/compiler/core/Cargo.toml b/compiler/core/Cargo.toml index d515d20707..12413b7b9e 100644 --- a/compiler/core/Cargo.toml +++ b/compiler/core/Cargo.toml @@ -16,4 +16,4 @@ malachite-bigint = { workspace = true } num-complex = { workspace = true } serde = { version = "1.0.133", optional = true, default-features = false, features = ["derive"] } -lz4_flex = "0.9.2" +lz4_flex = "0.11" From 10e4f715a5b72bee3a33b479ccf29c198075fdd5 Mon Sep 17 00:00:00 2001 From: dvermd <315743+dvermd@users.noreply.github.com> Date: Sun, 24 Sep 2023 07:10:50 +0200 Subject: [PATCH 108/893] Update compileall to CPython 3.11.5 --- Lib/compileall.py | 286 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 227 insertions(+), 59 deletions(-) diff --git a/Lib/compileall.py b/Lib/compileall.py index 1c9ceb6930..a388931fb5 100644 --- a/Lib/compileall.py +++ b/Lib/compileall.py @@ -4,7 +4,7 @@ given as arguments recursively; the -l option prevents it from recursing into directories. -Without arguments, if compiles all modules on sys.path, without +Without arguments, it compiles all modules on sys.path, without recursing into subdirectories. (Even though it should do so for packages -- for now, you'll have to deal with packages separately.) @@ -15,16 +15,14 @@ import importlib.util import py_compile import struct +import filecmp -try: - from concurrent.futures import ProcessPoolExecutor -except ImportError: - ProcessPoolExecutor = None from functools import partial +from pathlib import Path __all__ = ["compile_dir","compile_file","compile_path"] -def _walk_dir(dir, ddir=None, maxlevels=10, quiet=0): +def _walk_dir(dir, maxlevels, quiet=0): if quiet < 2 and isinstance(dir, os.PathLike): dir = os.fspath(dir) if not quiet: @@ -40,59 +38,94 @@ def _walk_dir(dir, ddir=None, maxlevels=10, quiet=0): if name == '__pycache__': continue fullname = os.path.join(dir, name) - if ddir is not None: - dfile = os.path.join(ddir, name) - else: - dfile = None if not os.path.isdir(fullname): yield fullname elif (maxlevels > 0 and name != os.curdir and name != os.pardir and os.path.isdir(fullname) and not os.path.islink(fullname)): - yield from _walk_dir(fullname, ddir=dfile, - maxlevels=maxlevels - 1, quiet=quiet) + yield from _walk_dir(fullname, maxlevels=maxlevels - 1, + quiet=quiet) -def compile_dir(dir, maxlevels=10, ddir=None, force=False, rx=None, - quiet=0, legacy=False, optimize=-1, workers=1): +def compile_dir(dir, maxlevels=None, ddir=None, force=False, + rx=None, quiet=0, legacy=False, optimize=-1, workers=1, + invalidation_mode=None, *, stripdir=None, + prependdir=None, limit_sl_dest=None, hardlink_dupes=False): """Byte-compile all modules in the given directory tree. Arguments (only dir is required): dir: the directory to byte-compile - maxlevels: maximum recursion level (default 10) + maxlevels: maximum recursion level (default `sys.getrecursionlimit()`) ddir: the directory that will be prepended to the path to the file as it is compiled into each byte-code file. force: if True, force compilation, even if timestamps are up-to-date quiet: full output with False or 0, errors only with 1, no output with 2 legacy: if True, produce legacy pyc paths instead of PEP 3147 paths - optimize: optimization level or -1 for level of the interpreter + optimize: int or list of optimization levels or -1 for level of + the interpreter. Multiple levels leads to multiple compiled + files each with one optimization level. workers: maximum number of parallel workers + invalidation_mode: how the up-to-dateness of the pyc will be checked + stripdir: part of path to left-strip from source file path + prependdir: path to prepend to beginning of original file path, applied + after stripdir + limit_sl_dest: ignore symlinks if they are pointing outside of + the defined path + hardlink_dupes: hardlink duplicated pyc files """ - if workers is not None and workers < 0: + ProcessPoolExecutor = None + if ddir is not None and (stripdir is not None or prependdir is not None): + raise ValueError(("Destination dir (ddir) cannot be used " + "in combination with stripdir or prependdir")) + if ddir is not None: + stripdir = dir + prependdir = ddir + ddir = None + if workers < 0: raise ValueError('workers must be greater or equal to 0') - - files = _walk_dir(dir, quiet=quiet, maxlevels=maxlevels, - ddir=ddir) + if workers != 1: + # Check if this is a system where ProcessPoolExecutor can function. + from concurrent.futures.process import _check_system_limits + try: + _check_system_limits() + except NotImplementedError: + workers = 1 + else: + from concurrent.futures import ProcessPoolExecutor + if maxlevels is None: + maxlevels = sys.getrecursionlimit() + files = _walk_dir(dir, quiet=quiet, maxlevels=maxlevels) success = True - if workers is not None and workers != 1 and ProcessPoolExecutor is not None: + if workers != 1 and ProcessPoolExecutor is not None: + # If workers == 0, let ProcessPoolExecutor choose workers = workers or None with ProcessPoolExecutor(max_workers=workers) as executor: results = executor.map(partial(compile_file, ddir=ddir, force=force, rx=rx, quiet=quiet, legacy=legacy, - optimize=optimize), + optimize=optimize, + invalidation_mode=invalidation_mode, + stripdir=stripdir, + prependdir=prependdir, + limit_sl_dest=limit_sl_dest, + hardlink_dupes=hardlink_dupes), files) success = min(results, default=True) else: for file in files: if not compile_file(file, ddir, force, rx, quiet, - legacy, optimize): + legacy, optimize, invalidation_mode, + stripdir=stripdir, prependdir=prependdir, + limit_sl_dest=limit_sl_dest, + hardlink_dupes=hardlink_dupes): success = False return success def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, - legacy=False, optimize=-1): + legacy=False, optimize=-1, + invalidation_mode=None, *, stripdir=None, prependdir=None, + limit_sl_dest=None, hardlink_dupes=False): """Byte-compile one file. Arguments (only fullname is required): @@ -104,49 +137,114 @@ def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, quiet: full output with False or 0, errors only with 1, no output with 2 legacy: if True, produce legacy pyc paths instead of PEP 3147 paths - optimize: optimization level or -1 for level of the interpreter + optimize: int or list of optimization levels or -1 for level of + the interpreter. Multiple levels leads to multiple compiled + files each with one optimization level. + invalidation_mode: how the up-to-dateness of the pyc will be checked + stripdir: part of path to left-strip from source file path + prependdir: path to prepend to beginning of original file path, applied + after stripdir + limit_sl_dest: ignore symlinks if they are pointing outside of + the defined path. + hardlink_dupes: hardlink duplicated pyc files """ + + if ddir is not None and (stripdir is not None or prependdir is not None): + raise ValueError(("Destination dir (ddir) cannot be used " + "in combination with stripdir or prependdir")) + success = True - if quiet < 2 and isinstance(fullname, os.PathLike): - fullname = os.fspath(fullname) + fullname = os.fspath(fullname) + stripdir = os.fspath(stripdir) if stripdir is not None else None name = os.path.basename(fullname) + + dfile = None + if ddir is not None: dfile = os.path.join(ddir, name) - else: - dfile = None + + if stripdir is not None: + fullname_parts = fullname.split(os.path.sep) + stripdir_parts = stripdir.split(os.path.sep) + ddir_parts = list(fullname_parts) + + for spart, opart in zip(stripdir_parts, fullname_parts): + if spart == opart: + ddir_parts.remove(spart) + + dfile = os.path.join(*ddir_parts) + + if prependdir is not None: + if dfile is None: + dfile = os.path.join(prependdir, fullname) + else: + dfile = os.path.join(prependdir, dfile) + + if isinstance(optimize, int): + optimize = [optimize] + + # Use set() to remove duplicates. + # Use sorted() to create pyc files in a deterministic order. + optimize = sorted(set(optimize)) + + if hardlink_dupes and len(optimize) < 2: + raise ValueError("Hardlinking of duplicated bytecode makes sense " + "only for more than one optimization level") + if rx is not None: mo = rx.search(fullname) if mo: return success + + if limit_sl_dest is not None and os.path.islink(fullname): + if Path(limit_sl_dest).resolve() not in Path(fullname).resolve().parents: + return success + + opt_cfiles = {} + if os.path.isfile(fullname): - if legacy: - cfile = fullname + 'c' - else: - if optimize >= 0: - opt = optimize if optimize >= 1 else '' - cfile = importlib.util.cache_from_source( - fullname, optimization=opt) + for opt_level in optimize: + if legacy: + opt_cfiles[opt_level] = fullname + 'c' else: - cfile = importlib.util.cache_from_source(fullname) - cache_dir = os.path.dirname(cfile) + if opt_level >= 0: + opt = opt_level if opt_level >= 1 else '' + cfile = (importlib.util.cache_from_source( + fullname, optimization=opt)) + opt_cfiles[opt_level] = cfile + else: + cfile = importlib.util.cache_from_source(fullname) + opt_cfiles[opt_level] = cfile + head, tail = name[:-3], name[-3:] if tail == '.py': if not force: try: mtime = int(os.stat(fullname).st_mtime) - expect = struct.pack('<4sl', importlib.util.MAGIC_NUMBER, - mtime) - with open(cfile, 'rb') as chandle: - actual = chandle.read(8) - if expect == actual: + expect = struct.pack('<4sLL', importlib.util.MAGIC_NUMBER, + 0, mtime & 0xFFFF_FFFF) + for cfile in opt_cfiles.values(): + with open(cfile, 'rb') as chandle: + actual = chandle.read(12) + if expect != actual: + break + else: return success except OSError: pass if not quiet: print('Compiling {!r}...'.format(fullname)) try: - ok = py_compile.compile(fullname, cfile, dfile, True, - optimize=optimize) + for index, opt_level in enumerate(optimize): + cfile = opt_cfiles[opt_level] + ok = py_compile.compile(fullname, cfile, dfile, True, + optimize=opt_level, + invalidation_mode=invalidation_mode) + if index > 0 and hardlink_dupes: + previous_cfile = opt_cfiles[optimize[index - 1]] + if filecmp.cmp(cfile, previous_cfile, shallow=False): + os.unlink(cfile) + os.link(previous_cfile, cfile) except py_compile.PyCompileError as err: success = False if quiet >= 2: @@ -156,9 +254,8 @@ def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, else: print('*** ', end='') # escape non-printable characters in msg - msg = err.msg.encode(sys.stdout.encoding, - errors='backslashreplace') - msg = msg.decode(sys.stdout.encoding) + encoding = sys.stdout.encoding or sys.getdefaultencoding() + msg = err.msg.encode(encoding, errors='backslashreplace').decode(encoding) print(msg) except (SyntaxError, UnicodeError, OSError) as e: success = False @@ -175,7 +272,8 @@ def compile_file(fullname, ddir=None, force=False, rx=None, quiet=0, return success def compile_path(skip_curdir=1, maxlevels=0, force=False, quiet=0, - legacy=False, optimize=-1): + legacy=False, optimize=-1, + invalidation_mode=None): """Byte-compile all module on sys.path. Arguments (all optional): @@ -186,6 +284,7 @@ def compile_path(skip_curdir=1, maxlevels=0, force=False, quiet=0, quiet: as for compile_dir() (default 0) legacy: as for compile_dir() (default False) optimize: as for compile_dir() (default -1) + invalidation_mode: as for compiler_dir() """ success = True for dir in sys.path: @@ -193,9 +292,16 @@ def compile_path(skip_curdir=1, maxlevels=0, force=False, quiet=0, if quiet < 2: print('Skipping current directory') else: - success = success and compile_dir(dir, maxlevels, None, - force, quiet=quiet, - legacy=legacy, optimize=optimize) + success = success and compile_dir( + dir, + maxlevels, + None, + force, + quiet=quiet, + legacy=legacy, + optimize=optimize, + invalidation_mode=invalidation_mode, + ) return success @@ -206,7 +312,7 @@ def main(): parser = argparse.ArgumentParser( description='Utilities to support installing Python libraries.') parser.add_argument('-l', action='store_const', const=0, - default=10, dest='maxlevels', + default=None, dest='maxlevels', help="don't recurse into subdirectories") parser.add_argument('-r', type=int, dest='recursion', help=('control the maximum recursion level. ' @@ -224,6 +330,20 @@ def main(): 'compile-time tracebacks and in runtime ' 'tracebacks in cases where the source file is ' 'unavailable')) + parser.add_argument('-s', metavar='STRIPDIR', dest='stripdir', + default=None, + help=('part of path to left-strip from path ' + 'to source file - for example buildroot. ' + '`-d` and `-s` options cannot be ' + 'specified together.')) + parser.add_argument('-p', metavar='PREPENDDIR', dest='prependdir', + default=None, + help=('path to add as prefix to path ' + 'to source file - for example / to make ' + 'it absolute when some part is removed ' + 'by `-s` option. ' + '`-d` and `-p` options cannot be ' + 'specified together.')) parser.add_argument('-x', metavar='REGEXP', dest='rx', default=None, help=('skip files matching the regular expression; ' 'the regexp is searched for in the full path ' @@ -238,6 +358,23 @@ def main(): 'to the equivalent of -l sys.path')) parser.add_argument('-j', '--workers', default=1, type=int, help='Run compileall concurrently') + invalidation_modes = [mode.name.lower().replace('_', '-') + for mode in py_compile.PycInvalidationMode] + parser.add_argument('--invalidation-mode', + choices=sorted(invalidation_modes), + help=('set .pyc invalidation mode; defaults to ' + '"checked-hash" if the SOURCE_DATE_EPOCH ' + 'environment variable is set, and ' + '"timestamp" otherwise.')) + parser.add_argument('-o', action='append', type=int, dest='opt_levels', + help=('Optimization levels to run compilation with. ' + 'Default is -1 which uses the optimization level ' + 'of the Python interpreter itself (see -O).')) + parser.add_argument('-e', metavar='DIR', dest='limit_sl_dest', + help='Ignore symlinks pointing outsite of the DIR') + parser.add_argument('--hardlink-dupes', action='store_true', + dest='hardlink_dupes', + help='Hardlink duplicated pyc files') args = parser.parse_args() compile_dests = args.compile_dest @@ -246,16 +383,31 @@ def main(): import re args.rx = re.compile(args.rx) + if args.limit_sl_dest == "": + args.limit_sl_dest = None if args.recursion is not None: maxlevels = args.recursion else: maxlevels = args.maxlevels + if args.opt_levels is None: + args.opt_levels = [-1] + + if len(args.opt_levels) == 1 and args.hardlink_dupes: + parser.error(("Hardlinking of duplicated bytecode makes sense " + "only for more than one optimization level.")) + + if args.ddir is not None and ( + args.stripdir is not None or args.prependdir is not None + ): + parser.error("-d cannot be used in combination with -s or -p") + # if flist is provided then load it if args.flist: try: - with (sys.stdin if args.flist=='-' else open(args.flist)) as f: + with (sys.stdin if args.flist=='-' else + open(args.flist, encoding="utf-8")) as f: for line in f: compile_dests.append(line.strip()) except OSError: @@ -263,8 +415,11 @@ def main(): print("Error reading file list {}".format(args.flist)) return False - if args.workers is not None: - args.workers = args.workers or None + if args.invalidation_mode: + ivl_mode = args.invalidation_mode.replace('-', '_').upper() + invalidation_mode = py_compile.PycInvalidationMode[ivl_mode] + else: + invalidation_mode = None success = True try: @@ -272,17 +427,30 @@ def main(): for dest in compile_dests: if os.path.isfile(dest): if not compile_file(dest, args.ddir, args.force, args.rx, - args.quiet, args.legacy): + args.quiet, args.legacy, + invalidation_mode=invalidation_mode, + stripdir=args.stripdir, + prependdir=args.prependdir, + optimize=args.opt_levels, + limit_sl_dest=args.limit_sl_dest, + hardlink_dupes=args.hardlink_dupes): success = False else: if not compile_dir(dest, maxlevels, args.ddir, args.force, args.rx, args.quiet, - args.legacy, workers=args.workers): + args.legacy, workers=args.workers, + invalidation_mode=invalidation_mode, + stripdir=args.stripdir, + prependdir=args.prependdir, + optimize=args.opt_levels, + limit_sl_dest=args.limit_sl_dest, + hardlink_dupes=args.hardlink_dupes): success = False return success else: return compile_path(legacy=args.legacy, force=args.force, - quiet=args.quiet) + quiet=args.quiet, + invalidation_mode=invalidation_mode) except KeyboardInterrupt: if args.quiet < 2: print("\n[interrupted]") From 48d4c22362e958259e5490a5e9544d0cb89dd0f0 Mon Sep 17 00:00:00 2001 From: dvermd <315743+dvermd@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:19:57 +0200 Subject: [PATCH 109/893] Update fileinput to CPython 3.11.5 --- Lib/fileinput.py | 26 ++-------------- Lib/test/test_fileinput.py | 64 +++----------------------------------- 2 files changed, 8 insertions(+), 82 deletions(-) diff --git a/Lib/fileinput.py b/Lib/fileinput.py index 2ce2f91143..e234dc9ea6 100644 --- a/Lib/fileinput.py +++ b/Lib/fileinput.py @@ -217,15 +217,10 @@ def __init__(self, files=None, inplace=False, backup="", *, EncodingWarning, 2) # restrict mode argument to reading modes - if mode not in ('r', 'rU', 'U', 'rb'): - raise ValueError("FileInput opening mode must be one of " - "'r', 'rU', 'U' and 'rb'") - if 'U' in mode: - import warnings - warnings.warn("'U' mode is deprecated", - DeprecationWarning, 2) + if mode not in ('r', 'rb'): + raise ValueError("FileInput opening mode must be 'r' or 'rb'") self._mode = mode - self._write_mode = mode.replace('r', 'w') if 'U' not in mode else 'w' + self._write_mode = mode.replace('r', 'w') if openhook: if inplace: raise ValueError("FileInput cannot use an opening hook in inplace mode") @@ -262,21 +257,6 @@ def __next__(self): self.nextfile() # repeat with next file - def __getitem__(self, i): - import warnings - warnings.warn( - "Support for indexing FileInput objects is deprecated. " - "Use iterator protocol instead.", - DeprecationWarning, - stacklevel=2 - ) - if i != self.lineno(): - raise RuntimeError("accessing lines out of order") - try: - return self.__next__() - except StopIteration: - raise IndexError("end of input reached") - def nextfile(self): savestdout = self._savestdout self._savestdout = None diff --git a/Lib/test/test_fileinput.py b/Lib/test/test_fileinput.py index 270e109eb8..df894c5b2a 100644 --- a/Lib/test/test_fileinput.py +++ b/Lib/test/test_fileinput.py @@ -29,7 +29,6 @@ from test.support.os_helper import TESTFN from test.support.os_helper import unlink as safe_unlink from test.support import os_helper -from test.support import warnings_helper from test import support from unittest import mock @@ -230,22 +229,11 @@ def test_fileno(self): line = list(fi) self.assertEqual(fi.fileno(), -1) - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_opening_mode(self): - try: - # invalid mode, should raise ValueError - fi = FileInput(mode="w", encoding="utf-8") - self.fail("FileInput should reject invalid mode argument") - except ValueError: - pass - # try opening in universal newline mode - t1 = self.writeTmp(b"A\nB\r\nC\rD", mode="wb") - with warnings_helper.check_warnings(('', DeprecationWarning)): - fi = FileInput(files=t1, mode="U", encoding="utf-8") - with warnings_helper.check_warnings(('', DeprecationWarning)): - lines = list(fi) - self.assertEqual(lines, ["A\n", "B\n", "C\n", "D"]) + def test_invalid_opening_mode(self): + for mode in ('w', 'rU', 'U'): + with self.subTest(mode=mode): + with self.assertRaises(ValueError): + FileInput(mode=mode) def test_stdin_binary_mode(self): with mock.patch('sys.stdin') as m_stdin: @@ -380,44 +368,6 @@ def test_empty_files_list_specified_to_constructor(self): with FileInput(files=[], encoding="utf-8") as fi: self.assertEqual(fi._files, ('-',)) - @warnings_helper.ignore_warnings(category=DeprecationWarning) - def test__getitem__(self): - """Tests invoking FileInput.__getitem__() with the current - line number""" - t = self.writeTmp("line1\nline2\n") - with FileInput(files=[t], encoding="utf-8") as fi: - retval1 = fi[0] - self.assertEqual(retval1, "line1\n") - retval2 = fi[1] - self.assertEqual(retval2, "line2\n") - - def test__getitem___deprecation(self): - t = self.writeTmp("line1\nline2\n") - with self.assertWarnsRegex(DeprecationWarning, - r'Use iterator protocol instead'): - with FileInput(files=[t]) as fi: - self.assertEqual(fi[0], "line1\n") - - @warnings_helper.ignore_warnings(category=DeprecationWarning) - def test__getitem__invalid_key(self): - """Tests invoking FileInput.__getitem__() with an index unequal to - the line number""" - t = self.writeTmp("line1\nline2\n") - with FileInput(files=[t], encoding="utf-8") as fi: - with self.assertRaises(RuntimeError) as cm: - fi[1] - self.assertEqual(cm.exception.args, ("accessing lines out of order",)) - - @warnings_helper.ignore_warnings(category=DeprecationWarning) - def test__getitem__eof(self): - """Tests invoking FileInput.__getitem__() with the line number but at - end-of-input""" - t = self.writeTmp('') - with FileInput(files=[t], encoding="utf-8") as fi: - with self.assertRaises(IndexError) as cm: - fi[0] - self.assertEqual(cm.exception.args, ("end of input reached",)) - def test_nextfile_oserror_deleting_backup(self): """Tests invoking FileInput.nextfile() when the attempt to delete the backup file would raise OSError. This error is expected to be @@ -1031,10 +981,6 @@ def check(mode, expected_lines): self.assertEqual(lines, expected_lines) check('r', ['A\n', 'B\n', 'C\n', 'D\u20ac']) - with self.assertWarns(DeprecationWarning): - check('rU', ['A\n', 'B\n', 'C\n', 'D\u20ac']) - with self.assertWarns(DeprecationWarning): - check('U', ['A\n', 'B\n', 'C\n', 'D\u20ac']) with self.assertRaises(ValueError): check('rb', ['A\n', 'B\r\n', 'C\r', 'D\u20ac']) From fd98ab208480c1b1e6f8a64ae336cbb0595fd73e Mon Sep 17 00:00:00 2001 From: dvermd <315743+dvermd@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:40:29 +0200 Subject: [PATCH 110/893] Update fractions to CPython 3.11.5 --- Lib/fractions.py | 232 ++++++++++++++++++++++++++----------- Lib/test/test_fractions.py | 156 ++++++++++++++++++++----- 2 files changed, 296 insertions(+), 92 deletions(-) diff --git a/Lib/fractions.py b/Lib/fractions.py index e4fcc8901b..f9ac882ec0 100644 --- a/Lib/fractions.py +++ b/Lib/fractions.py @@ -1,7 +1,7 @@ # Originally contributed by Sjoerd Mullender. # Significantly modified by Jeffrey Yasskin . -"""Fraction, infinite-precision, real numbers.""" +"""Fraction, infinite-precision, rational numbers.""" from decimal import Decimal import math @@ -10,31 +10,9 @@ import re import sys -__all__ = ['Fraction', 'gcd'] +__all__ = ['Fraction'] - -def gcd(a, b): - """Calculate the Greatest Common Divisor of a and b. - - Unless b==0, the result will have the same sign as b (so that when - b is divided by it, the result comes out positive). - """ - import warnings - warnings.warn('fractions.gcd() is deprecated. Use math.gcd() instead.', - DeprecationWarning, 2) - if type(a) is int is type(b): - if (b or a) < 0: - return -math.gcd(a, b) - return math.gcd(a, b) - return _gcd(a, b) - -def _gcd(a, b): - # Supports non-integers for backward compatibility. - while b: - a, b = b, a%b - return a - # Constants related to the hash implementation; hash(x) is based # on the reduction of x modulo the prime _PyHASH_MODULUS. _PyHASH_MODULUS = sys.hash_info.modulus @@ -43,17 +21,17 @@ def _gcd(a, b): _PyHASH_INF = sys.hash_info.inf _RATIONAL_FORMAT = re.compile(r""" - \A\s* # optional whitespace at the start, then - (?P[-+]?) # an optional sign, then - (?=\d|\.\d) # lookahead for digit or .digit - (?P\d*) # numerator (possibly empty) - (?: # followed by - (?:/(?P\d+))? # an optional denominator - | # or - (?:\.(?P\d*))? # an optional fractional part - (?:E(?P[-+]?\d+))? # and optional exponent + \A\s* # optional whitespace at the start, + (?P[-+]?) # an optional sign, then + (?=\d|\.\d) # lookahead for digit or .digit + (?P\d*|\d+(_\d+)*) # numerator (possibly empty) + (?: # followed by + (?:/(?P\d+(_\d+)*))? # an optional denominator + | # or + (?:\.(?Pd*|\d+(_\d+)*))? # an optional fractional part + (?:E(?P[-+]?\d+(_\d+)*))? # and optional exponent ) - \s*\Z # and optional whitespace to finish + \s*\Z # and optional whitespace to finish """, re.VERBOSE | re.IGNORECASE) @@ -144,6 +122,7 @@ def __new__(cls, numerator=0, denominator=None, *, _normalize=True): denominator = 1 decimal = m.group('decimal') if decimal: + decimal = decimal.replace('_', '') scale = 10**len(decimal) numerator = numerator * scale + int(decimal) denominator *= scale @@ -177,13 +156,9 @@ def __new__(cls, numerator=0, denominator=None, *, _normalize=True): if denominator == 0: raise ZeroDivisionError('Fraction(%s, 0)' % numerator) if _normalize: - if type(numerator) is int is type(denominator): - # *very* normal case - g = math.gcd(numerator, denominator) - if denominator < 0: - g = -g - else: - g = _gcd(numerator, denominator) + g = math.gcd(numerator, denominator) + if denominator < 0: + g = -g numerator //= g denominator //= g self._numerator = numerator @@ -406,32 +381,139 @@ def reverse(b, a): return forward, reverse + # Rational arithmetic algorithms: Knuth, TAOCP, Volume 2, 4.5.1. + # + # Assume input fractions a and b are normalized. + # + # 1) Consider addition/subtraction. + # + # Let g = gcd(da, db). Then + # + # na nb na*db ± nb*da + # a ± b == -- ± -- == ------------- == + # da db da*db + # + # na*(db//g) ± nb*(da//g) t + # == ----------------------- == - + # (da*db)//g d + # + # Now, if g > 1, we're working with smaller integers. + # + # Note, that t, (da//g) and (db//g) are pairwise coprime. + # + # Indeed, (da//g) and (db//g) share no common factors (they were + # removed) and da is coprime with na (since input fractions are + # normalized), hence (da//g) and na are coprime. By symmetry, + # (db//g) and nb are coprime too. Then, + # + # gcd(t, da//g) == gcd(na*(db//g), da//g) == 1 + # gcd(t, db//g) == gcd(nb*(da//g), db//g) == 1 + # + # Above allows us optimize reduction of the result to lowest + # terms. Indeed, + # + # g2 = gcd(t, d) == gcd(t, (da//g)*(db//g)*g) == gcd(t, g) + # + # t//g2 t//g2 + # a ± b == ----------------------- == ---------------- + # (da//g)*(db//g)*(g//g2) (da//g)*(db//g2) + # + # is a normalized fraction. This is useful because the unnormalized + # denominator d could be much larger than g. + # + # We should special-case g == 1 (and g2 == 1), since 60.8% of + # randomly-chosen integers are coprime: + # https://en.wikipedia.org/wiki/Coprime_integers#Probability_of_coprimality + # Note, that g2 == 1 always for fractions, obtained from floats: here + # g is a power of 2 and the unnormalized numerator t is an odd integer. + # + # 2) Consider multiplication + # + # Let g1 = gcd(na, db) and g2 = gcd(nb, da), then + # + # na*nb na*nb (na//g1)*(nb//g2) + # a*b == ----- == ----- == ----------------- + # da*db db*da (db//g1)*(da//g2) + # + # Note, that after divisions we're multiplying smaller integers. + # + # Also, the resulting fraction is normalized, because each of + # two factors in the numerator is coprime to each of the two factors + # in the denominator. + # + # Indeed, pick (na//g1). It's coprime with (da//g2), because input + # fractions are normalized. It's also coprime with (db//g1), because + # common factors are removed by g1 == gcd(na, db). + # + # As for addition/subtraction, we should special-case g1 == 1 + # and g2 == 1 for same reason. That happens also for multiplying + # rationals, obtained from floats. + def _add(a, b): """a + b""" - da, db = a.denominator, b.denominator - return Fraction(a.numerator * db + b.numerator * da, - da * db) + na, da = a.numerator, a.denominator + nb, db = b.numerator, b.denominator + g = math.gcd(da, db) + if g == 1: + return Fraction(na * db + da * nb, da * db, _normalize=False) + s = da // g + t = na * (db // g) + nb * s + g2 = math.gcd(t, g) + if g2 == 1: + return Fraction(t, s * db, _normalize=False) + return Fraction(t // g2, s * (db // g2), _normalize=False) __add__, __radd__ = _operator_fallbacks(_add, operator.add) def _sub(a, b): """a - b""" - da, db = a.denominator, b.denominator - return Fraction(a.numerator * db - b.numerator * da, - da * db) + na, da = a.numerator, a.denominator + nb, db = b.numerator, b.denominator + g = math.gcd(da, db) + if g == 1: + return Fraction(na * db - da * nb, da * db, _normalize=False) + s = da // g + t = na * (db // g) - nb * s + g2 = math.gcd(t, g) + if g2 == 1: + return Fraction(t, s * db, _normalize=False) + return Fraction(t // g2, s * (db // g2), _normalize=False) __sub__, __rsub__ = _operator_fallbacks(_sub, operator.sub) def _mul(a, b): """a * b""" - return Fraction(a.numerator * b.numerator, a.denominator * b.denominator) + na, da = a.numerator, a.denominator + nb, db = b.numerator, b.denominator + g1 = math.gcd(na, db) + if g1 > 1: + na //= g1 + db //= g1 + g2 = math.gcd(nb, da) + if g2 > 1: + nb //= g2 + da //= g2 + return Fraction(na * nb, db * da, _normalize=False) __mul__, __rmul__ = _operator_fallbacks(_mul, operator.mul) def _div(a, b): """a / b""" - return Fraction(a.numerator * b.denominator, - a.denominator * b.numerator) + # Same as _mul(), with inversed b. + na, da = a.numerator, a.denominator + nb, db = b.numerator, b.denominator + g1 = math.gcd(na, nb) + if g1 > 1: + na //= g1 + nb //= g1 + g2 = math.gcd(db, da) + if g2 > 1: + da //= g2 + db //= g2 + n, d = na * db, nb * da + if d < 0: + n, d = -n, -d + return Fraction(n, d, _normalize=False) __truediv__, __rtruediv__ = _operator_fallbacks(_div, operator.truediv) @@ -512,8 +594,15 @@ def __abs__(a): """abs(a)""" return Fraction(abs(a._numerator), a._denominator, _normalize=False) + def __int__(a, _index=operator.index): + """int(a)""" + if a._numerator < 0: + return _index(-(-a._numerator // a._denominator)) + else: + return _index(a._numerator // a._denominator) + def __trunc__(a): - """trunc(a)""" + """math.trunc(a)""" if a._numerator < 0: return -(-a._numerator // a._denominator) else: @@ -556,23 +645,34 @@ def __round__(self, ndigits=None): def __hash__(self): """hash(self)""" - # XXX since this method is expensive, consider caching the result - - # In order to make sure that the hash of a Fraction agrees - # with the hash of a numerically equal integer, float or - # Decimal instance, we follow the rules for numeric hashes - # outlined in the documentation. (See library docs, 'Built-in - # Types'). + # To make sure that the hash of a Fraction agrees with the hash + # of a numerically equal integer, float or Decimal instance, we + # follow the rules for numeric hashes outlined in the + # documentation. (See library docs, 'Built-in Types'). - # dinv is the inverse of self._denominator modulo the prime - # _PyHASH_MODULUS, or 0 if self._denominator is divisible by - # _PyHASH_MODULUS. - dinv = pow(self._denominator, _PyHASH_MODULUS - 2, _PyHASH_MODULUS) - if not dinv: + try: + dinv = pow(self._denominator, -1, _PyHASH_MODULUS) + except ValueError: + # ValueError means there is no modular inverse. hash_ = _PyHASH_INF else: - hash_ = abs(self._numerator) * dinv % _PyHASH_MODULUS - result = hash_ if self >= 0 else -hash_ + # The general algorithm now specifies that the absolute value of + # the hash is + # (|N| * dinv) % P + # where N is self._numerator and P is _PyHASH_MODULUS. That's + # optimized here in two ways: first, for a non-negative int i, + # hash(i) == i % P, but the int hash implementation doesn't need + # to divide, and is faster than doing % P explicitly. So we do + # hash(|N| * dinv) + # instead. Second, N is unbounded, so its product with dinv may + # be arbitrarily expensive to compute. The final answer is the + # same if we use the bounded |N| % P instead, which can again + # be done with an int hash() call. If 0 <= i < P, hash(i) == i, + # so this nested hash() call wastes a bit of time making a + # redundant copy when |N| < P, but can save an arbitrarily large + # amount of computation for large |N|. + hash_ = hash(hash(abs(self._numerator)) * dinv) + result = hash_ if self._numerator >= 0 else -hash_ return -2 if result == -1 else result def __eq__(a, b): @@ -643,7 +743,7 @@ def __bool__(a): # support for pickling, copy, and deepcopy def __reduce__(self): - return (self.__class__, (str(self),)) + return (self.__class__, (self._numerator, self._denominator)) def __copy__(self): if type(self) == Fraction: diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py index 02a5022853..a79932cfa8 100644 --- a/Lib/test/test_fractions.py +++ b/Lib/test/test_fractions.py @@ -8,12 +8,12 @@ import fractions import functools import sys +import typing import unittest -import warnings from copy import copy, deepcopy from pickle import dumps, loads F = fractions.Fraction -gcd = fractions.gcd + class DummyFloat(object): """Dummy float class for testing comparisons with Fractions""" @@ -82,30 +82,6 @@ def __float__(self): class DummyFraction(fractions.Fraction): """Dummy Fraction subclass for copy and deepcopy testing.""" -class GcdTest(unittest.TestCase): - - def testMisc(self): - # fractions.gcd() is deprecated - with self.assertWarnsRegex(DeprecationWarning, r'fractions\.gcd'): - gcd(1, 1) - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', r'fractions\.gcd', - DeprecationWarning) - self.assertEqual(0, gcd(0, 0)) - self.assertEqual(1, gcd(1, 0)) - self.assertEqual(-1, gcd(-1, 0)) - self.assertEqual(1, gcd(0, 1)) - self.assertEqual(-1, gcd(0, -1)) - self.assertEqual(1, gcd(7, 1)) - self.assertEqual(-1, gcd(7, -1)) - self.assertEqual(1, gcd(-23, 15)) - self.assertEqual(12, gcd(120, 84)) - self.assertEqual(-12, gcd(84, -120)) - self.assertEqual(gcd(120.0, 84), 12.0) - self.assertEqual(gcd(120, 84.0), 12.0) - self.assertEqual(gcd(F(120), F(84)), F(12)) - self.assertEqual(gcd(F(120, 77), F(84, 55)), F(12, 385)) - def _components(r): return (r.numerator, r.denominator) @@ -197,6 +173,12 @@ def testFromString(self): self.assertEqual((-12300, 1), _components(F("-1.23e4"))) self.assertEqual((0, 1), _components(F(" .0e+0\t"))) self.assertEqual((0, 1), _components(F("-0.000e0"))) + self.assertEqual((123, 1), _components(F("1_2_3"))) + self.assertEqual((41, 107), _components(F("1_2_3/3_2_1"))) + self.assertEqual((6283, 2000), _components(F("3.14_15"))) + self.assertEqual((6283, 2*10**13), _components(F("3.14_15e-1_0"))) + self.assertEqual((101, 100), _components(F("1.01"))) + self.assertEqual((101, 100), _components(F("1.0_1"))) self.assertRaisesMessage( ZeroDivisionError, "Fraction(3, 0)", @@ -234,6 +216,62 @@ def testFromString(self): # Allow 3. and .3, but not . ValueError, "Invalid literal for Fraction: '.'", F, ".") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '_'", + F, "_") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '_1'", + F, "_1") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1__2'", + F, "1__2") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '/_'", + F, "/_") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1_/'", + F, "1_/") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '_1/'", + F, "_1/") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1__2/'", + F, "1__2/") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1/_'", + F, "1/_") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1/_1'", + F, "1/_1") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1/1__2'", + F, "1/1__2") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1._111'", + F, "1._111") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1.1__1'", + F, "1.1__1") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1.1e+_1'", + F, "1.1e+_1") + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1.1e+1__1'", + F, "1.1e+1__1") + # Test catastrophic backtracking. + val = "9"*50 + "_" + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '" + val + "'", + F, val) + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1/" + val + "'", + F, "1/" + val) + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1." + val + "'", + F, "1." + val) + self.assertRaisesMessage( + ValueError, "Invalid literal for Fraction: '1.1+e" + val + "'", + F, "1.1+e" + val) def testImmutable(self): r = F(7, 3) @@ -347,6 +385,47 @@ def testConversions(self): self.assertTypedEquals(0.1+0j, complex(F(1,10))) + def testSupportsInt(self): + # See bpo-44547. + f = F(3, 2) + self.assertIsInstance(f, typing.SupportsInt) + self.assertEqual(int(f), 1) + self.assertEqual(type(int(f)), int) + + def testIntGuaranteesIntReturn(self): + # Check that int(some_fraction) gives a result of exact type `int` + # even if the fraction is using some other Integral type for its + # numerator and denominator. + + class CustomInt(int): + """ + Subclass of int with just enough machinery to convince the Fraction + constructor to produce something with CustomInt numerator and + denominator. + """ + + @property + def numerator(self): + return self + + @property + def denominator(self): + return CustomInt(1) + + def __mul__(self, other): + return CustomInt(int(self) * int(other)) + + def __floordiv__(self, other): + return CustomInt(int(self) // int(other)) + + f = F(CustomInt(13), CustomInt(5)) + + self.assertIsInstance(f.numerator, CustomInt) + self.assertIsInstance(f.denominator, CustomInt) + self.assertIsInstance(f, typing.SupportsInt) + self.assertEqual(int(f), 2) + self.assertEqual(type(int(f)), int) + def testBoolGuarateesBoolReturn(self): # Ensure that __bool__ is used on numerator which guarantees a bool # return. See also bpo-39274. @@ -394,7 +473,9 @@ def testArithmetic(self): self.assertEqual(F(1, 2), F(1, 10) + F(2, 5)) self.assertEqual(F(-3, 10), F(1, 10) - F(2, 5)) self.assertEqual(F(1, 25), F(1, 10) * F(2, 5)) + self.assertEqual(F(5, 6), F(2, 3) * F(5, 4)) self.assertEqual(F(1, 4), F(1, 10) / F(2, 5)) + self.assertEqual(F(-15, 8), F(3, 4) / F(-2, 5)) self.assertTypedEquals(2, F(9, 10) // F(2, 5)) self.assertTypedEquals(10**23, F(10**23, 1) // F(1)) self.assertEqual(F(5, 6), F(7, 3) % F(3, 2)) @@ -729,5 +810,28 @@ def test_slots(self): r = F(13, 7) self.assertRaises(AttributeError, setattr, r, 'a', 10) + def test_int_subclass(self): + class myint(int): + def __mul__(self, other): + return type(self)(int(self) * int(other)) + def __floordiv__(self, other): + return type(self)(int(self) // int(other)) + def __mod__(self, other): + x = type(self)(int(self) % int(other)) + return x + @property + def numerator(self): + return type(self)(int(self)) + @property + def denominator(self): + return type(self)(1) + + f = fractions.Fraction(myint(1 * 3), myint(2 * 3)) + self.assertEqual(f.numerator, 1) + self.assertEqual(f.denominator, 2) + self.assertEqual(type(f.numerator), myint) + self.assertEqual(type(f.denominator), myint) + + if __name__ == '__main__': unittest.main() From 0a76a9b115fbf1a6e985d56e8ab6e3a833c34790 Mon Sep 17 00:00:00 2001 From: dvermd <315743+dvermd@users.noreply.github.com> Date: Wed, 27 Sep 2023 05:58:09 +0200 Subject: [PATCH 111/893] Update ftplib to CPython 3.11.5 part_of: #4564 --- Lib/ftplib.py | 55 ++++++++------ Lib/test/test_ftplib.py | 164 ++++++++++++++++++++++++++++------------ 2 files changed, 149 insertions(+), 70 deletions(-) diff --git a/Lib/ftplib.py b/Lib/ftplib.py index 58a46bca4a..7c5a50715f 100644 --- a/Lib/ftplib.py +++ b/Lib/ftplib.py @@ -72,17 +72,17 @@ class error_proto(Error): pass # response does not begin with [1-5] # The class itself class FTP: - '''An FTP client class. To create a connection, call the class using these arguments: - host, user, passwd, acct, timeout + host, user, passwd, acct, timeout, source_address, encoding The first four arguments are all strings, and have default value ''. - timeout must be numeric and defaults to None if not passed, - meaning that no timeout will be set on any ftp socket(s) + The parameter ´timeout´ must be numeric and defaults to None if not + passed, meaning that no timeout will be set on any ftp socket(s). If a timeout is passed, then this is now the default timeout for all ftp socket operations for this instance. + The last parameter is the encoding of filenames, which defaults to utf-8. Then use self.connect() with optional host and port argument. @@ -102,15 +102,19 @@ class FTP: sock = None file = None welcome = None - passiveserver = 1 - encoding = "latin-1" + passiveserver = True + # Disables https://bugs.python.org/issue43285 security if set to True. + trust_server_pasv_ipv4_address = False - # Initialization method (called by class instantiation). - # Initialize host to localhost, port to standard ftp port - # Optional arguments are host (for connect()), - # and user, passwd, acct (for login()) def __init__(self, host='', user='', passwd='', acct='', - timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None): + timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None, *, + encoding='utf-8'): + """Initialization method (called by class instantiation). + Initialize host to localhost, port to standard ftp port. + Optional arguments are host (for connect()), + and user, passwd, acct (for login()). + """ + self.encoding = encoding self.source_address = source_address self.timeout = timeout if host: @@ -146,6 +150,8 @@ def connect(self, host='', port=0, timeout=-999, source_address=None): self.port = port if timeout != -999: self.timeout = timeout + if self.timeout is not None and not self.timeout: + raise ValueError('Non-blocking socket (timeout=0) is not supported') if source_address is not None: self.source_address = source_address sys.audit("ftplib.connect", self, self.host, self.port) @@ -316,8 +322,13 @@ def makeport(self): return sock def makepasv(self): + """Internal: Does the PASV or EPSV handshake -> (address, port)""" if self.af == socket.AF_INET: - host, port = parse227(self.sendcmd('PASV')) + untrusted_host, port = parse227(self.sendcmd('PASV')) + if self.trust_server_pasv_ipv4_address: + host = untrusted_host + else: + host = self.sock.getpeername()[0] else: host, port = parse229(self.sendcmd('EPSV'), self.sock.getpeername()) return host, port @@ -704,9 +715,10 @@ class FTP_TLS(FTP): ''' ssl_version = ssl.PROTOCOL_TLS_CLIENT - def __init__(self, host='', user='', passwd='', acct='', keyfile=None, - certfile=None, context=None, - timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None): + def __init__(self, host='', user='', passwd='', acct='', + keyfile=None, certfile=None, context=None, + timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None, *, + encoding='utf-8'): if context is not None and keyfile is not None: raise ValueError("context and keyfile arguments are mutually " "exclusive") @@ -725,12 +737,13 @@ def __init__(self, host='', user='', passwd='', acct='', keyfile=None, keyfile=keyfile) self.context = context self._prot_p = False - FTP.__init__(self, host, user, passwd, acct, timeout, source_address) + super().__init__(host, user, passwd, acct, + timeout, source_address, encoding=encoding) def login(self, user='', passwd='', acct='', secure=True): if secure and not isinstance(self.sock, ssl.SSLSocket): self.auth() - return FTP.login(self, user, passwd, acct) + return super().login(user, passwd, acct) def auth(self): '''Set up secure control connection by using TLS/SSL.''' @@ -740,8 +753,7 @@ def auth(self): resp = self.voidcmd('AUTH TLS') else: resp = self.voidcmd('AUTH SSL') - self.sock = self.context.wrap_socket(self.sock, - server_hostname=self.host) + self.sock = self.context.wrap_socket(self.sock, server_hostname=self.host) self.file = self.sock.makefile(mode='r', encoding=self.encoding) return resp @@ -778,7 +790,7 @@ def prot_c(self): # --- Overridden FTP methods def ntransfercmd(self, cmd, rest=None): - conn, size = FTP.ntransfercmd(self, cmd, rest) + conn, size = super().ntransfercmd(cmd, rest) if self._prot_p: conn = self.context.wrap_socket(conn, server_hostname=self.host) @@ -823,7 +835,6 @@ def parse227(resp): '''Parse the '227' response for a PASV request. Raises error_proto if it does not contain '(h1,h2,h3,h4,p1,p2)' Return ('host.addr.as.numbers', port#) tuple.''' - if resp[:3] != '227': raise error_reply(resp) global _227_re @@ -843,7 +854,6 @@ def parse229(resp, peer): '''Parse the '229' response for an EPSV request. Raises error_proto if it does not contain '(|||port|)' Return ('host.addr.as.numbers', port#) tuple.''' - if resp[:3] != '229': raise error_reply(resp) left = resp.find('(') @@ -865,7 +875,6 @@ def parse257(resp): '''Parse the '257' response for a MKD or PWD request. This is a response to a MKD or PWD request: a directory name. Returns the directoryname in the 257 reply.''' - if resp[:3] != '257': raise error_reply(resp) if resp[3:5] != ' "': diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index 1c61697e93..e8c126ddc4 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -3,16 +3,14 @@ # Modified by Giampaolo Rodola' to test FTP class, IPv6 and TLS # environment -import unittest import ftplib -import asyncore -import asynchat import socket import io import errno import os import threading import time +import unittest try: import ssl except ImportError: @@ -31,13 +29,19 @@ if getattr(sys, '_rustpython_debugbuild', False): raise unittest.SkipTest("something's weird on debug builds") +asynchat = warnings_helper.import_deprecated('asynchat') +asyncore = warnings_helper.import_deprecated('asyncore') + + +support.requires_working_socket(module=True) TIMEOUT = support.LOOPBACK_TIMEOUT +DEFAULT_ENCODING = 'utf-8' # the dummy data returned by server over the data channel when # RETR, LIST, NLST, MLSD commands are issued -RETR_DATA = 'abcde12345\r\n' * 1000 -LIST_DATA = 'foo\r\nbar\r\n' -NLST_DATA = 'foo\r\nbar\r\n' +RETR_DATA = 'abcde12345\r\n' * 1000 + 'non-ascii char \xAE\r\n' +LIST_DATA = 'foo\r\nbar\r\n non-ascii char \xAE\r\n' +NLST_DATA = 'foo\r\nbar\r\n non-ascii char \xAE\r\n' MLSD_DATA = ("type=cdir;perm=el;unique==keVO1+ZF4; test\r\n" "type=pdir;perm=e;unique==keVO1+d?3; ..\r\n" "type=OS.unix=slink:/foobar;perm=;unique==keVO1+4G4; foobar\r\n" @@ -52,7 +56,16 @@ "type=dir;perm=cpmel;unique==keVO1+7G4; incoming\r\n" "type=file;perm=r;unique==keVO1+1G4; file2\r\n" "type=file;perm=r;unique==keVO1+1G4; file3\r\n" - "type=file;perm=r;unique==keVO1+1G4; file4\r\n") + "type=file;perm=r;unique==keVO1+1G4; file4\r\n" + "type=dir;perm=cpmel;unique==SGP1; dir \xAE non-ascii char\r\n" + "type=file;perm=r;unique==SGP2; file \xAE non-ascii char\r\n") + + +def default_error_handler(): + # bpo-44359: Silently ignore socket errors. Such errors occur when a client + # socket is closed, in TestFTPClass.tearDown() and makepasv() tests, and + # the server gets an error on its side. + pass class DummyDTPHandler(asynchat.async_chat): @@ -62,9 +75,11 @@ def __init__(self, conn, baseclass): asynchat.async_chat.__init__(self, conn) self.baseclass = baseclass self.baseclass.last_received_data = '' + self.encoding = baseclass.encoding def handle_read(self): - self.baseclass.last_received_data += self.recv(1024).decode('ascii') + new_data = self.recv(1024).decode(self.encoding, 'replace') + self.baseclass.last_received_data += new_data def handle_close(self): # XXX: this method can be called many times in a row for a single @@ -81,17 +96,17 @@ def push(self, what): self.baseclass.next_data = None if not what: return self.close_when_done() - super(DummyDTPHandler, self).push(what.encode('ascii')) + super(DummyDTPHandler, self).push(what.encode(self.encoding)) def handle_error(self): - raise Exception + default_error_handler() class DummyFTPHandler(asynchat.async_chat): dtp_handler = DummyDTPHandler - def __init__(self, conn): + def __init__(self, conn, encoding=DEFAULT_ENCODING): asynchat.async_chat.__init__(self, conn) # tells the socket to handle urgent data inline (ABOR command) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_OOBINLINE, 1) @@ -105,12 +120,17 @@ def __init__(self, conn): self.rest = None self.next_retr_data = RETR_DATA self.push('220 welcome') + self.encoding = encoding + # We use this as the string IPv4 address to direct the client + # to in response to a PASV command. To test security behavior. + # https://bugs.python.org/issue43285/. + self.fake_pasv_server_ip = '252.253.254.255' def collect_incoming_data(self, data): self.in_buffer.append(data) def found_terminator(self): - line = b''.join(self.in_buffer).decode('ascii') + line = b''.join(self.in_buffer).decode(self.encoding) self.in_buffer = [] if self.next_response: self.push(self.next_response) @@ -129,10 +149,10 @@ def found_terminator(self): self.push('550 command "%s" not understood.' %cmd) def handle_error(self): - raise Exception + default_error_handler() def push(self, data): - asynchat.async_chat.push(self, data.encode('ascii') + b'\r\n') + asynchat.async_chat.push(self, data.encode(self.encoding) + b'\r\n') def cmd_port(self, arg): addr = list(map(int, arg.split(','))) @@ -145,7 +165,8 @@ def cmd_port(self, arg): def cmd_pasv(self, arg): with socket.create_server((self.socket.getsockname()[0], 0)) as sock: sock.settimeout(TIMEOUT) - ip, port = sock.getsockname()[:2] + port = sock.getsockname()[1] + ip = self.fake_pasv_server_ip ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256 self.push('227 entering passive mode (%s,%d,%d)' %(ip, p1, p2)) conn, addr = sock.accept() @@ -262,7 +283,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): handler = DummyFTPHandler - def __init__(self, address, af=socket.AF_INET): + def __init__(self, address, af=socket.AF_INET, encoding=DEFAULT_ENCODING): threading.Thread.__init__(self) asyncore.dispatcher.__init__(self) self.daemon = True @@ -273,6 +294,7 @@ def __init__(self, address, af=socket.AF_INET): self.active_lock = threading.Lock() self.host, self.port = self.socket.getsockname()[:2] self.handler_instance = None + self.encoding = encoding def start(self): assert not self.active @@ -295,7 +317,7 @@ def stop(self): self.join() def handle_accepted(self, conn, addr): - self.handler_instance = self.handler(conn) + self.handler_instance = self.handler(conn, encoding=self.encoding) def handle_connect(self): self.close() @@ -305,7 +327,7 @@ def writable(self): return 0 def handle_error(self): - raise Exception + default_error_handler() if ssl is not None: @@ -320,7 +342,7 @@ class SSLConnection(asyncore.dispatcher): _ssl_closing = False def secure_connection(self): - context = ssl.SSLContext() + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) context.load_cert_chain(CERTFILE) socket = context.wrap_socket(self.socket, suppress_ragged_eofs=False, @@ -357,7 +379,7 @@ def _do_ssl_shutdown(self): if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): return - except OSError as err: + except OSError: # Any "socket error" corresponds to a SSL_ERROR_SYSCALL return # from OpenSSL's SSL_shutdown(), corresponding to a # closed socket condition. See also: @@ -408,7 +430,7 @@ def recv(self, buffer_size): raise def handle_error(self): - raise Exception + default_error_handler() def close(self): if (isinstance(self.socket, ssl.SSLSocket) and @@ -432,8 +454,8 @@ class DummyTLS_FTPHandler(SSLConnection, DummyFTPHandler): dtp_handler = DummyTLS_DTPHandler - def __init__(self, conn): - DummyFTPHandler.__init__(self, conn) + def __init__(self, conn, encoding=DEFAULT_ENCODING): + DummyFTPHandler.__init__(self, conn, encoding=encoding) self.secure_data_channel = False self._ccc = False @@ -473,10 +495,10 @@ class DummyTLS_FTPServer(DummyFTPServer): class TestFTPClass(TestCase): - def setUp(self): - self.server = DummyFTPServer((HOST, 0)) + def setUp(self, encoding=DEFAULT_ENCODING): + self.server = DummyFTPServer((HOST, 0), encoding=encoding) self.server.start() - self.client = ftplib.FTP(timeout=TIMEOUT) + self.client = ftplib.FTP(timeout=TIMEOUT, encoding=encoding) self.client.connect(self.server.host, self.server.port) def tearDown(self): @@ -576,14 +598,14 @@ def test_abort(self): def test_retrbinary(self): def callback(data): - received.append(data.decode('ascii')) + received.append(data.decode(self.client.encoding)) received = [] self.client.retrbinary('retr', callback) self.check_data(''.join(received), RETR_DATA) def test_retrbinary_rest(self): def callback(data): - received.append(data.decode('ascii')) + received.append(data.decode(self.client.encoding)) for rest in (0, 10, 20): received = [] self.client.retrbinary('retr', callback, rest=rest) @@ -596,7 +618,7 @@ def test_retrlines(self): @unittest.skip("TODO: RUSTPYTHON; weird limiting to 8192, something w/ buffering?") def test_storbinary(self): - f = io.BytesIO(RETR_DATA.encode('ascii')) + f = io.BytesIO(RETR_DATA.encode(self.client.encoding)) self.client.storbinary('stor', f) self.check_data(self.server.handler_instance.last_received_data, RETR_DATA) # test new callback arg @@ -606,14 +628,16 @@ def test_storbinary(self): self.assertTrue(flag) def test_storbinary_rest(self): - f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii')) + data = RETR_DATA.replace('\r\n', '\n').encode(self.client.encoding) + f = io.BytesIO(data) for r in (30, '30'): f.seek(0) self.client.storbinary('stor', f, rest=r) self.assertEqual(self.server.handler_instance.rest, str(r)) def test_storlines(self): - f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii')) + data = RETR_DATA.replace('\r\n', '\n').encode(self.client.encoding) + f = io.BytesIO(data) self.client.storlines('stor', f) self.check_data(self.server.handler_instance.last_received_data, RETR_DATA) # test new callback arg @@ -623,7 +647,7 @@ def test_storlines(self): self.assertTrue(flag) f = io.StringIO(RETR_DATA.replace('\r\n', '\n')) - # stowarnings_helper.check_warningsary file, not a text file + # storlines() expects a binary file, not a text file with warnings_helper.check_warnings(('', BytesWarning), quiet=True): self.assertRaises(TypeError, self.client.storlines, 'stor foo', f) @@ -707,6 +731,26 @@ def test_makepasv(self): # IPv4 is in use, just make sure send_epsv has not been used self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv') + def test_makepasv_issue43285_security_disabled(self): + """Test the opt-in to the old vulnerable behavior.""" + self.client.trust_server_pasv_ipv4_address = True + bad_host, port = self.client.makepasv() + self.assertEqual( + bad_host, self.server.handler_instance.fake_pasv_server_ip) + # Opening and closing a connection keeps the dummy server happy + # instead of timing out on accept. + socket.create_connection((self.client.sock.getpeername()[0], port), + timeout=TIMEOUT).close() + + def test_makepasv_issue43285_security_enabled_default(self): + self.assertFalse(self.client.trust_server_pasv_ipv4_address) + trusted_host, port = self.client.makepasv() + self.assertNotEqual( + trusted_host, self.server.handler_instance.fake_pasv_server_ip) + # Opening and closing a connection keeps the dummy server happy + # instead of timing out on accept. + socket.create_connection((trusted_host, port), timeout=TIMEOUT).close() + def test_with_statement(self): self.client.quit() @@ -802,14 +846,32 @@ def test_storlines_too_long(self): f = io.BytesIO(b'x' * self.client.maxline * 2) self.assertRaises(ftplib.Error, self.client.storlines, 'stor', f) + def test_encoding_param(self): + encodings = ['latin-1', 'utf-8'] + for encoding in encodings: + with self.subTest(encoding=encoding): + self.tearDown() + self.setUp(encoding=encoding) + self.assertEqual(encoding, self.client.encoding) + self.test_retrbinary() + self.test_storbinary() + self.test_retrlines() + new_dir = self.client.mkd('/non-ascii dir \xAE') + self.check_data(new_dir, '/non-ascii dir \xAE') + # Check default encoding + client = ftplib.FTP(timeout=TIMEOUT) + self.assertEqual(DEFAULT_ENCODING, client.encoding) + @skipUnless(socket_helper.IPV6_ENABLED, "IPv6 not enabled") class TestIPv6Environment(TestCase): def setUp(self): - self.server = DummyFTPServer((HOSTv6, 0), af=socket.AF_INET6) + self.server = DummyFTPServer((HOSTv6, 0), + af=socket.AF_INET6, + encoding=DEFAULT_ENCODING) self.server.start() - self.client = ftplib.FTP(timeout=TIMEOUT) + self.client = ftplib.FTP(timeout=TIMEOUT, encoding=DEFAULT_ENCODING) self.client.connect(self.server.host, self.server.port) def tearDown(self): @@ -836,7 +898,7 @@ def test_makepasv(self): def test_transfer(self): def retr(): def callback(data): - received.append(data.decode('ascii')) + received.append(data.decode(self.client.encoding)) received = [] self.client.retrbinary('retr', callback) self.assertEqual(len(''.join(received)), len(RETR_DATA)) @@ -854,10 +916,10 @@ class TestTLS_FTPClassMixin(TestFTPClass): and data connections first. """ - def setUp(self): - self.server = DummyTLS_FTPServer((HOST, 0)) + def setUp(self, encoding=DEFAULT_ENCODING): + self.server = DummyTLS_FTPServer((HOST, 0), encoding=encoding) self.server.start() - self.client = ftplib.FTP_TLS(timeout=TIMEOUT) + self.client = ftplib.FTP_TLS(timeout=TIMEOUT, encoding=encoding) self.client.connect(self.server.host, self.server.port) # enable TLS self.client.auth() @@ -869,8 +931,8 @@ def setUp(self): class TestTLS_FTPClass(TestCase): """Specific TLS_FTP class tests.""" - def setUp(self): - self.server = DummyTLS_FTPServer((HOST, 0)) + def setUp(self, encoding=DEFAULT_ENCODING): + self.server = DummyTLS_FTPServer((HOST, 0), encoding=encoding) self.server.start() self.client = ftplib.FTP_TLS(timeout=TIMEOUT) self.client.connect(self.server.host, self.server.port) @@ -891,7 +953,8 @@ def test_data_connection(self): # clear text with self.client.transfercmd('list') as sock: self.assertNotIsInstance(sock, ssl.SSLSocket) - self.assertEqual(sock.recv(1024), LIST_DATA.encode('ascii')) + self.assertEqual(sock.recv(1024), + LIST_DATA.encode(self.client.encoding)) self.assertEqual(self.client.voidresp(), "226 transfer complete") # secured, after PROT P @@ -900,14 +963,16 @@ def test_data_connection(self): self.assertIsInstance(sock, ssl.SSLSocket) # consume from SSL socket to finalize handshake and avoid # "SSLError [SSL] shutdown while in init" - self.assertEqual(sock.recv(1024), LIST_DATA.encode('ascii')) + self.assertEqual(sock.recv(1024), + LIST_DATA.encode(self.client.encoding)) self.assertEqual(self.client.voidresp(), "226 transfer complete") # PROT C is issued, the connection must be in cleartext again self.client.prot_c() with self.client.transfercmd('list') as sock: self.assertNotIsInstance(sock, ssl.SSLSocket) - self.assertEqual(sock.recv(1024), LIST_DATA.encode('ascii')) + self.assertEqual(sock.recv(1024), + LIST_DATA.encode(self.client.encoding)) self.assertEqual(self.client.voidresp(), "226 transfer complete") def test_login(self): @@ -1017,7 +1082,7 @@ def server(self): self.evt.set() try: conn, addr = self.sock.accept() - except socket.timeout: + except TimeoutError: pass else: conn.sendall(b"1 Hola mundo\n") @@ -1059,6 +1124,10 @@ def testTimeoutValue(self): self.evt.wait() ftp.close() + # bpo-39259 + with self.assertRaises(ValueError): + ftplib.FTP(HOST, timeout=0) + def testTimeoutConnect(self): ftp = ftplib.FTP() ftp.connect(HOST, timeout=30) @@ -1084,9 +1153,10 @@ def testTimeoutDirectAccess(self): class MiscTestCase(TestCase): def test__all__(self): - not_exported = {'MSG_OOB', 'FTP_PORT', 'MAXLINE', 'CRLF', 'B_CRLF', - 'Error', 'parse150', 'parse227', 'parse229', 'parse257', - 'print_line', 'ftpcp', 'test'} + not_exported = { + 'MSG_OOB', 'FTP_PORT', 'MAXLINE', 'CRLF', 'B_CRLF', 'Error', + 'parse150', 'parse227', 'parse229', 'parse257', 'print_line', + 'ftpcp', 'test'} support.check__all__(self, ftplib, not_exported=not_exported) From 0d0139b3226a501577ea47f8a1cd94bfa7e76b90 Mon Sep 17 00:00:00 2001 From: dvermd <315743+dvermd@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:32:38 +0200 Subject: [PATCH 112/893] Update enum to CPython 3.11.5 (#5074) part of: #4564 --- .github/workflows/ci.yaml | 2 +- Lib/enum.py | 1924 +++++++++++++----- Lib/test/test_enum.py | 3893 ++++++++++++++++++++++++++----------- Lib/test/test_re.py | 2 + Lib/test/test_socket.py | 4 + Lib/test/test_unicode.py | 3 + 6 files changed, 4286 insertions(+), 1542 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5c1b1b6637..12308c1dbc 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -315,7 +315,7 @@ jobs: with: python-version: ${{ env.PYTHON_VERSION }} - name: install ruff - run: python -m pip install ruff + run: python -m pip install ruff==0.0.291 # astral-sh/ruff#7778 - name: run python lint run: ruff extra_tests wasm examples --exclude='./.*',./Lib,./vm/Lib,./benches/ --select=E9,F63,F7,F82 --show-source - name: install prettier diff --git a/Lib/enum.py b/Lib/enum.py index 31afdd3a24..625e9ea56a 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -1,14 +1,40 @@ import sys +import builtins as bltns from types import MappingProxyType, DynamicClassAttribute +from operator import or_ as _or_ +from functools import reduce __all__ = [ - 'EnumMeta', - 'Enum', 'IntEnum', 'Flag', 'IntFlag', - 'auto', 'unique', + 'EnumType', 'EnumMeta', + 'Enum', 'IntEnum', 'StrEnum', 'Flag', 'IntFlag', 'ReprEnum', + 'auto', 'unique', 'property', 'verify', 'member', 'nonmember', + 'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP', + 'global_flag_repr', 'global_enum_repr', 'global_str', 'global_enum', + 'EnumCheck', 'CONTINUOUS', 'NAMED_FLAGS', 'UNIQUE', + 'pickle_by_global_name', 'pickle_by_enum_name', ] +# Dummy value for Enum and Flag as there are explicit checks for them +# before they have been created. +# This is also why there are checks in EnumType like `if Enum is not None` +Enum = Flag = EJECT = _stdlib_enums = ReprEnum = None + +class nonmember(object): + """ + Protects item from becoming an Enum member during class creation. + """ + def __init__(self, value): + self.value = value + +class member(object): + """ + Forces item to become an Enum member during class creation. + """ + def __init__(self, value): + self.value = value + def _is_descriptor(obj): """ Returns True if obj is a descriptor, False otherwise. @@ -41,33 +67,297 @@ def _is_sunder(name): name[-2:-1] != '_' ) -def _make_class_unpicklable(cls): +def _is_internal_class(cls_name, obj): + # do not use `re` as `re` imports `enum` + if not isinstance(obj, type): + return False + qualname = getattr(obj, '__qualname__', '') + s_pattern = cls_name + '.' + getattr(obj, '__name__', '') + e_pattern = '.' + s_pattern + return qualname == s_pattern or qualname.endswith(e_pattern) + +def _is_private(cls_name, name): + # do not use `re` as `re` imports `enum` + pattern = '_%s__' % (cls_name, ) + pat_len = len(pattern) + if ( + len(name) > pat_len + and name.startswith(pattern) + and name[pat_len:pat_len+1] != ['_'] + and (name[-1] != '_' or name[-2] != '_') + ): + return True + else: + return False + +def _is_single_bit(num): """ - Make the given class un-picklable. + True if only one bit set in num (should be an int) + """ + if num == 0: + return False + num &= num - 1 + return num == 0 + +def _make_class_unpicklable(obj): + """ + Make the given obj un-picklable. + + obj should be either a dictionary, or an Enum """ def _break_on_call_reduce(self, proto): raise TypeError('%r cannot be pickled' % self) - cls.__reduce_ex__ = _break_on_call_reduce - cls.__module__ = '' + if isinstance(obj, dict): + obj['__reduce_ex__'] = _break_on_call_reduce + obj['__module__'] = '' + else: + setattr(obj, '__reduce_ex__', _break_on_call_reduce) + setattr(obj, '__module__', '') + +def _iter_bits_lsb(num): + # num must be a positive integer + original = num + if isinstance(num, Enum): + num = num.value + if num < 0: + raise ValueError('%r is not a positive integer' % original) + while num: + b = num & (~num + 1) + yield b + num ^= b + +def show_flag_values(value): + return list(_iter_bits_lsb(value)) + +def bin(num, max_bits=None): + """ + Like built-in bin(), except negative values are represented in + twos-compliment, and the leading bit always indicates sign + (0=positive, 1=negative). + + >>> bin(10) + '0b0 1010' + >>> bin(~10) # ~10 is -11 + '0b1 0101' + """ + + ceiling = 2 ** (num).bit_length() + if num >= 0: + s = bltns.bin(num + ceiling).replace('1', '0', 1) + else: + s = bltns.bin(~num ^ (ceiling - 1) + ceiling) + sign = s[:3] + digits = s[3:] + if max_bits is not None: + if len(digits) < max_bits: + digits = (sign[-1] * max_bits + digits)[-max_bits:] + return "%s %s" % (sign, digits) + +def _dedent(text): + """ + Like textwrap.dedent. Rewritten because we cannot import textwrap. + """ + lines = text.split('\n') + blanks = 0 + for i, ch in enumerate(lines[0]): + if ch != ' ': + break + for j, l in enumerate(lines): + lines[j] = l[i:] + return '\n'.join(lines) + +class _auto_null: + def __repr__(self): + return '_auto_null' +_auto_null = _auto_null() -_auto_null = object() class auto: """ Instances are replaced with an appropriate value in Enum class suites. """ - value = _auto_null + def __init__(self, value=_auto_null): + self.value = value + + def __repr__(self): + return "auto(%r)" % self.value + +class property(DynamicClassAttribute): + """ + This is a descriptor, used to define attributes that act differently + when accessed through an enum member and through an enum class. + Instance access is the same as property(), but access to an attribute + through the enum class will instead look in the class' _member_map_ for + a corresponding enum member. + """ + + def __get__(self, instance, ownerclass=None): + if instance is None: + try: + return ownerclass._member_map_[self.name] + except KeyError: + raise AttributeError( + '%r has no attribute %r' % (ownerclass, self.name) + ) + else: + if self.fget is None: + # look for a member by this name. + try: + return ownerclass._member_map_[self.name] + except KeyError: + raise AttributeError( + '%r has no attribute %r' % (ownerclass, self.name) + ) from None + else: + return self.fget(instance) + + def __set__(self, instance, value): + if self.fset is None: + raise AttributeError( + " cannot set attribute %r" % (self.clsname, self.name) + ) + else: + return self.fset(instance, value) + + def __delete__(self, instance): + if self.fdel is None: + raise AttributeError( + " cannot delete attribute %r" % (self.clsname, self.name) + ) + else: + return self.fdel(instance) + + def __set_name__(self, ownerclass, name): + self.name = name + self.clsname = ownerclass.__name__ + + +class _proto_member: + """ + intermediate step for enum members between class execution and final creation + """ + + def __init__(self, value): + self.value = value + + def __set_name__(self, enum_class, member_name): + """ + convert each quasi-member into an instance of the new enum class + """ + # first step: remove ourself from enum_class + delattr(enum_class, member_name) + # second step: create member based on enum_class + value = self.value + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if enum_class._member_type_ is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not enum_class._use_args_: + enum_member = enum_class._new_member_(enum_class) + else: + enum_member = enum_class._new_member_(enum_class, *args) + if not hasattr(enum_member, '_value_'): + if enum_class._member_type_ is object: + enum_member._value_ = value + else: + try: + enum_member._value_ = enum_class._member_type_(*args) + except Exception as exc: + new_exc = TypeError( + '_value_ not set in __new__, unable to create it' + ) + new_exc.__cause__ = exc + raise new_exc + value = enum_member._value_ + enum_member._name_ = member_name + enum_member.__objclass__ = enum_class + enum_member.__init__(*args) + enum_member._sort_order_ = len(enum_class._member_names_) + + if Flag is not None and issubclass(enum_class, Flag): + enum_class._flag_mask_ |= value + if _is_single_bit(value): + enum_class._singles_mask_ |= value + enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1 + + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + try: + try: + # try to do a fast lookup to avoid the quadratic loop + enum_member = enum_class._value2member_map_[value] + except TypeError: + for name, canonical_member in enum_class._member_map_.items(): + if canonical_member._value_ == value: + enum_member = canonical_member + break + else: + raise KeyError + except KeyError: + # this could still be an alias if the value is multi-bit and the + # class is a flag class + if ( + Flag is None + or not issubclass(enum_class, Flag) + ): + # no other instances found, record this member in _member_names_ + enum_class._member_names_.append(member_name) + elif ( + Flag is not None + and issubclass(enum_class, Flag) + and _is_single_bit(value) + ): + # no other instances found, record this member in _member_names_ + enum_class._member_names_.append(member_name) + # if necessary, get redirect in place and then add it to _member_map_ + found_descriptor = None + for base in enum_class.__mro__[1:]: + descriptor = base.__dict__.get(member_name) + if descriptor is not None: + if isinstance(descriptor, (property, DynamicClassAttribute)): + found_descriptor = descriptor + break + elif ( + hasattr(descriptor, 'fget') and + hasattr(descriptor, 'fset') and + hasattr(descriptor, 'fdel') + ): + found_descriptor = descriptor + continue + if found_descriptor: + redirect = property() + redirect.member = enum_member + redirect.__set_name__(enum_class, member_name) + # earlier descriptor found; copy fget, fset, fdel to this one. + redirect.fget = found_descriptor.fget + redirect.fset = found_descriptor.fset + redirect.fdel = found_descriptor.fdel + setattr(enum_class, member_name, redirect) + else: + setattr(enum_class, member_name, enum_member) + # now add to _member_map_ (even aliases) + enum_class._member_map_[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map_.setdefault(value, enum_member) + except TypeError: + # keep track of the value in a list so containment checks are quick + enum_class._unhashable_values_.append(value) class _EnumDict(dict): """ Track enum member order and ensure member names are not reused. - EnumMeta will use the names found in self._member_names as the + EnumType will use the names found in self._member_names as the enumeration member names. """ def __init__(self): super().__init__() - self._member_names = [] + self._member_names = {} # use a dict to keep insertion order self._last_values = [] self._ignore = [] self._auto_called = False @@ -81,17 +371,33 @@ def __setitem__(self, key, value): Single underscore (sunder) names are reserved. """ - if _is_sunder(key): + if _is_internal_class(self._cls_name, value): + import warnings + warnings.warn( + "In 3.13 classes created inside an enum will not become a member. " + "Use the `member` decorator to keep the current behavior.", + DeprecationWarning, + stacklevel=2, + ) + if _is_private(self._cls_name, key): + # also do nothing, name will be a normal attribute + pass + elif _is_sunder(key): if key not in ( - '_order_', '_create_pseudo_member_', - '_generate_next_value_', '_missing_', '_ignore_', + '_order_', + '_generate_next_value_', '_numeric_repr_', '_missing_', '_ignore_', + '_iter_member_', '_iter_member_by_value_', '_iter_member_by_def_', ): - raise ValueError('_names_ are reserved for future Enum use') + raise ValueError( + '_sunder_ names, such as %r, are reserved for future Enum use' + % (key, ) + ) if key == '_generate_next_value_': # check if members already defined as auto() if self._auto_called: raise TypeError("_generate_next_value_ must be defined before members") - setattr(self, '_generate_next_value', value) + _gnv = value.__func__ if isinstance(value, staticmethod) else value + setattr(self, '_generate_next_value', _gnv) elif key == '_ignore_': if isinstance(value, str): value = value.replace(',',' ').split() @@ -109,43 +415,77 @@ def __setitem__(self, key, value): key = '_order_' elif key in self._member_names: # descriptor overwriting an enum? - raise TypeError('Attempted to reuse key: %r' % key) + raise TypeError('%r already defined as %r' % (key, self[key])) elif key in self._ignore: pass - elif not _is_descriptor(value): + elif isinstance(value, nonmember): + # unwrap value here; it won't be processed by the below `else` + value = value.value + elif _is_descriptor(value): + pass + # TODO: uncomment next three lines in 3.13 + # elif _is_internal_class(self._cls_name, value): + # # do nothing, name will be a normal attribute + # pass + else: if key in self: # enum overwriting a descriptor? - raise TypeError('%r already defined as: %r' % (key, self[key])) - if isinstance(value, auto): - if value.value == _auto_null: - value.value = self._generate_next_value( - key, - 1, - len(self._member_names), - self._last_values[:], - ) - self._auto_called = True + raise TypeError('%r already defined as %r' % (key, self[key])) + elif isinstance(value, member): + # unwrap value here -- it will become a member value = value.value - self._member_names.append(key) - self._last_values.append(value) + non_auto_store = True + single = False + if isinstance(value, auto): + single = True + value = (value, ) + if type(value) is tuple and any(isinstance(v, auto) for v in value): + # insist on an actual tuple, no subclasses, in keeping with only supporting + # top-level auto() usage (not contained in any other data structure) + auto_valued = [] + for v in value: + if isinstance(v, auto): + non_auto_store = False + if v.value == _auto_null: + v.value = self._generate_next_value( + key, 1, len(self._member_names), self._last_values[:], + ) + self._auto_called = True + v = v.value + self._last_values.append(v) + auto_valued.append(v) + if single: + value = auto_valued[0] + else: + value = tuple(auto_valued) + self._member_names[key] = None + if non_auto_store: + self._last_values.append(value) super().__setitem__(key, value) + def update(self, members, **more_members): + try: + for name in members.keys(): + self[name] = members[name] + except AttributeError: + for name, value in members: + self[name] = value + for name, value in more_members.items(): + self[name] = value -# Dummy value for Enum as EnumMeta explicitly checks for it, but of course -# until EnumMeta finishes running the first time the Enum class doesn't exist. -# This is also why there are checks in EnumMeta like `if Enum is not None` -Enum = None -class EnumMeta(type): +class EnumType(type): """ Metaclass for Enum """ + @classmethod - def __prepare__(metacls, cls, bases): + def __prepare__(metacls, cls, bases, **kwds): # check that previous enum members do not exist - metacls._check_for_existing_members(cls, bases) + metacls._check_for_existing_members_(cls, bases) # create the namespace dict enum_dict = _EnumDict() + enum_dict._cls_name = cls # inherit previous flags and _generate_next_value_ function member_type, first_enum = metacls._get_mixins_(cls, bases) if first_enum is not None: @@ -154,138 +494,125 @@ def __prepare__(metacls, cls, bases): ) return enum_dict - def __new__(metacls, cls, bases, classdict): + def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **kwds): # an Enum class is final once enumeration items have been defined; it # cannot be mixed with other types (int, float, etc.) if it has an # inherited __new__ unless a new __new__ is defined (or the resulting # class will fail). # + if _simple: + return super().__new__(metacls, cls, bases, classdict, **kwds) + # # remove any keys listed in _ignore_ classdict.setdefault('_ignore_', []).append('_ignore_') ignore = classdict['_ignore_'] for key in ignore: classdict.pop(key, None) + # + # grab member names + member_names = classdict._member_names + # + # check for illegal enum names (any others?) + invalid_names = set(member_names) & {'mro', ''} + if invalid_names: + raise ValueError('invalid enum member name(s) %s' % ( + ','.join(repr(n) for n in invalid_names) + )) + # + # adjust the sunders + _order_ = classdict.pop('_order_', None) + # convert to normal dict + classdict = dict(classdict.items()) + # + # data type of member and the controlling Enum class member_type, first_enum = metacls._get_mixins_(cls, bases) __new__, save_new, use_args = metacls._find_new_( classdict, member_type, first_enum, ) - - # save enum items into separate mapping so they don't get baked into - # the new class - enum_members = {k: classdict[k] for k in classdict._member_names} - for name in classdict._member_names: - del classdict[name] - - # adjust the sunders - _order_ = classdict.pop('_order_', None) - - # check for illegal enum names (any others?) - invalid_names = set(enum_members) & {'mro', ''} - if invalid_names: - raise ValueError('Invalid enum member name: {0}'.format( - ','.join(invalid_names))) - - # create a default docstring if one has not been provided - if '__doc__' not in classdict: - classdict['__doc__'] = 'An enumeration.' - - # create our new Enum type - enum_class = super().__new__(metacls, cls, bases, classdict) - enum_class._member_names_ = [] # names in definition order - enum_class._member_map_ = {} # name->value map - enum_class._member_type_ = member_type - - # save DynamicClassAttribute attributes from super classes so we know - # if we can take the shortcut of storing members in the class dict - dynamic_attributes = { - k for c in enum_class.mro() - for k, v in c.__dict__.items() - if isinstance(v, DynamicClassAttribute) - } - - # Reverse value->name map for hashable values. - enum_class._value2member_map_ = {} - - # If a custom type is mixed into the Enum, and it does not know how - # to pickle itself, pickle.dumps will succeed but pickle.loads will - # fail. Rather than have the error show up later and possibly far - # from the source, sabotage the pickle protocol for this class so - # that pickle.dumps also fails. + classdict['_new_member_'] = __new__ + classdict['_use_args_'] = use_args + # + # convert future enum members into temporary _proto_members + for name in member_names: + value = classdict[name] + classdict[name] = _proto_member(value) + # + # house-keeping structures + classdict['_member_names_'] = [] + classdict['_member_map_'] = {} + classdict['_value2member_map_'] = {} + classdict['_unhashable_values_'] = [] + classdict['_member_type_'] = member_type + # now set the __repr__ for the value + classdict['_value_repr_'] = metacls._find_data_repr_(cls, bases) + # + # Flag structures (will be removed if final class is not a Flag + classdict['_boundary_'] = ( + boundary + or getattr(first_enum, '_boundary_', None) + ) + classdict['_flag_mask_'] = 0 + classdict['_singles_mask_'] = 0 + classdict['_all_bits_'] = 0 + classdict['_inverted_'] = None + try: + exc = None + enum_class = super().__new__(metacls, cls, bases, classdict, **kwds) + except RuntimeError as e: + # any exceptions raised by member.__new__ will get converted to a + # RuntimeError, so get that original exception back and raise it instead + exc = e.__cause__ or e + if exc is not None: + raise exc + # + # update classdict with any changes made by __init_subclass__ + classdict.update(enum_class.__dict__) # - # However, if the new class implements its own __reduce_ex__, do not - # sabotage -- it's on them to make sure it works correctly. We use - # __reduce_ex__ instead of any of the others as it is preferred by - # pickle over __reduce__, and it handles all pickle protocols. - if '__reduce_ex__' not in classdict: - if member_type is not object: - methods = ('__getnewargs_ex__', '__getnewargs__', - '__reduce_ex__', '__reduce__') - if not any(m in member_type.__dict__ for m in methods): - _make_class_unpicklable(enum_class) - - # instantiate them, checking for duplicates as we go - # we instantiate first instead of checking for duplicates first in case - # a custom __new__ is doing something funky with the values -- such as - # auto-numbering ;) - for member_name in classdict._member_names: - value = enum_members[member_name] - if not isinstance(value, tuple): - args = (value, ) - else: - args = value - if member_type is tuple: # special case for tuple enums - args = (args, ) # wrap it one more time - if not use_args: - enum_member = __new__(enum_class) - if not hasattr(enum_member, '_value_'): - enum_member._value_ = value - else: - enum_member = __new__(enum_class, *args) - if not hasattr(enum_member, '_value_'): - if member_type is object: - enum_member._value_ = value - else: - enum_member._value_ = member_type(*args) - value = enum_member._value_ - enum_member._name_ = member_name - enum_member.__objclass__ = enum_class - enum_member.__init__(*args) - # If another member with the same value was already defined, the - # new member becomes an alias to the existing one. - for name, canonical_member in enum_class._member_map_.items(): - if canonical_member._value_ == enum_member._value_: - enum_member = canonical_member - break - else: - # Aliases don't appear in member names (only in __members__). - enum_class._member_names_.append(member_name) - # performance boost for any member that would not shadow - # a DynamicClassAttribute - if member_name not in dynamic_attributes: - setattr(enum_class, member_name, enum_member) - # now add to _member_map_ - enum_class._member_map_[member_name] = enum_member - try: - # This may fail if value is not hashable. We can't add the value - # to the map, and by-value lookups for this value will be - # linear. - enum_class._value2member_map_[value] = enum_member - except TypeError: - pass - # double check that repr and friends are not the mixin's or various # things break (such as pickle) # however, if the method is defined in the Enum itself, don't replace # it + # + # Also, special handling for ReprEnum + if ReprEnum is not None and ReprEnum in bases: + if member_type is object: + raise TypeError( + 'ReprEnum subclasses must be mixed with a data type (i.e.' + ' int, str, float, etc.)' + ) + if '__format__' not in classdict: + enum_class.__format__ = member_type.__format__ + classdict['__format__'] = enum_class.__format__ + if '__str__' not in classdict: + method = member_type.__str__ + if method is object.__str__: + # if member_type does not define __str__, object.__str__ will use + # its __repr__ instead, so we'll also use its __repr__ + method = member_type.__repr__ + enum_class.__str__ = method + classdict['__str__'] = enum_class.__str__ for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): - if name in classdict: - continue - class_method = getattr(enum_class, name) - obj_method = getattr(member_type, name, None) - enum_method = getattr(first_enum, name, None) - if obj_method is not None and obj_method is class_method: - setattr(enum_class, name, enum_method) - + if name not in classdict: + # check for mixin overrides before replacing + enum_method = getattr(first_enum, name) + found_method = getattr(enum_class, name) + object_method = getattr(object, name) + data_type_method = getattr(member_type, name) + if found_method in (data_type_method, object_method): + setattr(enum_class, name, enum_method) + # + # for Flag, add __or__, __and__, __xor__, and __invert__ + if Flag is not None and issubclass(enum_class, Flag): + for name in ( + '__or__', '__and__', '__xor__', + '__ror__', '__rand__', '__rxor__', + '__invert__' + ): + if name not in classdict: + enum_method = getattr(Flag, name) + setattr(enum_class, name, enum_method) + classdict[name] = enum_method + # # replace any other __new__ with our own (as long as Enum is not None, # anyway) -- again, this is to support pickle if Enum is not None: @@ -294,23 +621,69 @@ def __new__(metacls, cls, bases, classdict): if save_new: enum_class.__new_member__ = __new__ enum_class.__new__ = Enum.__new__ - + # # py3 support for definition order (helps keep py2/py3 code in sync) + # + # _order_ checking is spread out into three/four steps + # - if enum_class is a Flag: + # - remove any non-single-bit flags from _order_ + # - remove any aliases from _order_ + # - check that _order_ and _member_names_ match + # + # step 1: ensure we have a list if _order_ is not None: if isinstance(_order_, str): _order_ = _order_.replace(',', ' ').split() + # + # remove Flag structures if final class is not a Flag + if ( + Flag is None and cls != 'Flag' + or Flag is not None and not issubclass(enum_class, Flag) + ): + delattr(enum_class, '_boundary_') + delattr(enum_class, '_flag_mask_') + delattr(enum_class, '_singles_mask_') + delattr(enum_class, '_all_bits_') + delattr(enum_class, '_inverted_') + elif Flag is not None and issubclass(enum_class, Flag): + # set correct __iter__ + member_list = [m._value_ for m in enum_class] + if member_list != sorted(member_list): + enum_class._iter_member_ = enum_class._iter_member_by_def_ + if _order_: + # _order_ step 2: remove any items from _order_ that are not single-bit + _order_ = [ + o + for o in _order_ + if o not in enum_class._member_map_ or _is_single_bit(enum_class[o]._value_) + ] + # + if _order_: + # _order_ step 3: remove aliases from _order_ + _order_ = [ + o + for o in _order_ + if ( + o not in enum_class._member_map_ + or + (o in enum_class._member_map_ and o in enum_class._member_names_) + )] + # _order_ step 4: verify that _order_ and _member_names_ match if _order_ != enum_class._member_names_: - raise TypeError('member order does not match _order_') + raise TypeError( + 'member order does not match _order_:\n %r\n %r' + % (enum_class._member_names_, _order_) + ) return enum_class - def __bool__(self): + def __bool__(cls): """ classes/types should always be True. """ return True - def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1): + def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None): """ Either returns an existing member, or creates a new enum class. @@ -345,10 +718,25 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s qualname=qualname, type=type, start=start, + boundary=boundary, ) def __contains__(cls, member): + """ + Return True if member is a member of this enum + raises TypeError if member is not an enum member + + note: in 3.12 TypeError will no longer be raised, and True will also be + returned if member is the value of a member in this enum + """ if not isinstance(member, Enum): + import warnings + warnings.warn( + "in 3.12 __contains__ will no longer raise TypeError, but will return True or\n" + "False depending on whether the value is a member or the value of a member", + DeprecationWarning, + stacklevel=2, + ) raise TypeError( "unsupported operand type(s) for 'in': '%s' and '%s'" % ( type(member).__qualname__, cls.__class__.__qualname__)) @@ -358,14 +746,26 @@ def __delattr__(cls, attr): # nicer error message when someone tries to delete an attribute # (see issue19025). if attr in cls._member_map_: - raise AttributeError("%s: cannot delete Enum member." % cls.__name__) + raise AttributeError("%r cannot delete member %r." % (cls.__name__, attr)) super().__delattr__(attr) - def __dir__(self): - return ( - ['__class__', '__doc__', '__members__', '__module__'] - + self._member_names_ + def __dir__(cls): + interesting = set([ + '__class__', '__contains__', '__doc__', '__getitem__', + '__iter__', '__len__', '__members__', '__module__', + '__name__', '__qualname__', + ] + + cls._member_names_ ) + if cls._new_member_ is not object.__new__: + interesting.add('__new__') + if cls.__init_subclass__ is not object.__init_subclass__: + interesting.add('__init_subclass__') + if cls._member_type_ is object: + return sorted(interesting) + else: + # return whatever mixed-in data type has + return sorted(set(dir(cls._member_type_)) | interesting) def __getattr__(cls, name): """ @@ -384,18 +784,24 @@ def __getattr__(cls, name): raise AttributeError(name) from None def __getitem__(cls, name): + """ + Return the member matching `name`. + """ return cls._member_map_[name] def __iter__(cls): """ - Returns members in definition order. + Return members in definition order. """ return (cls._member_map_[name] for name in cls._member_names_) def __len__(cls): + """ + Return the number of members (no aliases) + """ return len(cls._member_names_) - @property + @bltns.property def __members__(cls): """ Returns a mapping of member name->value. @@ -406,11 +812,14 @@ def __members__(cls): return MappingProxyType(cls._member_map_) def __repr__(cls): - return "" % cls.__name__ + if Flag is not None and issubclass(cls, Flag): + return "" % cls.__name__ + else: + return "" % cls.__name__ def __reversed__(cls): """ - Returns members in reverse definition order. + Return members in reverse definition order. """ return (cls._member_map_[name] for name in reversed(cls._member_names_)) @@ -424,10 +833,10 @@ def __setattr__(cls, name, value): """ member_map = cls.__dict__.get('_member_map_', {}) if name in member_map: - raise AttributeError('Cannot reassign members.') + raise AttributeError('cannot reassign member %r' % (name, )) super().__setattr__(name, value) - def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, start=1): + def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, start=1, boundary=None): """ Convenience method to create a new Enum class. @@ -441,7 +850,7 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s """ metacls = cls.__class__ bases = (cls, ) if type is None else (type, cls) - _, first_enum = cls._get_mixins_(cls, bases) + _, first_enum = cls._get_mixins_(class_name, bases) classdict = metacls.__prepare__(class_name, bases) # special processing needed for names? @@ -462,25 +871,24 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s else: member_name, member_value = item classdict[member_name] = member_value - enum_class = metacls.__new__(metacls, class_name, bases, classdict) # TODO: replace the frame hack if a blessed way to know the calling # module is ever developed if module is None: try: module = sys._getframe(2).f_globals['__name__'] - except (AttributeError, ValueError, KeyError) as exc: + except (AttributeError, ValueError, KeyError): pass if module is None: - _make_class_unpicklable(enum_class) + _make_class_unpicklable(classdict) else: - enum_class.__module__ = module + classdict['__module__'] = module if qualname is not None: - enum_class.__qualname__ = qualname + classdict['__qualname__'] = qualname - return enum_class + return metacls.__new__(metacls, class_name, bases, classdict, boundary=boundary) - def _convert_(cls, name, module, filter, source=None): + def _convert_(cls, name, module, filter, source=None, *, boundary=None, as_global=False): """ Create a new Enum subclass that replaces a collection of global constants """ @@ -489,9 +897,9 @@ def _convert_(cls, name, module, filter, source=None): # module; # also, replace the __reduce_ex__ method so unpickling works in # previous Python versions - module_globals = vars(sys.modules[module]) + module_globals = sys.modules[module].__dict__ if source: - source = vars(source) + source = source.__dict__ else: source = module_globals # _value2member_map_ is populated in the same order every time @@ -507,30 +915,29 @@ def _convert_(cls, name, module, filter, source=None): except TypeError: # unless some values aren't comparable, in which case sort by name members.sort(key=lambda t: t[0]) - cls = cls(name, members, module=module) - cls.__reduce_ex__ = _reduce_ex_by_name - module_globals.update(cls.__members__) + body = {t[0]: t[1] for t in members} + body['__module__'] = module + tmp_cls = type(name, (object, ), body) + cls = _simple_enum(etype=cls, boundary=boundary or KEEP)(tmp_cls) + if as_global: + global_enum(cls) + else: + sys.modules[cls.__module__].__dict__.update(cls.__members__) module_globals[name] = cls return cls - def _convert(cls, *args, **kwargs): - import warnings - warnings.warn("_convert is deprecated and will be removed in 3.9, use " - "_convert_ instead.", DeprecationWarning, stacklevel=2) - return cls._convert_(*args, **kwargs) - - @staticmethod - def _check_for_existing_members(class_name, bases): + @classmethod + def _check_for_existing_members_(mcls, class_name, bases): for chain in bases: for base in chain.__mro__: - if issubclass(base, Enum) and base._member_names_: + if isinstance(base, EnumType) and base._member_names_: raise TypeError( - "%s: cannot extend enumeration %r" - % (class_name, base.__name__) + " cannot extend %r" + % (class_name, base) ) - @staticmethod - def _get_mixins_(class_name, bases): + @classmethod + def _get_mixins_(mcls, class_name, bases): """ Returns the type for creating enum members, and the first inherited enum class. @@ -540,44 +947,62 @@ def _get_mixins_(class_name, bases): if not bases: return object, Enum - def _find_data_type(bases): - data_types = [] - for chain in bases: - candidate = None - for base in chain.__mro__: - if base is object: - continue - elif issubclass(base, Enum): - if base._member_type_ is not object: - data_types.append(base._member_type_) - break - elif '__new__' in base.__dict__: - if issubclass(base, Enum): - continue - data_types.append(candidate or base) - break - else: - candidate = base - if len(data_types) > 1: - raise TypeError('%r: too many data types: %r' % (class_name, data_types)) - elif data_types: - return data_types[0] - else: - return None + mcls._check_for_existing_members_(class_name, bases) # ensure final parent class is an Enum derivative, find any concrete # data type, and check that Enum has no members first_enum = bases[-1] - if not issubclass(first_enum, Enum): + if not isinstance(first_enum, EnumType): raise TypeError("new enumerations should be created as " "`EnumName([mixin_type, ...] [data_type,] enum_type)`") - member_type = _find_data_type(bases) or object - if first_enum._member_names_: - raise TypeError("Cannot extend enumerations") + member_type = mcls._find_data_type_(class_name, bases) or object return member_type, first_enum - @staticmethod - def _find_new_(classdict, member_type, first_enum): + @classmethod + def _find_data_repr_(mcls, class_name, bases): + for chain in bases: + for base in chain.__mro__: + if base is object: + continue + elif isinstance(base, EnumType): + # if we hit an Enum, use it's _value_repr_ + return base._value_repr_ + elif '__repr__' in base.__dict__: + # this is our data repr + return base.__dict__['__repr__'] + return None + + @classmethod + def _find_data_type_(mcls, class_name, bases): + # a datatype has a __new__ method + data_types = set() + base_chain = set() + for chain in bases: + candidate = None + for base in chain.__mro__: + base_chain.add(base) + if base is object: + continue + elif isinstance(base, EnumType): + if base._member_type_ is not object: + data_types.add(base._member_type_) + break + elif '__new__' in base.__dict__ or '__dataclass_fields__' in base.__dict__: + if isinstance(base, EnumType): + continue + data_types.add(candidate or base) + break + else: + candidate = candidate or base + if len(data_types) > 1: + raise TypeError('too many data types for %r: %r' % (class_name, data_types)) + elif data_types: + return data_types.pop() + else: + return None + + @classmethod + def _find_new_(mcls, classdict, member_type, first_enum): """ Returns the __new__ to be used for creating the enum members. @@ -591,7 +1016,7 @@ def _find_new_(classdict, member_type, first_enum): __new__ = classdict.get('__new__', None) # should __new__ be saved as __new_member__ later? - save_new = __new__ is not None + save_new = first_enum is not None and __new__ is not None if __new__ is None: # check all possibles for __new_member__ before falling back to @@ -615,19 +1040,54 @@ def _find_new_(classdict, member_type, first_enum): # if a non-object.__new__ is used then whatever value/tuple was # assigned to the enum member name will be passed to __new__ and to the # new enum member's __init__ - if __new__ is object.__new__: + if first_enum is None or __new__ in (Enum.__new__, object.__new__): use_args = False else: use_args = True return __new__, save_new, use_args +EnumMeta = EnumType -class Enum(metaclass=EnumMeta): +class Enum(metaclass=EnumType): """ - Generic enumeration. + Create a collection of name/value pairs. + + Example enumeration: + + >>> class Color(Enum): + ... RED = 1 + ... BLUE = 2 + ... GREEN = 3 + + Access them by: + + - attribute access:: + + >>> Color.RED + + + - value lookup: + + >>> Color(1) + + + - name lookup: + + >>> Color['RED'] + + + Enumerations can be iterated over, and know how many members they have: + + >>> len(Color) + 3 - Derive from this class to define new enumerations. + >>> list(Color) + [, , ] + + Methods can be added to enumerations, and members can have their own + attributes -- see the documentation for details. """ + def __new__(cls, value): # all enum instances are actually created during class construction # without calling this method; this method is called by the metaclass' @@ -654,19 +1114,33 @@ def __new__(cls, value): except Exception as e: exc = e result = None - if isinstance(result, cls): - return result - else: - ve_exc = ValueError("%r is not a valid %s" % (value, cls.__name__)) - if result is None and exc is None: - raise ve_exc - elif exc is None: - exc = TypeError( - 'error in %s._missing_: returned %r instead of None or a valid member' - % (cls.__name__, result) - ) - exc.__context__ = ve_exc - raise exc + try: + if isinstance(result, cls): + return result + elif ( + Flag is not None and issubclass(cls, Flag) + and cls._boundary_ is EJECT and isinstance(result, int) + ): + return result + else: + ve_exc = ValueError("%r is not a valid %s" % (value, cls.__qualname__)) + if result is None and exc is None: + raise ve_exc + elif exc is None: + exc = TypeError( + 'error in %s._missing_: returned %r instead of None or a valid member' + % (cls.__name__, result) + ) + if not isinstance(exc, ValueError): + exc.__context__ = ve_exc + raise exc + finally: + # ensure all variables that could hold an exception are destroyed + exc = None + ve_exc = None + + def __init__(self, *args, **kwds): + pass def _generate_next_value_(name, start, count, last_values): """ @@ -675,14 +1149,32 @@ def _generate_next_value_(name, start, count, last_values): name: the name of the member start: the initial start value or None count: the number of existing members - last_value: the last value assigned or None + last_values: the list of values assigned """ - for last_value in reversed(last_values): - try: - return last_value + 1 - except TypeError: - pass - else: + if not last_values: + return start + try: + last = last_values[-1] + last_values.sort() + if last == last_values[-1]: + # no difference between old and new methods + return last + 1 + else: + # trigger old method (with warning) + raise TypeError + except TypeError: + import warnings + warnings.warn( + "In 3.13 the default `auto()`/`_generate_next_value_` will require all values to be sortable and support adding +1\n" + "and the value returned will be the largest value in the enum incremented by 1", + DeprecationWarning, + stacklevel=3, + ) + for v in reversed(last_values): + try: + return v + 1 + except TypeError: + pass return start @classmethod @@ -690,42 +1182,44 @@ def _missing_(cls, value): return None def __repr__(self): - return "<%s.%s: %r>" % ( - self.__class__.__name__, self._name_, self._value_) + v_repr = self.__class__._value_repr_ or repr + return "<%s.%s: %s>" % (self.__class__.__name__, self._name_, v_repr(self._value_)) def __str__(self): - return "%s.%s" % (self.__class__.__name__, self._name_) + return "%s.%s" % (self.__class__.__name__, self._name_, ) def __dir__(self): """ Returns all members and all public methods """ - added_behavior = [ - m - for cls in self.__class__.mro() - for m in cls.__dict__ - if m[0] != '_' and m not in self._member_map_ - ] + [m for m in self.__dict__ if m[0] != '_'] - return (['__class__', '__doc__', '__module__'] + added_behavior) + if self.__class__._member_type_ is object: + interesting = set(['__class__', '__doc__', '__eq__', '__hash__', '__module__', 'name', 'value']) + else: + interesting = set(object.__dir__(self)) + for name in getattr(self, '__dict__', []): + if name[0] != '_': + interesting.add(name) + for cls in self.__class__.mro(): + for name, obj in cls.__dict__.items(): + if name[0] == '_': + continue + if isinstance(obj, property): + # that's an enum.property + if obj.fget is not None or name not in self._member_map_: + interesting.add(name) + else: + # in case it was added by `dir(self)` + interesting.discard(name) + else: + interesting.add(name) + names = sorted( + set(['__class__', '__doc__', '__eq__', '__hash__', '__module__']) + | interesting + ) + return names def __format__(self, format_spec): - """ - Returns format using actual value type unless __str__ has been overridden. - """ - # mixed-in Enums should use the mixed-in type's __format__, otherwise - # we can get strange results with the Enum name showing up instead of - # the value - - # pure Enum branch, or branch with __str__ explicitly overridden - str_overridden = type(self).__str__ not in (Enum.__str__, Flag.__str__) - if self._member_type_ is object or str_overridden: - cls = str - val = str(self) - # mix-in branch - else: - cls = self._member_type_ - val = self._value_ - return cls.__format__(val, format_spec) + return str.__format__(str(self), format_spec) def __hash__(self): return hash(self._name_) @@ -733,36 +1227,107 @@ def __hash__(self): def __reduce_ex__(self, proto): return self.__class__, (self._value_, ) - # DynamicClassAttribute is used to provide access to the `name` and - # `value` properties of enum members while keeping some measure of + def __deepcopy__(self,memo): + return self + + def __copy__(self): + return self + + # enum.property is used to provide access to the `name` and + # `value` attributes of enum members while keeping some measure of # protection from modification, while still allowing for an enumeration # to have members named `name` and `value`. This works because enumeration - # members are not set directly on the enum class -- __getattr__ is - # used to look them up. + # members are not set directly on the enum class; they are kept in a + # separate structure, _member_map_, which is where enum.property looks for + # them - @DynamicClassAttribute + @property def name(self): """The name of the Enum member.""" return self._name_ - @DynamicClassAttribute + @property def value(self): """The value of the Enum member.""" return self._value_ -class IntEnum(int, Enum): - """Enum where members are also (and must be) ints""" +class ReprEnum(Enum): + """ + Only changes the repr(), leaving str() and format() to the mixed-in type. + """ -def _reduce_ex_by_name(self, proto): +class IntEnum(int, ReprEnum): + """ + Enum where members are also (and must be) ints + """ + + +class StrEnum(str, ReprEnum): + """ + Enum where members are also (and must be) strings + """ + + def __new__(cls, *values): + "values must already be of type `str`" + if len(values) > 3: + raise TypeError('too many arguments for str(): %r' % (values, )) + if len(values) == 1: + # it must be a string + if not isinstance(values[0], str): + raise TypeError('%r is not a string' % (values[0], )) + if len(values) >= 2: + # check that encoding argument is a string + if not isinstance(values[1], str): + raise TypeError('encoding must be a string, not %r' % (values[1], )) + if len(values) == 3: + # check that errors argument is a string + if not isinstance(values[2], str): + raise TypeError('errors must be a string, not %r' % (values[2])) + value = str(*values) + member = str.__new__(cls, value) + member._value_ = value + return member + + def _generate_next_value_(name, start, count, last_values): + """ + Return the lower-cased version of the member name. + """ + return name.lower() + + +def pickle_by_global_name(self, proto): + # should not be used with Flag-type enums return self.name +_reduce_ex_by_global_name = pickle_by_global_name + +def pickle_by_enum_name(self, proto): + # should not be used with Flag-type enums + return getattr, (self.__class__, self._name_) -class Flag(Enum): +class FlagBoundary(StrEnum): + """ + control how out of range values are handled + "strict" -> error is raised [default for Flag] + "conform" -> extra bits are discarded + "eject" -> lose flag status + "keep" -> keep flag status and all bits [default for IntFlag] + """ + STRICT = auto() + CONFORM = auto() + EJECT = auto() + KEEP = auto() +STRICT, CONFORM, EJECT, KEEP = FlagBoundary + + +class Flag(Enum, boundary=STRICT): """ Support for flags """ + _numeric_repr_ = repr + def _generate_next_value_(name, start, count, last_values): """ Generate the next value when not given. @@ -770,49 +1335,128 @@ def _generate_next_value_(name, start, count, last_values): name: the name of the member start: the initial start value or None count: the number of existing members - last_value: the last value assigned or None + last_values: the last value assigned or None """ if not count: return start if start is not None else 1 - for last_value in reversed(last_values): - try: - high_bit = _high_bit(last_value) - break - except Exception: - raise TypeError('Invalid Flag value: %r' % last_value) from None + last_value = max(last_values) + try: + high_bit = _high_bit(last_value) + except Exception: + raise TypeError('invalid flag value %r' % last_value) from None return 2 ** (high_bit+1) @classmethod - def _missing_(cls, value): + def _iter_member_by_value_(cls, value): """ - Returns member (possibly creating it) if one can be found for value. + Extract all members from the value in definition (i.e. increasing value) order. """ - original_value = value - if value < 0: - value = ~value - possible_member = cls._create_pseudo_member_(value) - if original_value < 0: - possible_member = ~possible_member - return possible_member + for val in _iter_bits_lsb(value & cls._flag_mask_): + yield cls._value2member_map_.get(val) + + _iter_member_ = _iter_member_by_value_ @classmethod - def _create_pseudo_member_(cls, value): + def _iter_member_by_def_(cls, value): """ - Create a composite member iff value contains only members. + Extract all members from the value in definition order. """ - pseudo_member = cls._value2member_map_.get(value, None) - if pseudo_member is None: - # verify all bits are accounted for - _, extra_flags = _decompose(cls, value) - if extra_flags: - raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + yield from sorted( + cls._iter_member_by_value_(value), + key=lambda m: m._sort_order_, + ) + + @classmethod + def _missing_(cls, value): + """ + Create a composite member containing all canonical members present in `value`. + + If non-member values are present, result depends on `_boundary_` setting. + """ + if not isinstance(value, int): + raise ValueError( + "%r is not a valid %s" % (value, cls.__qualname__) + ) + # check boundaries + # - value must be in range (e.g. -16 <-> +15, i.e. ~15 <-> 15) + # - value must not include any skipped flags (e.g. if bit 2 is not + # defined, then 0d10 is invalid) + flag_mask = cls._flag_mask_ + singles_mask = cls._singles_mask_ + all_bits = cls._all_bits_ + neg_value = None + if ( + not ~all_bits <= value <= all_bits + or value & (all_bits ^ flag_mask) + ): + if cls._boundary_ is STRICT: + max_bits = max(value.bit_length(), flag_mask.bit_length()) + raise ValueError( + "%r invalid value %r\n given %s\n allowed %s" % ( + cls, value, bin(value, max_bits), bin(flag_mask, max_bits), + )) + elif cls._boundary_ is CONFORM: + value = value & flag_mask + elif cls._boundary_ is EJECT: + return value + elif cls._boundary_ is KEEP: + if value < 0: + value = ( + max(all_bits+1, 2**(value.bit_length())) + + value + ) + else: + raise ValueError( + '%r unknown flag boundary %r' % (cls, cls._boundary_, ) + ) + if value < 0: + neg_value = value + value = all_bits + 1 + value + # get members and unknown + unknown = value & ~flag_mask + aliases = value & ~singles_mask + member_value = value & singles_mask + if unknown and cls._boundary_ is not KEEP: + raise ValueError( + '%s(%r) --> unknown values %r [%s]' + % (cls.__name__, value, unknown, bin(unknown)) + ) + # normal Flag? + if cls._member_type_ is object: # construct a singleton enum pseudo-member pseudo_member = object.__new__(cls) - pseudo_member._name_ = None + else: + pseudo_member = cls._member_type_.__new__(cls, value) + if not hasattr(pseudo_member, '_value_'): pseudo_member._value_ = value - # use setdefault in case another thread already created a composite - # with this value - pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) + if member_value or aliases: + members = [] + combined_value = 0 + for m in cls._iter_member_(member_value): + members.append(m) + combined_value |= m._value_ + if aliases: + value = member_value | aliases + for n, pm in cls._member_map_.items(): + if pm not in members and pm._value_ and pm._value_ & value == pm._value_: + members.append(pm) + combined_value |= pm._value_ + unknown = value ^ combined_value + pseudo_member._name_ = '|'.join([m._name_ for m in members]) + if not combined_value: + pseudo_member._name_ = None + elif unknown and cls._boundary_ is STRICT: + raise ValueError('%r: no members with value %r' % (cls, unknown)) + elif unknown: + pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown) + else: + pseudo_member._name_ = None + # use setdefault in case another thread already created a composite + # with this value + # note: zero is a special case -- always add it + pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) + if neg_value is not None: + cls._value2member_map_[neg_value] = pseudo_member return pseudo_member def __contains__(self, other): @@ -821,133 +1465,85 @@ def __contains__(self, other): """ if not isinstance(other, self.__class__): raise TypeError( - "unsupported operand type(s) for 'in': '%s' and '%s'" % ( + "unsupported operand type(s) for 'in': %r and %r" % ( type(other).__qualname__, self.__class__.__qualname__)) return other._value_ & self._value_ == other._value_ + def __iter__(self): + """ + Returns flags in definition order. + """ + yield from self._iter_member_(self._value_) + + def __len__(self): + return self._value_.bit_count() + def __repr__(self): - cls = self.__class__ - if self._name_ is not None: - return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_) - members, uncovered = _decompose(cls, self._value_) - return '<%s.%s: %r>' % ( - cls.__name__, - '|'.join([str(m._name_ or m._value_) for m in members]), - self._value_, - ) + cls_name = self.__class__.__name__ + v_repr = self.__class__._value_repr_ or repr + if self._name_ is None: + return "<%s: %s>" % (cls_name, v_repr(self._value_)) + else: + return "<%s.%s: %s>" % (cls_name, self._name_, v_repr(self._value_)) def __str__(self): - cls = self.__class__ - if self._name_ is not None: - return '%s.%s' % (cls.__name__, self._name_) - members, uncovered = _decompose(cls, self._value_) - if len(members) == 1 and members[0]._name_ is None: - return '%s.%r' % (cls.__name__, members[0]._value_) + cls_name = self.__class__.__name__ + if self._name_ is None: + return '%s(%r)' % (cls_name, self._value_) else: - return '%s.%s' % ( - cls.__name__, - '|'.join([str(m._name_ or m._value_) for m in members]), - ) + return "%s.%s" % (cls_name, self._name_) def __bool__(self): return bool(self._value_) def __or__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ | other._value_) + value = self._value_ + return self.__class__(value | other) def __and__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ & other._value_) + value = self._value_ + return self.__class__(value & other) def __xor__(self, other): - if not isinstance(other, self.__class__): + if isinstance(other, self.__class__): + other = other._value_ + elif self._member_type_ is not object and isinstance(other, self._member_type_): + other = other + else: return NotImplemented - return self.__class__(self._value_ ^ other._value_) + value = self._value_ + return self.__class__(value ^ other) def __invert__(self): - members, uncovered = _decompose(self.__class__, self._value_) - inverted = self.__class__(0) - for m in self.__class__: - if m not in members and not (m._value_ & self._value_): - inverted = inverted | m - return self.__class__(inverted) + if self._inverted_ is None: + if self._boundary_ in (EJECT, KEEP): + self._inverted_ = self.__class__(~self._value_) + else: + self._inverted_ = self.__class__(self._singles_mask_ & ~self._value_) + return self._inverted_ + + __rand__ = __and__ + __ror__ = __or__ + __rxor__ = __xor__ -class IntFlag(int, Flag): +class IntFlag(int, ReprEnum, Flag, boundary=KEEP): """ Support for integer-based Flags """ - @classmethod - def _missing_(cls, value): - """ - Returns member (possibly creating it) if one can be found for value. - """ - if not isinstance(value, int): - raise ValueError("%r is not a valid %s" % (value, cls.__name__)) - new_member = cls._create_pseudo_member_(value) - return new_member - - @classmethod - def _create_pseudo_member_(cls, value): - """ - Create a composite member iff value contains only members. - """ - pseudo_member = cls._value2member_map_.get(value, None) - if pseudo_member is None: - need_to_create = [value] - # get unaccounted for bits - _, extra_flags = _decompose(cls, value) - # timer = 10 - while extra_flags: - # timer -= 1 - bit = _high_bit(extra_flags) - flag_value = 2 ** bit - if (flag_value not in cls._value2member_map_ and - flag_value not in need_to_create - ): - need_to_create.append(flag_value) - if extra_flags == -flag_value: - extra_flags = 0 - else: - extra_flags ^= flag_value - for value in reversed(need_to_create): - # construct singleton pseudo-members - pseudo_member = int.__new__(cls, value) - pseudo_member._name_ = None - pseudo_member._value_ = value - # use setdefault in case another thread already created a composite - # with this value - pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) - return pseudo_member - - def __or__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - result = self.__class__(self._value_ | self.__class__(other)._value_) - return result - - def __and__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - return self.__class__(self._value_ & self.__class__(other)._value_) - - def __xor__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - return self.__class__(self._value_ ^ self.__class__(other)._value_) - - __ror__ = __or__ - __rand__ = __and__ - __rxor__ = __xor__ - - def __invert__(self): - result = self.__class__(~self._value_) - return result - def _high_bit(value): """ @@ -970,44 +1566,478 @@ def unique(enumeration): (enumeration, alias_details)) return enumeration -def _decompose(flag, value): - """ - Extract all members from the value. - """ - # _decompose is only called if the value is not named - not_covered = value - negative = value < 0 - # issue29167: wrap accesses to _value2member_map_ in a list to avoid race - # conditions between iterating over it and having more pseudo- - # members added to it - if negative: - # only check for named flags - flags_to_check = [ - (m, v) - for v, m in list(flag._value2member_map_.items()) - if m.name is not None - ] - else: - # check for named flags and powers-of-two flags - flags_to_check = [ - (m, v) - for v, m in list(flag._value2member_map_.items()) - if m.name is not None or _power_of_two(v) - ] - members = [] - for member, member_value in flags_to_check: - if member_value and member_value & value == member_value: - members.append(member) - not_covered &= ~member_value - if not members and value in flag._value2member_map_: - members.append(flag._value2member_map_[value]) - members.sort(key=lambda m: m._value_, reverse=True) - if len(members) > 1 and members[0].value == value: - # we have the breakdown, don't need the value member itself - members.pop(0) - return members, not_covered - def _power_of_two(value): if value < 1: return False return value == 2 ** _high_bit(value) + +def global_enum_repr(self): + """ + use module.enum_name instead of class.enum_name + + the module is the last module in case of a multi-module name + """ + module = self.__class__.__module__.split('.')[-1] + return '%s.%s' % (module, self._name_) + +def global_flag_repr(self): + """ + use module.flag_name instead of class.flag_name + + the module is the last module in case of a multi-module name + """ + module = self.__class__.__module__.split('.')[-1] + cls_name = self.__class__.__name__ + if self._name_ is None: + return "%s.%s(%r)" % (module, cls_name, self._value_) + if _is_single_bit(self): + return '%s.%s' % (module, self._name_) + if self._boundary_ is not FlagBoundary.KEEP: + return '|'.join(['%s.%s' % (module, name) for name in self.name.split('|')]) + else: + name = [] + for n in self._name_.split('|'): + if n[0].isdigit(): + name.append(n) + else: + name.append('%s.%s' % (module, n)) + return '|'.join(name) + +def global_str(self): + """ + use enum_name instead of class.enum_name + """ + if self._name_ is None: + cls_name = self.__class__.__name__ + return "%s(%r)" % (cls_name, self._value_) + else: + return self._name_ + +def global_enum(cls, update_str=False): + """ + decorator that makes the repr() of an enum member reference its module + instead of its class; also exports all members to the enum's module's + global namespace + """ + if issubclass(cls, Flag): + cls.__repr__ = global_flag_repr + else: + cls.__repr__ = global_enum_repr + if not issubclass(cls, ReprEnum) or update_str: + cls.__str__ = global_str + sys.modules[cls.__module__].__dict__.update(cls.__members__) + return cls + +def _simple_enum(etype=Enum, *, boundary=None, use_args=None): + """ + Class decorator that converts a normal class into an :class:`Enum`. No + safety checks are done, and some advanced behavior (such as + :func:`__init_subclass__`) is not available. Enum creation can be faster + using :func:`simple_enum`. + + >>> from enum import Enum, _simple_enum + >>> @_simple_enum(Enum) + ... class Color: + ... RED = auto() + ... GREEN = auto() + ... BLUE = auto() + >>> Color + + """ + def convert_class(cls): + nonlocal use_args + cls_name = cls.__name__ + if use_args is None: + use_args = etype._use_args_ + __new__ = cls.__dict__.get('__new__') + if __new__ is not None: + new_member = __new__.__func__ + else: + new_member = etype._member_type_.__new__ + attrs = {} + body = {} + if __new__ is not None: + body['__new_member__'] = new_member + body['_new_member_'] = new_member + body['_use_args_'] = use_args + body['_generate_next_value_'] = gnv = etype._generate_next_value_ + body['_member_names_'] = member_names = [] + body['_member_map_'] = member_map = {} + body['_value2member_map_'] = value2member_map = {} + body['_unhashable_values_'] = [] + body['_member_type_'] = member_type = etype._member_type_ + body['_value_repr_'] = etype._value_repr_ + if issubclass(etype, Flag): + body['_boundary_'] = boundary or etype._boundary_ + body['_flag_mask_'] = None + body['_all_bits_'] = None + body['_singles_mask_'] = None + body['_inverted_'] = None + body['__or__'] = Flag.__or__ + body['__xor__'] = Flag.__xor__ + body['__and__'] = Flag.__and__ + body['__ror__'] = Flag.__ror__ + body['__rxor__'] = Flag.__rxor__ + body['__rand__'] = Flag.__rand__ + body['__invert__'] = Flag.__invert__ + for name, obj in cls.__dict__.items(): + if name in ('__dict__', '__weakref__'): + continue + if _is_dunder(name) or _is_private(cls_name, name) or _is_sunder(name) or _is_descriptor(obj): + body[name] = obj + else: + attrs[name] = obj + if cls.__dict__.get('__doc__') is None: + body['__doc__'] = 'An enumeration.' + # + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + # however, if the method is defined in the Enum itself, don't replace + # it + enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True) + for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): + if name not in body: + # check for mixin overrides before replacing + enum_method = getattr(etype, name) + found_method = getattr(enum_class, name) + object_method = getattr(object, name) + data_type_method = getattr(member_type, name) + if found_method in (data_type_method, object_method): + setattr(enum_class, name, enum_method) + gnv_last_values = [] + if issubclass(enum_class, Flag): + # Flag / IntFlag + single_bits = multi_bits = 0 + for name, value in attrs.items(): + if isinstance(value, auto) and auto.value is _auto_null: + value = gnv(name, 1, len(member_names), gnv_last_values) + if value in value2member_map: + # an alias to an existing member + redirect = property() + redirect.__set_name__(enum_class, name) + setattr(enum_class, name, redirect) + member_map[name] = value2member_map[value] + else: + # create the member + if use_args: + if not isinstance(value, tuple): + value = (value, ) + member = new_member(enum_class, *value) + value = value[0] + else: + member = new_member(enum_class) + if __new__ is None: + member._value_ = value + member._name_ = name + member.__objclass__ = enum_class + member.__init__(value) + redirect = property() + redirect.__set_name__(enum_class, name) + setattr(enum_class, name, redirect) + member_map[name] = member + member._sort_order_ = len(member_names) + value2member_map[value] = member + if _is_single_bit(value): + # not a multi-bit alias, record in _member_names_ and _flag_mask_ + member_names.append(name) + single_bits |= value + else: + multi_bits |= value + gnv_last_values.append(value) + enum_class._flag_mask_ = single_bits | multi_bits + enum_class._singles_mask_ = single_bits + enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1 + # set correct __iter__ + member_list = [m._value_ for m in enum_class] + if member_list != sorted(member_list): + enum_class._iter_member_ = enum_class._iter_member_by_def_ + else: + # Enum / IntEnum / StrEnum + for name, value in attrs.items(): + if isinstance(value, auto): + if value.value is _auto_null: + value.value = gnv(name, 1, len(member_names), gnv_last_values) + value = value.value + if value in value2member_map: + # an alias to an existing member + redirect = property() + redirect.__set_name__(enum_class, name) + setattr(enum_class, name, redirect) + member_map[name] = value2member_map[value] + else: + # create the member + if use_args: + if not isinstance(value, tuple): + value = (value, ) + member = new_member(enum_class, *value) + value = value[0] + else: + member = new_member(enum_class) + if __new__ is None: + member._value_ = value + member._name_ = name + member.__objclass__ = enum_class + member.__init__(value) + member._sort_order_ = len(member_names) + redirect = property() + redirect.__set_name__(enum_class, name) + setattr(enum_class, name, redirect) + member_map[name] = member + value2member_map[value] = member + member_names.append(name) + gnv_last_values.append(value) + if '__new__' in body: + enum_class.__new_member__ = enum_class.__new__ + enum_class.__new__ = Enum.__new__ + return enum_class + return convert_class + +@_simple_enum(StrEnum) +class EnumCheck: + """ + various conditions to check an enumeration for + """ + CONTINUOUS = "no skipped integer values" + NAMED_FLAGS = "multi-flag aliases may not contain unnamed flags" + UNIQUE = "one name per value" +CONTINUOUS, NAMED_FLAGS, UNIQUE = EnumCheck + + +class verify: + """ + Check an enumeration for various constraints. (see EnumCheck) + """ + def __init__(self, *checks): + self.checks = checks + def __call__(self, enumeration): + checks = self.checks + cls_name = enumeration.__name__ + if Flag is not None and issubclass(enumeration, Flag): + enum_type = 'flag' + elif issubclass(enumeration, Enum): + enum_type = 'enum' + else: + raise TypeError("the 'verify' decorator only works with Enum and Flag") + for check in checks: + if check is UNIQUE: + # check for duplicate names + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + alias_details = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates]) + raise ValueError('aliases found in %r: %s' % + (enumeration, alias_details)) + elif check is CONTINUOUS: + values = set(e.value for e in enumeration) + if len(values) < 2: + continue + low, high = min(values), max(values) + missing = [] + if enum_type == 'flag': + # check for powers of two + for i in range(_high_bit(low)+1, _high_bit(high)): + if 2**i not in values: + missing.append(2**i) + elif enum_type == 'enum': + # check for powers of one + for i in range(low+1, high): + if i not in values: + missing.append(i) + else: + raise Exception('verify: unknown type %r' % enum_type) + if missing: + raise ValueError(('invalid %s %r: missing values %s' % ( + enum_type, cls_name, ', '.join((str(m) for m in missing))) + )[:256]) + # limit max length to protect against DOS attacks + elif check is NAMED_FLAGS: + # examine each alias and check for unnamed flags + member_names = enumeration._member_names_ + member_values = [m.value for m in enumeration] + missing_names = [] + missing_value = 0 + for name, alias in enumeration._member_map_.items(): + if name in member_names: + # not an alias + continue + if alias.value < 0: + # negative numbers are not checked + continue + values = list(_iter_bits_lsb(alias.value)) + missed = [v for v in values if v not in member_values] + if missed: + missing_names.append(name) + missing_value |= reduce(_or_, missed) + if missing_names: + if len(missing_names) == 1: + alias = 'alias %s is missing' % missing_names[0] + else: + alias = 'aliases %s and %s are missing' % ( + ', '.join(missing_names[:-1]), missing_names[-1] + ) + if _is_single_bit(missing_value): + value = 'value 0x%x' % missing_value + else: + value = 'combined values of 0x%x' % missing_value + raise ValueError( + 'invalid Flag %r: %s %s [use enum.show_flag_values(value) for details]' + % (cls_name, alias, value) + ) + return enumeration + +def _test_simple_enum(checked_enum, simple_enum): + """ + A function that can be used to test an enum created with :func:`_simple_enum` + against the version created by subclassing :class:`Enum`:: + + >>> from enum import Enum, _simple_enum, _test_simple_enum + >>> @_simple_enum(Enum) + ... class Color: + ... RED = auto() + ... GREEN = auto() + ... BLUE = auto() + >>> class CheckedColor(Enum): + ... RED = auto() + ... GREEN = auto() + ... BLUE = auto() + >>> # TODO: RUSTPYTHON + >>> # _test_simple_enum(CheckedColor, Color) + + If differences are found, a :exc:`TypeError` is raised. + """ + failed = [] + if checked_enum.__dict__ != simple_enum.__dict__: + checked_dict = checked_enum.__dict__ + checked_keys = list(checked_dict.keys()) + simple_dict = simple_enum.__dict__ + simple_keys = list(simple_dict.keys()) + member_names = set( + list(checked_enum._member_map_.keys()) + + list(simple_enum._member_map_.keys()) + ) + for key in set(checked_keys + simple_keys): + if key in ('__module__', '_member_map_', '_value2member_map_', '__doc__'): + # keys known to be different, or very long + continue + elif key in member_names: + # members are checked below + continue + elif key not in simple_keys: + failed.append("missing key: %r" % (key, )) + elif key not in checked_keys: + failed.append("extra key: %r" % (key, )) + else: + checked_value = checked_dict[key] + simple_value = simple_dict[key] + if callable(checked_value) or isinstance(checked_value, bltns.property): + continue + if key == '__doc__': + # remove all spaces/tabs + compressed_checked_value = checked_value.replace(' ','').replace('\t','') + compressed_simple_value = simple_value.replace(' ','').replace('\t','') + if compressed_checked_value != compressed_simple_value: + failed.append("%r:\n %s\n %s" % ( + key, + "checked -> %r" % (checked_value, ), + "simple -> %r" % (simple_value, ), + )) + elif checked_value != simple_value: + failed.append("%r:\n %s\n %s" % ( + key, + "checked -> %r" % (checked_value, ), + "simple -> %r" % (simple_value, ), + )) + failed.sort() + for name in member_names: + failed_member = [] + if name not in simple_keys: + failed.append('missing member from simple enum: %r' % name) + elif name not in checked_keys: + failed.append('extra member in simple enum: %r' % name) + else: + checked_member_dict = checked_enum[name].__dict__ + checked_member_keys = list(checked_member_dict.keys()) + simple_member_dict = simple_enum[name].__dict__ + simple_member_keys = list(simple_member_dict.keys()) + for key in set(checked_member_keys + simple_member_keys): + if key in ('__module__', '__objclass__', '_inverted_'): + # keys known to be different or absent + continue + elif key not in simple_member_keys: + failed_member.append("missing key %r not in the simple enum member %r" % (key, name)) + elif key not in checked_member_keys: + failed_member.append("extra key %r in simple enum member %r" % (key, name)) + else: + checked_value = checked_member_dict[key] + simple_value = simple_member_dict[key] + if checked_value != simple_value: + failed_member.append("%r:\n %s\n %s" % ( + key, + "checked member -> %r" % (checked_value, ), + "simple member -> %r" % (simple_value, ), + )) + if failed_member: + failed.append('%r member mismatch:\n %s' % ( + name, '\n '.join(failed_member), + )) + for method in ( + '__str__', '__repr__', '__reduce_ex__', '__format__', + '__getnewargs_ex__', '__getnewargs__', '__reduce_ex__', '__reduce__' + ): + if method in simple_keys and method in checked_keys: + # cannot compare functions, and it exists in both, so we're good + continue + elif method not in simple_keys and method not in checked_keys: + # method is inherited -- check it out + checked_method = getattr(checked_enum, method, None) + simple_method = getattr(simple_enum, method, None) + if hasattr(checked_method, '__func__'): + checked_method = checked_method.__func__ + simple_method = simple_method.__func__ + if checked_method != simple_method: + failed.append("%r: %-30s %s" % ( + method, + "checked -> %r" % (checked_method, ), + "simple -> %r" % (simple_method, ), + )) + else: + # if the method existed in only one of the enums, it will have been caught + # in the first checks above + pass + if failed: + raise TypeError('enum mismatch:\n %s' % '\n '.join(failed)) + +def _old_convert_(etype, name, module, filter, source=None, *, boundary=None): + """ + Create a new Enum subclass that replaces a collection of global constants + """ + # convert all constants from source (or module) that pass filter() to + # a new Enum called name, and export the enum and its members back to + # module; + # also, replace the __reduce_ex__ method so unpickling works in + # previous Python versions + module_globals = sys.modules[module].__dict__ + if source: + source = source.__dict__ + else: + source = module_globals + # _value2member_map_ is populated in the same order every time + # for a consistent reverse mapping of number to name when there + # are multiple names for the same number. + members = [ + (name, value) + for name, value in source.items() + if filter(name)] + try: + # sort by value + members.sort(key=lambda t: (t[1], t[0])) + except TypeError: + # unless some values aren't comparable, in which case sort by name + members.sort(key=lambda t: t[0]) + cls = etype(name, members, module=module, boundary=boundary or KEEP) + return cls + +_stdlib_enums = IntEnum, StrEnum, IntFlag diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 1cccd27dee..be242e93f7 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -1,16 +1,46 @@ +import copy import enum +import doctest import inspect +import os import pydoc import sys import unittest import threading +import typing +import builtins as bltns from collections import OrderedDict -from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto +from datetime import date +from enum import Enum, EnumMeta, IntEnum, StrEnum, EnumType, Flag, IntFlag, unique, auto +from enum import STRICT, CONFORM, EJECT, KEEP, _simple_enum, _test_simple_enum +from enum import verify, UNIQUE, CONTINUOUS, NAMED_FLAGS, ReprEnum +from enum import member, nonmember, _iter_bits_lsb from io import StringIO from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL -from test.support import ALWAYS_EQ, check__all__, threading_helper +from test import support +from test.support import ALWAYS_EQ +from test.support import threading_helper +from textwrap import dedent from datetime import timedelta +python_version = sys.version_info[:2] + +def load_tests(loader, tests, ignore): + tests.addTests(doctest.DocTestSuite(enum)) + if os.path.exists('Doc/library/enum.rst'): + tests.addTests(doctest.DocFileSuite( + '../../Doc/library/enum.rst', + optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE, + )) + if os.path.exists('Doc/howto/enum.rst'): + tests.addTests(doctest.DocFileSuite( + '../../Doc/howto/enum.rst', + optionflags=doctest.ELLIPSIS|doctest.NORMALIZE_WHITESPACE, + )) + return tests + +MODULE = __name__ +SHORT_MODULE = MODULE.split('.')[-1] # for pickle tests try: @@ -41,19 +71,35 @@ class FloatStooges(float, Enum): class FlagStooges(Flag): LARRY = 1 CURLY = 2 - MOE = 3 + MOE = 4 + BIG = 389 except Exception as exc: FlagStooges = exc +class FlagStoogesWithZero(Flag): + NOFLAG = 0 + LARRY = 1 + CURLY = 2 + MOE = 4 + BIG = 389 + +class IntFlagStooges(IntFlag): + LARRY = 1 + CURLY = 2 + MOE = 4 + BIG = 389 + +class IntFlagStoogesWithZero(IntFlag): + NOFLAG = 0 + LARRY = 1 + CURLY = 2 + MOE = 4 + BIG = 389 + # for pickle test and subclass tests -try: - class StrEnum(str, Enum): - 'accepts only string values' - class Name(StrEnum): - BDFL = 'Guido van Rossum' - FLUFL = 'Barry Warsaw' -except Exception as exc: - Name = exc +class Name(StrEnum): + BDFL = 'Guido van Rossum' + FLUFL = 'Barry Warsaw' try: Question = Enum('Question', 'who what when where why', module=__name__) @@ -93,6 +139,12 @@ def test_pickle_exception(assertion, exception, obj): class TestHelpers(unittest.TestCase): # _is_descriptor, _is_sunder, _is_dunder + sunder_names = '_bad_', '_good_', '_what_ho_' + dunder_names = '__mal__', '__bien__', '__que_que__' + private_names = '_MyEnum__private', '_MyEnum__still_private' + private_and_sunder_names = '_MyEnum__private_', '_MyEnum__also_private_' + random_names = 'okay', '_semi_private', '_weird__', '_MyEnum__' + def test_is_descriptor(self): class foo: pass @@ -102,21 +154,40 @@ class foo: setattr(obj, attr, 1) self.assertTrue(enum._is_descriptor(obj)) - def test_is_sunder(self): + def test_sunder(self): + for name in self.sunder_names + self.private_and_sunder_names: + self.assertTrue(enum._is_sunder(name), '%r is a not sunder name?' % name) + for name in self.dunder_names + self.private_names + self.random_names: + self.assertFalse(enum._is_sunder(name), '%r is a sunder name?' % name) for s in ('_a_', '_aa_'): self.assertTrue(enum._is_sunder(s)) - for s in ('a', 'a_', '_a', '__a', 'a__', '__a__', '_a__', '__a_', '_', '__', '___', '____', '_____',): self.assertFalse(enum._is_sunder(s)) - def test_is_dunder(self): + def test_dunder(self): + for name in self.dunder_names: + self.assertTrue(enum._is_dunder(name), '%r is a not dunder name?' % name) + for name in self.sunder_names + self.private_names + self.private_and_sunder_names + self.random_names: + self.assertFalse(enum._is_dunder(name), '%r is a dunder name?' % name) for s in ('__a__', '__aa__'): self.assertTrue(enum._is_dunder(s)) for s in ('a', 'a_', '_a', '__a', 'a__', '_a_', '_a__', '__a_', '_', '__', '___', '____', '_____',): self.assertFalse(enum._is_dunder(s)) + + def test_is_private(self): + for name in self.private_names + self.private_and_sunder_names: + self.assertTrue(enum._is_private('MyEnum', name), '%r is a not private name?') + for name in self.sunder_names + self.dunder_names + self.random_names: + self.assertFalse(enum._is_private('MyEnum', name), '%r is a private name?') + + def test_iter_bits_lsb(self): + self.assertEqual(list(_iter_bits_lsb(7)), [1, 2, 4]) + self.assertRaisesRegex(ValueError, '-8 is not a positive integer', list, _iter_bits_lsb(-8)) + + # for subclassing tests class classproperty: @@ -132,199 +203,794 @@ def __init__(self, fget=None, fset=None, fdel=None, doc=None): def __get__(self, instance, ownerclass): return self.fget(ownerclass) +# for global repr tests + +@enum.global_enum +class HeadlightsK(IntFlag, boundary=enum.KEEP): + OFF_K = 0 + LOW_BEAM_K = auto() + HIGH_BEAM_K = auto() + FOG_K = auto() + + +@enum.global_enum +class HeadlightsC(IntFlag, boundary=enum.CONFORM): + OFF_C = 0 + LOW_BEAM_C = auto() + HIGH_BEAM_C = auto() + FOG_C = auto() + + +@enum.global_enum +class NoName(Flag): + ONE = 1 + TWO = 2 + # tests -class TestEnum(unittest.TestCase): +class _EnumTests: + """ + Test for behavior that is the same across the different types of enumerations. + """ - def setUp(self): - class Season(Enum): - SPRING = 1 - SUMMER = 2 - AUTUMN = 3 - WINTER = 4 - self.Season = Season + values = None - class Konstants(float, Enum): - E = 2.7182818 - PI = 3.1415926 - TAU = 2 * PI - self.Konstants = Konstants + def setUp(self): + class BaseEnum(self.enum_type): + @enum.property + def first(self): + return '%s is first!' % self.name + class MainEnum(BaseEnum): + first = auto() + second = auto() + third = auto() + if issubclass(self.enum_type, Flag): + dupe = 3 + else: + dupe = third + self.MainEnum = MainEnum + # + class NewStrEnum(self.enum_type): + def __str__(self): + return self.name.upper() + first = auto() + self.NewStrEnum = NewStrEnum + # + class NewFormatEnum(self.enum_type): + def __format__(self, spec): + return self.name.upper() + first = auto() + self.NewFormatEnum = NewFormatEnum + # + class NewStrFormatEnum(self.enum_type): + def __str__(self): + return self.name.title() + def __format__(self, spec): + return ''.join(reversed(self.name)) + first = auto() + self.NewStrFormatEnum = NewStrFormatEnum + # + class NewBaseEnum(self.enum_type): + def __str__(self): + return self.name.title() + def __format__(self, spec): + return ''.join(reversed(self.name)) + class NewSubEnum(NewBaseEnum): + first = auto() + self.NewSubEnum = NewSubEnum + # + self.is_flag = False + self.names = ['first', 'second', 'third'] + if issubclass(MainEnum, StrEnum): + self.values = self.names + elif MainEnum._member_type_ is str: + self.values = ['1', '2', '3'] + elif issubclass(self.enum_type, Flag): + self.values = [1, 2, 4] + self.is_flag = True + self.dupe2 = MainEnum(5) + else: + self.values = self.values or [1, 2, 3] + # + if not getattr(self, 'source_values', False): + self.source_values = self.values - class Grades(IntEnum): - A = 5 - B = 4 - C = 3 - D = 2 - F = 0 - self.Grades = Grades + def assertFormatIsValue(self, spec, member): + self.assertEqual(spec.format(member), spec.format(member.value)) - class Directional(str, Enum): - EAST = 'east' - WEST = 'west' - NORTH = 'north' - SOUTH = 'south' - self.Directional = Directional + def assertFormatIsStr(self, spec, member): + self.assertEqual(spec.format(member), spec.format(str(member))) - from datetime import date - class Holiday(date, Enum): - NEW_YEAR = 2013, 1, 1 - IDES_OF_MARCH = 2013, 3, 15 - self.Holiday = Holiday + def test_attribute_deletion(self): + class Season(self.enum_type): + SPRING = auto() + SUMMER = auto() + AUTUMN = auto() + # + def spam(cls): + pass + # + self.assertTrue(hasattr(Season, 'spam')) + del Season.spam + self.assertFalse(hasattr(Season, 'spam')) + # + with self.assertRaises(AttributeError): + del Season.SPRING + with self.assertRaises(AttributeError): + del Season.DRY + with self.assertRaises(AttributeError): + del Season.SPRING.name - def test_dir_on_class(self): - Season = self.Season + def test_basics(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(repr(TE), "") + self.assertEqual(str(TE), "") + self.assertEqual(format(TE), "") + self.assertTrue(TE(5) is self.dupe2) + else: + self.assertEqual(repr(TE), "") + self.assertEqual(str(TE), "") + self.assertEqual(format(TE), "") + self.assertEqual(list(TE), [TE.first, TE.second, TE.third]) + self.assertEqual( + [m.name for m in TE], + self.names, + ) self.assertEqual( - set(dir(Season)), - set(['__class__', '__doc__', '__members__', '__module__', - 'SPRING', 'SUMMER', 'AUTUMN', 'WINTER']), + [m.value for m in TE], + self.values, + ) + self.assertEqual( + [m.first for m in TE], + ['first is first!', 'second is first!', 'third is first!'] + ) + for member, name in zip(TE, self.names, strict=True): + self.assertIs(TE[name], member) + for member, value in zip(TE, self.values, strict=True): + self.assertIs(TE(value), member) + if issubclass(TE, StrEnum): + self.assertTrue(TE.dupe is TE('third') is TE['dupe']) + elif TE._member_type_ is str: + self.assertTrue(TE.dupe is TE('3') is TE['dupe']) + elif issubclass(TE, Flag): + self.assertTrue(TE.dupe is TE(3) is TE['dupe']) + else: + self.assertTrue(TE.dupe is TE(self.values[2]) is TE['dupe']) + + def test_bool_is_true(self): + class Empty(self.enum_type): + pass + self.assertTrue(Empty) + # + self.assertTrue(self.MainEnum) + for member in self.MainEnum: + self.assertTrue(member) + + def test_changing_member_fails(self): + MainEnum = self.MainEnum + with self.assertRaises(AttributeError): + self.MainEnum.second = 'really first' + + @unittest.skipIf( + python_version >= (3, 12), + '__contains__ now returns True/False for all inputs', ) + def test_contains_er(self): + MainEnum = self.MainEnum + self.assertIn(MainEnum.third, MainEnum) + with self.assertRaises(TypeError): + with self.assertWarns(DeprecationWarning): + self.source_values[1] in MainEnum + with self.assertRaises(TypeError): + with self.assertWarns(DeprecationWarning): + 'first' in MainEnum + val = MainEnum.dupe + self.assertIn(val, MainEnum) + # + class OtherEnum(Enum): + one = auto() + two = auto() + self.assertNotIn(OtherEnum.two, MainEnum) - def test_dir_on_item(self): - Season = self.Season - self.assertEqual( - set(dir(Season.WINTER)), - set(['__class__', '__doc__', '__module__', 'name', 'value']), + @unittest.skipIf( + python_version < (3, 12), + '__contains__ works only with enum memmbers before 3.12', ) + def test_contains_tf(self): + MainEnum = self.MainEnum + self.assertIn(MainEnum.first, MainEnum) + self.assertTrue(self.source_values[0] in MainEnum) + self.assertFalse('first' in MainEnum) + val = MainEnum.dupe + self.assertIn(val, MainEnum) + # + class OtherEnum(Enum): + one = auto() + two = auto() + self.assertNotIn(OtherEnum.two, MainEnum) + + def test_dir_on_class(self): + TE = self.MainEnum + self.assertEqual(set(dir(TE)), set(enum_dir(TE))) + + def test_dir_on_item(self): + TE = self.MainEnum + self.assertEqual(set(dir(TE.first)), set(member_dir(TE.first))) def test_dir_with_added_behavior(self): - class Test(Enum): - this = 'that' - these = 'those' + class Test(self.enum_type): + this = auto() + these = auto() def wowser(self): return ("Wowser! I'm %s!" % self.name) - self.assertEqual( - set(dir(Test)), - set(['__class__', '__doc__', '__members__', '__module__', 'this', 'these']), - ) - self.assertEqual( - set(dir(Test.this)), - set(['__class__', '__doc__', '__module__', 'name', 'value', 'wowser']), - ) + self.assertTrue('wowser' not in dir(Test)) + self.assertTrue('wowser' in dir(Test.this)) def test_dir_on_sub_with_behavior_on_super(self): # see issue22506 - class SuperEnum(Enum): + class SuperEnum(self.enum_type): def invisible(self): return "did you see me?" class SubEnum(SuperEnum): - sample = 5 - self.assertEqual( - set(dir(SubEnum.sample)), - set(['__class__', '__doc__', '__module__', 'name', 'value', 'invisible']), - ) + sample = auto() + self.assertTrue('invisible' not in dir(SubEnum)) + self.assertTrue('invisible' in dir(SubEnum.sample)) def test_dir_on_sub_with_behavior_including_instance_dict_on_super(self): # see issue40084 - class SuperEnum(IntEnum): - def __new__(cls, value, description=""): - obj = int.__new__(cls, value) - obj._value_ = value - obj.description = description + class SuperEnum(self.enum_type): + def __new__(cls, *value, **kwds): + new = self.enum_type._member_type_.__new__ + if self.enum_type._member_type_ is object: + obj = new(cls) + else: + if isinstance(value[0], tuple): + create_value ,= value[0] + else: + create_value = value + obj = new(cls, *create_value) + obj._value_ = value[0] if len(value) == 1 else value + obj.description = 'test description' return obj class SubEnum(SuperEnum): - sample = 5 - self.assertTrue({'description'} <= set(dir(SubEnum.sample))) + sample = self.source_values[1] + self.assertTrue('description' not in dir(SubEnum)) + self.assertTrue('description' in dir(SubEnum.sample), dir(SubEnum.sample)) def test_enum_in_enum_out(self): - Season = self.Season - self.assertIs(Season(Season.WINTER), Season.WINTER) - - def test_enum_value(self): - Season = self.Season - self.assertEqual(Season.SPRING.value, 1) - - def test_intenum_value(self): - self.assertEqual(IntStooges.CURLY.value, 2) - - def test_enum(self): - Season = self.Season - lst = list(Season) - self.assertEqual(len(lst), len(Season)) - self.assertEqual(len(Season), 4, Season) - self.assertEqual( - [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER], lst) - - for i, season in enumerate('SPRING SUMMER AUTUMN WINTER'.split(), 1): - e = Season(i) - self.assertEqual(e, getattr(Season, season)) - self.assertEqual(e.value, i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, season) - self.assertIn(e, Season) - self.assertIs(type(e), Season) - self.assertIsInstance(e, Season) - self.assertEqual(str(e), 'Season.' + season) - self.assertEqual( - repr(e), - ''.format(season, i), - ) - - def test_value_name(self): - Season = self.Season - self.assertEqual(Season.SPRING.name, 'SPRING') - self.assertEqual(Season.SPRING.value, 1) - with self.assertRaises(AttributeError): - Season.SPRING.name = 'invierno' - with self.assertRaises(AttributeError): - Season.SPRING.value = 2 - - def test_changing_member(self): - Season = self.Season - with self.assertRaises(AttributeError): - Season.WINTER = 'really cold' - - def test_attribute_deletion(self): - class Season(Enum): - SPRING = 1 - SUMMER = 2 - AUTUMN = 3 - WINTER = 4 + Main = self.MainEnum + self.assertIs(Main(Main.first), Main.first) - def spam(cls): - pass - - self.assertTrue(hasattr(Season, 'spam')) - del Season.spam - self.assertFalse(hasattr(Season, 'spam')) - - with self.assertRaises(AttributeError): - del Season.SPRING - with self.assertRaises(AttributeError): - del Season.DRY - with self.assertRaises(AttributeError): - del Season.SPRING.name - - def test_bool_of_class(self): - class Empty(Enum): - pass - self.assertTrue(bool(Empty)) - - def test_bool_of_member(self): - class Count(Enum): - zero = 0 - one = 1 - two = 2 - for member in Count: - self.assertTrue(bool(member)) + def test_hash(self): + MainEnum = self.MainEnum + mapping = {} + mapping[MainEnum.first] = '1225' + mapping[MainEnum.second] = '0315' + mapping[MainEnum.third] = '0704' + self.assertEqual(mapping[MainEnum.second], '0315') def test_invalid_names(self): with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): mro = 9 with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): _create_= 11 with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): _get_mixins_ = 9 with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): _find_new_ = 1 with self.assertRaises(ValueError): - class Wrong(Enum): + class Wrong(self.enum_type): _any_name_ = 9 + def test_object_str_override(self): + "check that setting __str__ to object's is not reset to Enum's" + class Generic(self.enum_type): + item = self.source_values[2] + def __repr__(self): + return "%s.test" % (self._name_, ) + __str__ = object.__str__ + self.assertEqual(str(Generic.item), 'item.test') + + def test_overridden_str(self): + # TODO: RUSTPYTHON, format(NS.first) does not use __str__ + if isinstance(self, TestIntFlag) or isinstance(self, TestIntEnum) or isinstance(self, TestMinimalFloat): + self.skipTest("format(NS.first) does not use __str__") + NS = self.NewStrEnum + self.assertEqual(str(NS.first), NS.first.name.upper()) + self.assertEqual(format(NS.first), NS.first.name.upper()) + + def test_overridden_str_format(self): + NSF = self.NewStrFormatEnum + self.assertEqual(str(NSF.first), NSF.first.name.title()) + self.assertEqual(format(NSF.first), ''.join(reversed(NSF.first.name))) + + def test_overridden_str_format_inherited(self): + NSE = self.NewSubEnum + self.assertEqual(str(NSE.first), NSE.first.name.title()) + self.assertEqual(format(NSE.first), ''.join(reversed(NSE.first.name))) + + def test_programmatic_function_string(self): + MinorEnum = self.enum_type('MinorEnum', 'june july august') + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + values = self.values + if self.enum_type is StrEnum: + values = ['june','july','august'] + for month, av in zip('june july august'.split(), values): + e = MinorEnum[month] + self.assertEqual(e.value, av, list(MinorEnum)) + self.assertEqual(e.name, month) + if MinorEnum._member_type_ is not object and issubclass(MinorEnum, MinorEnum._member_type_): + self.assertEqual(e, av) + else: + self.assertNotEqual(e, av) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + self.assertIs(e, MinorEnum(av)) + + def test_programmatic_function_string_list(self): + MinorEnum = self.enum_type('MinorEnum', ['june', 'july', 'august']) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + values = self.values + if self.enum_type is StrEnum: + values = ['june','july','august'] + for month, av in zip('june july august'.split(), values): + e = MinorEnum[month] + self.assertEqual(e.value, av) + self.assertEqual(e.name, month) + if MinorEnum._member_type_ is not object and issubclass(MinorEnum, MinorEnum._member_type_): + self.assertEqual(e, av) + else: + self.assertNotEqual(e, av) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + self.assertIs(e, MinorEnum(av)) + + def test_programmatic_function_iterable(self): + MinorEnum = self.enum_type( + 'MinorEnum', + (('june', self.source_values[0]), ('july', self.source_values[1]), ('august', self.source_values[2])) + ) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for month, av in zip('june july august'.split(), self.values): + e = MinorEnum[month] + self.assertEqual(e.value, av) + self.assertEqual(e.name, month) + if MinorEnum._member_type_ is not object and issubclass(MinorEnum, MinorEnum._member_type_): + self.assertEqual(e, av) + else: + self.assertNotEqual(e, av) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + self.assertIs(e, MinorEnum(av)) + + def test_programmatic_function_from_dict(self): + MinorEnum = self.enum_type( + 'MinorEnum', + OrderedDict((('june', self.source_values[0]), ('july', self.source_values[1]), ('august', self.source_values[2]))) + ) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for month, av in zip('june july august'.split(), self.values): + e = MinorEnum[month] + if MinorEnum._member_type_ is not object and issubclass(MinorEnum, MinorEnum._member_type_): + self.assertEqual(e, av) + else: + self.assertNotEqual(e, av) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + self.assertIs(e, MinorEnum(av)) + + def test_repr(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(repr(TE(0)), "") + self.assertEqual(repr(TE.dupe), "") + self.assertEqual(repr(self.dupe2), "") + elif issubclass(TE, StrEnum): + self.assertEqual(repr(TE.dupe), "") + else: + self.assertEqual(repr(TE.dupe), "" % (self.values[2], ), TE._value_repr_) + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(repr(member), "" % (member.name, member.value)) + + def test_repr_override(self): + class Generic(self.enum_type): + first = auto() + second = auto() + third = auto() + def __repr__(self): + return "don't you just love shades of %s?" % self.name + self.assertEqual( + repr(Generic.third), + "don't you just love shades of third?", + ) + + def test_inherited_repr(self): + class MyEnum(self.enum_type): + def __repr__(self): + return "My name is %s." % self.name + class MySubEnum(MyEnum): + this = auto() + that = auto() + theother = auto() + self.assertEqual(repr(MySubEnum.that), "My name is that.") + + def test_multiple_superclasses_repr(self): + class _EnumSuperClass(metaclass=EnumMeta): + pass + class E(_EnumSuperClass, Enum): + A = 1 + self.assertEqual(repr(E.A), "") + + def test_reversed_iteration_order(self): + self.assertEqual( + list(reversed(self.MainEnum)), + [self.MainEnum.third, self.MainEnum.second, self.MainEnum.first], + ) + +class _PlainOutputTests: + + def test_str(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(str(TE(0)), "MainEnum(0)") + self.assertEqual(str(TE.dupe), "MainEnum.dupe") + self.assertEqual(str(self.dupe2), "MainEnum.first|third") + else: + self.assertEqual(str(TE.dupe), "MainEnum.third") + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(str(member), "MainEnum.%s" % (member.name, )) + + def test_format(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(format(TE.dupe), "MainEnum.dupe") + self.assertEqual(format(self.dupe2), "MainEnum.first|third") + else: + self.assertEqual(format(TE.dupe), "MainEnum.third") + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(format(member), "MainEnum.%s" % (member.name, )) + + def test_overridden_format(self): + NF = self.NewFormatEnum + self.assertEqual(str(NF.first), "NewFormatEnum.first", '%s %r' % (NF.__str__, NF.first)) + self.assertEqual(format(NF.first), "FIRST") + + def test_format_specs(self): + TE = self.MainEnum + self.assertFormatIsStr('{}', TE.second) + self.assertFormatIsStr('{:}', TE.second) + self.assertFormatIsStr('{:20}', TE.second) + self.assertFormatIsStr('{:^20}', TE.second) + self.assertFormatIsStr('{:>20}', TE.second) + self.assertFormatIsStr('{:<20}', TE.second) + self.assertFormatIsStr('{:5.2}', TE.second) + + +class _MixedOutputTests: + + def test_str(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(str(TE.dupe), "MainEnum.dupe") + self.assertEqual(str(self.dupe2), "MainEnum.first|third") + else: + self.assertEqual(str(TE.dupe), "MainEnum.third") + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(str(member), "MainEnum.%s" % (member.name, )) + + def test_format(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(format(TE.dupe), "MainEnum.dupe") + self.assertEqual(format(self.dupe2), "MainEnum.first|third") + else: + self.assertEqual(format(TE.dupe), "MainEnum.third") + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(format(member), "MainEnum.%s" % (member.name, )) + + def test_overridden_format(self): + NF = self.NewFormatEnum + self.assertEqual(str(NF.first), "NewFormatEnum.first") + self.assertEqual(format(NF.first), "FIRST") + + def test_format_specs(self): + TE = self.MainEnum + self.assertFormatIsStr('{}', TE.first) + self.assertFormatIsStr('{:}', TE.first) + self.assertFormatIsStr('{:20}', TE.first) + self.assertFormatIsStr('{:^20}', TE.first) + self.assertFormatIsStr('{:>20}', TE.first) + self.assertFormatIsStr('{:<20}', TE.first) + self.assertFormatIsStr('{:5.2}', TE.first) + + +class _MinimalOutputTests: + + def test_str(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(str(TE.dupe), "3") + self.assertEqual(str(self.dupe2), "5") + else: + self.assertEqual(str(TE.dupe), str(self.values[2])) + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(str(member), str(value)) + + def test_format(self): + TE = self.MainEnum + if self.is_flag: + self.assertEqual(format(TE.dupe), "3") + self.assertEqual(format(self.dupe2), "5") + else: + self.assertEqual(format(TE.dupe), format(self.values[2])) + for name, value, member in zip(self.names, self.values, TE, strict=True): + self.assertEqual(format(member), format(value)) + + def test_overridden_format(self): + NF = self.NewFormatEnum + self.assertEqual(str(NF.first), str(self.values[0])) + self.assertEqual(format(NF.first), "FIRST") + + def test_format_specs(self): + TE = self.MainEnum + self.assertFormatIsValue('{}', TE.third) + self.assertFormatIsValue('{:}', TE.third) + self.assertFormatIsValue('{:20}', TE.third) + self.assertFormatIsValue('{:^20}', TE.third) + self.assertFormatIsValue('{:>20}', TE.third) + self.assertFormatIsValue('{:<20}', TE.third) + if TE._member_type_ is float: + self.assertFormatIsValue('{:n}', TE.third) + self.assertFormatIsValue('{:5.2}', TE.third) + self.assertFormatIsValue('{:f}', TE.third) + + def test_copy(self): + TE = self.MainEnum + copied = copy.copy(TE) + self.assertEqual(copied, TE) + self.assertIs(copied, TE) + deep = copy.deepcopy(TE) + self.assertEqual(deep, TE) + self.assertIs(deep, TE) + + def test_copy_member(self): + TE = self.MainEnum + copied = copy.copy(TE.first) + self.assertIs(copied, TE.first) + deep = copy.deepcopy(TE.first) + self.assertIs(deep, TE.first) + +class _FlagTests: + + def test_default_missing_with_wrong_type_value(self): + with self.assertRaisesRegex( + ValueError, + "'RED' is not a valid ", + ) as ctx: + self.MainEnum('RED') + self.assertIs(ctx.exception.__context__, None) + + def test_closed_invert_expectations(self): + class ClosedAB(self.enum_type): + A = 1 + B = 2 + MASK = 3 + A, B = ClosedAB + AB_MASK = ClosedAB.MASK + # + self.assertIs(~A, B) + self.assertIs(~B, A) + self.assertIs(~(A|B), ClosedAB(0)) + self.assertIs(~AB_MASK, ClosedAB(0)) + self.assertIs(~ClosedAB(0), (A|B)) + # + class ClosedXYZ(self.enum_type): + X = 4 + Y = 2 + Z = 1 + MASK = 7 + X, Y, Z = ClosedXYZ + XYZ_MASK = ClosedXYZ.MASK + # + self.assertIs(~X, Y|Z) + self.assertIs(~Y, X|Z) + self.assertIs(~Z, X|Y) + self.assertIs(~(X|Y), Z) + self.assertIs(~(X|Z), Y) + self.assertIs(~(Y|Z), X) + self.assertIs(~(X|Y|Z), ClosedXYZ(0)) + self.assertIs(~XYZ_MASK, ClosedXYZ(0)) + self.assertIs(~ClosedXYZ(0), (X|Y|Z)) + + def test_open_invert_expectations(self): + class OpenAB(self.enum_type): + A = 1 + B = 2 + MASK = 255 + A, B = OpenAB + AB_MASK = OpenAB.MASK + # + if OpenAB._boundary_ in (EJECT, KEEP): + self.assertIs(~A, OpenAB(254)) + self.assertIs(~B, OpenAB(253)) + self.assertIs(~(A|B), OpenAB(252)) + self.assertIs(~AB_MASK, OpenAB(0)) + self.assertIs(~OpenAB(0), AB_MASK) + else: + self.assertIs(~A, B) + self.assertIs(~B, A) + self.assertIs(~(A|B), OpenAB(0)) + self.assertIs(~AB_MASK, OpenAB(0)) + self.assertIs(~OpenAB(0), (A|B)) + # + class OpenXYZ(self.enum_type): + X = 4 + Y = 2 + Z = 1 + MASK = 31 + X, Y, Z = OpenXYZ + XYZ_MASK = OpenXYZ.MASK + # + if OpenXYZ._boundary_ in (EJECT, KEEP): + self.assertIs(~X, OpenXYZ(27)) + self.assertIs(~Y, OpenXYZ(29)) + self.assertIs(~Z, OpenXYZ(30)) + self.assertIs(~(X|Y), OpenXYZ(25)) + self.assertIs(~(X|Z), OpenXYZ(26)) + self.assertIs(~(Y|Z), OpenXYZ(28)) + self.assertIs(~(X|Y|Z), OpenXYZ(24)) + self.assertIs(~XYZ_MASK, OpenXYZ(0)) + self.assertTrue(~OpenXYZ(0), XYZ_MASK) + else: + self.assertIs(~X, Y|Z) + self.assertIs(~Y, X|Z) + self.assertIs(~Z, X|Y) + self.assertIs(~(X|Y), Z) + self.assertIs(~(X|Z), Y) + self.assertIs(~(Y|Z), X) + self.assertIs(~(X|Y|Z), OpenXYZ(0)) + self.assertIs(~XYZ_MASK, OpenXYZ(0)) + self.assertTrue(~OpenXYZ(0), (X|Y|Z)) + + +class TestPlainEnum(_EnumTests, _PlainOutputTests, unittest.TestCase): + enum_type = Enum + + +class TestPlainFlag(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): + enum_type = Flag + + +class TestIntEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): + enum_type = IntEnum + + +class TestStrEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): + enum_type = StrEnum + + +class TestIntFlag(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase): + enum_type = IntFlag + + +class TestMixedInt(_EnumTests, _MixedOutputTests, unittest.TestCase): + class enum_type(int, Enum): pass + + +class TestMixedStr(_EnumTests, _MixedOutputTests, unittest.TestCase): + class enum_type(str, Enum): pass + + +class TestMixedIntFlag(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase): + class enum_type(int, Flag): pass + + +class TestMixedDate(_EnumTests, _MixedOutputTests, unittest.TestCase): + + values = [date(2021, 12, 25), date(2020, 3, 15), date(2019, 11, 27)] + source_values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] + + class enum_type(date, Enum): + def _generate_next_value_(name, start, count, last_values): + values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] + return values[count] + + +class TestMinimalDate(_EnumTests, _MinimalOutputTests, unittest.TestCase): + + values = [date(2023, 12, 1), date(2016, 2, 29), date(2009, 1, 1)] + source_values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] + + class enum_type(date, ReprEnum): + def _generate_next_value_(name, start, count, last_values): + values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] + return values[count] + + +class TestMixedFloat(_EnumTests, _MixedOutputTests, unittest.TestCase): + + values = [1.1, 2.2, 3.3] + + class enum_type(float, Enum): + def _generate_next_value_(name, start, count, last_values): + values = [1.1, 2.2, 3.3] + return values[count] + + +class TestMinimalFloat(_EnumTests, _MinimalOutputTests, unittest.TestCase): + + values = [4.4, 5.5, 6.6] + + class enum_type(float, ReprEnum): + def _generate_next_value_(name, start, count, last_values): + values = [4.4, 5.5, 6.6] + return values[count] + + +class TestSpecial(unittest.TestCase): + """ + various operations that are not attributable to every possible enum + """ + + def setUp(self): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + self.Season = Season + # + class Grades(IntEnum): + A = 5 + B = 4 + C = 3 + D = 2 + F = 0 + self.Grades = Grades + # + class Directional(str, Enum): + EAST = 'east' + WEST = 'west' + NORTH = 'north' + SOUTH = 'south' + self.Directional = Directional + # + from datetime import date + class Holiday(date, Enum): + NEW_YEAR = 2013, 1, 1 + IDES_OF_MARCH = 2013, 3, 15 + self.Holiday = Holiday + def test_bool(self): # plain Enum members are always True class Logic(Enum): @@ -347,71 +1013,57 @@ class IntLogic(int, Enum): self.assertTrue(IntLogic.true) self.assertFalse(IntLogic.false) - def test_contains(self): - Season = self.Season - self.assertIn(Season.AUTUMN, Season) - with self.assertRaises(TypeError): - 3 in Season - with self.assertRaises(TypeError): - 'AUTUMN' in Season - - val = Season(3) - self.assertIn(val, Season) - - class OtherEnum(Enum): - one = 1; two = 2 - self.assertNotIn(OtherEnum.two, Season) - def test_comparisons(self): Season = self.Season with self.assertRaises(TypeError): Season.SPRING < Season.WINTER with self.assertRaises(TypeError): Season.SPRING > 4 - + # self.assertNotEqual(Season.SPRING, 1) - + # class Part(Enum): SPRING = 1 CLIP = 2 BARREL = 3 - + # self.assertNotEqual(Season.SPRING, Part.SPRING) with self.assertRaises(TypeError): Season.SPRING < Part.CLIP - def test_enum_duplicates(self): - class Season(Enum): - SPRING = 1 - SUMMER = 2 - AUTUMN = FALL = 3 - WINTER = 4 - ANOTHER_SPRING = 1 - lst = list(Season) - self.assertEqual( - lst, - [Season.SPRING, Season.SUMMER, - Season.AUTUMN, Season.WINTER, - ]) - self.assertIs(Season.FALL, Season.AUTUMN) - self.assertEqual(Season.FALL.value, 3) - self.assertEqual(Season.AUTUMN.value, 3) - self.assertIs(Season(3), Season.AUTUMN) - self.assertIs(Season(1), Season.SPRING) - self.assertEqual(Season.FALL.name, 'AUTUMN') - self.assertEqual( - [k for k,v in Season.__members__.items() if v.name != k], - ['FALL', 'ANOTHER_SPRING'], - ) + @unittest.skip('to-do list') + def test_dir_with_custom_dunders(self): + class PlainEnum(Enum): + pass + cls_dir = dir(PlainEnum) + self.assertNotIn('__repr__', cls_dir) + self.assertNotIn('__str__', cls_dir) + self.assertNotIn('__format__', cls_dir) + self.assertNotIn('__init__', cls_dir) + # + class MyEnum(Enum): + def __repr__(self): + return object.__repr__(self) + def __str__(self): + return object.__repr__(self) + def __format__(self): + return object.__repr__(self) + def __init__(self): + pass + cls_dir = dir(MyEnum) + self.assertIn('__repr__', cls_dir) + self.assertIn('__str__', cls_dir) + self.assertIn('__format__', cls_dir) + self.assertIn('__init__', cls_dir) - def test_duplicate_name(self): + def test_duplicate_name_error(self): with self.assertRaises(TypeError): class Color(Enum): red = 1 green = 2 blue = 3 red = 4 - + # with self.assertRaises(TypeError): class Color(Enum): red = 1 @@ -419,186 +1071,345 @@ class Color(Enum): blue = 3 def red(self): return 'red' - + # with self.assertRaises(TypeError): class Color(Enum): - @property + @enum.property def red(self): return 'redder' red = 1 green = 2 blue = 3 + def test_enum_function_with_qualname(self): + if isinstance(Theory, Exception): + raise Theory + self.assertEqual(Theory.__qualname__, 'spanish_inquisition') + + def test_enum_of_types(self): + """Support using Enum to refer to types deliberately.""" + class MyTypes(Enum): + i = int + f = float + s = str + self.assertEqual(MyTypes.i.value, int) + self.assertEqual(MyTypes.f.value, float) + self.assertEqual(MyTypes.s.value, str) + class Foo: + pass + class Bar: + pass + class MyTypes2(Enum): + a = Foo + b = Bar + self.assertEqual(MyTypes2.a.value, Foo) + self.assertEqual(MyTypes2.b.value, Bar) + class SpamEnumNotInner: + pass + class SpamEnum(Enum): + spam = SpamEnumNotInner + self.assertEqual(SpamEnum.spam.value, SpamEnumNotInner) + + def test_enum_of_generic_aliases(self): + class E(Enum): + a = typing.List[int] + b = list[int] + self.assertEqual(E.a.value, typing.List[int]) + self.assertEqual(E.b.value, list[int]) + self.assertEqual(repr(E.a), '') + self.assertEqual(repr(E.b), '') + + @unittest.skipIf( + python_version >= (3, 13), + 'inner classes are not members', + ) + def test_nested_classes_in_enum_are_members(self): + """ + Check for warnings pre-3.13 + """ + with self.assertWarnsRegex(DeprecationWarning, 'will not become a member'): + class Outer(Enum): + a = 1 + b = 2 + class Inner(Enum): + foo = 10 + bar = 11 + self.assertTrue(isinstance(Outer.Inner, Outer)) + self.assertEqual(Outer.a.value, 1) + self.assertEqual(Outer.Inner.value.foo.value, 10) + self.assertEqual( + list(Outer.Inner.value), + [Outer.Inner.value.foo, Outer.Inner.value.bar], + ) + self.assertEqual( + list(Outer), + [Outer.a, Outer.b, Outer.Inner], + ) + + @unittest.skipIf( + python_version < (3, 13), + 'inner classes are still members', + ) + def test_nested_classes_in_enum_are_not_members(self): + """Support locally-defined nested classes.""" + class Outer(Enum): + a = 1 + b = 2 + class Inner(Enum): + foo = 10 + bar = 11 + self.assertTrue(isinstance(Outer.Inner, type)) + self.assertEqual(Outer.a.value, 1) + self.assertEqual(Outer.Inner.foo.value, 10) + self.assertEqual( + list(Outer.Inner), + [Outer.Inner.foo, Outer.Inner.bar], + ) + self.assertEqual( + list(Outer), + [Outer.a, Outer.b], + ) + + def test_nested_classes_in_enum_with_nonmember(self): + class Outer(Enum): + a = 1 + b = 2 + @nonmember + class Inner(Enum): + foo = 10 + bar = 11 + self.assertTrue(isinstance(Outer.Inner, type)) + self.assertEqual(Outer.a.value, 1) + self.assertEqual(Outer.Inner.foo.value, 10) + self.assertEqual( + list(Outer.Inner), + [Outer.Inner.foo, Outer.Inner.bar], + ) + self.assertEqual( + list(Outer), + [Outer.a, Outer.b], + ) + + def test_enum_of_types_with_nonmember(self): + """Support using Enum to refer to types deliberately.""" + class MyTypes(Enum): + i = int + f = nonmember(float) + s = str + self.assertEqual(MyTypes.i.value, int) + self.assertTrue(MyTypes.f is float) + self.assertEqual(MyTypes.s.value, str) + class Foo: + pass + class Bar: + pass + class MyTypes2(Enum): + a = Foo + b = nonmember(Bar) + self.assertEqual(MyTypes2.a.value, Foo) + self.assertTrue(MyTypes2.b is Bar) + class SpamEnumIsInner: + pass + class SpamEnum(Enum): + spam = nonmember(SpamEnumIsInner) + self.assertTrue(SpamEnum.spam is SpamEnumIsInner) + + def test_nested_classes_in_enum_with_member(self): + """Support locally-defined nested classes.""" + class Outer(Enum): + a = 1 + b = 2 + @member + class Inner(Enum): + foo = 10 + bar = 11 + self.assertTrue(isinstance(Outer.Inner, Outer)) + self.assertEqual(Outer.a.value, 1) + self.assertEqual(Outer.Inner.value.foo.value, 10) + self.assertEqual( + list(Outer.Inner.value), + [Outer.Inner.value.foo, Outer.Inner.value.bar], + ) + self.assertEqual( + list(Outer), + [Outer.a, Outer.b, Outer.Inner], + ) + def test_enum_with_value_name(self): class Huh(Enum): name = 1 value = 2 - self.assertEqual( - list(Huh), - [Huh.name, Huh.value], - ) + self.assertEqual(list(Huh), [Huh.name, Huh.value]) self.assertIs(type(Huh.name), Huh) self.assertEqual(Huh.name.name, 'name') self.assertEqual(Huh.name.value, 1) - def test_format_enum(self): - Season = self.Season - self.assertEqual('{}'.format(Season.SPRING), - '{}'.format(str(Season.SPRING))) - self.assertEqual( '{:}'.format(Season.SPRING), - '{:}'.format(str(Season.SPRING))) - self.assertEqual('{:20}'.format(Season.SPRING), - '{:20}'.format(str(Season.SPRING))) - self.assertEqual('{:^20}'.format(Season.SPRING), - '{:^20}'.format(str(Season.SPRING))) - self.assertEqual('{:>20}'.format(Season.SPRING), - '{:>20}'.format(str(Season.SPRING))) - self.assertEqual('{:<20}'.format(Season.SPRING), - '{:<20}'.format(str(Season.SPRING))) - - def test_str_override_enum(self): - class EnumWithStrOverrides(Enum): - one = auto() - two = auto() + def test_inherited_data_type(self): + class HexInt(int): + __qualname__ = 'HexInt' + def __repr__(self): + return hex(self) + class MyEnum(HexInt, enum.Enum): + __qualname__ = 'MyEnum' + A = 1 + B = 2 + C = 3 + self.assertEqual(repr(MyEnum.A), '') + globals()['HexInt'] = HexInt + globals()['MyEnum'] = MyEnum + test_pickle_dump_load(self.assertIs, MyEnum.A) + test_pickle_dump_load(self.assertIs, MyEnum) + # + class SillyInt(HexInt): + __qualname__ = 'SillyInt' + pass + class MyOtherEnum(SillyInt, enum.Enum): + __qualname__ = 'MyOtherEnum' + D = 4 + E = 5 + F = 6 + self.assertIs(MyOtherEnum._member_type_, SillyInt) + globals()['SillyInt'] = SillyInt + globals()['MyOtherEnum'] = MyOtherEnum + test_pickle_dump_load(self.assertIs, MyOtherEnum.E) + test_pickle_dump_load(self.assertIs, MyOtherEnum) + # + # This did not work in 3.10, but does now with pickling by name + class UnBrokenInt(int): + __qualname__ = 'UnBrokenInt' + def __new__(cls, value): + return int.__new__(cls, value) + class MyUnBrokenEnum(UnBrokenInt, Enum): + __qualname__ = 'MyUnBrokenEnum' + G = 7 + H = 8 + I = 9 + self.assertIs(MyUnBrokenEnum._member_type_, UnBrokenInt) + self.assertIs(MyUnBrokenEnum(7), MyUnBrokenEnum.G) + globals()['UnBrokenInt'] = UnBrokenInt + globals()['MyUnBrokenEnum'] = MyUnBrokenEnum + test_pickle_dump_load(self.assertIs, MyUnBrokenEnum.I) + test_pickle_dump_load(self.assertIs, MyUnBrokenEnum) - def __str__(self): - return 'Str!' - self.assertEqual(str(EnumWithStrOverrides.one), 'Str!') - self.assertEqual('{}'.format(EnumWithStrOverrides.one), 'Str!') - - def test_format_override_enum(self): - class EnumWithFormatOverride(Enum): - one = 1.0 - two = 2.0 - def __format__(self, spec): - return 'Format!!' - self.assertEqual(str(EnumWithFormatOverride.one), 'EnumWithFormatOverride.one') - self.assertEqual('{}'.format(EnumWithFormatOverride.one), 'Format!!') + def test_floatenum_fromhex(self): + h = float.hex(FloatStooges.MOE.value) + self.assertIs(FloatStooges.fromhex(h), FloatStooges.MOE) + h = float.hex(FloatStooges.MOE.value + 0.01) + with self.assertRaises(ValueError): + FloatStooges.fromhex(h) - def test_str_and_format_override_enum(self): - class EnumWithStrFormatOverrides(Enum): - one = auto() - two = auto() - def __str__(self): - return 'Str!' - def __format__(self, spec): - return 'Format!' - self.assertEqual(str(EnumWithStrFormatOverrides.one), 'Str!') - self.assertEqual('{}'.format(EnumWithStrFormatOverrides.one), 'Format!') - - def test_str_override_mixin(self): - class MixinEnumWithStrOverride(float, Enum): - one = 1.0 - two = 2.0 - def __str__(self): - return 'Overridden!' - self.assertEqual(str(MixinEnumWithStrOverride.one), 'Overridden!') - self.assertEqual('{}'.format(MixinEnumWithStrOverride.one), 'Overridden!') - - def test_str_and_format_override_mixin(self): - class MixinWithStrFormatOverrides(float, Enum): - one = 1.0 - two = 2.0 - def __str__(self): - return 'Str!' - def __format__(self, spec): - return 'Format!' - self.assertEqual(str(MixinWithStrFormatOverrides.one), 'Str!') - self.assertEqual('{}'.format(MixinWithStrFormatOverrides.one), 'Format!') - - def test_format_override_mixin(self): - class TestFloat(float, Enum): - one = 1.0 - two = 2.0 - def __format__(self, spec): - return 'TestFloat success!' - self.assertEqual(str(TestFloat.one), 'TestFloat.one') - self.assertEqual('{}'.format(TestFloat.one), 'TestFloat success!') + def test_programmatic_function_type(self): + MinorEnum = Enum('MinorEnum', 'june july august', type=int) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = MinorEnum(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) - def assertFormatIsValue(self, spec, member): - self.assertEqual(spec.format(member), spec.format(member.value)) + def test_programmatic_function_string_with_start(self): + MinorEnum = Enum('MinorEnum', 'june july august', start=10) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 10): + e = MinorEnum(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + + def test_programmatic_function_type_with_start(self): + MinorEnum = Enum('MinorEnum', 'june july august', type=int, start=30) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 30): + e = MinorEnum(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) - def test_format_enum_date(self): - Holiday = self.Holiday - self.assertFormatIsValue('{}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:20}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:^20}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:>20}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:<20}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:%Y %m}', Holiday.IDES_OF_MARCH) - self.assertFormatIsValue('{:%Y %m %M:00}', Holiday.IDES_OF_MARCH) - - def test_format_enum_float(self): - Konstants = self.Konstants - self.assertFormatIsValue('{}', Konstants.TAU) - self.assertFormatIsValue('{:}', Konstants.TAU) - self.assertFormatIsValue('{:20}', Konstants.TAU) - self.assertFormatIsValue('{:^20}', Konstants.TAU) - self.assertFormatIsValue('{:>20}', Konstants.TAU) - self.assertFormatIsValue('{:<20}', Konstants.TAU) - self.assertFormatIsValue('{:n}', Konstants.TAU) - self.assertFormatIsValue('{:5.2}', Konstants.TAU) - self.assertFormatIsValue('{:f}', Konstants.TAU) - - def test_format_enum_int(self): - Grades = self.Grades - self.assertFormatIsValue('{}', Grades.C) - self.assertFormatIsValue('{:}', Grades.C) - self.assertFormatIsValue('{:20}', Grades.C) - self.assertFormatIsValue('{:^20}', Grades.C) - self.assertFormatIsValue('{:>20}', Grades.C) - self.assertFormatIsValue('{:<20}', Grades.C) - self.assertFormatIsValue('{:+}', Grades.C) - self.assertFormatIsValue('{:08X}', Grades.C) - self.assertFormatIsValue('{:b}', Grades.C) - - def test_format_enum_str(self): - Directional = self.Directional - self.assertFormatIsValue('{}', Directional.WEST) - self.assertFormatIsValue('{:}', Directional.WEST) - self.assertFormatIsValue('{:20}', Directional.WEST) - self.assertFormatIsValue('{:^20}', Directional.WEST) - self.assertFormatIsValue('{:>20}', Directional.WEST) - self.assertFormatIsValue('{:<20}', Directional.WEST) + def test_programmatic_function_string_list_with_start(self): + MinorEnum = Enum('MinorEnum', ['june', 'july', 'august'], start=20) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 20): + e = MinorEnum(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) + + def test_programmatic_function_type_from_subclass(self): + MinorEnum = IntEnum('MinorEnum', 'june july august') + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 1): + e = MinorEnum(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) - def test_object_str_override(self): - class Colors(Enum): - RED, GREEN, BLUE = 1, 2, 3 - def __repr__(self): - return "test.%s" % (self._name_, ) - __str__ = object.__str__ - self.assertEqual(str(Colors.RED), 'test.RED') + def test_programmatic_function_type_from_subclass_with_start(self): + MinorEnum = IntEnum('MinorEnum', 'june july august', start=40) + lst = list(MinorEnum) + self.assertEqual(len(lst), len(MinorEnum)) + self.assertEqual(len(MinorEnum), 3, MinorEnum) + self.assertEqual( + [MinorEnum.june, MinorEnum.july, MinorEnum.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 40): + e = MinorEnum(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, MinorEnum) + self.assertIs(type(e), MinorEnum) - def test_enum_str_override(self): - class MyStrEnum(Enum): - def __str__(self): - return 'MyStr' - class MyMethodEnum(Enum): - def hello(self): - return 'Hello! My name is %s' % self.name - class Test1Enum(MyMethodEnum, int, MyStrEnum): - One = 1 - Two = 2 - self.assertTrue(Test1Enum._member_type_ is int) - self.assertEqual(str(Test1Enum.One), 'MyStr') - self.assertEqual(format(Test1Enum.One, ''), 'MyStr') - # - class Test2Enum(MyStrEnum, MyMethodEnum): - One = 1 - Two = 2 - self.assertEqual(str(Test2Enum.One), 'MyStr') - self.assertEqual(format(Test1Enum.One, ''), 'MyStr') + # TODO: RUSTPYTHON, AssertionError: is not + @unittest.expectedFailure + def test_intenum_from_bytes(self): + self.assertIs(IntStooges.from_bytes(b'\x00\x03', 'big'), IntStooges.MOE) + with self.assertRaises(ValueError): + IntStooges.from_bytes(b'\x00\x05', 'big') - def test_inherited_data_type(self): - class HexInt(int): - def __repr__(self): - return hex(self) - class MyEnum(HexInt, enum.Enum): - A = 1 - B = 2 - C = 3 - self.assertEqual(repr(MyEnum.A), '') + def test_reserved_sunder_error(self): + with self.assertRaisesRegex( + ValueError, + '_sunder_ names, such as ._bad_., are reserved', + ): + class Bad(Enum): + _bad_ = 1 def test_too_many_data_types(self): with self.assertRaisesRegex(TypeError, 'too many data types'): @@ -615,116 +1426,6 @@ def repr(self): class Huh(MyStr, MyInt, Enum): One = 1 - def test_hash(self): - Season = self.Season - dates = {} - dates[Season.WINTER] = '1225' - dates[Season.SPRING] = '0315' - dates[Season.SUMMER] = '0704' - dates[Season.AUTUMN] = '1031' - self.assertEqual(dates[Season.AUTUMN], '1031') - - def test_intenum_from_scratch(self): - class phy(int, Enum): - pi = 3 - tau = 2 * pi - self.assertTrue(phy.pi < phy.tau) - - def test_intenum_inherited(self): - class IntEnum(int, Enum): - pass - class phy(IntEnum): - pi = 3 - tau = 2 * pi - self.assertTrue(phy.pi < phy.tau) - - def test_floatenum_from_scratch(self): - class phy(float, Enum): - pi = 3.1415926 - tau = 2 * pi - self.assertTrue(phy.pi < phy.tau) - - def test_floatenum_inherited(self): - class FloatEnum(float, Enum): - pass - class phy(FloatEnum): - pi = 3.1415926 - tau = 2 * pi - self.assertTrue(phy.pi < phy.tau) - - def test_strenum_from_scratch(self): - class phy(str, Enum): - pi = 'Pi' - tau = 'Tau' - self.assertTrue(phy.pi < phy.tau) - - def test_strenum_inherited(self): - class StrEnum(str, Enum): - pass - class phy(StrEnum): - pi = 'Pi' - tau = 'Tau' - self.assertTrue(phy.pi < phy.tau) - - - def test_intenum(self): - class WeekDay(IntEnum): - SUNDAY = 1 - MONDAY = 2 - TUESDAY = 3 - WEDNESDAY = 4 - THURSDAY = 5 - FRIDAY = 6 - SATURDAY = 7 - - self.assertEqual(['a', 'b', 'c'][WeekDay.MONDAY], 'c') - self.assertEqual([i for i in range(WeekDay.TUESDAY)], [0, 1, 2]) - - lst = list(WeekDay) - self.assertEqual(len(lst), len(WeekDay)) - self.assertEqual(len(WeekDay), 7) - target = 'SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY' - target = target.split() - for i, weekday in enumerate(target, 1): - e = WeekDay(i) - self.assertEqual(e, i) - self.assertEqual(int(e), i) - self.assertEqual(e.name, weekday) - self.assertIn(e, WeekDay) - self.assertEqual(lst.index(e)+1, i) - self.assertTrue(0 < e < 8) - self.assertIs(type(e), WeekDay) - self.assertIsInstance(e, int) - self.assertIsInstance(e, Enum) - - def test_intenum_duplicates(self): - class WeekDay(IntEnum): - SUNDAY = 1 - MONDAY = 2 - TUESDAY = TEUSDAY = 3 - WEDNESDAY = 4 - THURSDAY = 5 - FRIDAY = 6 - SATURDAY = 7 - self.assertIs(WeekDay.TEUSDAY, WeekDay.TUESDAY) - self.assertEqual(WeekDay(3).name, 'TUESDAY') - self.assertEqual([k for k,v in WeekDay.__members__.items() - if v.name != k], ['TEUSDAY', ]) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_intenum_from_bytes(self): - self.assertIs(IntStooges.from_bytes(b'\x00\x03', 'big'), IntStooges.MOE) - with self.assertRaises(ValueError): - IntStooges.from_bytes(b'\x00\x05', 'big') - - def test_floatenum_fromhex(self): - h = float.hex(FloatStooges.MOE.value) - self.assertIs(FloatStooges.fromhex(h), FloatStooges.MOE) - h = float.hex(FloatStooges.MOE.value + 0.01) - with self.assertRaises(ValueError): - FloatStooges.fromhex(h) - def test_pickle_enum(self): if isinstance(Stooges, Exception): raise Stooges @@ -755,12 +1456,7 @@ def test_pickle_enum_function_with_module(self): test_pickle_dump_load(self.assertIs, Question.who) test_pickle_dump_load(self.assertIs, Question) - def test_enum_function_with_qualname(self): - if isinstance(Theory, Exception): - raise Theory - self.assertEqual(Theory.__qualname__, 'spanish_inquisition') - - def test_class_nested_enum_and_pickle_protocol_four(self): + def test_pickle_nested_class(self): # would normally just have this directly in the class namespace class NestedEnum(Enum): twigs = 'common' @@ -774,11 +1470,11 @@ def test_pickle_by_name(self): class ReplaceGlobalInt(IntEnum): ONE = 1 TWO = 2 - ReplaceGlobalInt.__reduce_ex__ = enum._reduce_ex_by_name + ReplaceGlobalInt.__reduce_ex__ = enum._reduce_ex_by_global_name for proto in range(HIGHEST_PROTOCOL): self.assertEqual(ReplaceGlobalInt.TWO.__reduce_ex__(proto), 'TWO') - def test_exploding_pickle(self): + def test_pickle_explodes(self): BadPickle = Enum( 'BadPickle', 'dill sweet bread-n-butter', module=__name__) globals()['BadPickle'] = BadPickle @@ -819,185 +1515,6 @@ class Season(Enum): [Season.SUMMER, Season.WINTER, Season.AUTUMN, Season.SPRING], ) - def test_reversed_iteration_order(self): - self.assertEqual( - list(reversed(self.Season)), - [self.Season.WINTER, self.Season.AUTUMN, self.Season.SUMMER, - self.Season.SPRING] - ) - - def test_programmatic_function_string(self): - SummerMonth = Enum('SummerMonth', 'june july august') - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_string_with_start(self): - SummerMonth = Enum('SummerMonth', 'june july august', start=10) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 10): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_string_list(self): - SummerMonth = Enum('SummerMonth', ['june', 'july', 'august']) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_string_list_with_start(self): - SummerMonth = Enum('SummerMonth', ['june', 'july', 'august'], start=20) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 20): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_iterable(self): - SummerMonth = Enum( - 'SummerMonth', - (('june', 1), ('july', 2), ('august', 3)) - ) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_from_dict(self): - SummerMonth = Enum( - 'SummerMonth', - OrderedDict((('june', 1), ('july', 2), ('august', 3))) - ) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(int(e.value), i) - self.assertNotEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_type(self): - SummerMonth = Enum('SummerMonth', 'june july august', type=int) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_type_with_start(self): - SummerMonth = Enum('SummerMonth', 'june july august', type=int, start=30) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 30): - e = SummerMonth(i) - self.assertEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_type_from_subclass(self): - SummerMonth = IntEnum('SummerMonth', 'june july august') - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 1): - e = SummerMonth(i) - self.assertEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - - def test_programmatic_function_type_from_subclass_with_start(self): - SummerMonth = IntEnum('SummerMonth', 'june july august', start=40) - lst = list(SummerMonth) - self.assertEqual(len(lst), len(SummerMonth)) - self.assertEqual(len(SummerMonth), 3, SummerMonth) - self.assertEqual( - [SummerMonth.june, SummerMonth.july, SummerMonth.august], - lst, - ) - for i, month in enumerate('june july august'.split(), 40): - e = SummerMonth(i) - self.assertEqual(e, i) - self.assertEqual(e.name, month) - self.assertIn(e, SummerMonth) - self.assertIs(type(e), SummerMonth) - def test_subclassing(self): if isinstance(Name, Exception): raise Name @@ -1011,14 +1528,19 @@ class Color(Enum): red = 1 green = 2 blue = 3 + # with self.assertRaises(TypeError): class MoreColor(Color): cyan = 4 magenta = 5 yellow = 6 - with self.assertRaisesRegex(TypeError, "EvenMoreColor: cannot extend enumeration 'Color'"): + # + with self.assertRaisesRegex(TypeError, " cannot extend "): class EvenMoreColor(Color, IntEnum): chartruese = 7 + # + with self.assertRaisesRegex(TypeError, " cannot extend "): + Color('Foo', ('pink', 'black')) def test_exclude_methods(self): class whatever(Enum): @@ -1121,32 +1643,13 @@ class Color(Enum): with self.assertRaises(KeyError): Color['chartreuse'] - def test_new_repr(self): - class Color(Enum): - red = 1 - green = 2 - blue = 3 - def __repr__(self): - return "don't you just love shades of %s?" % self.name - self.assertEqual( - repr(Color.blue), - "don't you just love shades of blue?", - ) - - def test_inherited_repr(self): - class MyEnum(Enum): - def __repr__(self): - return "My name is %s." % self.name - class MyIntEnum(int, MyEnum): - this = 1 - that = 2 - theother = 3 - self.assertEqual(repr(MyIntEnum.that), "My name is that.") + # tests that need to be evalualted for moving def test_multiple_mixin_mro(self): class auto_enum(type(Enum)): def __new__(metacls, cls, bases, classdict): temp = type(classdict)() + temp._cls_name = cls names = set(classdict._member_names) i = 0 for k in classdict._member_names: @@ -1193,14 +1696,16 @@ def __new__(cls, *args): return self def __getnewargs__(self): return self._args - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1215,7 +1720,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1236,7 +1742,7 @@ class NEI(NamedInt, Enum): test_pickle_dump_load(self.assertIs, NEI.y) test_pickle_dump_load(self.assertIs, NEI) - # TODO: RUSTPYTHON + # TODO: RUSTPYTHON, fails on pickle @unittest.expectedFailure def test_subclasses_with_getnewargs_ex(self): class NamedInt(int): @@ -1252,14 +1758,16 @@ def __new__(cls, *args): return self def __getnewargs_ex__(self): return self._args, {} - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1274,7 +1782,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1309,14 +1818,16 @@ def __new__(cls, *args): return self def __reduce__(self): return self.__class__, self._args - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1331,7 +1842,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1366,14 +1878,16 @@ def __new__(cls, *args): return self def __reduce_ex__(self, proto): return self.__class__, self._args - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1388,7 +1902,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1397,7 +1912,6 @@ class NEI(NamedInt, Enum): x = ('the-x', 1) y = ('the-y', 2) - self.assertIs(NEI.__new__, Enum.__new__) self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") globals()['NamedInt'] = NamedInt @@ -1421,14 +1935,16 @@ def __new__(cls, *args): self._intname = name self._args = _args return self - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1451,7 +1967,6 @@ class NEI(NamedInt, Enum): __qualname__ = 'NEI' x = ('the-x', 1) y = ('the-y', 2) - self.assertIs(NEI.__new__, Enum.__new__) self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") globals()['NamedInt'] = NamedInt @@ -1459,10 +1974,14 @@ class NEI(NamedInt, Enum): NI5 = NamedInt('test', 5) self.assertEqual(NI5, 5) self.assertEqual(NEI.y.value, 2) - test_pickle_exception(self.assertRaises, TypeError, NEI.x) - test_pickle_exception(self.assertRaises, PicklingError, NEI) + with self.assertRaisesRegex(TypeError, "name and value must be specified"): + test_pickle_dump_load(self.assertIs, NEI.y) + # fix pickle support and try again + NEI.__reduce_ex__ = enum.pickle_by_enum_name + test_pickle_dump_load(self.assertIs, NEI.y) + test_pickle_dump_load(self.assertIs, NEI) - def test_subclasses_without_direct_pickle_support_using_name(self): + def test_subclasses_with_direct_pickle_support(self): class NamedInt(int): __qualname__ = 'NamedInt' def __new__(cls, *args): @@ -1474,14 +1993,16 @@ def __new__(cls, *args): self._intname = name self._args = _args return self - @property + @bltns.property def __name__(self): return self._intname def __repr__(self): # repr() is updated to include the name and type info - return "{}({!r}, {})".format(type(self).__name__, - self.__name__, - int.__repr__(self)) + return "{}({!r}, {})".format( + type(self).__name__, + self.__name__, + int.__repr__(self), + ) def __str__(self): # str() is unchanged, even if it relies on the repr() fallback base = int @@ -1496,7 +2017,8 @@ def __add__(self, other): if isinstance(self, NamedInt) and isinstance(other, NamedInt): return NamedInt( '({0} + {1})'.format(self.__name__, other.__name__), - temp ) + temp, + ) else: return temp @@ -1580,13 +2102,10 @@ class Color(AutoNumber): self.assertEqual(list(map(int, Color)), [1, 2, 3]) def test_equality(self): - class AlwaysEqual: - def __eq__(self, other): - return True class OrdinaryEnum(Enum): a = 1 - self.assertEqual(AlwaysEqual(), OrdinaryEnum.a) - self.assertEqual(OrdinaryEnum.a, AlwaysEqual()) + self.assertEqual(ALWAYS_EQ, OrdinaryEnum.a) + self.assertEqual(OrdinaryEnum.a, ALWAYS_EQ) def test_ordered_mixin(self): class OrderedEnum(Enum): @@ -1663,6 +2182,15 @@ def test(self): class Test(Base): test = 1 self.assertEqual(Test.test.test, 'dynamic') + self.assertEqual(Test.test.value, 1) + class Base2(Enum): + @enum.property + def flash(self): + return 'flashy dynamic' + class Test(Base2): + flash = 1 + self.assertEqual(Test.flash.flash, 'flashy dynamic') + self.assertEqual(Test.flash.value, 1) def test_no_duplicates(self): class UniqueEnum(Enum): @@ -1699,7 +2227,7 @@ class Planet(Enum): def __init__(self, mass, radius): self.mass = mass # in kilograms self.radius = radius # in meters - @property + @enum.property def surface_gravity(self): # universal gravitational constant (m3 kg-1 s-2) G = 6.67300E-11 @@ -1735,124 +2263,41 @@ def __new__(cls, value, period): self.assertTrue(Period.month_1 is Period.day_30) self.assertTrue(Period.week_4 is Period.day_28) - def test_nonhash_value(self): - class AutoNumberInAList(Enum): - def __new__(cls): - value = [len(cls.__members__) + 1] - obj = object.__new__(cls) - obj._value_ = value - return obj - class ColorInAList(AutoNumberInAList): - red = () - green = () - blue = () - self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue]) - for enum, value in zip(ColorInAList, range(3)): - value += 1 - self.assertEqual(enum.value, [value]) - self.assertIs(ColorInAList([value]), enum) - - def test_conflicting_types_resolved_in_new(self): - class LabelledIntEnum(int, Enum): - def __new__(cls, *args): - value, label = args - obj = int.__new__(cls, value) - obj.label = label - obj._value_ = value - return obj - - class LabelledList(LabelledIntEnum): - unprocessed = (1, "Unprocessed") - payment_complete = (2, "Payment Complete") - - self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete]) - self.assertEqual(LabelledList.unprocessed, 1) - self.assertEqual(LabelledList(1), LabelledList.unprocessed) - - def test_auto_number(self): - class Color(Enum): - red = auto() - blue = auto() - green = auto() - - self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) - self.assertEqual(Color.red.value, 1) - self.assertEqual(Color.blue.value, 2) - self.assertEqual(Color.green.value, 3) - - def test_auto_name(self): - class Color(Enum): - def _generate_next_value_(name, start, count, last): - return name - red = auto() - blue = auto() - green = auto() - - self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) - self.assertEqual(Color.red.value, 'red') - self.assertEqual(Color.blue.value, 'blue') - self.assertEqual(Color.green.value, 'green') - - def test_auto_name_inherit(self): - class AutoNameEnum(Enum): - def _generate_next_value_(name, start, count, last): - return name - class Color(AutoNameEnum): - red = auto() - blue = auto() - green = auto() - - self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) - self.assertEqual(Color.red.value, 'red') - self.assertEqual(Color.blue.value, 'blue') - self.assertEqual(Color.green.value, 'green') - - def test_auto_garbage(self): - class Color(Enum): - red = 'red' - blue = auto() - self.assertEqual(Color.blue.value, 1) - - def test_auto_garbage_corrected(self): - class Color(Enum): - red = 'red' - blue = 2 - green = auto() - - self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) - self.assertEqual(Color.red.value, 'red') - self.assertEqual(Color.blue.value, 2) - self.assertEqual(Color.green.value, 3) - - def test_auto_order(self): - with self.assertRaises(TypeError): - class Color(Enum): - red = auto() - green = auto() - blue = auto() - def _generate_next_value_(name, start, count, last): - return name + def test_nonhash_value(self): + class AutoNumberInAList(Enum): + def __new__(cls): + value = [len(cls.__members__) + 1] + obj = object.__new__(cls) + obj._value_ = value + return obj + class ColorInAList(AutoNumberInAList): + red = () + green = () + blue = () + self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue]) + for enum, value in zip(ColorInAList, range(3)): + value += 1 + self.assertEqual(enum.value, [value]) + self.assertIs(ColorInAList([value]), enum) - def test_auto_order_wierd(self): - weird_auto = auto() - weird_auto.value = 'pathological case' - class Color(Enum): - red = weird_auto - def _generate_next_value_(name, start, count, last): - return name - blue = auto() - self.assertEqual(list(Color), [Color.red, Color.blue]) - self.assertEqual(Color.red.value, 'pathological case') - self.assertEqual(Color.blue.value, 'blue') + def test_conflicting_types_resolved_in_new(self): + class LabelledIntEnum(int, Enum): + def __new__(cls, *args): + value, label = args + obj = int.__new__(cls, value) + obj.label = label + obj._value_ = value + return obj - def test_duplicate_auto(self): - class Dupes(Enum): - first = primero = auto() - second = auto() - third = auto() - self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) + class LabelledList(LabelledIntEnum): + unprocessed = (1, "Unprocessed") + payment_complete = (2, "Payment Complete") + + self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete]) + self.assertEqual(LabelledList.unprocessed, 1) + self.assertEqual(LabelledList(1), LabelledList.unprocessed) - def test_default_missing(self): + def test_default_missing_no_chained_exception(self): class Color(Enum): RED = 1 GREEN = 2 @@ -1864,7 +2309,7 @@ class Color(Enum): else: raise Exception('Exception not raised.') - def test_missing(self): + def test_missing_override(self): class Color(Enum): red = 1 green = 2 @@ -1901,6 +2346,40 @@ def _missing_(cls, item): else: raise Exception('Exception not raised.') + def test_missing_exceptions_reset(self): + import gc + import weakref + # + class TestEnum(enum.Enum): + VAL1 = 'val1' + VAL2 = 'val2' + # + class Class1: + def __init__(self): + # Gracefully handle an exception of our own making + try: + raise ValueError() + except ValueError: + pass + # + class Class2: + def __init__(self): + # Gracefully handle an exception of Enum's making + try: + TestEnum('invalid_value') + except ValueError: + pass + # No strong refs here so these are free to die. + class_1_ref = weakref.ref(Class1()) + class_2_ref = weakref.ref(Class2()) + # + # The exception raised by Enum used to create a reference loop and thus + # Class2 instances would stick around until the next garbage collection + # cycle, unlike Class1. Verify Class2 no longer does this. + gc.collect() # For PyPy or other GCs. + self.assertIs(class_1_ref(), None) + self.assertIs(class_2_ref(), None) + def test_multiple_mixin(self): class MaxMixin: @classproperty @@ -1932,6 +2411,7 @@ class Color(MaxMixin, StrMixin, Enum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 3) @@ -1941,6 +2421,7 @@ class Color(StrMixin, MaxMixin, Enum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 3) @@ -1950,6 +2431,7 @@ class CoolColor(StrMixin, SomeEnum, Enum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(CoolColor.RED.value, 1) self.assertEqual(CoolColor.GREEN.value, 2) self.assertEqual(CoolColor.BLUE.value, 3) @@ -1959,6 +2441,7 @@ class CoolerColor(StrMixin, AnotherEnum, Enum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(CoolerColor.RED.value, 1) self.assertEqual(CoolerColor.GREEN.value, 2) self.assertEqual(CoolerColor.BLUE.value, 3) @@ -1969,6 +2452,7 @@ class CoolestColor(StrMixin, SomeEnum, AnotherEnum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(CoolestColor.RED.value, 1) self.assertEqual(CoolestColor.GREEN.value, 2) self.assertEqual(CoolestColor.BLUE.value, 3) @@ -1979,6 +2463,7 @@ class ConfusedColor(StrMixin, AnotherEnum, SomeEnum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(ConfusedColor.RED.value, 1) self.assertEqual(ConfusedColor.GREEN.value, 2) self.assertEqual(ConfusedColor.BLUE.value, 3) @@ -1989,6 +2474,7 @@ class ReformedColor(StrMixin, IntEnum, SomeEnum, AnotherEnum): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ # needed as of 3.11 self.assertEqual(ReformedColor.RED.value, 1) self.assertEqual(ReformedColor.GREEN.value, 2) self.assertEqual(ReformedColor.BLUE.value, 3) @@ -1998,13 +2484,6 @@ class ReformedColor(StrMixin, IntEnum, SomeEnum, AnotherEnum): self.assertTrue(issubclass(ReformedColor, int)) def test_multiple_inherited_mixin(self): - class StrEnum(str, Enum): - def __new__(cls, *args, **kwargs): - for a in args: - if not isinstance(a, str): - raise TypeError("Enumeration '%s' (%s) is not" - " a string" % (a, type(a).__name__)) - return str.__new__(cls, *args, **kwargs) @unique class Decision1(StrEnum): REVERT = "REVERT" @@ -2028,11 +2507,12 @@ def __repr__(self): return hex(self) class MyIntEnum(HexMixin, MyInt, enum.Enum): - pass + __repr__ = HexMixin.__repr__ class Foo(MyIntEnum): TEST = 1 self.assertTrue(isinstance(Foo.TEST, MyInt)) + self.assertEqual(Foo._member_type_, MyInt) self.assertEqual(repr(Foo.TEST), "0x1") class Fee(MyIntEnum): @@ -2044,6 +2524,50 @@ def __new__(cls, value): return member self.assertEqual(Fee.TEST, 2) + def test_multiple_mixin_with_common_data_type(self): + class CaseInsensitiveStrEnum(str, Enum): + @classmethod + def _missing_(cls, value): + for member in cls._member_map_.values(): + if member._value_.lower() == value.lower(): + return member + return super()._missing_(value) + # + class LenientStrEnum(str, Enum): + def __init__(self, *args): + self._valid = True + @classmethod + def _missing_(cls, value): + unknown = cls._member_type_.__new__(cls, value) + unknown._valid = False + unknown._name_ = value.upper() + unknown._value_ = value + cls._member_map_[value] = unknown + return unknown + @enum.property + def valid(self): + return self._valid + # + class JobStatus(CaseInsensitiveStrEnum, LenientStrEnum): + ACTIVE = "active" + PENDING = "pending" + TERMINATED = "terminated" + # + JS = JobStatus + self.assertEqual(list(JobStatus), [JS.ACTIVE, JS.PENDING, JS.TERMINATED]) + self.assertEqual(JS.ACTIVE, 'active') + self.assertEqual(JS.ACTIVE.value, 'active') + self.assertIs(JS('Active'), JS.ACTIVE) + self.assertTrue(JS.ACTIVE.valid) + missing = JS('missing') + self.assertEqual(list(JobStatus), [JS.ACTIVE, JS.PENDING, JS.TERMINATED]) + self.assertEqual(JS.ACTIVE, 'active') + self.assertEqual(JS.ACTIVE.value, 'active') + self.assertIs(JS('Active'), JS.ACTIVE) + self.assertTrue(JS.ACTIVE.valid) + self.assertTrue(isinstance(missing, JS)) + self.assertFalse(missing.valid) + def test_empty_globals(self): # bpo-35717: sys._getframe(2).f_globals['__name__'] fails with KeyError # when using compile and exec because f_globals is empty @@ -2053,8 +2577,358 @@ def test_empty_globals(self): local_ls = {} exec(code, global_ns, local_ls) + def test_strenum(self): + class GoodStrEnum(StrEnum): + one = '1' + two = '2' + three = b'3', 'ascii' + four = b'4', 'latin1', 'strict' + self.assertEqual(GoodStrEnum.one, '1') + self.assertEqual(str(GoodStrEnum.one), '1') + self.assertEqual('{}'.format(GoodStrEnum.one), '1') + self.assertEqual(GoodStrEnum.one, str(GoodStrEnum.one)) + self.assertEqual(GoodStrEnum.one, '{}'.format(GoodStrEnum.one)) + self.assertEqual(repr(GoodStrEnum.one), "") + # + class DumbMixin: + def __str__(self): + return "don't do this" + class DumbStrEnum(DumbMixin, StrEnum): + five = '5' + six = '6' + seven = '7' + __str__ = DumbMixin.__str__ # needed as of 3.11 + self.assertEqual(DumbStrEnum.seven, '7') + self.assertEqual(str(DumbStrEnum.seven), "don't do this") + # + class EnumMixin(Enum): + def hello(self): + print('hello from %s' % (self, )) + class HelloEnum(EnumMixin, StrEnum): + eight = '8' + self.assertEqual(HelloEnum.eight, '8') + self.assertEqual(HelloEnum.eight, str(HelloEnum.eight)) + # + class GoodbyeMixin: + def goodbye(self): + print('%s wishes you a fond farewell') + class GoodbyeEnum(GoodbyeMixin, EnumMixin, StrEnum): + nine = '9' + self.assertEqual(GoodbyeEnum.nine, '9') + self.assertEqual(GoodbyeEnum.nine, str(GoodbyeEnum.nine)) + # + with self.assertRaisesRegex(TypeError, '1 is not a string'): + class FirstFailedStrEnum(StrEnum): + one = 1 + two = '2' + with self.assertRaisesRegex(TypeError, "2 is not a string"): + class SecondFailedStrEnum(StrEnum): + one = '1' + two = 2, + three = '3' + with self.assertRaisesRegex(TypeError, '2 is not a string'): + class ThirdFailedStrEnum(StrEnum): + one = '1' + two = 2 + with self.assertRaisesRegex(TypeError, 'encoding must be a string, not %r' % (sys.getdefaultencoding, )): + class ThirdFailedStrEnum(StrEnum): + one = '1' + two = b'2', sys.getdefaultencoding + with self.assertRaisesRegex(TypeError, 'errors must be a string, not 9'): + class ThirdFailedStrEnum(StrEnum): + one = '1' + two = b'2', 'ascii', 9 + + # TODO: RUSTPYTHON, fails on encoding testing : TypeError: Expected type 'str' but 'builtin_function_or_method' found + @unittest.expectedFailure + def test_custom_strenum(self): + class CustomStrEnum(str, Enum): + pass + class OkayEnum(CustomStrEnum): + one = '1' + two = '2' + three = b'3', 'ascii' + four = b'4', 'latin1', 'strict' + self.assertEqual(OkayEnum.one, '1') + self.assertEqual(str(OkayEnum.one), 'OkayEnum.one') + self.assertEqual('{}'.format(OkayEnum.one), 'OkayEnum.one') + self.assertEqual(repr(OkayEnum.one), "") + # + class DumbMixin: + def __str__(self): + return "don't do this" + class DumbStrEnum(DumbMixin, CustomStrEnum): + five = '5' + six = '6' + seven = '7' + __str__ = DumbMixin.__str__ # needed as of 3.11 + self.assertEqual(DumbStrEnum.seven, '7') + self.assertEqual(str(DumbStrEnum.seven), "don't do this") + # + class EnumMixin(Enum): + def hello(self): + print('hello from %s' % (self, )) + class HelloEnum(EnumMixin, CustomStrEnum): + eight = '8' + self.assertEqual(HelloEnum.eight, '8') + self.assertEqual(str(HelloEnum.eight), 'HelloEnum.eight') + # + class GoodbyeMixin: + def goodbye(self): + print('%s wishes you a fond farewell') + class GoodbyeEnum(GoodbyeMixin, EnumMixin, CustomStrEnum): + nine = '9' + self.assertEqual(GoodbyeEnum.nine, '9') + self.assertEqual(str(GoodbyeEnum.nine), 'GoodbyeEnum.nine') + # + class FirstFailedStrEnum(CustomStrEnum): + one = 1 # this will become '1' + two = '2' + class SecondFailedStrEnum(CustomStrEnum): + one = '1' + two = 2, # this will become '2' + three = '3' + class ThirdFailedStrEnum(CustomStrEnum): + one = '1' + two = 2 # this will become '2' + with self.assertRaisesRegex(TypeError, '.encoding. must be str, not '): + class ThirdFailedStrEnum(CustomStrEnum): + one = '1' + two = b'2', sys.getdefaultencoding + with self.assertRaisesRegex(TypeError, '.errors. must be str, not '): + class ThirdFailedStrEnum(CustomStrEnum): + one = '1' + two = b'2', 'ascii', 9 + + def test_missing_value_error(self): + with self.assertRaisesRegex(TypeError, "_value_ not set in __new__"): + class Combined(str, Enum): + # + def __new__(cls, value, sequence): + enum = str.__new__(cls, value) + if '(' in value: + fis_name, segment = value.split('(', 1) + segment = segment.strip(' )') + else: + fis_name = value + segment = None + enum.fis_name = fis_name + enum.segment = segment + enum.sequence = sequence + return enum + # + def __repr__(self): + return "<%s.%s>" % (self.__class__.__name__, self._name_) + # + key_type = 'An$(1,2)', 0 + company_id = 'An$(3,2)', 1 + code = 'An$(5,1)', 2 + description = 'Bn$', 3 + + + def test_private_variable_is_normal_attribute(self): + class Private(Enum): + __corporal = 'Radar' + __major_ = 'Hoolihan' + self.assertEqual(Private._Private__corporal, 'Radar') + self.assertEqual(Private._Private__major_, 'Hoolihan') + + @unittest.skip("Accessing all values retained for performance reasons, see GH-93910") + def test_exception_for_member_from_member_access(self): + with self.assertRaisesRegex(AttributeError, " member has no attribute .NO."): + class Di(Enum): + YES = 1 + NO = 0 + nope = Di.YES.NO + + + def test_dynamic_members_with_static_methods(self): + # + foo_defines = {'FOO_CAT': 'aloof', 'BAR_DOG': 'friendly', 'FOO_HORSE': 'big'} + class Foo(Enum): + vars().update({ + k: v + for k, v in foo_defines.items() + if k.startswith('FOO_') + }) + def upper(self): + return self.value.upper() + self.assertEqual(list(Foo), [Foo.FOO_CAT, Foo.FOO_HORSE]) + self.assertEqual(Foo.FOO_CAT.value, 'aloof') + self.assertEqual(Foo.FOO_HORSE.upper(), 'BIG') + # + with self.assertRaisesRegex(TypeError, "'FOO_CAT' already defined as 'aloof'"): + class FooBar(Enum): + vars().update({ + k: v + for k, v in foo_defines.items() + if k.startswith('FOO_') + }, + **{'FOO_CAT': 'small'}, + ) + def upper(self): + return self.value.upper() + + def test_repr_with_dataclass(self): + "ensure dataclass-mixin has correct repr()" + from dataclasses import dataclass + @dataclass + class Foo: + __qualname__ = 'Foo' + a: int + class Entries(Foo, Enum): + ENTRY1 = 1 + self.assertTrue(isinstance(Entries.ENTRY1, Foo)) + self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_) + self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value) + self.assertEqual(repr(Entries.ENTRY1), '') + + def test_repr_with_init_data_type_mixin(self): + # non-data_type is a mixin that doesn't define __new__ + class Foo: + def __init__(self, a): + self.a = a + def __repr__(self): + return f'Foo(a={self.a!r})' + class Entries(Foo, Enum): + ENTRY1 = 1 + # + self.assertEqual(repr(Entries.ENTRY1), 'Foo(a=1)') + + def test_repr_and_str_with_non_data_type_mixin(self): + # non-data_type is a mixin that doesn't define __new__ + class Foo: + def __repr__(self): + return 'Foo' + def __str__(self): + return 'ooF' + class Entries(Foo, Enum): + ENTRY1 = 1 + # + self.assertEqual(repr(Entries.ENTRY1), 'Foo') + self.assertEqual(str(Entries.ENTRY1), 'ooF') + + def test_value_backup_assign(self): + # check that enum will add missing values when custom __new__ does not + class Some(Enum): + def __new__(cls, val): + return object.__new__(cls) + x = 1 + y = 2 + self.assertEqual(Some.x.value, 1) + self.assertEqual(Some.y.value, 2) + + def test_custom_flag_bitwise(self): + class MyIntFlag(int, Flag): + ONE = 1 + TWO = 2 + FOUR = 4 + self.assertTrue(isinstance(MyIntFlag.ONE | MyIntFlag.TWO, MyIntFlag), MyIntFlag.ONE | MyIntFlag.TWO) + self.assertTrue(isinstance(MyIntFlag.ONE | 2, MyIntFlag)) + + def test_int_flags_copy(self): + class MyIntFlag(IntFlag): + ONE = 1 + TWO = 2 + FOUR = 4 + + flags = MyIntFlag.ONE | MyIntFlag.TWO + copied = copy.copy(flags) + deep = copy.deepcopy(flags) + self.assertEqual(copied, flags) + self.assertEqual(deep, flags) + + flags = MyIntFlag.ONE | MyIntFlag.TWO | 8 + copied = copy.copy(flags) + deep = copy.deepcopy(flags) + self.assertEqual(copied, flags) + self.assertEqual(deep, flags) + self.assertEqual(copied.value, 1 | 2 | 8) + + def test_namedtuple_as_value(self): + from collections import namedtuple + TTuple = namedtuple('TTuple', 'id a blist') + class NTEnum(Enum): + NONE = TTuple(0, 0, []) + A = TTuple(1, 2, [4]) + B = TTuple(2, 4, [0, 1, 2]) + self.assertEqual(repr(NTEnum.NONE), "") + self.assertEqual(NTEnum.NONE.value, TTuple(id=0, a=0, blist=[])) + self.assertEqual( + [x.value for x in NTEnum], + [TTuple(id=0, a=0, blist=[]), TTuple(id=1, a=2, blist=[4]), TTuple(id=2, a=4, blist=[0, 1, 2])], + ) + + def test_flag_with_custom_new(self): + class FlagFromChar(IntFlag): + def __new__(cls, c): + value = 1 << c + self = int.__new__(cls, value) + self._value_ = value + return self + # + a = ord('a') + # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672) + self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) + self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) + # + # + class FlagFromChar(Flag): + def __new__(cls, c): + value = 1 << c + self = object.__new__(cls) + self._value_ = value + return self + # + a = ord('a') + z = 1 + # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674) + self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672) + self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674) + # + # + class FlagFromChar(int, Flag, boundary=KEEP): + def __new__(cls, c): + value = 1 << c + self = int.__new__(cls, value) + self._value_ = value + return self + # + a = ord('a') + # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672) + self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) + self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) + + def test_init_exception(self): + class Base: + def __new__(cls, *args): + return object.__new__(cls) + def __init__(self, x): + raise ValueError("I don't like", x) + with self.assertRaises(TypeError): + class MyEnum(Base, enum.Enum): + A = 'a' + def __init__(self, y): + self.y = y + with self.assertRaises(ValueError): + class MyEnum(Base, enum.Enum): + A = 'a' + def __init__(self, y): + self.y = y + def __new__(cls, value): + member = Base.__new__(cls) + member._value_ = Base(value) + return member + class TestOrder(unittest.TestCase): + "test usage of the `_order_` attribute" def test_same_members(self): class Color(Enum): @@ -2116,7 +2990,7 @@ class Color(Enum): verde = green -class TestFlag(unittest.TestCase): +class OldTestFlag(unittest.TestCase): """Tests of the Flags.""" class Perm(Flag): @@ -2132,68 +3006,12 @@ class Open(Flag): class Color(Flag): BLACK = 0 RED = 1 + ROJO = 1 GREEN = 2 BLUE = 4 PURPLE = RED|BLUE - - def test_str(self): - Perm = self.Perm - self.assertEqual(str(Perm.R), 'Perm.R') - self.assertEqual(str(Perm.W), 'Perm.W') - self.assertEqual(str(Perm.X), 'Perm.X') - self.assertEqual(str(Perm.R | Perm.W), 'Perm.R|W') - self.assertEqual(str(Perm.R | Perm.W | Perm.X), 'Perm.R|W|X') - self.assertEqual(str(Perm(0)), 'Perm.0') - self.assertEqual(str(~Perm.R), 'Perm.W|X') - self.assertEqual(str(~Perm.W), 'Perm.R|X') - self.assertEqual(str(~Perm.X), 'Perm.R|W') - self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X') - self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.0') - self.assertEqual(str(Perm(~0)), 'Perm.R|W|X') - - Open = self.Open - self.assertEqual(str(Open.RO), 'Open.RO') - self.assertEqual(str(Open.WO), 'Open.WO') - self.assertEqual(str(Open.AC), 'Open.AC') - self.assertEqual(str(Open.RO | Open.CE), 'Open.CE') - self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO') - self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO') - self.assertEqual(str(~Open.WO), 'Open.CE|RW') - self.assertEqual(str(~Open.AC), 'Open.CE') - self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC') - self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW') - - def test_repr(self): - Perm = self.Perm - self.assertEqual(repr(Perm.R), '') - self.assertEqual(repr(Perm.W), '') - self.assertEqual(repr(Perm.X), '') - self.assertEqual(repr(Perm.R | Perm.W), '') - self.assertEqual(repr(Perm.R | Perm.W | Perm.X), '') - self.assertEqual(repr(Perm(0)), '') - self.assertEqual(repr(~Perm.R), '') - self.assertEqual(repr(~Perm.W), '') - self.assertEqual(repr(~Perm.X), '') - self.assertEqual(repr(~(Perm.R | Perm.W)), '') - self.assertEqual(repr(~(Perm.R | Perm.W | Perm.X)), '') - self.assertEqual(repr(Perm(~0)), '') - - Open = self.Open - self.assertEqual(repr(Open.RO), '') - self.assertEqual(repr(Open.WO), '') - self.assertEqual(repr(Open.AC), '') - self.assertEqual(repr(Open.RO | Open.CE), '') - self.assertEqual(repr(Open.WO | Open.CE), '') - self.assertEqual(repr(~Open.RO), '') - self.assertEqual(repr(~Open.WO), '') - self.assertEqual(repr(~Open.AC), '') - self.assertEqual(repr(~(Open.RO | Open.CE)), '') - self.assertEqual(repr(~(Open.WO | Open.CE)), '') - - def test_format(self): - Perm = self.Perm - self.assertEqual(format(Perm.R, ''), 'Perm.R') - self.assertEqual(format(Perm.R | Perm.X, ''), 'Perm.R|X') + WHITE = RED|GREEN|BLUE + BLANCO = RED|GREEN|BLUE def test_or(self): Perm = self.Perm @@ -2238,22 +3056,6 @@ def test_xor(self): self.assertIs(Open.RO ^ Open.CE, Open.CE) self.assertIs(Open.CE ^ Open.CE, Open.RO) - def test_invert(self): - Perm = self.Perm - RW = Perm.R | Perm.W - RX = Perm.R | Perm.X - WX = Perm.W | Perm.X - RWX = Perm.R | Perm.W | Perm.X - values = list(Perm) + [RW, RX, WX, RWX, Perm(0)] - for i in values: - self.assertIs(type(~i), Perm) - self.assertEqual(~~i, i) - for i in Perm: - self.assertIs(~~i, i) - Open = self.Open - self.assertIs(Open.WO & ~Open.WO, Open.RO) - self.assertIs((Open.WO|Open.CE) & ~Open.WO, Open.CE) - def test_bool(self): Perm = self.Perm for f in Perm: @@ -2262,6 +3064,74 @@ def test_bool(self): for f in Open: self.assertEqual(bool(f.value), bool(f)) + def test_boundary(self): + self.assertIs(enum.Flag._boundary_, STRICT) + class Iron(Flag, boundary=CONFORM): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Iron._boundary_, CONFORM) + # + class Water(Flag, boundary=STRICT): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Water._boundary_, STRICT) + # + class Space(Flag, boundary=EJECT): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Space._boundary_, EJECT) + # + class Bizarre(Flag, boundary=KEEP): + b = 3 + c = 4 + d = 6 + # + self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7) + # + self.assertIs(Iron(7), Iron.ONE|Iron.TWO) + self.assertIs(Iron(~9), Iron.TWO) + # + self.assertEqual(Space(7), 7) + self.assertTrue(type(Space(7)) is int) + # + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertIs(Bizarre(3), Bizarre.b) + self.assertIs(Bizarre(6), Bizarre.d) + # + class SkipFlag(enum.Flag): + A = 1 + B = 2 + C = 4 | B + # + self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C)) + self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42) + # + class SkipIntFlag(enum.IntFlag): + A = 1 + B = 2 + C = 4 | B + # + self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C)) + self.assertEqual(SkipIntFlag(42).value, 42) + # + class MethodHint(Flag): + HiddenText = 0x10 + DigitsOnly = 0x01 + LettersOnly = 0x02 + OnlyMask = 0x0f + # + self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask') + + + def test_iter(self): + Color = self.Color + Open = self.Open + self.assertEqual(list(Color), [Color.RED, Color.GREEN, Color.BLUE]) + self.assertEqual(list(Open), [Open.WO, Open.RW, Open.CE]) + def test_programatic_function_string(self): Perm = Flag('Perm', 'R W X') lst = list(Perm) @@ -2340,22 +3210,81 @@ def test_programatic_function_from_dict(self): def test_pickle(self): if isinstance(FlagStooges, Exception): raise FlagStooges - test_pickle_dump_load(self.assertIs, FlagStooges.CURLY|FlagStooges.MOE) + test_pickle_dump_load(self.assertIs, FlagStooges.CURLY) + test_pickle_dump_load(self.assertEqual, + FlagStooges.CURLY|FlagStooges.MOE) + test_pickle_dump_load(self.assertEqual, + FlagStooges.CURLY&~FlagStooges.CURLY) test_pickle_dump_load(self.assertIs, FlagStooges) - - def test_contains(self): + test_pickle_dump_load(self.assertEqual, FlagStooges.BIG) + test_pickle_dump_load(self.assertEqual, + FlagStooges.CURLY|FlagStooges.BIG) + + test_pickle_dump_load(self.assertIs, FlagStoogesWithZero.CURLY) + test_pickle_dump_load(self.assertEqual, + FlagStoogesWithZero.CURLY|FlagStoogesWithZero.MOE) + test_pickle_dump_load(self.assertIs, FlagStoogesWithZero.NOFLAG) + test_pickle_dump_load(self.assertEqual, FlagStoogesWithZero.BIG) + test_pickle_dump_load(self.assertEqual, + FlagStoogesWithZero.CURLY|FlagStoogesWithZero.BIG) + + test_pickle_dump_load(self.assertIs, IntFlagStooges.CURLY) + test_pickle_dump_load(self.assertEqual, + IntFlagStooges.CURLY|IntFlagStooges.MOE) + test_pickle_dump_load(self.assertEqual, + IntFlagStooges.CURLY|IntFlagStooges.MOE|0x30) + test_pickle_dump_load(self.assertEqual, IntFlagStooges(0)) + test_pickle_dump_load(self.assertEqual, IntFlagStooges(0x30)) + test_pickle_dump_load(self.assertIs, IntFlagStooges) + test_pickle_dump_load(self.assertEqual, IntFlagStooges.BIG) + test_pickle_dump_load(self.assertEqual, IntFlagStooges.BIG|1) + test_pickle_dump_load(self.assertEqual, + IntFlagStooges.CURLY|IntFlagStooges.BIG) + + test_pickle_dump_load(self.assertIs, IntFlagStoogesWithZero.CURLY) + test_pickle_dump_load(self.assertEqual, + IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.MOE) + test_pickle_dump_load(self.assertIs, IntFlagStoogesWithZero.NOFLAG) + test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.BIG) + test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.BIG|1) + test_pickle_dump_load(self.assertEqual, + IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.BIG) + + @unittest.skipIf( + python_version >= (3, 12), + '__contains__ now returns True/False for all inputs', + ) + def test_contains_er(self): Open = self.Open Color = self.Color self.assertFalse(Color.BLACK in Open) self.assertFalse(Open.RO in Color) with self.assertRaises(TypeError): - 'BLACK' in Color + with self.assertWarns(DeprecationWarning): + 'BLACK' in Color with self.assertRaises(TypeError): - 'RO' in Open + with self.assertWarns(DeprecationWarning): + 'RO' in Open with self.assertRaises(TypeError): - 1 in Color + with self.assertWarns(DeprecationWarning): + 1 in Color with self.assertRaises(TypeError): - 1 in Open + with self.assertWarns(DeprecationWarning): + 1 in Open + + @unittest.skipIf( + python_version < (3, 12), + '__contains__ only works with enum memmbers before 3.12', + ) + def test_contains_tf(self): + Open = self.Open + Color = self.Color + self.assertFalse(Color.BLACK in Open) + self.assertFalse(Open.RO in Color) + self.assertFalse('BLACK' in Color) + self.assertFalse('RO' in Open) + self.assertTrue(1 in Color) + self.assertTrue(1 in Open) def test_member_contains(self): Perm = self.Perm @@ -2377,6 +3306,48 @@ def test_member_contains(self): self.assertFalse(W in RX) self.assertFalse(X in RW) + def test_member_iter(self): + Color = self.Color + self.assertEqual(list(Color.BLACK), []) + self.assertEqual(list(Color.PURPLE), [Color.RED, Color.BLUE]) + self.assertEqual(list(Color.BLUE), [Color.BLUE]) + self.assertEqual(list(Color.GREEN), [Color.GREEN]) + self.assertEqual(list(Color.WHITE), [Color.RED, Color.GREEN, Color.BLUE]) + self.assertEqual(list(Color.WHITE), [Color.RED, Color.GREEN, Color.BLUE]) + + def test_member_length(self): + self.assertEqual(self.Color.__len__(self.Color.BLACK), 0) + self.assertEqual(self.Color.__len__(self.Color.GREEN), 1) + self.assertEqual(self.Color.__len__(self.Color.PURPLE), 2) + self.assertEqual(self.Color.__len__(self.Color.BLANCO), 3) + + def test_number_reset_and_order_cleanup(self): + class Confused(Flag): + _order_ = 'ONE TWO FOUR DOS EIGHT SIXTEEN' + ONE = auto() + TWO = auto() + FOUR = auto() + DOS = 2 + EIGHT = auto() + SIXTEEN = auto() + self.assertEqual( + list(Confused), + [Confused.ONE, Confused.TWO, Confused.FOUR, Confused.EIGHT, Confused.SIXTEEN]) + self.assertIs(Confused.TWO, Confused.DOS) + self.assertEqual(Confused.DOS._value_, 2) + self.assertEqual(Confused.EIGHT._value_, 8) + self.assertEqual(Confused.SIXTEEN._value_, 16) + + def test_aliases(self): + Color = self.Color + self.assertEqual(Color(1).name, 'RED') + self.assertEqual(Color['ROJO'].name, 'RED') + self.assertEqual(Color(7).name, 'WHITE') + self.assertEqual(Color['BLANCO'].name, 'WHITE') + self.assertIs(Color.BLANCO, Color.WHITE) + Open = self.Open + self.assertIs(Open['AC'], Open.AC) + def test_auto_number(self): class Color(Flag): red = auto() @@ -2389,24 +3360,11 @@ class Color(Flag): self.assertEqual(Color.green.value, 4) def test_auto_number_garbage(self): - with self.assertRaisesRegex(TypeError, 'Invalid Flag value: .not an int.'): + with self.assertRaisesRegex(TypeError, 'invalid flag value .not an int.'): class Color(Flag): red = 'not an int' blue = auto() - def test_cascading_failure(self): - class Bizarre(Flag): - c = 3 - d = 4 - f = 6 - # Bizarre.c | Bizarre.d - self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5) - self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5) - self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2) - self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2) - self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1) - self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1) - def test_duplicate_auto(self): class Dupes(Enum): first = primero = auto() @@ -2414,13 +3372,6 @@ class Dupes(Enum): third = auto() self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) - def test_bizarre(self): - class Bizarre(Flag): - b = 3 - c = 4 - d = 6 - self.assertEqual(repr(Bizarre(7)), '') - def test_multiple_mixin(self): class AllMixin: @classproperty @@ -2449,6 +3400,7 @@ class Color(AllMixin, StrMixin, Flag): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) @@ -2458,14 +3410,15 @@ class Color(StrMixin, AllMixin, Flag): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) self.assertEqual(Color.ALL.value, 7) self.assertEqual(str(Color.BLUE), 'blue') - @unittest.skip("TODO: RUSTPYTHON, inconsistent test result on Windows due to threading") @threading_helper.reap_threads + @threading_helper.requires_working_threading() def test_unique_composite(self): # override __eq__ to be identity only class TestFlag(Flag): @@ -2495,22 +3448,59 @@ def cycle_enum(): threading.Thread(target=cycle_enum) for _ in range(8) ] - with threading_helper.start_threads(threads): - pass + with threading_helper.wait_threads_exit(): + with threading_helper.start_threads(threads): + pass # check that only 248 members were created self.assertFalse( failed, 'at least one thread failed while creating composite members') self.assertEqual(256, len(seen), 'too many composite members created') + def test_init_subclass(self): + class MyEnum(Flag): + def __init_subclass__(cls, **kwds): + super().__init_subclass__(**kwds) + self.assertFalse(cls.__dict__.get('_test', False)) + cls._test1 = 'MyEnum' + # + class TheirEnum(MyEnum): + def __init_subclass__(cls, **kwds): + super(TheirEnum, cls).__init_subclass__(**kwds) + cls._test2 = 'TheirEnum' + class WhoseEnum(TheirEnum): + def __init_subclass__(cls, **kwds): + pass + class NoEnum(WhoseEnum): + ONE = 1 + self.assertEqual(TheirEnum.__dict__['_test1'], 'MyEnum') + self.assertEqual(WhoseEnum.__dict__['_test1'], 'MyEnum') + self.assertEqual(WhoseEnum.__dict__['_test2'], 'TheirEnum') + self.assertFalse(NoEnum.__dict__.get('_test1', False)) + self.assertFalse(NoEnum.__dict__.get('_test2', False)) + # + class OurEnum(MyEnum): + def __init_subclass__(cls, **kwds): + cls._test2 = 'OurEnum' + class WhereEnum(OurEnum): + def __init_subclass__(cls, **kwds): + pass + class NeverEnum(WhereEnum): + ONE = 1 + self.assertEqual(OurEnum.__dict__['_test1'], 'MyEnum') + self.assertFalse(WhereEnum.__dict__.get('_test1', False)) + self.assertEqual(WhereEnum.__dict__['_test2'], 'OurEnum') + self.assertFalse(NeverEnum.__dict__.get('_test1', False)) + self.assertFalse(NeverEnum.__dict__.get('_test2', False)) + -class TestIntFlag(unittest.TestCase): +class OldTestIntFlag(unittest.TestCase): """Tests of the IntFlags.""" class Perm(IntFlag): - X = 1 << 0 - W = 1 << 1 R = 1 << 2 + W = 1 << 1 + X = 1 << 0 class Open(IntFlag): RO = 0 @@ -2522,9 +3512,17 @@ class Open(IntFlag): class Color(IntFlag): BLACK = 0 RED = 1 + ROJO = 1 GREEN = 2 BLUE = 4 PURPLE = RED|BLUE + WHITE = RED|GREEN|BLUE + BLANCO = RED|GREEN|BLUE + + class Skip(IntFlag): + FIRST = 1 + SECOND = 2 + EIGHTH = 8 def test_type(self): Perm = self.Perm @@ -2541,77 +3539,54 @@ def test_type(self): self.assertTrue(isinstance(Open.WO | Open.RW, Open)) self.assertEqual(Open.WO | Open.RW, 3) + def test_global_repr_keep(self): + self.assertEqual( + repr(HeadlightsK(0)), + '%s.OFF_K' % SHORT_MODULE, + ) + self.assertEqual( + repr(HeadlightsK(2**0 + 2**2 + 2**3)), + '%(m)s.LOW_BEAM_K|%(m)s.FOG_K|8' % {'m': SHORT_MODULE}, + ) + self.assertEqual( + repr(HeadlightsK(2**3)), + '%(m)s.HeadlightsK(8)' % {'m': SHORT_MODULE}, + ) - def test_str(self): - Perm = self.Perm - self.assertEqual(str(Perm.R), 'Perm.R') - self.assertEqual(str(Perm.W), 'Perm.W') - self.assertEqual(str(Perm.X), 'Perm.X') - self.assertEqual(str(Perm.R | Perm.W), 'Perm.R|W') - self.assertEqual(str(Perm.R | Perm.W | Perm.X), 'Perm.R|W|X') - self.assertEqual(str(Perm.R | 8), 'Perm.8|R') - self.assertEqual(str(Perm(0)), 'Perm.0') - self.assertEqual(str(Perm(8)), 'Perm.8') - self.assertEqual(str(~Perm.R), 'Perm.W|X') - self.assertEqual(str(~Perm.W), 'Perm.R|X') - self.assertEqual(str(~Perm.X), 'Perm.R|W') - self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X') - self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.-8') - self.assertEqual(str(~(Perm.R | 8)), 'Perm.W|X') - self.assertEqual(str(Perm(~0)), 'Perm.R|W|X') - self.assertEqual(str(Perm(~8)), 'Perm.R|W|X') - - Open = self.Open - self.assertEqual(str(Open.RO), 'Open.RO') - self.assertEqual(str(Open.WO), 'Open.WO') - self.assertEqual(str(Open.AC), 'Open.AC') - self.assertEqual(str(Open.RO | Open.CE), 'Open.CE') - self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO') - self.assertEqual(str(Open(4)), 'Open.4') - self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO') - self.assertEqual(str(~Open.WO), 'Open.CE|RW') - self.assertEqual(str(~Open.AC), 'Open.CE') - self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC|RW|WO') - self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW') - self.assertEqual(str(Open(~4)), 'Open.CE|AC|RW|WO') + def test_global_repr_conform1(self): + self.assertEqual( + repr(HeadlightsC(0)), + '%s.OFF_C' % SHORT_MODULE, + ) + self.assertEqual( + repr(HeadlightsC(2**0 + 2**2 + 2**3)), + '%(m)s.LOW_BEAM_C|%(m)s.FOG_C' % {'m': SHORT_MODULE}, + ) + self.assertEqual( + repr(HeadlightsC(2**3)), + '%(m)s.OFF_C' % {'m': SHORT_MODULE}, + ) - def test_repr(self): - Perm = self.Perm - self.assertEqual(repr(Perm.R), '') - self.assertEqual(repr(Perm.W), '') - self.assertEqual(repr(Perm.X), '') - self.assertEqual(repr(Perm.R | Perm.W), '') - self.assertEqual(repr(Perm.R | Perm.W | Perm.X), '') - self.assertEqual(repr(Perm.R | 8), '') - self.assertEqual(repr(Perm(0)), '') - self.assertEqual(repr(Perm(8)), '') - self.assertEqual(repr(~Perm.R), '') - self.assertEqual(repr(~Perm.W), '') - self.assertEqual(repr(~Perm.X), '') - self.assertEqual(repr(~(Perm.R | Perm.W)), '') - self.assertEqual(repr(~(Perm.R | Perm.W | Perm.X)), '') - self.assertEqual(repr(~(Perm.R | 8)), '') - self.assertEqual(repr(Perm(~0)), '') - self.assertEqual(repr(Perm(~8)), '') + def test_global_enum_str(self): + self.assertEqual(str(NoName.ONE & NoName.TWO), 'NoName(0)') + self.assertEqual(str(NoName(0)), 'NoName(0)') - Open = self.Open - self.assertEqual(repr(Open.RO), '') - self.assertEqual(repr(Open.WO), '') - self.assertEqual(repr(Open.AC), '') - self.assertEqual(repr(Open.RO | Open.CE), '') - self.assertEqual(repr(Open.WO | Open.CE), '') - self.assertEqual(repr(Open(4)), '') - self.assertEqual(repr(~Open.RO), '') - self.assertEqual(repr(~Open.WO), '') - self.assertEqual(repr(~Open.AC), '') - self.assertEqual(repr(~(Open.RO | Open.CE)), '') - self.assertEqual(repr(~(Open.WO | Open.CE)), '') - self.assertEqual(repr(Open(~4)), '') + # TODO: RUSTPYTHON, format(NewPerm.R) does not use __str__ + @unittest.expectedFailure def test_format(self): Perm = self.Perm self.assertEqual(format(Perm.R, ''), '4') self.assertEqual(format(Perm.R | Perm.X, ''), '5') + # + class NewPerm(IntFlag): + R = 1 << 2 + W = 1 << 1 + X = 1 << 0 + def __str__(self): + return self._name_ + self.assertEqual(format(NewPerm.R, ''), 'R') + self.assertEqual(format(NewPerm.R | Perm.X, ''), 'R|X') def test_or(self): Perm = self.Perm @@ -2689,8 +3664,7 @@ def test_invert(self): RWX = Perm.R | Perm.W | Perm.X values = list(Perm) + [RW, RX, WX, RWX, Perm(0)] for i in values: - self.assertEqual(~i, ~i.value) - self.assertEqual((~i).value, ~i.value) + self.assertEqual(~i, (~i).value) self.assertIs(type(~i), Perm) self.assertEqual(~~i, i) for i in Perm: @@ -2699,6 +3673,58 @@ def test_invert(self): self.assertIs(Open.WO & ~Open.WO, Open.RO) self.assertIs((Open.WO|Open.CE) & ~Open.WO, Open.CE) + def test_boundary(self): + self.assertIs(enum.IntFlag._boundary_, KEEP) + class Simple(IntFlag, boundary=KEEP): + SINGLE = 1 + # + class Iron(IntFlag, boundary=STRICT): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Iron._boundary_, STRICT) + # + class Water(IntFlag, boundary=CONFORM): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Water._boundary_, CONFORM) + # + class Space(IntFlag, boundary=EJECT): + ONE = 1 + TWO = 2 + EIGHT = 8 + self.assertIs(Space._boundary_, EJECT) + # + class Bizarre(IntFlag, boundary=KEEP): + b = 3 + c = 4 + d = 6 + # + self.assertRaisesRegex(ValueError, 'invalid value 5', Iron, 5) + # + self.assertIs(Water(7), Water.ONE|Water.TWO) + self.assertIs(Water(~9), Water.TWO) + # + self.assertEqual(Space(7), 7) + self.assertTrue(type(Space(7)) is int) + # + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertIs(Bizarre(3), Bizarre.b) + self.assertIs(Bizarre(6), Bizarre.d) + # + simple = Simple.SINGLE | Iron.TWO + self.assertEqual(simple, 3) + self.assertIsInstance(simple, Simple) + self.assertEqual(repr(simple), ': 3>') + self.assertEqual(str(simple), '3') + + def test_iter(self): + Color = self.Color + Open = self.Open + self.assertEqual(list(Color), [Color.RED, Color.GREEN, Color.BLUE]) + self.assertEqual(list(Open), [Open.WO, Open.RW, Open.CE]) + def test_programatic_function_string(self): Perm = IntFlag('Perm', 'R W X') lst = list(Perm) @@ -2800,7 +3826,11 @@ def test_programatic_function_from_empty_tuple(self): self.assertEqual(len(lst), len(Thing)) self.assertEqual(len(Thing), 0, Thing) - def test_contains(self): + @unittest.skipIf( + python_version >= (3, 12), + '__contains__ now returns True/False for all inputs', + ) + def test_contains_er(self): Open = self.Open Color = self.Color self.assertTrue(Color.GREEN in Color) @@ -2808,13 +3838,33 @@ def test_contains(self): self.assertFalse(Color.GREEN in Open) self.assertFalse(Open.RW in Color) with self.assertRaises(TypeError): - 'GREEN' in Color + with self.assertWarns(DeprecationWarning): + 'GREEN' in Color with self.assertRaises(TypeError): - 'RW' in Open + with self.assertWarns(DeprecationWarning): + 'RW' in Open with self.assertRaises(TypeError): - 2 in Color + with self.assertWarns(DeprecationWarning): + 2 in Color with self.assertRaises(TypeError): - 2 in Open + with self.assertWarns(DeprecationWarning): + 2 in Open + + @unittest.skipIf( + python_version < (3, 12), + '__contains__ only works with enum memmbers before 3.12', + ) + def test_contains_tf(self): + Open = self.Open + Color = self.Color + self.assertTrue(Color.GREEN in Color) + self.assertTrue(Open.RW in Open) + self.assertTrue(Color.GREEN in Open) + self.assertTrue(Open.RW in Color) + self.assertFalse('GREEN' in Color) + self.assertFalse('RW' in Open) + self.assertTrue(2 in Color) + self.assertTrue(2 in Open) def test_member_contains(self): Perm = self.Perm @@ -2838,6 +3888,30 @@ def test_member_contains(self): with self.assertRaises(TypeError): self.assertFalse('test' in RW) + def test_member_iter(self): + Color = self.Color + self.assertEqual(list(Color.BLACK), []) + self.assertEqual(list(Color.PURPLE), [Color.RED, Color.BLUE]) + self.assertEqual(list(Color.BLUE), [Color.BLUE]) + self.assertEqual(list(Color.GREEN), [Color.GREEN]) + self.assertEqual(list(Color.WHITE), [Color.RED, Color.GREEN, Color.BLUE]) + + def test_member_length(self): + self.assertEqual(self.Color.__len__(self.Color.BLACK), 0) + self.assertEqual(self.Color.__len__(self.Color.GREEN), 1) + self.assertEqual(self.Color.__len__(self.Color.PURPLE), 2) + self.assertEqual(self.Color.__len__(self.Color.BLANCO), 3) + + def test_aliases(self): + Color = self.Color + self.assertEqual(Color(1).name, 'RED') + self.assertEqual(Color['ROJO'].name, 'RED') + self.assertEqual(Color(7).name, 'WHITE') + self.assertEqual(Color['BLANCO'].name, 'WHITE') + self.assertIs(Color.BLANCO, Color.WHITE) + Open = self.Open + self.assertIs(Open['AC'], Open.AC) + def test_bool(self): Perm = self.Perm for f in Perm: @@ -2846,6 +3920,7 @@ def test_bool(self): for f in Open: self.assertEqual(bool(f.value), bool(f)) + def test_multiple_mixin(self): class AllMixin: @classproperty @@ -2869,11 +3944,12 @@ class Color(AllMixin, IntFlag): self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) self.assertEqual(Color.ALL.value, 7) - self.assertEqual(str(Color.BLUE), 'Color.BLUE') + self.assertEqual(str(Color.BLUE), '4') class Color(AllMixin, StrMixin, IntFlag): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) @@ -2883,14 +3959,15 @@ class Color(StrMixin, AllMixin, IntFlag): RED = auto() GREEN = auto() BLUE = auto() + __str__ = StrMixin.__str__ self.assertEqual(Color.RED.value, 1) self.assertEqual(Color.GREEN.value, 2) self.assertEqual(Color.BLUE.value, 4) self.assertEqual(Color.ALL.value, 7) self.assertEqual(str(Color.BLUE), 'blue') - @unittest.skip("TODO: RUSTPYTHON, inconsistent test result due to threading") @threading_helper.reap_threads + @threading_helper.requires_working_threading() def test_unique_composite(self): # override __eq__ to be identity only class TestFlag(IntFlag): @@ -2920,8 +3997,9 @@ def cycle_enum(): threading.Thread(target=cycle_enum) for _ in range(8) ] - with threading_helper.start_threads(threads): - pass + with threading_helper.wait_threads_exit(): + with threading_helper.start_threads(threads): + pass # check that only 248 members were created self.assertFalse( failed, @@ -2954,6 +4032,7 @@ class Clean(Enum): one = 1 two = 'dos' tres = 4.0 + # @unique class Cleaner(IntEnum): single = 1 @@ -2979,27 +4058,400 @@ class Dirtier(IntEnum): turkey = 3 def test_unique_with_name(self): - @unique + @verify(UNIQUE) class Silly(Enum): one = 1 two = 'dos' name = 3 - @unique + # + @verify(UNIQUE) + class Sillier(IntEnum): + single = 1 + name = 2 + triple = 3 + value = 4 + +class TestVerify(unittest.TestCase): + + def test_continuous(self): + @verify(CONTINUOUS) + class Auto(Enum): + FIRST = auto() + SECOND = auto() + THIRD = auto() + FORTH = auto() + # + @verify(CONTINUOUS) + class Manual(Enum): + FIRST = 3 + SECOND = 4 + THIRD = 5 + FORTH = 6 + # + with self.assertRaisesRegex(ValueError, 'invalid enum .Missing.: missing values 5, 6, 7, 8, 9, 10, 12'): + @verify(CONTINUOUS) + class Missing(Enum): + FIRST = 3 + SECOND = 4 + THIRD = 11 + FORTH = 13 + # + with self.assertRaisesRegex(ValueError, 'invalid flag .Incomplete.: missing values 32'): + @verify(CONTINUOUS) + class Incomplete(Flag): + FIRST = 4 + SECOND = 8 + THIRD = 16 + FORTH = 64 + # + with self.assertRaisesRegex(ValueError, 'invalid flag .StillIncomplete.: missing values 16'): + @verify(CONTINUOUS) + class StillIncomplete(Flag): + FIRST = 4 + SECOND = 8 + THIRD = 11 + FORTH = 32 + + + def test_composite(self): + class Bizarre(Flag): + b = 3 + c = 4 + d = 6 + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertEqual(Bizarre.b.value, 3) + self.assertEqual(Bizarre.c.value, 4) + self.assertEqual(Bizarre.d.value, 6) + with self.assertRaisesRegex( + ValueError, + "invalid Flag 'Bizarre': aliases b and d are missing combined values of 0x3 .use enum.show_flag_values.value. for details.", + ): + @verify(NAMED_FLAGS) + class Bizarre(Flag): + b = 3 + c = 4 + d = 6 + # + self.assertEqual(enum.show_flag_values(3), [1, 2]) + class Bizarre(IntFlag): + b = 3 + c = 4 + d = 6 + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertEqual(Bizarre.b.value, 3) + self.assertEqual(Bizarre.c.value, 4) + self.assertEqual(Bizarre.d.value, 6) + with self.assertRaisesRegex( + ValueError, + "invalid Flag 'Bizarre': alias d is missing value 0x2 .use enum.show_flag_values.value. for details.", + ): + @verify(NAMED_FLAGS) + class Bizarre(IntFlag): + c = 4 + d = 6 + self.assertEqual(enum.show_flag_values(2), [2]) + + def test_unique_clean(self): + @verify(UNIQUE) + class Clean(Enum): + one = 1 + two = 'dos' + tres = 4.0 + # + @verify(UNIQUE) + class Cleaner(IntEnum): + single = 1 + double = 2 + triple = 3 + + def test_unique_dirty(self): + with self.assertRaisesRegex(ValueError, 'tres.*one'): + @verify(UNIQUE) + class Dirty(Enum): + one = 1 + two = 'dos' + tres = 1 + with self.assertRaisesRegex( + ValueError, + 'double.*single.*turkey.*triple', + ): + @verify(UNIQUE) + class Dirtier(IntEnum): + single = 1 + double = 1 + triple = 3 + turkey = 3 + + def test_unique_with_name(self): + @verify(UNIQUE) + class Silly(Enum): + one = 1 + two = 'dos' + name = 3 + # + @verify(UNIQUE) class Sillier(IntEnum): single = 1 name = 2 triple = 3 value = 4 + def test_negative_alias(self): + @verify(NAMED_FLAGS) + class Color(Flag): + RED = 1 + GREEN = 2 + BLUE = 4 + WHITE = -1 + # no error means success + + +class TestInternals(unittest.TestCase): + + sunder_names = '_bad_', '_good_', '_what_ho_' + dunder_names = '__mal__', '__bien__', '__que_que__' + private_names = '_MyEnum__private', '_MyEnum__still_private' + private_and_sunder_names = '_MyEnum__private_', '_MyEnum__also_private_' + random_names = 'okay', '_semi_private', '_weird__', '_MyEnum__' + + def test_sunder(self): + for name in self.sunder_names + self.private_and_sunder_names: + self.assertTrue(enum._is_sunder(name), '%r is a not sunder name?' % name) + for name in self.dunder_names + self.private_names + self.random_names: + self.assertFalse(enum._is_sunder(name), '%r is a sunder name?' % name) + + def test_dunder(self): + for name in self.dunder_names: + self.assertTrue(enum._is_dunder(name), '%r is a not dunder name?' % name) + for name in self.sunder_names + self.private_names + self.private_and_sunder_names + self.random_names: + self.assertFalse(enum._is_dunder(name), '%r is a dunder name?' % name) + + def test_is_private(self): + for name in self.private_names + self.private_and_sunder_names: + self.assertTrue(enum._is_private('MyEnum', name), '%r is a not private name?') + for name in self.sunder_names + self.dunder_names + self.random_names: + self.assertFalse(enum._is_private('MyEnum', name), '%r is a private name?') + + def test_auto_number(self): + class Color(Enum): + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 1) + self.assertEqual(Color.blue.value, 2) + self.assertEqual(Color.green.value, 3) + + def test_auto_name(self): + class Color(Enum): + def _generate_next_value_(name, start, count, last): + return name + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 'blue') + self.assertEqual(Color.green.value, 'green') + + def test_auto_name_inherit(self): + class AutoNameEnum(Enum): + def _generate_next_value_(name, start, count, last): + return name + class Color(AutoNameEnum): + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 'blue') + self.assertEqual(Color.green.value, 'green') + + @unittest.skipIf( + python_version >= (3, 13), + 'mixed types with auto() no longer supported', + ) + def test_auto_garbage_ok(self): + with self.assertWarnsRegex(DeprecationWarning, 'will require all values to be sortable'): + class Color(Enum): + red = 'red' + blue = auto() + self.assertEqual(Color.blue.value, 1) + + @unittest.skipIf( + python_version >= (3, 13), + 'mixed types with auto() no longer supported', + ) + def test_auto_garbage_corrected_ok(self): + with self.assertWarnsRegex(DeprecationWarning, 'will require all values to be sortable'): + class Color(Enum): + red = 'red' + blue = 2 + green = auto() + yellow = auto() + + self.assertEqual(list(Color), + [Color.red, Color.blue, Color.green, Color.yellow]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 2) + self.assertEqual(Color.green.value, 3) + self.assertEqual(Color.yellow.value, 4) + + @unittest.skipIf( + python_version < (3, 13), + 'mixed types with auto() will raise in 3.13', + ) + def test_auto_garbage_fail(self): + with self.assertRaisesRegex(TypeError, 'will require all values to be sortable'): + class Color(Enum): + red = 'red' + blue = auto() + + @unittest.skipIf( + python_version < (3, 13), + 'mixed types with auto() will raise in 3.13', + ) + def test_auto_garbage_corrected_fail(self): + with self.assertRaisesRegex(TypeError, 'will require all values to be sortable'): + class Color(Enum): + red = 'red' + blue = 2 + green = auto() + + def test_auto_order(self): + with self.assertRaises(TypeError): + class Color(Enum): + red = auto() + green = auto() + blue = auto() + def _generate_next_value_(name, start, count, last): + return name + + def test_auto_order_wierd(self): + weird_auto = auto() + weird_auto.value = 'pathological case' + class Color(Enum): + red = weird_auto + def _generate_next_value_(name, start, count, last): + return name + blue = auto() + self.assertEqual(list(Color), [Color.red, Color.blue]) + self.assertEqual(Color.red.value, 'pathological case') + self.assertEqual(Color.blue.value, 'blue') + + @unittest.skipIf( + python_version < (3, 13), + 'auto() will return highest value + 1 in 3.13', + ) + def test_auto_with_aliases(self): + class Color(Enum): + red = auto() + blue = auto() + oxford = blue + crimson = red + green = auto() + self.assertIs(Color.crimson, Color.red) + self.assertIs(Color.oxford, Color.blue) + self.assertIsNot(Color.green, Color.red) + self.assertIsNot(Color.green, Color.blue) + + def test_duplicate_auto(self): + class Dupes(Enum): + first = primero = auto() + second = auto() + third = auto() + self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) + def test_multiple_auto_on_line(self): + class Huh(Enum): + ONE = auto() + TWO = auto(), auto() + THREE = auto(), auto(), auto() + self.assertEqual(Huh.ONE.value, 1) + self.assertEqual(Huh.TWO.value, (2, 3)) + self.assertEqual(Huh.THREE.value, (4, 5, 6)) + # + class Hah(Enum): + def __new__(cls, value, abbr=None): + member = object.__new__(cls) + member._value_ = value + member.abbr = abbr or value[:3].lower() + return member + def _generate_next_value_(name, start, count, last): + return name + # + MONDAY = auto() + TUESDAY = auto() + WEDNESDAY = auto(), 'WED' + THURSDAY = auto(), 'Thu' + FRIDAY = auto() + self.assertEqual(Hah.MONDAY.value, 'MONDAY') + self.assertEqual(Hah.MONDAY.abbr, 'mon') + self.assertEqual(Hah.TUESDAY.value, 'TUESDAY') + self.assertEqual(Hah.TUESDAY.abbr, 'tue') + self.assertEqual(Hah.WEDNESDAY.value, 'WEDNESDAY') + self.assertEqual(Hah.WEDNESDAY.abbr, 'WED') + self.assertEqual(Hah.THURSDAY.value, 'THURSDAY') + self.assertEqual(Hah.THURSDAY.abbr, 'Thu') + self.assertEqual(Hah.FRIDAY.value, 'FRIDAY') + self.assertEqual(Hah.FRIDAY.abbr, 'fri') + # + class Huh(Enum): + def _generate_next_value_(name, start, count, last): + return count+1 + ONE = auto() + TWO = auto(), auto() + THREE = auto(), auto(), auto() + self.assertEqual(Huh.ONE.value, 1) + self.assertEqual(Huh.TWO.value, (2, 2)) + self.assertEqual(Huh.THREE.value, (3, 3, 3)) + +class TestEnumTypeSubclassing(unittest.TestCase): + pass expected_help_output_with_docs = """\ Help on class Color in module %s: class Color(enum.Enum) - | Color(value, names=None, *, module=None, qualname=None, type=None, start=1) + | Create a collection of name/value pairs. + |\x20\x20 + | Example enumeration: + |\x20\x20 + | >>> class Color(Enum): + | ... RED = 1 + | ... BLUE = 2 + | ... GREEN = 3 + |\x20\x20 + | Access them by: |\x20\x20 - | An enumeration. + | - attribute access:: + |\x20\x20 + | >>> Color.RED + | + |\x20\x20 + | - value lookup: + |\x20\x20 + | >>> Color(1) + | + |\x20\x20 + | - name lookup: + |\x20\x20 + | >>> Color['RED'] + | + |\x20\x20 + | Enumerations can be iterated over, and know how many members they have: + |\x20\x20 + | >>> len(Color) + | 3 + |\x20\x20 + | >>> list(Color) + | [, , ] + |\x20\x20 + | Methods can be added to enumerations, and members can have their own + | attributes -- see the documentation for details. |\x20\x20 | Method resolution order: | Color @@ -3008,11 +4460,11 @@ class Color(enum.Enum) |\x20\x20 | Data and other attributes defined here: |\x20\x20 - | blue = + | CYAN = |\x20\x20 - | green = + | MAGENTA = |\x20\x20 - | red = + | YELLOW = |\x20\x20 | ---------------------------------------------------------------------- | Data descriptors inherited from enum.Enum: @@ -3024,13 +4476,28 @@ class Color(enum.Enum) | The value of the Enum member. |\x20\x20 | ---------------------------------------------------------------------- - | Readonly properties inherited from enum.EnumMeta: + | Methods inherited from enum.EnumType: |\x20\x20 - | __members__ - | Returns a mapping of member name->value. + | __contains__(member) from enum.EnumType + | Return True if member is a member of this enum + | raises TypeError if member is not an enum member |\x20\x20\x20\x20\x20\x20 - | This mapping lists all enum members, including aliases. Note that this - | is a read-only view of the internal mapping.""" + | note: in 3.12 TypeError will no longer be raised, and True will also be + | returned if member is the value of a member in this enum + |\x20\x20 + | __getitem__(name) from enum.EnumType + | Return the member matching `name`. + |\x20\x20 + | __iter__() from enum.EnumType + | Return members in definition order. + |\x20\x20 + | __len__() from enum.EnumType + | Return the number of members (no aliases) + |\x20\x20 + | ---------------------------------------------------------------------- + | Data descriptors inherited from enum.EnumType: + |\x20\x20 + | __members__""" expected_help_output_without_docs = """\ Help on class Color in module %s: @@ -3045,11 +4512,11 @@ class Color(enum.Enum) |\x20\x20 | Data and other attributes defined here: |\x20\x20 - | blue = + | YELLOW = |\x20\x20 - | green = + | MAGENTA = |\x20\x20 - | red = + | CYAN = |\x20\x20 | ---------------------------------------------------------------------- | Data descriptors inherited from enum.Enum: @@ -3059,7 +4526,7 @@ class Color(enum.Enum) | value |\x20\x20 | ---------------------------------------------------------------------- - | Data descriptors inherited from enum.EnumMeta: + | Data descriptors inherited from enum.EnumType: |\x20\x20 | __members__""" @@ -3068,12 +4535,10 @@ class TestStdLib(unittest.TestCase): maxDiff = None class Color(Enum): - red = 1 - green = 2 - blue = 3 + CYAN = 1 + MAGENTA = 2 + YELLOW = 3 - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pydoc(self): # indirectly test __objclass__ if StrEnum.__doc__ is None: @@ -3084,26 +4549,34 @@ def test_pydoc(self): helper = pydoc.Helper(output=output) helper(self.Color) result = output.getvalue().strip() - self.assertEqual(result, expected_text) + self.assertEqual(result, expected_text, result) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_inspect_getmembers(self): values = dict(( - ('__class__', EnumMeta), - ('__doc__', 'An enumeration.'), + ('__class__', EnumType), + ('__doc__', '...'), ('__members__', self.Color.__members__), ('__module__', __name__), - ('blue', self.Color.blue), - ('green', self.Color.green), + ('YELLOW', self.Color.YELLOW), + ('MAGENTA', self.Color.MAGENTA), + ('CYAN', self.Color.CYAN), ('name', Enum.__dict__['name']), - ('red', self.Color.red), ('value', Enum.__dict__['value']), + ('__len__', self.Color.__len__), + ('__contains__', self.Color.__contains__), + ('__name__', 'Color'), + ('__getitem__', self.Color.__getitem__), + ('__qualname__', 'TestStdLib.Color'), + ('__init_subclass__', getattr(self.Color, '__init_subclass__')), + ('__iter__', self.Color.__iter__), )) result = dict(inspect.getmembers(self.Color)) - self.assertEqual(values.keys(), result.keys()) + self.assertEqual(set(values.keys()), set(result.keys())) failed = False for k in values.keys(): + if k == '__doc__': + # __doc__ is huge, not comparing + continue if result[k] != values[k]: print() print('\n%s\n key: %s\n result: %s\nexpected: %s\n%s\n' % @@ -3112,46 +4585,145 @@ def test_inspect_getmembers(self): if failed: self.fail("result does not equal expected, see print above") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_inspect_classify_class_attrs(self): # indirectly test __objclass__ from inspect import Attribute values = [ Attribute(name='__class__', kind='data', - defining_class=object, object=EnumMeta), + defining_class=object, object=EnumType), + Attribute(name='__contains__', kind='method', + defining_class=EnumType, object=self.Color.__contains__), Attribute(name='__doc__', kind='data', - defining_class=self.Color, object='An enumeration.'), + defining_class=self.Color, object='...'), + Attribute(name='__getitem__', kind='method', + defining_class=EnumType, object=self.Color.__getitem__), + Attribute(name='__iter__', kind='method', + defining_class=EnumType, object=self.Color.__iter__), + Attribute(name='__init_subclass__', kind='class method', + defining_class=object, object=getattr(self.Color, '__init_subclass__')), + Attribute(name='__len__', kind='method', + defining_class=EnumType, object=self.Color.__len__), Attribute(name='__members__', kind='property', - defining_class=EnumMeta, object=EnumMeta.__members__), + defining_class=EnumType, object=EnumType.__members__), Attribute(name='__module__', kind='data', defining_class=self.Color, object=__name__), - Attribute(name='blue', kind='data', - defining_class=self.Color, object=self.Color.blue), - Attribute(name='green', kind='data', - defining_class=self.Color, object=self.Color.green), - Attribute(name='red', kind='data', - defining_class=self.Color, object=self.Color.red), + Attribute(name='__name__', kind='data', + defining_class=self.Color, object='Color'), + Attribute(name='__qualname__', kind='data', + defining_class=self.Color, object='TestStdLib.Color'), + Attribute(name='YELLOW', kind='data', + defining_class=self.Color, object=self.Color.YELLOW), + Attribute(name='MAGENTA', kind='data', + defining_class=self.Color, object=self.Color.MAGENTA), + Attribute(name='CYAN', kind='data', + defining_class=self.Color, object=self.Color.CYAN), Attribute(name='name', kind='data', defining_class=Enum, object=Enum.__dict__['name']), Attribute(name='value', kind='data', defining_class=Enum, object=Enum.__dict__['value']), ] + for v in values: + try: + v.name + except AttributeError: + print(v) values.sort(key=lambda item: item.name) result = list(inspect.classify_class_attrs(self.Color)) result.sort(key=lambda item: item.name) + self.assertEqual( + len(values), len(result), + "%s != %s" % ([a.name for a in values], [a.name for a in result]) + ) failed = False for v, r in zip(values, result): - if r != v: + if r.name in ('__init_subclass__', '__doc__'): + # not sure how to make the __init_subclass_ Attributes match + # so as long as there is one, call it good + # __doc__ is too big to check exactly, so treat the same as __init_subclass__ + for name in ('name','kind','defining_class'): + if getattr(v, name) != getattr(r, name): + print('\n%s\n%s\n%s\n%s\n' % ('=' * 75, r, v, '=' * 75), sep='') + failed = True + elif r != v: print('\n%s\n%s\n%s\n%s\n' % ('=' * 75, r, v, '=' * 75), sep='') failed = True if failed: self.fail("result does not equal expected, see print above") + # TODO: RUSTPYTHON, len is often/always > 256 + @unittest.expectedFailure + def test_test_simple_enum(self): + @_simple_enum(Enum) + class SimpleColor: + CYAN = 1 + MAGENTA = 2 + YELLOW = 3 + @bltns.property + def zeroth(self): + return 'zeroed %s' % self.name + class CheckedColor(Enum): + CYAN = 1 + MAGENTA = 2 + YELLOW = 3 + @bltns.property + def zeroth(self): + return 'zeroed %s' % self.name + self.assertTrue(_test_simple_enum(CheckedColor, SimpleColor) is None) + SimpleColor.MAGENTA._value_ = 9 + self.assertRaisesRegex( + TypeError, "enum mismatch", + _test_simple_enum, CheckedColor, SimpleColor, + ) + class CheckedMissing(IntFlag, boundary=KEEP): + SIXTY_FOUR = 64 + ONE_TWENTY_EIGHT = 128 + TWENTY_FORTY_EIGHT = 2048 + ALL = 2048 + 128 + 64 + 12 + CM = CheckedMissing + self.assertEqual(list(CheckedMissing), [CM.SIXTY_FOUR, CM.ONE_TWENTY_EIGHT, CM.TWENTY_FORTY_EIGHT]) + # + @_simple_enum(IntFlag, boundary=KEEP) + class Missing: + SIXTY_FOUR = 64 + ONE_TWENTY_EIGHT = 128 + TWENTY_FORTY_EIGHT = 2048 + ALL = 2048 + 128 + 64 + 12 + M = Missing + self.assertEqual(list(CheckedMissing), [M.SIXTY_FOUR, M.ONE_TWENTY_EIGHT, M.TWENTY_FORTY_EIGHT]) + # + _test_simple_enum(CheckedMissing, Missing) + class MiscTestCase(unittest.TestCase): + def test__all__(self): - check__all__(self, enum) + support.check__all__(self, enum, not_exported={'bin', 'show_flag_values'}) + + def test_doc_1(self): + class Single(Enum): + ONE = 1 + self.assertEqual(Single.__doc__, None) + + def test_doc_2(self): + class Double(Enum): + ONE = 1 + TWO = 2 + self.assertEqual(Double.__doc__, None) + + def test_doc_3(self): + class Triple(Enum): + ONE = 1 + TWO = 2 + THREE = 3 + self.assertEqual(Triple.__doc__, None) + + def test_doc_4(self): + class Quadruple(Enum): + ONE = 1 + TWO = 2 + THREE = 3 + FOUR = 4 + self.assertEqual(Quadruple.__doc__, None) # These are unordered here on purpose to ensure that declaration order @@ -3163,21 +4735,61 @@ def test__all__(self): CONVERT_TEST_NAME_E = 5 CONVERT_TEST_NAME_F = 5 -class TestIntEnumConvert(unittest.TestCase): +CONVERT_STRING_TEST_NAME_D = 5 +CONVERT_STRING_TEST_NAME_C = 5 +CONVERT_STRING_TEST_NAME_B = 5 +CONVERT_STRING_TEST_NAME_A = 5 # This one should sort first. +CONVERT_STRING_TEST_NAME_E = 5 +CONVERT_STRING_TEST_NAME_F = 5 + +# global names for StrEnum._convert_ test +CONVERT_STR_TEST_2 = 'goodbye' +CONVERT_STR_TEST_1 = 'hello' + +# We also need values that cannot be compared: +UNCOMPARABLE_A = 5 +UNCOMPARABLE_C = (9, 1) # naming order is broken on purpose +UNCOMPARABLE_B = 'value' + +COMPLEX_C = 1j +COMPLEX_A = 2j +COMPLEX_B = 3j + +class _ModuleWrapper: + """We use this class as a namespace for swapping modules.""" + def __init__(self, module): + self.__dict__.update(module.__dict__) + +class TestConvert(unittest.TestCase): + def tearDown(self): + # Reset the module-level test variables to their original integer + # values, otherwise the already created enum values get converted + # instead. + g = globals() + for suffix in ['A', 'B', 'C', 'D', 'E', 'F']: + g['CONVERT_TEST_NAME_%s' % suffix] = 5 + g['CONVERT_STRING_TEST_NAME_%s' % suffix] = 5 + for suffix, value in (('A', 5), ('B', (9, 1)), ('C', 'value')): + g['UNCOMPARABLE_%s' % suffix] = value + for suffix, value in (('A', 2j), ('B', 3j), ('C', 1j)): + g['COMPLEX_%s' % suffix] = value + for suffix, value in (('1', 'hello'), ('2', 'goodbye')): + g['CONVERT_STR_TEST_%s' % suffix] = value + def test_convert_value_lookup_priority(self): test_type = enum.IntEnum._convert_( 'UnittestConvert', - ('test.test_enum', '__main__')[__name__=='__main__'], + MODULE, filter=lambda x: x.startswith('CONVERT_TEST_')) # We don't want the reverse lookup value to vary when there are # multiple possible names for a given value. It should always # report the first lexigraphical name in that case. self.assertEqual(test_type(5).name, 'CONVERT_TEST_NAME_A') - def test_convert(self): + def test_convert_int(self): test_type = enum.IntEnum._convert_( 'UnittestConvert', - ('test.test_enum', '__main__')[__name__=='__main__'], + MODULE, filter=lambda x: x.startswith('CONVERT_TEST_')) # Ensure that test_type has all of the desired names and values. self.assertEqual(test_type.CONVERT_TEST_NAME_F, @@ -3187,30 +4799,123 @@ def test_convert(self): self.assertEqual(test_type.CONVERT_TEST_NAME_D, 5) self.assertEqual(test_type.CONVERT_TEST_NAME_E, 5) # Ensure that test_type only picked up names matching the filter. - self.assertEqual([name for name in dir(test_type) - if name[0:2] not in ('CO', '__')], - [], msg='Names other than CONVERT_TEST_* found.') - - @unittest.skipUnless(sys.version_info[:2] == (3, 8), - '_convert was deprecated in 3.8') - def test_convert_warn(self): - with self.assertWarns(DeprecationWarning): - enum.IntEnum._convert( + int_dir = dir(int) + [ + 'CONVERT_TEST_NAME_A', 'CONVERT_TEST_NAME_B', 'CONVERT_TEST_NAME_C', + 'CONVERT_TEST_NAME_D', 'CONVERT_TEST_NAME_E', 'CONVERT_TEST_NAME_F', + 'CONVERT_TEST_SIGABRT', 'CONVERT_TEST_SIGIOT', + 'CONVERT_TEST_EIO', 'CONVERT_TEST_EBUS', + ] + extra = [name for name in dir(test_type) if name not in enum_dir(test_type)] + missing = [name for name in enum_dir(test_type) if name not in dir(test_type)] + self.assertEqual( + extra + missing, + [], + msg='extra names: %r; missing names: %r' % (extra, missing), + ) + + + def test_convert_uncomparable(self): + uncomp = enum.Enum._convert_( + 'Uncomparable', + MODULE, + filter=lambda x: x.startswith('UNCOMPARABLE_')) + # Should be ordered by `name` only: + self.assertEqual( + list(uncomp), + [uncomp.UNCOMPARABLE_A, uncomp.UNCOMPARABLE_B, uncomp.UNCOMPARABLE_C], + ) + + def test_convert_complex(self): + uncomp = enum.Enum._convert_( + 'Uncomparable', + MODULE, + filter=lambda x: x.startswith('COMPLEX_')) + # Should be ordered by `name` only: + self.assertEqual( + list(uncomp), + [uncomp.COMPLEX_A, uncomp.COMPLEX_B, uncomp.COMPLEX_C], + ) + + def test_convert_str(self): + test_type = enum.StrEnum._convert_( 'UnittestConvert', - ('test.test_enum', '__main__')[__name__=='__main__'], - filter=lambda x: x.startswith('CONVERT_TEST_')) + MODULE, + filter=lambda x: x.startswith('CONVERT_STR_'), + as_global=True) + # Ensure that test_type has all of the desired names and values. + self.assertEqual(test_type.CONVERT_STR_TEST_1, 'hello') + self.assertEqual(test_type.CONVERT_STR_TEST_2, 'goodbye') + # Ensure that test_type only picked up names matching the filter. + str_dir = dir(str) + ['CONVERT_STR_TEST_1', 'CONVERT_STR_TEST_2'] + extra = [name for name in dir(test_type) if name not in enum_dir(test_type)] + missing = [name for name in enum_dir(test_type) if name not in dir(test_type)] + self.assertEqual( + extra + missing, + [], + msg='extra names: %r; missing names: %r' % (extra, missing), + ) + self.assertEqual(repr(test_type.CONVERT_STR_TEST_1), '%s.CONVERT_STR_TEST_1' % SHORT_MODULE) + self.assertEqual(str(test_type.CONVERT_STR_TEST_2), 'goodbye') + self.assertEqual(format(test_type.CONVERT_STR_TEST_1), 'hello') - # TODO: RUSTPYTHON - @unittest.expectedFailure - @unittest.skipUnless(sys.version_info >= (3, 9), - '_convert was removed in 3.9') def test_convert_raise(self): with self.assertRaises(AttributeError): enum.IntEnum._convert( 'UnittestConvert', - ('test.test_enum', '__main__')[__name__=='__main__'], + MODULE, filter=lambda x: x.startswith('CONVERT_TEST_')) + def test_convert_repr_and_str(self): + test_type = enum.IntEnum._convert_( + 'UnittestConvert', + MODULE, + filter=lambda x: x.startswith('CONVERT_STRING_TEST_'), + as_global=True) + self.assertEqual(repr(test_type.CONVERT_STRING_TEST_NAME_A), '%s.CONVERT_STRING_TEST_NAME_A' % SHORT_MODULE) + self.assertEqual(str(test_type.CONVERT_STRING_TEST_NAME_A), '5') + self.assertEqual(format(test_type.CONVERT_STRING_TEST_NAME_A), '5') + + +# helpers + +def enum_dir(cls): + interesting = set([ + '__class__', '__contains__', '__doc__', '__getitem__', + '__iter__', '__len__', '__members__', '__module__', + '__name__', '__qualname__', + ] + + cls._member_names_ + ) + if cls._new_member_ is not object.__new__: + interesting.add('__new__') + if cls.__init_subclass__ is not object.__init_subclass__: + interesting.add('__init_subclass__') + if cls._member_type_ is object: + return sorted(interesting) + else: + # return whatever mixed-in data type has + return sorted(set(dir(cls._member_type_)) | interesting) + +def member_dir(member): + if member.__class__._member_type_ is object: + allowed = set(['__class__', '__doc__', '__eq__', '__hash__', '__module__', 'name', 'value']) + else: + allowed = set(dir(member)) + for cls in member.__class__.mro(): + for name, obj in cls.__dict__.items(): + if name[0] == '_': + continue + if isinstance(obj, enum.property): + if obj.fget is not None or name not in member._member_map_: + allowed.add(name) + else: + allowed.discard(name) + else: + allowed.add(name) + return sorted(allowed) + +missing = object() + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index 03cb8172de..9b30b4137c 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -2304,6 +2304,8 @@ def test_long_pattern(self): self.assertEqual(r[:30], "re.compile('Very long long lon") self.assertEqual(r[-16:], ", re.IGNORECASE)") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_flags_repr(self): self.assertEqual(repr(re.I), "re.IGNORECASE") self.assertEqual(repr(re.I|re.S|re.X), diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 72b0f19275..35f94a4e22 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1493,6 +1493,8 @@ def test_sio_loopback_fast_path(self): raise self.assertRaises(TypeError, s.ioctl, socket.SIO_LOOPBACK_FAST_PATH, None) + # TODO: RUSTPYTHON, AssertionError: '2' != 'AddressFamily.AF_INET' + @unittest.expectedFailure def testGetaddrinfo(self): try: socket.getaddrinfo('localhost', 80) @@ -1799,6 +1801,8 @@ def test_getnameinfo_ipv6_scopeid_numeric(self): nameinfo = socket.getnameinfo(sockaddr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV) self.assertEqual(nameinfo, ('ff02::1de:c0:face:8d%' + str(ifindex), '1234')) + # TODO: RUSTPYTHON, AssertionError: '2' != 'AddressFamily.AF_INET' + @unittest.expectedFailure def test_str_for_enums(self): # Make sure that the AF_* and SOCK_* constants have enum-like string # reprs. diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index cf1b14bacd..071a2a06c1 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -1509,6 +1509,9 @@ def __int__(self): self.assertRaisesRegex(TypeError, '%x format: an integer is required, not PseudoFloat', operator.mod, '%x', pi), self.assertRaises(TypeError, operator.mod, '%c', pi), + + # TODO: RUSTPYTHON, AssertionError: '...15...' != '...Int.IDES...' + @unittest.expectedFailure def test_formatting_with_enum(self): # issue18780 import enum From c3a8d5a9b556b698785d5a34245ce09188ff44a8 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Tue, 3 Oct 2023 23:19:31 +0900 Subject: [PATCH 113/893] Retry to fix win_lib_path again (#5076) --- pylib/build.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pylib/build.rs b/pylib/build.rs index 1aca8b3318..a541bc56d5 100644 --- a/pylib/build.rs +++ b/pylib/build.rs @@ -12,7 +12,12 @@ fn main() { if cfg!(windows) { if let Ok(real_path) = std::fs::read_to_string("Lib") { - println!("cargo:rustc-env=win_lib_path={real_path:?}"); + let canonicalized_path = std::fs::canonicalize(real_path) + .expect("failed to resolve RUSTPYTHONPATH during build time"); + println!( + "cargo:rustc-env=win_lib_path={}", + canonicalized_path.to_str().unwrap() + ); } } } From 23bf5c42cad7d236f59f72a761eb3b5fbc45dd6c Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 4 Oct 2023 23:42:07 +0900 Subject: [PATCH 114/893] bump up and sync dependencies with rustpython-parser (#5075) --- Cargo.lock | 169 +++++++++++++++++++++++++------------------- Cargo.toml | 37 +++++----- vm/Cargo.toml | 4 +- wasm/lib/Cargo.toml | 2 +- 4 files changed, 119 insertions(+), 93 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ee6aab2125..98fcfc5fa5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22,10 +22,11 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" -version = "0.7.6" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ + "cfg-if", "getrandom 0.2.8", "once_cell", "version_check", @@ -40,6 +41,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -125,9 +132,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6776fc96284a0bb647b615056fc496d1fe1644a7ab01829818a6d91cae888b84" +checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" [[package]] name = "blake2" @@ -237,17 +244,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.23" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b0a3d9ed01224b22057780a37bb8c5dbfe1be8ba48678e7bf57ec4b385411f" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ + "android-tzdata", "iana-time-zone", "js-sys", - "num-integer", "num-traits", - "time", "wasm-bindgen", - "winapi", + "windows-targets 0.48.0", ] [[package]] @@ -894,6 +900,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "getopts" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" +dependencies = [ + "unicode-width", +] + [[package]] name = "getrandom" version = "0.1.16" @@ -1014,9 +1029,9 @@ dependencies = [ [[package]] name = "insta" -version = "1.28.0" +version = "1.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea5b3894afe466b4bcf0388630fc15e11938a6074af0cd637c825ba2ec8a099" +checksum = "1aa511b2e298cd49b1856746f6bb73e17036bcd66b25f5e92cdcdbec9bd75686" dependencies = [ "console", "lazy_static 1.4.0", @@ -1036,19 +1051,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "is-macro" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7d079e129b77477a49c5c4f1cfe9ce6c2c909ef52520693e8e811a714c7b20" -dependencies = [ - "Inflector", - "pmutil 0.5.3", - "proc-macro2", - "quote", - "syn 1.0.107", -] - [[package]] name = "is-macro" version = "0.3.0" @@ -1080,6 +1082,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.5" @@ -1505,9 +1516,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.1" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "oorandom" @@ -2029,10 +2040,10 @@ dependencies = [ [[package]] name = "rustpython-ast" -version = "0.3.0" -source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" +version = "0.3.1" +source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" dependencies = [ - "is-macro 0.2.2", + "is-macro", "malachite-bigint", "rustpython-literal", "rustpython-parser-core", @@ -2044,10 +2055,10 @@ name = "rustpython-codegen" version = "0.3.0" dependencies = [ "ahash", - "bitflags 2.3.1", + "bitflags 2.4.0", "indexmap", "insta", - "itertools 0.10.5", + "itertools 0.11.0", "log", "num-complex", "num-traits", @@ -2062,10 +2073,10 @@ name = "rustpython-common" version = "0.3.0" dependencies = [ "ascii", - "bitflags 2.3.1", + "bitflags 2.4.0", "bstr", "cfg-if", - "itertools 0.10.5", + "itertools 0.11.0", "libc", "lock_api", "malachite-base", @@ -2096,8 +2107,8 @@ dependencies = [ name = "rustpython-compiler-core" version = "0.3.0" dependencies = [ - "bitflags 2.3.1", - "itertools 0.10.5", + "bitflags 2.4.0", + "itertools 0.11.0", "lz4_flex", "malachite-bigint", "num-complex", @@ -2118,7 +2129,7 @@ dependencies = [ name = "rustpython-derive-impl" version = "0.3.0" dependencies = [ - "itertools 0.10.5", + "itertools 0.11.0", "maplit", "once_cell", "proc-macro2", @@ -2141,11 +2152,11 @@ dependencies = [ [[package]] name = "rustpython-format" -version = "0.3.0" -source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" +version = "0.3.1" +source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" dependencies = [ - "bitflags 2.3.1", - "itertools 0.10.5", + "bitflags 2.4.0", + "itertools 0.11.0", "malachite-bigint", "num-traits", "rustpython-literal", @@ -2168,11 +2179,11 @@ dependencies = [ [[package]] name = "rustpython-literal" -version = "0.3.0" -source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" +version = "0.3.1" +source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" dependencies = [ "hexf-parse", - "is-macro 0.2.2", + "is-macro", "lexical-parse-float", "num-traits", "unic-ucd-category", @@ -2180,12 +2191,12 @@ dependencies = [ [[package]] name = "rustpython-parser" -version = "0.3.0" -source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" +version = "0.3.1" +source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" dependencies = [ "anyhow", - "is-macro 0.2.2", - "itertools 0.10.5", + "is-macro", + "itertools 0.11.0", "lalrpop-util", "log", "malachite-bigint", @@ -2203,18 +2214,18 @@ dependencies = [ [[package]] name = "rustpython-parser-core" -version = "0.3.0" -source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" +version = "0.3.1" +source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" dependencies = [ - "is-macro 0.2.2", + "is-macro", "memchr", "rustpython-parser-vendored", ] [[package]] name = "rustpython-parser-vendored" -version = "0.3.0" -source = "git+https://github.com/RustPython/Parser.git?tag=0.3.0#a1e4336f7043807eda8a5ecb15d4115172cc4a7e" +version = "0.3.1" +source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" dependencies = [ "memchr", "once_cell", @@ -2250,7 +2261,7 @@ dependencies = [ "foreign-types-shared", "gethostname", "hex", - "itertools 0.10.5", + "itertools 0.11.0", "libc", "libsqlite3-sys", "libz-sys", @@ -2307,7 +2318,7 @@ dependencies = [ "ahash", "ascii", "atty", - "bitflags 2.3.1", + "bitflags 2.4.0", "bstr", "caseless", "cfg-if", @@ -2321,8 +2332,8 @@ dependencies = [ "half", "hex", "indexmap", - "is-macro 0.3.0", - "itertools 0.10.5", + "is-macro", + "itertools 0.11.0", "libc", "log", "malachite-bigint", @@ -2440,11 +2451,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.21" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys 0.42.0", + "windows-sys 0.48.0", ] [[package]] @@ -2784,15 +2795,20 @@ dependencies = [ [[package]] name = "time" -version = "0.1.45" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" +checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890" dependencies = [ - "libc", - "wasi 0.10.0+wasi-snapshot-preview1", - "winapi", + "serde", + "time-core", ] +[[package]] +name = "time-core" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" + [[package]] name = "timsort" version = "0.1.2" @@ -3032,10 +3048,25 @@ checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" [[package]] name = "unicode_names2" -version = "0.6.0" -source = "git+https://github.com/youknowone/unicode_names2.git?rev=4ce16aa85cbcdd9cc830410f1a72ef9a235f2fde#4ce16aa85cbcdd9cc830410f1a72ef9a235f2fde" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b2c0942619ae1797f999a0ce7efc6c09592ad30e68e16cdbfdcd48a98c3579" dependencies = [ "phf", + "unicode_names2_generator", +] + +[[package]] +name = "unicode_names2_generator" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d0d66ab60be9799a70f8eb227ea43da7dcc47561dd9102cbadacfe0930113f7" +dependencies = [ + "getopts", + "log", + "phf_codegen", + "rand 0.8.5", + "time", ] [[package]] @@ -3108,12 +3139,6 @@ version = "0.9.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" -[[package]] -name = "wasi" -version = "0.10.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index d56f61ee59..a3b78fd70a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,37 +24,38 @@ rustpython-common = { path = "common", version = "0.3.0" } rustpython-derive = { path = "derive", version = "0.3.0" } rustpython-derive-impl = { path = "derive-impl", version = "0.3.0" } rustpython-jit = { path = "jit", version = "0.3.0" } -rustpython-vm = { path = "vm", version = "0.3.0" } +rustpython-vm = { path = "vm", default-features = false, version = "0.3.0" } rustpython-pylib = { path = "pylib", version = "0.3.0" } -rustpython-stdlib = { path = "stdlib", version = "0.3.0" } +rustpython-stdlib = { path = "stdlib", default-features = false, version = "0.3.0" } rustpython-doc = { git = "https://github.com/RustPython/__doc__", tag = "0.3.0", version = "0.3.0" } -rustpython-literal = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } -rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } -rustpython-parser = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } -rustpython-ast = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } -rustpython-format = { git = "https://github.com/RustPython/Parser.git", tag = "0.3.0", version = "0.3.0" } +rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } +rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } +rustpython-parser = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } +rustpython-ast = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } +rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } # rustpython-literal = { path = "../RustPython-parser/literal" } # rustpython-parser-core = { path = "../RustPython-parser/core" } # rustpython-parser = { path = "../RustPython-parser/parser" } # rustpython-ast = { path = "../RustPython-parser/ast" } # rustpython-format = { path = "../RustPython-parser/format" } -ahash = "0.7.6" +ahash = "0.8.3" anyhow = "1.0.45" ascii = "1.0" atty = "0.2.14" -bitflags = "2.2.1" +bitflags = "2.4.0" bstr = "0.2.17" cfg-if = "1.0" -chrono = "0.4.19" +chrono = "0.4.31" crossbeam-utils = "0.8.16" flame = "0.2.2" glob = "0.3" hex = "0.4.3" indexmap = "1.8.1" -insta = "1.14.0" -itertools = "0.10.3" +insta = "1.33.0" +itertools = "0.11.0" +is-macro = "0.3.0" libc = "0.2.133" log = "0.4.16" nix = "0.26" @@ -65,18 +66,18 @@ num-complex = "0.4.0" num-integer = "0.1.44" num-traits = "0.2" num_enum = "0.5.7" -once_cell = "1.13" -parking_lot = "0.12" +once_cell = "1.18" +parking_lot = "0.12.1" paste = "1.0.7" rand = "0.8.5" rustyline = "11" -serde = "1.0" -schannel = "0.1.19" +serde = { version = "1.0.133", default-features = false } +schannel = "0.1.22" static_assertions = "1.1" syn = "1.0.91" thiserror = "1.0" thread_local = "1.1.4" -unicode_names2 = { version = "0.6.0", git = "https://github.com/youknowone/unicode_names2.git", rev = "4ce16aa85cbcdd9cc830410f1a72ef9a235f2fde" } +unicode_names2 = "1.1.0" widestring = "0.5.1" [features] @@ -97,7 +98,7 @@ ssl-vendor = ["rustpython-stdlib/ssl-vendor"] rustpython-compiler = { workspace = true } rustpython-pylib = { workspace = true, optional = true } rustpython-stdlib = { workspace = true, optional = true } -rustpython-vm = { workspace = true, default-features = false, features = ["compiler"] } +rustpython-vm = { workspace = true, features = ["compiler"] } rustpython-parser = { workspace = true } atty = { workspace = true } diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 966858c087..ee2eadc681 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -49,6 +49,7 @@ flame = { workspace = true, optional = true } hex = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } +is-macro = { workspace = true } libc = { workspace = true } log = { workspace = true } nix = { workspace = true } @@ -70,7 +71,6 @@ caseless = "0.2.1" getrandom = { version = "0.2.6", features = ["js"] } flamer = { version = "0.4", optional = true } half = "1.8.2" -is-macro = "0.3" memchr = "2.4.1" memoffset = "0.6.5" optional = "0.5.0" @@ -123,7 +123,7 @@ version = "0.3.9" features = [ "winsock2", "handleapi", "ws2def", "std", "winbase", "wincrypt", "fileapi", "processenv", "namedpipeapi", "winnt", "processthreadsapi", "errhandlingapi", "winuser", "synchapi", "wincon", - "impl-default", "vcruntime", "ifdef", "netioapi", "memoryapi", + "impl-default", "vcruntime", "ifdef", "netioapi", "memoryapi", "profileapi", ] [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/wasm/lib/Cargo.toml b/wasm/lib/Cargo.toml index 98be000fee..309607e70f 100644 --- a/wasm/lib/Cargo.toml +++ b/wasm/lib/Cargo.toml @@ -20,7 +20,7 @@ rustpython-common = { workspace = true } rustpython-pylib = { workspace = true, optional = true } rustpython-stdlib = { workspace = true, default-features = false, optional = true } # make sure no threading! otherwise wasm build will fail -rustpython-vm = { workspace = true, default-features = false, features = ["compiler", "encodings", "serde"] } +rustpython-vm = { workspace = true, features = ["compiler", "encodings", "serde"] } rustpython-parser = { workspace = true } From d975c51b9607db30265540ecc8a70ee8ab18a7d5 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 4 Oct 2023 23:42:37 +0900 Subject: [PATCH 115/893] implement more warnings (#5077) --- .cspell.json | 2 + vm/src/vm/context.rs | 2 + vm/src/warn.rs | 104 ++++++++++++++++++++++++++++++++++++------- 3 files changed, 91 insertions(+), 17 deletions(-) diff --git a/.cspell.json b/.cspell.json index 3f60bb076d..ad90153f56 100644 --- a/.cspell.json +++ b/.cspell.json @@ -171,6 +171,8 @@ "scproxy", "setattro", "setcomp", + "showwarnmsg", + "warnmsg", "stacklevel", "subclasscheck", "subclasshook", diff --git a/vm/src/vm/context.rs b/vm/src/vm/context.rs index 124fffdd4d..4181a65896 100644 --- a/vm/src/vm/context.rs +++ b/vm/src/vm/context.rs @@ -222,6 +222,7 @@ declare_const_name! { // common names _attributes, _fields, + _showwarnmsg, decode, encode, keys, @@ -232,6 +233,7 @@ declare_const_name! { copy, flush, close, + WarningMessage, } // Basic objects: diff --git a/vm/src/warn.rs b/vm/src/warn.rs index dde546dcfa..ba45714853 100644 --- a/vm/src/warn.rs +++ b/vm/src/warn.rs @@ -1,5 +1,8 @@ use crate::{ - builtins::{PyDict, PyDictRef, PyListRef, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef}, + builtins::{ + PyDict, PyDictRef, PyListRef, PyStr, PyStrInterned, PyStrRef, PyTuple, PyTupleRef, + PyTypeRef, + }, convert::{IntoObject, TryFromObject}, types::PyComparisonOp, AsObject, Context, Py, PyObjectRef, PyResult, VirtualMachine, @@ -48,19 +51,26 @@ fn check_matched(obj: &PyObjectRef, arg: &PyObjectRef, vm: &VirtualMachine) -> P Ok(result.is_ok()) } -pub fn py_warn( - category: &Py, - message: String, - stack_level: usize, +fn get_warnings_attr( vm: &VirtualMachine, -) -> PyResult<()> { - // TODO: use rust warnings module - if let Ok(module) = vm.import("warnings", None, 0) { - if let Ok(func) = module.get_attr("warn", vm) { - let _ = func.call((message, category.to_owned(), stack_level), vm); + attr_name: &'static PyStrInterned, + try_import: bool, +) -> PyResult> { + let module = if try_import + && !vm + .state + .finalizing + .load(std::sync::atomic::Ordering::SeqCst) + { + match vm.import("warnings", None, 0) { + Ok(module) => module, + Err(_) => return Ok(None), } - } - Ok(()) + } else { + // TODO: finalizing support + return Ok(None); + }; + Ok(Some(module.get_attr(attr_name, vm)?)) } pub fn warn( @@ -192,7 +202,7 @@ fn already_warned( Ok(true) } -fn normalize_module(filename: PyStrRef, vm: &VirtualMachine) -> Option { +fn normalize_module(filename: &Py, vm: &VirtualMachine) -> Option { let obj = match filename.char_len() { 0 => vm.new_pyobj(""), len if len >= 3 && filename.as_str().ends_with(".py") => { @@ -211,8 +221,8 @@ fn warn_explicit( lineno: usize, module: Option, registry: PyObjectRef, - _source_line: Option, - _source: Option, + source_line: Option, + source: Option, vm: &VirtualMachine, ) -> PyResult<()> { let registry: PyObjectRef = registry @@ -220,7 +230,7 @@ fn warn_explicit( .map_err(|_| vm.new_type_error("'registry' must be a dict or None".to_owned()))?; // Normalize module. - let module = match module.or_else(|| normalize_module(filename, vm)) { + let module = match module.or_else(|| normalize_module(&filename, vm)) { Some(module) => module, None => return Ok(()), }; @@ -280,8 +290,68 @@ fn warn_explicit( return Ok(()); } + call_show_warning( + // t_state, + category, + message, + filename, + lineno, // lineno_obj, + source_line, + source, + vm, + ) +} + +fn call_show_warning( + category: PyTypeRef, + message: PyStrRef, + filename: PyStrRef, + lineno: usize, + source_line: Option, + source: Option, + vm: &VirtualMachine, +) -> PyResult<()> { + let Some(show_fn) = + get_warnings_attr(vm, identifier!(&vm.ctx, _showwarnmsg), source.is_some())? + else { + return show_warning(filename, lineno, message, category, source_line, vm); + }; + if !show_fn.is_callable() { + return Err( + vm.new_type_error("warnings._showwarnmsg() must be set to a callable".to_owned()) + ); + } + let Some(warnmsg_cls) = get_warnings_attr(vm, identifier!(&vm.ctx, WarningMessage), false)? + else { + return Err(vm.new_type_error("unable to get warnings.WarningMessage".to_owned())); + }; + + let msg = warnmsg_cls.call( + vec![ + message.into(), + category.into(), + filename.into(), + vm.new_pyobj(lineno), + vm.ctx.none(), + vm.ctx.none(), + vm.unwrap_or_none(source), + ], + vm, + )?; + show_fn.call((msg,), vm)?; + Ok(()) +} + +fn show_warning( + _filename: PyStrRef, + _lineno: usize, + text: PyStrRef, + category: PyTypeRef, + _source_line: Option, + vm: &VirtualMachine, +) -> PyResult<()> { let stderr = crate::stdlib::sys::PyStderr(vm); - writeln!(stderr, "{}: {}", category.name(), text,); + writeln!(stderr, "{}: {}", category.name(), text.as_str(),); Ok(()) } From 4135da42ac6ec8490547845c94ba7a32783bd0a9 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Fri, 6 Oct 2023 03:17:03 +0900 Subject: [PATCH 116/893] Fix clippy (#5083) * Fix clippy * Fix nightly clippy --- compiler/codegen/src/symboltable.rs | 2 +- src/shell/helper.rs | 2 +- stdlib/src/zlib.rs | 6 +++--- vm/src/frame.rs | 28 ++++++++++++---------------- vm/src/import.rs | 5 +---- vm/src/macros.rs | 2 ++ vm/src/stdlib/ast.rs | 2 +- vm/src/stdlib/sys.rs | 4 ++-- wasm/lib/src/lib.rs | 2 +- 9 files changed, 24 insertions(+), 29 deletions(-) diff --git a/compiler/codegen/src/symboltable.rs b/compiler/codegen/src/symboltable.rs index e09f0c4aab..5283abad53 100644 --- a/compiler/codegen/src/symboltable.rs +++ b/compiler/codegen/src/symboltable.rs @@ -297,7 +297,7 @@ impl SymbolTableAnalyzer { &mut self, symbol: &mut Symbol, st_typ: SymbolTableType, - sub_tables: &mut [SymbolTable], + sub_tables: &[SymbolTable], ) -> SymbolTableResult { if symbol .flags diff --git a/src/shell/helper.rs b/src/shell/helper.rs index 83d72907bd..34691e7995 100644 --- a/src/shell/helper.rs +++ b/src/shell/helper.rs @@ -119,8 +119,8 @@ impl<'vm> ShellHelper<'vm> { // only the completions that don't start with a '_' let no_underscore = all_completions .iter() + .filter(|&s| !s.as_str().starts_with('_')) .cloned() - .filter(|s| !s.as_str().starts_with('_')) .collect::>(); // if there are only completions that start with a '_', give them all of the diff --git a/stdlib/src/zlib.rs b/stdlib/src/zlib.rs index 92cefc863f..37ee1c83f7 100644 --- a/stdlib/src/zlib.rs +++ b/stdlib/src/zlib.rs @@ -310,7 +310,7 @@ mod zlib { fn save_unused_input( &self, - d: &mut Decompress, + d: &Decompress, data: &[u8], stream_end: bool, orig_in: u64, @@ -349,7 +349,7 @@ mod zlib { Ok((buf, false)) => (Ok(buf), false), Err(err) => (Err(err), false), }; - self.save_unused_input(&mut d, data, stream_end, orig_in, vm); + self.save_unused_input(&d, data, stream_end, orig_in, vm); let leftover = if stream_end { b"" @@ -390,7 +390,7 @@ mod zlib { Ok((buf, stream_end)) => (Ok(buf), stream_end), Err(err) => (Err(err), false), }; - self.save_unused_input(&mut d, &data, stream_end, orig_in, vm); + self.save_unused_input(&d, &data, stream_end, orig_in, vm); *data = PyBytes::from(Vec::new()).into_ref(&vm.ctx); diff --git a/vm/src/frame.rs b/vm/src/frame.rs index e53b57a2ee..719aa6288a 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1946,22 +1946,18 @@ impl ExecutingFrame<'_> { impl fmt::Debug for Frame { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let state = self.state.lock(); - let stack_str = state - .stack - .iter() - .map(|elem| { - if elem.payload_is::() { - "\n > {frame}".to_owned() - } else { - format!("\n > {elem:?}") - } - }) - .collect::(); - let block_str = state - .blocks - .iter() - .map(|elem| format!("\n > {elem:?}")) - .collect::(); + let stack_str = state.stack.iter().fold(String::new(), |mut s, elem| { + if elem.payload_is::() { + s.push_str("\n > {frame}"); + } else { + std::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap(); + } + s + }); + let block_str = state.blocks.iter().fold(String::new(), |mut s, elem| { + std::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap(); + s + }); // TODO: fix this up let locals = self.locals.clone(); write!( diff --git a/vm/src/import.rs b/vm/src/import.rs index 0edc2f77a2..918424296e 100644 --- a/vm/src/import.rs +++ b/vm/src/import.rs @@ -31,10 +31,7 @@ pub(crate) fn init_importlib_base(vm: &mut VirtualMachine) -> PyResult PyResult<()> { +pub(crate) fn init_importlib_package(vm: &VirtualMachine, importlib: PyObjectRef) -> PyResult<()> { thread::enter_vm(vm, || { flame_guard!("install_external"); diff --git a/vm/src/macros.rs b/vm/src/macros.rs index c058764f38..4554a65c26 100644 --- a/vm/src/macros.rs +++ b/vm/src/macros.rs @@ -116,10 +116,12 @@ macro_rules! match_class { // The default arm, binding the original object to the specified identifier. (match ($obj:expr) { $binding:ident => $default:expr $(,)? }) => {{ + #[allow(clippy::redundant_locals)] let $binding = $obj; $default }}; (match ($obj:expr) { ref $binding:ident => $default:expr $(,)? }) => {{ + #[allow(clippy::redundant_locals)] let $binding = &$obj; $default }}; diff --git a/vm/src/stdlib/ast.rs b/vm/src/stdlib/ast.rs index bf34528e2c..50b3153d0c 100644 --- a/vm/src/stdlib/ast.rs +++ b/vm/src/stdlib/ast.rs @@ -168,7 +168,7 @@ fn range_from_object( None }; let range = SourceRange { - start: location.unwrap_or(SourceLocation::default()), + start: location.unwrap_or_default(), end: end_location, }; Ok(range) diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index 81498a7b06..65d50eaa15 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -784,8 +784,8 @@ mod sys { impl PyThreadInfo { const INFO: Self = PyThreadInfo { name: crate::stdlib::thread::_thread::PYTHREAD_NAME, - /// As I know, there's only way to use lock as "Mutex" in Rust - /// with satisfying python document spec. + // As I know, there's only way to use lock as "Mutex" in Rust + // with satisfying python document spec. lock: Some("mutex+cond"), version: None, }; diff --git a/wasm/lib/src/lib.rs b/wasm/lib/src/lib.rs index 07bb2f5f7d..85546c78d3 100644 --- a/wasm/lib/src/lib.rs +++ b/wasm/lib/src/lib.rs @@ -52,7 +52,7 @@ pub mod eval { fn run_py(source: &str, options: Option, mode: Mode) -> Result { let vm = VMStore::init(PY_EVAL_VM_ID.into(), Some(true)); - let options = options.unwrap_or_else(Object::new); + let options = options.unwrap_or_default(); let js_vars = { let prop = Reflect::get(&options, &"vars".into())?; if prop.is_undefined() { From 7022512b83d2bc2dcc473f0edf2ae1e31b20cf3d Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Fri, 6 Oct 2023 14:34:35 +0900 Subject: [PATCH 117/893] retry windows ci openssl fix (#5082) --- .github/workflows/ci.yaml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 12308c1dbc..2f53a5dd1c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -127,7 +127,11 @@ jobs: shell: bash run: | choco install llvm openssl --no-progress - echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64" >> $GITHUB_ENV + if [[ -d "C:\Program Files\OpenSSL-Win64" ]]; then + echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64" >> $GITHUB_ENV + else + echo "OPENSSL_DIR=C:\Program Files\OpenSSL" >> $GITHUB_ENV + fi if: runner.os == 'Windows' - name: Set up the Mac environment run: brew install autoconf automake libtool @@ -252,7 +256,11 @@ jobs: shell: bash run: | choco install llvm openssl --no-progress - echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64" >>$GITHUB_ENV + if [[ -d "C:\Program Files\OpenSSL-Win64" ]]; then + echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64;" >> $GITHUB_ENV + else + echo "OPENSSL_DIR=C:\Program Files\OpenSSL;" >> $GITHUB_ENV + fi if: runner.os == 'Windows' - name: Set up the Mac environment run: brew install autoconf automake libtool openssl@3 From 2fa88f94b6e1c7693aed8680b522d9b05f0b5659 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Fri, 6 Oct 2023 14:34:50 +0900 Subject: [PATCH 118/893] Skip flaky test_enum::test_unique_composite (#5084) --- Lib/test/test_enum.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index be242e93f7..1c307e75ee 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -3417,6 +3417,7 @@ class Color(StrMixin, AllMixin, Flag): self.assertEqual(Color.ALL.value, 7) self.assertEqual(str(Color.BLUE), 'blue') + @unittest.skip("TODO: RUSTPYTHON; flaky test") @threading_helper.reap_threads @threading_helper.requires_working_threading() def test_unique_composite(self): From 987d50c09257d1d94d5a50eaf4704989f590b729 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sat, 7 Oct 2023 03:01:42 +0900 Subject: [PATCH 119/893] port to windows-rs (#5080) * Fix OpenSSL in windows CI * bump windows-rs * prepare windows-sys 0.48 * CloseHandle * DuplicateHandle * CreatePipe * GetFileType * GetExitCodeProcess * TerminateProcess * GetStdHandle * GetCurrentProcess * DeleteProcThreadAttributeList * WaitForSingleObject * CreateProcessW * InitializeProcThreadAttributeList * UpdateProcThreadAttribute * clean up helpers --- .github/workflows/ci.yaml | 4 +- Cargo.lock | 101 ++++++--------- vm/Cargo.toml | 20 ++- vm/src/stdlib/winapi.rs | 256 +++++++++++++++++++++++--------------- 4 files changed, 213 insertions(+), 168 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2f53a5dd1c..b4e1eb1932 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -257,9 +257,9 @@ jobs: run: | choco install llvm openssl --no-progress if [[ -d "C:\Program Files\OpenSSL-Win64" ]]; then - echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64;" >> $GITHUB_ENV + echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64" >> $GITHUB_ENV else - echo "OPENSSL_DIR=C:\Program Files\OpenSSL;" >> $GITHUB_ENV + echo "OPENSSL_DIR=C:\Program Files\OpenSSL" >> $GITHUB_ENV fi if: runner.os == 'Windows' - name: Set up the Mac environment diff --git a/Cargo.lock b/Cargo.lock index 98fcfc5fa5..8688266ad1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,7 +253,7 @@ dependencies = [ "js-sys", "num-traits", "wasm-bindgen", - "windows-targets 0.48.0", + "windows-targets 0.48.5", ] [[package]] @@ -2384,6 +2384,7 @@ dependencies = [ "widestring", "winapi", "windows", + "windows-sys 0.48.0", "winreg", ] @@ -3271,15 +3272,21 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.39.0" +version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1c4bd0a50ac6020f65184721f758dba47bb9fbc2133df715ec74a237b26794a" +checksum = "ca229916c5ee38c2f2bc1e9d8f04df975b4bd93f9955dc69fabb5d91270045c9" dependencies = [ - "windows_aarch64_msvc 0.39.0", - "windows_i686_gnu 0.39.0", - "windows_i686_msvc 0.39.0", - "windows_x86_64_gnu 0.39.0", - "windows_x86_64_msvc 0.39.0", + "windows-core", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-core" +version = "0.51.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" +dependencies = [ + "windows-targets 0.48.5", ] [[package]] @@ -3325,7 +3332,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.48.5", ] [[package]] @@ -3345,17 +3352,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.48.0", - "windows_aarch64_msvc 0.48.0", - "windows_i686_gnu 0.48.0", - "windows_i686_msvc 0.48.0", - "windows_x86_64_gnu 0.48.0", - "windows_x86_64_gnullvm 0.48.0", - "windows_x86_64_msvc 0.48.0", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] @@ -3366,9 +3373,9 @@ checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_msvc" @@ -3376,12 +3383,6 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" -[[package]] -name = "windows_aarch64_msvc" -version = "0.39.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec7711666096bd4096ffa835238905bb33fb87267910e154b18b44eaabb340f2" - [[package]] name = "windows_aarch64_msvc" version = "0.42.1" @@ -3390,9 +3391,9 @@ checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_i686_gnu" @@ -3400,12 +3401,6 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" -[[package]] -name = "windows_i686_gnu" -version = "0.39.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763fc57100a5f7042e3057e7e8d9bdd7860d330070251a73d003563a3bb49e1b" - [[package]] name = "windows_i686_gnu" version = "0.42.1" @@ -3414,9 +3409,9 @@ checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_msvc" @@ -3424,12 +3419,6 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" -[[package]] -name = "windows_i686_msvc" -version = "0.39.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bc7cbfe58828921e10a9f446fcaaf649204dcfe6c1ddd712c5eebae6bda1106" - [[package]] name = "windows_i686_msvc" version = "0.42.1" @@ -3438,9 +3427,9 @@ checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_x86_64_gnu" @@ -3448,12 +3437,6 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" -[[package]] -name = "windows_x86_64_gnu" -version = "0.39.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6868c165637d653ae1e8dc4d82c25d4f97dd6605eaa8d784b5c6e0ab2a252b65" - [[package]] name = "windows_x86_64_gnu" version = "0.42.1" @@ -3462,9 +3445,9 @@ checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnullvm" @@ -3474,9 +3457,9 @@ checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_msvc" @@ -3484,12 +3467,6 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" -[[package]] -name = "windows_x86_64_msvc" -version = "0.39.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e4d40883ae9cae962787ca76ba76390ffa29214667a111db9e0a1ad8377e809" - [[package]] name = "windows_x86_64_msvc" version = "0.42.1" @@ -3498,9 +3475,9 @@ checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winreg" diff --git a/vm/Cargo.toml b/vm/Cargo.toml index ee2eadc681..ec1190d165 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -113,9 +113,25 @@ widestring = { workspace = true } winreg = "0.10.1" [target.'cfg(windows)'.dependencies.windows] -version = "0.39.0" +version = "0.51.1" features = [ - "Win32_UI_Shell", "Win32_System_LibraryLoader", "Win32_Foundation" + "Win32_Foundation", + "Win32_System_LibraryLoader", + "Win32_System_Threading", + "Win32_UI_Shell", +] + +[target.'cfg(windows)'.dependencies.windows-sys] +version = "0.48.0" +features = [ + "Win32_Foundation", + "Win32_Security", + "Win32_Storage_FileSystem", + "Win32_System_Console", + "Win32_System_LibraryLoader", + "Win32_System_Pipes", + "Win32_System_Threading", + "Win32_UI_Shell", ] [target.'cfg(windows)'.dependencies.winapi] diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index 68d6f51256..34ec941531 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -6,22 +6,18 @@ mod _winapi { use crate::{ builtins::PyStrRef, common::windows::ToWideString, - convert::ToPyException, + convert::{ToPyException, ToPyObject, ToPyResult}, function::{ArgMapping, ArgSequence, OptionalArg}, stdlib::os::errno_err, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use std::ptr::{null, null_mut}; - use winapi::shared::winerror; - use winapi::um::{ - fileapi, handleapi, namedpipeapi, processenv, processthreadsapi, synchapi, winbase, - winnt::HANDLE, - }; + use winapi::um::winbase; use windows::{ core::PCWSTR, - Win32::Foundation::{HINSTANCE, MAX_PATH}, - Win32::System::LibraryLoader::{GetModuleFileNameW, LoadLibraryW}, + Win32::Foundation::{HANDLE, HINSTANCE, MAX_PATH}, }; + use windows_sys::Win32::Foundation::{BOOL, HANDLE as RAW_HANDLE}; #[pyattr] use winapi::{ @@ -66,41 +62,77 @@ mod _winapi { unsafe { winapi::um::errhandlingapi::GetLastError() } } - fn husize(h: HANDLE) -> usize { - h as usize - } - - trait Convertible { + trait WindowsSysResultValue { + type Ok: ToPyObject; fn is_err(&self) -> bool; + fn into_ok(self) -> Self::Ok; } - impl Convertible for HANDLE { + impl WindowsSysResultValue for RAW_HANDLE { + type Ok = HANDLE; fn is_err(&self) -> bool { - *self == handleapi::INVALID_HANDLE_VALUE + *self == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE + } + fn into_ok(self) -> Self::Ok { + HANDLE(self) } } - impl Convertible for i32 { + + impl WindowsSysResultValue for BOOL { + type Ok = (); fn is_err(&self) -> bool { *self == 0 } + fn into_ok(self) -> Self::Ok {} } - fn cvt(vm: &VirtualMachine, res: T) -> PyResult { - if res.is_err() { - Err(errno_err(vm)) - } else { - Ok(res) + struct WindowsSysResult(T); + + impl WindowsSysResult { + fn is_err(&self) -> bool { + self.0.is_err() + } + fn into_pyresult(self, vm: &VirtualMachine) -> PyResult { + if self.is_err() { + Err(errno_err(vm)) + } else { + Ok(self.0.into_ok()) + } + } + } + + impl ToPyResult for WindowsSysResult { + fn to_pyresult(self, vm: &VirtualMachine) -> PyResult { + let ok = self.into_pyresult(vm)?; + Ok(ok.to_pyobject(vm)) + } + } + + type HandleInt = usize; // TODO: change to isize when fully ported to windows-rs + + impl TryFromObject for HANDLE { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let handle = HandleInt::try_from_object(vm, obj)?; + Ok(HANDLE(handle as isize)) + } + } + + impl ToPyObject for HANDLE { + fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + (self.0 as HandleInt).to_pyobject(vm) } } #[pyfunction] - fn CloseHandle(handle: usize, vm: &VirtualMachine) -> PyResult<()> { - cvt(vm, unsafe { handleapi::CloseHandle(handle as HANDLE) }).map(drop) + fn CloseHandle(handle: HANDLE) -> WindowsSysResult { + WindowsSysResult(unsafe { windows_sys::Win32::Foundation::CloseHandle(handle.0) }) } #[pyfunction] - fn GetStdHandle(std_handle: u32, vm: &VirtualMachine) -> PyResult { - cvt(vm, unsafe { processenv::GetStdHandle(std_handle) }).map(husize) + fn GetStdHandle( + std_handle: windows_sys::Win32::System::Console::STD_HANDLE, + ) -> WindowsSysResult { + WindowsSysResult(unsafe { windows_sys::Win32::System::Console::GetStdHandle(std_handle) }) } #[pyfunction] @@ -108,51 +140,63 @@ mod _winapi { _pipe_attrs: PyObjectRef, size: u32, vm: &VirtualMachine, - ) -> PyResult<(usize, usize)> { - let mut read = null_mut(); - let mut write = null_mut(); - cvt(vm, unsafe { - namedpipeapi::CreatePipe(&mut read, &mut write, null_mut(), size) - })?; - Ok((read as usize, write as usize)) + ) -> PyResult<(HANDLE, HANDLE)> { + let (read, write) = unsafe { + let mut read = std::mem::MaybeUninit::::uninit(); + let mut write = std::mem::MaybeUninit::::uninit(); + WindowsSysResult(windows_sys::Win32::System::Pipes::CreatePipe( + read.as_mut_ptr(), + write.as_mut_ptr(), + std::ptr::null(), + size, + )) + .to_pyresult(vm)?; + (read.assume_init(), write.assume_init()) + }; + Ok((HANDLE(read), HANDLE(write))) } #[pyfunction] fn DuplicateHandle( - (src_process, src): (usize, usize), - target_process: usize, + (src_process, src): (HANDLE, HANDLE), + target_process: HANDLE, access: u32, - inherit: i32, + inherit: BOOL, options: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { - let mut target = null_mut(); - cvt(vm, unsafe { - handleapi::DuplicateHandle( - src_process as _, - src as _, - target_process as _, - &mut target, + ) -> PyResult { + let target = unsafe { + let mut target = std::mem::MaybeUninit::::uninit(); + WindowsSysResult(windows_sys::Win32::Foundation::DuplicateHandle( + src_process.0, + src.0, + target_process.0, + target.as_mut_ptr(), access, inherit, options.unwrap_or(0), - ) - })?; - Ok(target as usize) + )) + .to_pyresult(vm)?; + target.assume_init() + }; + Ok(HANDLE(target)) } #[pyfunction] - fn GetCurrentProcess() -> usize { - unsafe { processthreadsapi::GetCurrentProcess() as usize } + fn GetCurrentProcess() -> HANDLE { + unsafe { windows::Win32::System::Threading::GetCurrentProcess() } } #[pyfunction] - fn GetFileType(h: usize, vm: &VirtualMachine) -> PyResult { - let ret = unsafe { fileapi::GetFileType(h as _) }; - if ret == 0 && GetLastError() != 0 { + fn GetFileType( + h: HANDLE, + vm: &VirtualMachine, + ) -> PyResult { + let file_type = unsafe { windows_sys::Win32::Storage::FileSystem::GetFileType(h.0) }; + if file_type == 0 && GetLastError() != 0 { Err(errno_err(vm)) } else { - Ok(ret) + Ok(file_type) } } @@ -182,7 +226,7 @@ mod _winapi { fn CreateProcess( args: CreateProcessArgs, vm: &VirtualMachine, - ) -> PyResult<(usize, usize, u32, u32)> { + ) -> PyResult<(HANDLE, HANDLE, u32, u32)> { let mut si = winbase::STARTUPINFOEXW::default(); si.StartupInfo.cb = std::mem::size_of_val(&si) as _; @@ -241,11 +285,11 @@ mod _winapi { let procinfo = unsafe { let mut procinfo = std::mem::MaybeUninit::uninit(); - let ret = processthreadsapi::CreateProcessW( + WindowsSysResult(windows_sys::Win32::System::Threading::CreateProcessW( app_name, command_line, - null_mut(), - null_mut(), + std::ptr::null(), + std::ptr::null(), args.inherit_handles, args.creation_flags | winbase::EXTENDED_STARTUPINFO_PRESENT @@ -254,16 +298,14 @@ mod _winapi { current_dir, &mut si as *mut winbase::STARTUPINFOEXW as _, procinfo.as_mut_ptr(), - ); - if ret == 0 { - return Err(errno_err(vm)); - } + )) + .into_pyresult(vm)?; procinfo.assume_init() }; Ok(( - procinfo.hProcess as usize, - procinfo.hThread as usize, + HANDLE(procinfo.hProcess), + HANDLE(procinfo.hThread), procinfo.dwProcessId, procinfo.dwThreadId, )) @@ -310,7 +352,9 @@ mod _winapi { impl Drop for AttrList { fn drop(&mut self) { unsafe { - processthreadsapi::DeleteProcThreadAttributeList(self.attrlist.as_mut_ptr() as _) + windows_sys::Win32::System::Threading::DeleteProcThreadAttributeList( + self.attrlist.as_mut_ptr() as *mut _, + ) }; } } @@ -333,49 +377,50 @@ mod _winapi { .transpose()?; let attr_count = handlelist.is_some() as u32; - let mut size = 0; - let ret = unsafe { - processthreadsapi::InitializeProcThreadAttributeList( - null_mut(), - attr_count, - 0, - &mut size, - ) + let (result, mut size) = unsafe { + let mut size = std::mem::MaybeUninit::uninit(); + let result = WindowsSysResult( + windows_sys::Win32::System::Threading::InitializeProcThreadAttributeList( + std::ptr::null_mut(), + attr_count, + 0, + size.as_mut_ptr(), + ), + ); + (result, size.assume_init()) }; - if ret != 0 || GetLastError() != winerror::ERROR_INSUFFICIENT_BUFFER { + if !result.is_err() + || GetLastError() != winapi::shared::winerror::ERROR_INSUFFICIENT_BUFFER + { return Err(errno_err(vm)); } let mut attrlist = vec![0u8; size]; - let ret = unsafe { - processthreadsapi::InitializeProcThreadAttributeList( - attrlist.as_mut_ptr() as _, + WindowsSysResult(unsafe { + windows_sys::Win32::System::Threading::InitializeProcThreadAttributeList( + attrlist.as_mut_ptr() as *mut _, attr_count, 0, &mut size, ) - }; - if ret == 0 { - return Err(errno_err(vm)); - } + }) + .into_pyresult(vm)?; let mut attrs = AttrList { handlelist, attrlist, }; if let Some(ref mut handlelist) = attrs.handlelist { - let ret = unsafe { - processthreadsapi::UpdateProcThreadAttribute( + WindowsSysResult(unsafe { + windows_sys::Win32::System::Threading::UpdateProcThreadAttribute( attrs.attrlist.as_mut_ptr() as _, 0, (2 & 0xffff) | 0x20000, // PROC_THREAD_ATTRIBUTE_HANDLE_LIST handlelist.as_mut_ptr() as _, (handlelist.len() * std::mem::size_of::()) as _, - null_mut(), - null_mut(), + std::ptr::null_mut(), + std::ptr::null(), ) - }; - if ret == 0 { - return Err(errno_err(vm)); - } + }) + .into_pyresult(vm)?; } Ok(attrs) }) @@ -383,9 +428,9 @@ mod _winapi { } #[pyfunction] - fn WaitForSingleObject(h: usize, ms: u32, vm: &VirtualMachine) -> PyResult { - let ret = unsafe { synchapi::WaitForSingleObject(h as _, ms) }; - if ret == winbase::WAIT_FAILED { + fn WaitForSingleObject(h: HANDLE, ms: u32, vm: &VirtualMachine) -> PyResult { + let ret = unsafe { windows_sys::Win32::System::Threading::WaitForSingleObject(h.0, ms) }; + if ret == windows_sys::Win32::Foundation::WAIT_FAILED { Err(errno_err(vm)) } else { Ok(ret) @@ -393,27 +438,33 @@ mod _winapi { } #[pyfunction] - fn GetExitCodeProcess(h: usize, vm: &VirtualMachine) -> PyResult { - let mut ec = 0; - cvt(vm, unsafe { - processthreadsapi::GetExitCodeProcess(h as _, &mut ec) - })?; - Ok(ec) + fn GetExitCodeProcess(h: HANDLE, vm: &VirtualMachine) -> PyResult { + unsafe { + let mut ec = std::mem::MaybeUninit::uninit(); + WindowsSysResult(windows_sys::Win32::System::Threading::GetExitCodeProcess( + h.0, + ec.as_mut_ptr(), + )) + .to_pyresult(vm)?; + Ok(ec.assume_init()) + } } #[pyfunction] - fn TerminateProcess(h: usize, exit_code: u32, vm: &VirtualMachine) -> PyResult<()> { - cvt(vm, unsafe { - processthreadsapi::TerminateProcess(h as _, exit_code) + fn TerminateProcess(h: HANDLE, exit_code: u32) -> WindowsSysResult { + WindowsSysResult(unsafe { + windows_sys::Win32::System::Threading::TerminateProcess(h.0, exit_code) }) - .map(drop) } // TODO: ctypes.LibraryLoader.LoadLibrary #[allow(dead_code)] fn LoadLibrary(path: PyStrRef, vm: &VirtualMachine) -> PyResult { let path = path.as_str().to_wides_with_nul(); - let handle = unsafe { LoadLibraryW(PCWSTR::from_raw(path.as_ptr())).unwrap() }; + let handle = unsafe { + windows::Win32::System::LibraryLoader::LoadLibraryW(PCWSTR::from_raw(path.as_ptr())) + .unwrap() + }; if handle.is_invalid() { return Err(vm.new_runtime_error("LoadLibrary failed".to_owned())); } @@ -425,7 +476,8 @@ mod _winapi { let mut path: Vec = vec![0; MAX_PATH as usize]; let handle = HINSTANCE(handle); - let length = unsafe { GetModuleFileNameW(handle, &mut path) }; + let length = + unsafe { windows::Win32::System::LibraryLoader::GetModuleFileNameW(handle, &mut path) }; if length == 0 { return Err(vm.new_runtime_error("GetModuleFileName failed".to_owned())); } From 0f3a9311e0ebd1f0d4ecd4e2e5f196836b117449 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Sat, 7 Oct 2023 02:26:40 +0900 Subject: [PATCH 120/893] port winbase --- vm/src/stdlib/winapi.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index 34ec941531..1b16b9352b 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -12,7 +12,6 @@ mod _winapi { PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use std::ptr::{null, null_mut}; - use winapi::um::winbase; use windows::{ core::PCWSTR, Win32::Foundation::{HANDLE, HINSTANCE, MAX_PATH}, @@ -227,7 +226,8 @@ mod _winapi { args: CreateProcessArgs, vm: &VirtualMachine, ) -> PyResult<(HANDLE, HANDLE, u32, u32)> { - let mut si = winbase::STARTUPINFOEXW::default(); + let mut si: windows_sys::Win32::System::Threading::STARTUPINFOEXW = + unsafe { std::mem::zeroed() }; si.StartupInfo.cb = std::mem::size_of_val(&si) as _; macro_rules! si_attr { @@ -292,11 +292,11 @@ mod _winapi { std::ptr::null(), args.inherit_handles, args.creation_flags - | winbase::EXTENDED_STARTUPINFO_PRESENT - | winbase::CREATE_UNICODE_ENVIRONMENT, + | windows_sys::Win32::System::Threading::EXTENDED_STARTUPINFO_PRESENT + | windows_sys::Win32::System::Threading::CREATE_UNICODE_ENVIRONMENT, env as _, current_dir, - &mut si as *mut winbase::STARTUPINFOEXW as _, + &mut si as *mut _ as *mut _, procinfo.as_mut_ptr(), )) .into_pyresult(vm)?; @@ -390,7 +390,7 @@ mod _winapi { (result, size.assume_init()) }; if !result.is_err() - || GetLastError() != winapi::shared::winerror::ERROR_INSUFFICIENT_BUFFER + || GetLastError() != windows_sys::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER { return Err(errno_err(vm)); } From 6313e4c9fb9f0a83d819986942fb09839ed2c2b2 Mon Sep 17 00:00:00 2001 From: Yaminyam Date: Sat, 7 Oct 2023 02:35:44 +0900 Subject: [PATCH 121/893] windows-sys attrs Copied from https://github.com/RustPython/RustPython/pull/4086 --- vm/Cargo.toml | 2 ++ vm/src/stdlib/winapi.rs | 59 ++++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index ec1190d165..ed226016e3 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -130,8 +130,10 @@ features = [ "Win32_System_Console", "Win32_System_LibraryLoader", "Win32_System_Pipes", + "Win32_System_SystemServices", "Win32_System_Threading", "Win32_UI_Shell", + "Win32_UI_WindowsAndMessaging", ] [target.'cfg(windows)'.dependencies.winapi] diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index 1b16b9352b..9356690ff7 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -19,42 +19,41 @@ mod _winapi { use windows_sys::Win32::Foundation::{BOOL, HANDLE as RAW_HANDLE}; #[pyattr] - use winapi::{ - shared::winerror::{ - ERROR_ALREADY_EXISTS, ERROR_BROKEN_PIPE, ERROR_IO_PENDING, ERROR_MORE_DATA, - ERROR_NETNAME_DELETED, ERROR_NO_DATA, ERROR_NO_SYSTEM_RESOURCES, - ERROR_OPERATION_ABORTED, ERROR_PIPE_BUSY, ERROR_PIPE_CONNECTED, ERROR_SEM_TIMEOUT, - WAIT_TIMEOUT, + use windows_sys::Win32::{ + Foundation::{ + DUPLICATE_CLOSE_SOURCE, DUPLICATE_SAME_ACCESS, ERROR_ALREADY_EXISTS, ERROR_BROKEN_PIPE, + ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_NETNAME_DELETED, ERROR_NO_DATA, + ERROR_NO_SYSTEM_RESOURCES, ERROR_OPERATION_ABORTED, ERROR_PIPE_BUSY, + ERROR_PIPE_CONNECTED, ERROR_SEM_TIMEOUT, GENERIC_READ, GENERIC_WRITE, STILL_ACTIVE, + WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_OBJECT_0, WAIT_TIMEOUT, }, - um::{ - fileapi::OPEN_EXISTING, - memoryapi::{ - FILE_MAP_ALL_ACCESS, FILE_MAP_COPY, FILE_MAP_EXECUTE, FILE_MAP_READ, FILE_MAP_WRITE, + Storage::FileSystem::{ + FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, FILE_GENERIC_READ, + FILE_GENERIC_WRITE, FILE_TYPE_CHAR, FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_REMOTE, + FILE_TYPE_UNKNOWN, OPEN_EXISTING, PIPE_ACCESS_DUPLEX, PIPE_ACCESS_INBOUND, SYNCHRONIZE, + }, + System::{ + Console::{STD_ERROR_HANDLE, STD_INPUT_HANDLE, STD_OUTPUT_HANDLE}, + Memory::{ + FILE_MAP_ALL_ACCESS, MEM_COMMIT, MEM_FREE, MEM_IMAGE, MEM_MAPPED, MEM_PRIVATE, + MEM_RESERVE, PAGE_EXECUTE, PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE, + PAGE_EXECUTE_WRITECOPY, PAGE_GUARD, PAGE_NOACCESS, PAGE_NOCACHE, PAGE_READONLY, + PAGE_READWRITE, PAGE_WRITECOMBINE, PAGE_WRITECOPY, SEC_COMMIT, SEC_IMAGE, + SEC_LARGE_PAGES, SEC_NOCACHE, SEC_RESERVE, SEC_WRITECOMBINE, + }, + Pipes::{ + PIPE_READMODE_MESSAGE, PIPE_TYPE_MESSAGE, PIPE_UNLIMITED_INSTANCES, PIPE_WAIT, }, - minwinbase::STILL_ACTIVE, - winbase::{ + SystemServices::LOCALE_NAME_MAX_LENGTH, + Threading::{ ABOVE_NORMAL_PRIORITY_CLASS, BELOW_NORMAL_PRIORITY_CLASS, CREATE_BREAKAWAY_FROM_JOB, CREATE_DEFAULT_ERROR_MODE, CREATE_NEW_CONSOLE, - CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW, DETACHED_PROCESS, - FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, FILE_TYPE_CHAR, - FILE_TYPE_DISK, FILE_TYPE_PIPE, FILE_TYPE_REMOTE, FILE_TYPE_UNKNOWN, - HIGH_PRIORITY_CLASS, IDLE_PRIORITY_CLASS, INFINITE, NORMAL_PRIORITY_CLASS, - PIPE_ACCESS_DUPLEX, PIPE_ACCESS_INBOUND, PIPE_READMODE_MESSAGE, PIPE_TYPE_MESSAGE, - PIPE_UNLIMITED_INSTANCES, PIPE_WAIT, REALTIME_PRIORITY_CLASS, STARTF_USESHOWWINDOW, - STARTF_USESTDHANDLES, STD_ERROR_HANDLE, STD_INPUT_HANDLE, STD_OUTPUT_HANDLE, - WAIT_ABANDONED, WAIT_ABANDONED_0, WAIT_OBJECT_0, - }, - winnt::{ - DUPLICATE_CLOSE_SOURCE, DUPLICATE_SAME_ACCESS, FILE_GENERIC_READ, - FILE_GENERIC_WRITE, GENERIC_READ, GENERIC_WRITE, LOCALE_NAME_MAX_LENGTH, - MEM_COMMIT, MEM_FREE, MEM_IMAGE, MEM_MAPPED, MEM_PRIVATE, MEM_RESERVE, - PAGE_EXECUTE, PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE, PAGE_EXECUTE_WRITECOPY, - PAGE_GUARD, PAGE_NOACCESS, PAGE_NOCACHE, PAGE_READONLY, PAGE_READWRITE, - PAGE_WRITECOMBINE, PAGE_WRITECOPY, PROCESS_DUP_HANDLE, SEC_COMMIT, SEC_IMAGE, - SEC_LARGE_PAGES, SEC_NOCACHE, SEC_RESERVE, SEC_WRITECOMBINE, SYNCHRONIZE, + CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW, DETACHED_PROCESS, HIGH_PRIORITY_CLASS, + IDLE_PRIORITY_CLASS, INFINITE, NORMAL_PRIORITY_CLASS, PROCESS_DUP_HANDLE, + REALTIME_PRIORITY_CLASS, STARTF_USESHOWWINDOW, STARTF_USESTDHANDLES, }, - winuser::SW_HIDE, }, + UI::WindowsAndMessaging::SW_HIDE, }; fn GetLastError() -> u32 { From 58df09b4922def231c7e5d90f668385b901dc075 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Sat, 7 Oct 2023 02:52:05 +0900 Subject: [PATCH 122/893] GetLastError --- vm/src/stdlib/winapi.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index 9356690ff7..a83d58b909 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -56,10 +56,6 @@ mod _winapi { UI::WindowsAndMessaging::SW_HIDE, }; - fn GetLastError() -> u32 { - unsafe { winapi::um::errhandlingapi::GetLastError() } - } - trait WindowsSysResultValue { type Ok: ToPyObject; fn is_err(&self) -> bool; @@ -191,7 +187,7 @@ mod _winapi { vm: &VirtualMachine, ) -> PyResult { let file_type = unsafe { windows_sys::Win32::Storage::FileSystem::GetFileType(h.0) }; - if file_type == 0 && GetLastError() != 0 { + if file_type == 0 && unsafe { windows_sys::Win32::Foundation::GetLastError() } != 0 { Err(errno_err(vm)) } else { Ok(file_type) @@ -389,7 +385,8 @@ mod _winapi { (result, size.assume_init()) }; if !result.is_err() - || GetLastError() != windows_sys::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER + || unsafe { windows_sys::Win32::Foundation::GetLastError() } + != windows_sys::Win32::Foundation::ERROR_INSUFFICIENT_BUFFER { return Err(errno_err(vm)); } From 4a61eba58e5e8f1b7fb8aea5506bb07a16f05a84 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Sat, 7 Oct 2023 03:01:21 +0900 Subject: [PATCH 123/893] rustpython_vm::windows --- vm/src/lib.rs | 2 ++ vm/src/stdlib/winapi.rs | 64 ++------------------------------------ vm/src/windows.rs | 68 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 62 deletions(-) create mode 100644 vm/src/windows.rs diff --git a/vm/src/lib.rs b/vm/src/lib.rs index 71ca05369b..aae29a9a28 100644 --- a/vm/src/lib.rs +++ b/vm/src/lib.rs @@ -79,6 +79,8 @@ pub mod utils; pub mod version; pub mod vm; pub mod warn; +#[cfg(windows)] +pub mod windows; pub use self::compiler::parser::source_code; pub use self::convert::{TryFromBorrowedObject, TryFromObject}; diff --git a/vm/src/stdlib/winapi.rs b/vm/src/stdlib/winapi.rs index a83d58b909..ed88fabceb 100644 --- a/vm/src/stdlib/winapi.rs +++ b/vm/src/stdlib/winapi.rs @@ -6,9 +6,10 @@ mod _winapi { use crate::{ builtins::PyStrRef, common::windows::ToWideString, - convert::{ToPyException, ToPyObject, ToPyResult}, + convert::{ToPyException, ToPyResult}, function::{ArgMapping, ArgSequence, OptionalArg}, stdlib::os::errno_err, + windows::WindowsSysResult, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use std::ptr::{null, null_mut}; @@ -56,67 +57,6 @@ mod _winapi { UI::WindowsAndMessaging::SW_HIDE, }; - trait WindowsSysResultValue { - type Ok: ToPyObject; - fn is_err(&self) -> bool; - fn into_ok(self) -> Self::Ok; - } - - impl WindowsSysResultValue for RAW_HANDLE { - type Ok = HANDLE; - fn is_err(&self) -> bool { - *self == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE - } - fn into_ok(self) -> Self::Ok { - HANDLE(self) - } - } - - impl WindowsSysResultValue for BOOL { - type Ok = (); - fn is_err(&self) -> bool { - *self == 0 - } - fn into_ok(self) -> Self::Ok {} - } - - struct WindowsSysResult(T); - - impl WindowsSysResult { - fn is_err(&self) -> bool { - self.0.is_err() - } - fn into_pyresult(self, vm: &VirtualMachine) -> PyResult { - if self.is_err() { - Err(errno_err(vm)) - } else { - Ok(self.0.into_ok()) - } - } - } - - impl ToPyResult for WindowsSysResult { - fn to_pyresult(self, vm: &VirtualMachine) -> PyResult { - let ok = self.into_pyresult(vm)?; - Ok(ok.to_pyobject(vm)) - } - } - - type HandleInt = usize; // TODO: change to isize when fully ported to windows-rs - - impl TryFromObject for HANDLE { - fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - let handle = HandleInt::try_from_object(vm, obj)?; - Ok(HANDLE(handle as isize)) - } - } - - impl ToPyObject for HANDLE { - fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { - (self.0 as HandleInt).to_pyobject(vm) - } - } - #[pyfunction] fn CloseHandle(handle: HANDLE) -> WindowsSysResult { WindowsSysResult(unsafe { windows_sys::Win32::Foundation::CloseHandle(handle.0) }) diff --git a/vm/src/windows.rs b/vm/src/windows.rs new file mode 100644 index 0000000000..9216f839fe --- /dev/null +++ b/vm/src/windows.rs @@ -0,0 +1,68 @@ +use crate::{ + convert::{ToPyObject, ToPyResult}, + stdlib::os::errno_err, + PyObjectRef, PyResult, TryFromObject, VirtualMachine, +}; +use windows::Win32::Foundation::HANDLE; +use windows_sys::Win32::Foundation::{BOOL, HANDLE as RAW_HANDLE}; + +pub(crate) trait WindowsSysResultValue { + type Ok: ToPyObject; + fn is_err(&self) -> bool; + fn into_ok(self) -> Self::Ok; +} + +impl WindowsSysResultValue for RAW_HANDLE { + type Ok = HANDLE; + fn is_err(&self) -> bool { + *self == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE + } + fn into_ok(self) -> Self::Ok { + HANDLE(self) + } +} + +impl WindowsSysResultValue for BOOL { + type Ok = (); + fn is_err(&self) -> bool { + *self == 0 + } + fn into_ok(self) -> Self::Ok {} +} + +pub(crate) struct WindowsSysResult(pub T); + +impl WindowsSysResult { + pub fn is_err(&self) -> bool { + self.0.is_err() + } + pub fn into_pyresult(self, vm: &VirtualMachine) -> PyResult { + if self.is_err() { + Err(errno_err(vm)) + } else { + Ok(self.0.into_ok()) + } + } +} + +impl ToPyResult for WindowsSysResult { + fn to_pyresult(self, vm: &VirtualMachine) -> PyResult { + let ok = self.into_pyresult(vm)?; + Ok(ok.to_pyobject(vm)) + } +} + +type HandleInt = usize; // TODO: change to isize when fully ported to windows-rs + +impl TryFromObject for HANDLE { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let handle = HandleInt::try_from_object(vm, obj)?; + Ok(HANDLE(handle as isize)) + } +} + +impl ToPyObject for HANDLE { + fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + (self.0 as HandleInt).to_pyobject(vm) + } +} From 285ba765a70ab243ee0f881604e60b999c5771af Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 7 Oct 2023 14:40:45 +0900 Subject: [PATCH 124/893] 0.4.2 with dependency update --- Cargo.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 53524f1446..30e403b54c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,15 @@ [package] name = "sre-engine" -version = "0.4.1" +version = "0.4.2" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" license = "MIT" -edition = "2018" +edition = "2021" keywords = ["regex"] include = ["LICENSE", "src/**/*.rs"] [dependencies] -num_enum = "0.5" -bitflags = "1.2" +num_enum = "0.5.9" +bitflags = "2" optional = "0.5" From 54d5869457d349636c7356d7e8a77b332ef9573c Mon Sep 17 00:00:00 2001 From: Amuthan Mannar Date: Sat, 7 Oct 2023 12:02:18 +0530 Subject: [PATCH 125/893] Implement binary operations for integers and floating-point numbers, allowing mixed type calculations --- jit/src/instructions.rs | 28 ++++++++++++++++++++ jit/tests/float_tests.rs | 57 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index 86dc58aca4..ffef92a413 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -381,6 +381,10 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { // the rhs is popped off first let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; + + let a_type = a.to_jit_type(); + let b_type = b.to_jit_type(); + let val = match (op, a, b) { (BinaryOperator::Add, JitValue::Int(a), JitValue::Int(b)) => { let (out, carry) = self.builder.ins().iadd_ifcout(a, b); @@ -443,6 +447,30 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { (BinaryOperator::Divide, JitValue::Float(a), JitValue::Float(b)) => { JitValue::Float(self.builder.ins().fdiv(a, b)) } + + // Floats and Integers + (_, JitValue::Int(a), JitValue::Float(b)) | + (_, JitValue::Float(a), JitValue::Int(b)) =>{ + + let operand_one = match a_type.unwrap() { + JitType::Int => self.builder.ins().fcvt_from_sint(types::F64, a), + _=> a + }; + + let operand_two = match b_type.unwrap() { + JitType::Int => self.builder.ins().fcvt_from_sint(types::F64, b), + _=> b + }; + + match op{ + BinaryOperator::Add => JitValue::Float(self.builder.ins().fadd(operand_one, operand_two)), + BinaryOperator::Subtract => JitValue::Float(self.builder.ins().fsub(operand_one, operand_two)), + BinaryOperator::Multiply => JitValue::Float(self.builder.ins().fmul(operand_one, operand_two)), + BinaryOperator::Divide => JitValue::Float(self.builder.ins().fdiv(operand_one, operand_two)), + _ => return Err(JitCompileError::NotSupported) + } + + } _ => return Err(JitCompileError::NotSupported), }; self.stack.push(val); diff --git a/jit/tests/float_tests.rs b/jit/tests/float_tests.rs index ae3fcd2c2d..a72433b8d1 100644 --- a/jit/tests/float_tests.rs +++ b/jit/tests/float_tests.rs @@ -32,6 +32,18 @@ fn test_add() { assert_eq!(add(1.0, f64::NEG_INFINITY), Ok(f64::NEG_INFINITY)); } +#[test] +fn test_add_with_integer() { + let add = jit_function! { add(a:f64, b:i64) -> f64 => r##" + def add(a: float, b: int): + return a + b + "## }; + + assert_approx_eq!(add(5.5, 10), Ok(15.5)); + assert_approx_eq!(add(-4.6, 7), Ok(2.4)); + assert_approx_eq!(add(-5.2, -3), Ok(-8.2)); +} + #[test] fn test_sub() { let sub = jit_function! { sub(a:f64, b:f64) -> f64 => r##" @@ -49,6 +61,19 @@ fn test_sub() { assert_eq!(sub(1.0, f64::INFINITY), Ok(f64::NEG_INFINITY)); } +#[test] +fn test_sub_with_integer() { + let sub = jit_function! { sub(a:i64, b:f64) -> f64 => r##" + def sub(a: int, b: float): + return a - b + "## }; + + assert_approx_eq!(sub(5, 3.6), Ok(1.4)); + assert_approx_eq!(sub(3, -4.2), Ok(7.2)); + assert_approx_eq!(sub(-2, 1.3), Ok(-3.3)); + assert_approx_eq!(sub(-3, -1.3), Ok(-1.7)); +} + #[test] fn test_mul() { let mul = jit_function! { mul(a:f64, b:f64) -> f64 => r##" @@ -70,6 +95,21 @@ fn test_mul() { assert_eq!(mul(f64::NEG_INFINITY, f64::INFINITY), Ok(f64::NEG_INFINITY)); } +#[test] +fn test_mul_with_integer() { + let mul = jit_function! { mul(a:f64, b:i64) -> f64 => r##" + def mul(a: float, b: int): + return a * b + "## }; + + assert_approx_eq!(mul(5.2, 2), Ok(10.4)); + assert_approx_eq!(mul(3.4, -1), Ok(-3.4)); + assert_bits_eq!(mul(1.0, 0), Ok(0.0f64)); + assert_bits_eq!(mul(-0.0,1), Ok(-0.0f64)); + assert_bits_eq!(mul(0.0, -1), Ok(-0.0f64)); + assert_bits_eq!(mul(-0.0,-1), Ok(0.0f64)); +} + #[test] fn test_div() { let div = jit_function! { div(a:f64, b:f64) -> f64 => r##" @@ -91,6 +131,23 @@ fn test_div() { assert_bits_eq!(div(-1.0, f64::INFINITY), Ok(-0.0f64)); } +#[test] +fn test_div_with_integer() { + let div = jit_function! { div(a:f64, b:i64) -> f64 => r##" + def div(a: float, b: int): + return a / b + "## }; + + assert_approx_eq!(div(5.2, 2), Ok(2.6)); + assert_approx_eq!(div(3.4, -1), Ok(-3.4)); + assert_eq!(div(1.0, 0), Ok(f64::INFINITY)); + assert_eq!(div(1.0, -0), Ok(f64::INFINITY)); + assert_eq!(div(-1.0, 0), Ok(f64::NEG_INFINITY)); + assert_eq!(div(-1.0, -0), Ok(f64::NEG_INFINITY)); + assert_eq!(div(f64::INFINITY, 2), Ok(f64::INFINITY)); + assert_eq!(div(f64::NEG_INFINITY, 3), Ok(f64::NEG_INFINITY)); +} + #[test] fn test_if_bool() { let if_bool = jit_function! { if_bool(a:f64) -> i64 => r##" From c2f159b24a5e6ec31262f9c482e5b6248db850a1 Mon Sep 17 00:00:00 2001 From: Amuthan Mannar Date: Sat, 7 Oct 2023 12:34:32 +0530 Subject: [PATCH 126/893] Formatted code files --- jit/src/instructions.rs | 30 ++++++++++++++++++------------ jit/tests/float_tests.rs | 4 ++-- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index ffef92a413..514a9d81cc 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -449,27 +449,33 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { } // Floats and Integers - (_, JitValue::Int(a), JitValue::Float(b)) | - (_, JitValue::Float(a), JitValue::Int(b)) =>{ - + (_, JitValue::Int(a), JitValue::Float(b)) + | (_, JitValue::Float(a), JitValue::Int(b)) => { let operand_one = match a_type.unwrap() { JitType::Int => self.builder.ins().fcvt_from_sint(types::F64, a), - _=> a + _ => a, }; let operand_two = match b_type.unwrap() { JitType::Int => self.builder.ins().fcvt_from_sint(types::F64, b), - _=> b + _ => b, }; - match op{ - BinaryOperator::Add => JitValue::Float(self.builder.ins().fadd(operand_one, operand_two)), - BinaryOperator::Subtract => JitValue::Float(self.builder.ins().fsub(operand_one, operand_two)), - BinaryOperator::Multiply => JitValue::Float(self.builder.ins().fmul(operand_one, operand_two)), - BinaryOperator::Divide => JitValue::Float(self.builder.ins().fdiv(operand_one, operand_two)), - _ => return Err(JitCompileError::NotSupported) + match op { + BinaryOperator::Add => { + JitValue::Float(self.builder.ins().fadd(operand_one, operand_two)) + } + BinaryOperator::Subtract => { + JitValue::Float(self.builder.ins().fsub(operand_one, operand_two)) + } + BinaryOperator::Multiply => { + JitValue::Float(self.builder.ins().fmul(operand_one, operand_two)) + } + BinaryOperator::Divide => { + JitValue::Float(self.builder.ins().fdiv(operand_one, operand_two)) + } + _ => return Err(JitCompileError::NotSupported), } - } _ => return Err(JitCompileError::NotSupported), }; diff --git a/jit/tests/float_tests.rs b/jit/tests/float_tests.rs index a72433b8d1..2ba7dec822 100644 --- a/jit/tests/float_tests.rs +++ b/jit/tests/float_tests.rs @@ -105,9 +105,9 @@ fn test_mul_with_integer() { assert_approx_eq!(mul(5.2, 2), Ok(10.4)); assert_approx_eq!(mul(3.4, -1), Ok(-3.4)); assert_bits_eq!(mul(1.0, 0), Ok(0.0f64)); - assert_bits_eq!(mul(-0.0,1), Ok(-0.0f64)); + assert_bits_eq!(mul(-0.0, 1), Ok(-0.0f64)); assert_bits_eq!(mul(0.0, -1), Ok(-0.0f64)); - assert_bits_eq!(mul(-0.0,-1), Ok(0.0f64)); + assert_bits_eq!(mul(-0.0, -1), Ok(0.0f64)); } #[test] From f8365ca6c3690e63a74b831f211f3087a4f93a80 Mon Sep 17 00:00:00 2001 From: Amuthan Mannar Date: Thu, 5 Oct 2023 21:27:56 +0530 Subject: [PATCH 127/893] Implemented compare operation for boolean types in JIT engine --- jit/src/instructions.rs | 16 +++++++++ jit/tests/bool_tests.rs | 75 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index 86dc58aca4..c6b5362fb9 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -348,6 +348,22 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { .push(JitValue::Bool(self.builder.ins().bint(types::I8, val))); Ok(()) } + (JitValue::Bool(a), JitValue::Bool(b)) => { + let cond = match op { + ComparisonOperator::Equal => IntCC::Equal, + ComparisonOperator::NotEqual => IntCC::NotEqual, + ComparisonOperator::Less => IntCC::UnsignedLessThan, + ComparisonOperator::LessOrEqual => IntCC::UnsignedLessThanOrEqual, + ComparisonOperator::Greater => IntCC::UnsignedGreaterThan, + ComparisonOperator::GreaterOrEqual => IntCC::UnsignedGreaterThanOrEqual, + }; + + let val = self.builder.ins().icmp(cond, a, b); + // TODO: Remove this `bint` in cranelift 0.90 as icmp now returns i8 + self.stack + .push(JitValue::Bool(self.builder.ins().bint(types::I8, val))); + Ok(()) + } _ => Err(JitCompileError::NotSupported), } } diff --git a/jit/tests/bool_tests.rs b/jit/tests/bool_tests.rs index ed25ddb83f..64b1fd0cd5 100644 --- a/jit/tests/bool_tests.rs +++ b/jit/tests/bool_tests.rs @@ -50,3 +50,78 @@ fn test_if_not() { assert_eq!(if_not(true), Ok(1)); assert_eq!(if_not(false), Ok(0)); } + +#[test] +fn test_eq() { + let eq = jit_function! { eq(a:bool, b:bool) -> i64 => r##" + def eq(a: bool, b: bool): + if a == b: + return 1 + return 0 + "## }; + + assert_eq!(eq(false, false), Ok(1)); + assert_eq!(eq(true, true), Ok(1)); + assert_eq!(eq(false, true), Ok(0)); + assert_eq!(eq(true, false), Ok(0)); +} + +#[test] +fn test_gt() { + let gt = jit_function! { gt(a:bool, b:bool) -> i64 => r##" + def gt(a: bool, b: bool): + if a > b: + return 1 + return 0 + "## }; + + assert_eq!(gt(false, false), Ok(0)); + assert_eq!(gt(true, true), Ok(0)); + assert_eq!(gt(false, true), Ok(0)); + assert_eq!(gt(true, false), Ok(1)); +} + +#[test] +fn test_lt() { + let lt = jit_function! { lt(a:bool, b:bool) -> i64 => r##" + def lt(a: bool, b: bool): + if a < b: + return 1 + return 0 + "## }; + + assert_eq!(lt(false, false), Ok(0)); + assert_eq!(lt(true, true), Ok(0)); + assert_eq!(lt(false, true), Ok(1)); + assert_eq!(lt(true, false), Ok(0)); +} + +#[test] +fn test_gte() { + let gte = jit_function! { gte(a:bool, b:bool) -> i64 => r##" + def gte(a: bool, b: bool): + if a >= b: + return 1 + return 0 + "## }; + + assert_eq!(gte(false, false), Ok(1)); + assert_eq!(gte(true, true), Ok(1)); + assert_eq!(gte(false, true), Ok(0)); + assert_eq!(gte(true, false), Ok(1)); +} + +#[test] +fn test_lte() { + let lte = jit_function! { lte(a:bool, b:bool) -> i64 => r##" + def lte(a: bool, b: bool): + if a <= b: + return 1 + return 0 + "## }; + + assert_eq!(lte(false, false), Ok(1)); + assert_eq!(lte(true, true), Ok(1)); + assert_eq!(lte(false, true), Ok(1)); + assert_eq!(lte(true, false), Ok(0)); +} From b9ee0b6b7a73df3d18980358d4cec2c84d0ebcdb Mon Sep 17 00:00:00 2001 From: Amuthan Mannar Date: Fri, 6 Oct 2023 17:51:54 +0530 Subject: [PATCH 128/893] Implemented compare operation for boolean and int types --- jit/src/instructions.rs | 40 +++++++++++----------- jit/tests/bool_tests.rs | 76 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 19 deletions(-) diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index c6b5362fb9..2959361a0f 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -315,18 +315,36 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; + let a_type:Option = a.to_jit_type(); + let b_type:Option = b.to_jit_type(); + match (a, b) { - (JitValue::Int(a), JitValue::Int(b)) => { + (JitValue::Int(a), JitValue::Int(b)) | + (JitValue::Bool(a), JitValue::Bool(b)) | + (JitValue::Bool(a), JitValue::Int(b)) | + (JitValue::Int(a), JitValue::Bool(b)) + => { + + let operand_one = match a_type.unwrap() { + JitType::Bool => self.builder.ins().uextend(types::I64, a), + _=> a + }; + + let operand_two = match b_type.unwrap() { + JitType::Bool => self.builder.ins().uextend(types::I64, b), + _=> b + }; + let cond = match op { ComparisonOperator::Equal => IntCC::Equal, ComparisonOperator::NotEqual => IntCC::NotEqual, ComparisonOperator::Less => IntCC::SignedLessThan, ComparisonOperator::LessOrEqual => IntCC::SignedLessThanOrEqual, ComparisonOperator::Greater => IntCC::SignedGreaterThan, - ComparisonOperator::GreaterOrEqual => IntCC::SignedLessThanOrEqual, + ComparisonOperator::GreaterOrEqual => IntCC::SignedGreaterThanOrEqual, }; - let val = self.builder.ins().icmp(cond, a, b); + let val = self.builder.ins().icmp(cond, operand_one, operand_two); // TODO: Remove this `bint` in cranelift 0.90 as icmp now returns i8 self.stack .push(JitValue::Bool(self.builder.ins().bint(types::I8, val))); @@ -348,22 +366,6 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { .push(JitValue::Bool(self.builder.ins().bint(types::I8, val))); Ok(()) } - (JitValue::Bool(a), JitValue::Bool(b)) => { - let cond = match op { - ComparisonOperator::Equal => IntCC::Equal, - ComparisonOperator::NotEqual => IntCC::NotEqual, - ComparisonOperator::Less => IntCC::UnsignedLessThan, - ComparisonOperator::LessOrEqual => IntCC::UnsignedLessThanOrEqual, - ComparisonOperator::Greater => IntCC::UnsignedGreaterThan, - ComparisonOperator::GreaterOrEqual => IntCC::UnsignedGreaterThanOrEqual, - }; - - let val = self.builder.ins().icmp(cond, a, b); - // TODO: Remove this `bint` in cranelift 0.90 as icmp now returns i8 - self.stack - .push(JitValue::Bool(self.builder.ins().bint(types::I8, val))); - Ok(()) - } _ => Err(JitCompileError::NotSupported), } } diff --git a/jit/tests/bool_tests.rs b/jit/tests/bool_tests.rs index 64b1fd0cd5..26afb97780 100644 --- a/jit/tests/bool_tests.rs +++ b/jit/tests/bool_tests.rs @@ -66,6 +66,21 @@ fn test_eq() { assert_eq!(eq(true, false), Ok(0)); } +#[test] +fn test_eq_with_integers() { + let eq = jit_function! { eq(a:bool, b:i64) -> i64 => r##" + def eq(a: bool, b: int): + if a == b: + return 1 + return 0 + "## }; + + assert_eq!(eq(false, 0), Ok(1)); + assert_eq!(eq(true, 1), Ok(1)); + assert_eq!(eq(false, 1), Ok(0)); + assert_eq!(eq(true, 0), Ok(0)); +} + #[test] fn test_gt() { let gt = jit_function! { gt(a:bool, b:bool) -> i64 => r##" @@ -81,6 +96,21 @@ fn test_gt() { assert_eq!(gt(true, false), Ok(1)); } +#[test] +fn test_gt_with_integers() { + let gt = jit_function! { gt(a:i64, b:bool) -> i64 => r##" + def gt(a: int, b: bool): + if a > b: + return 1 + return 0 + "## }; + + assert_eq!(gt(0, false), Ok(0)); + assert_eq!(gt(1, true), Ok(0)); + assert_eq!(gt(0, true), Ok(0)); + assert_eq!(gt(1, false), Ok(1)); +} + #[test] fn test_lt() { let lt = jit_function! { lt(a:bool, b:bool) -> i64 => r##" @@ -96,6 +126,21 @@ fn test_lt() { assert_eq!(lt(true, false), Ok(0)); } +#[test] +fn test_lt_with_integers() { + let lt = jit_function! { lt(a:i64, b:bool) -> i64 => r##" + def lt(a: int, b: bool): + if a < b: + return 1 + return 0 + "## }; + + assert_eq!(lt(0, false), Ok(0)); + assert_eq!(lt(1, true), Ok(0)); + assert_eq!(lt(0, true), Ok(1)); + assert_eq!(lt(1, false), Ok(0)); +} + #[test] fn test_gte() { let gte = jit_function! { gte(a:bool, b:bool) -> i64 => r##" @@ -111,6 +156,22 @@ fn test_gte() { assert_eq!(gte(true, false), Ok(1)); } + +#[test] +fn test_gte_with_integers() { + let gte = jit_function! { gte(a:bool, b:i64) -> i64 => r##" + def gte(a: bool, b: int): + if a >= b: + return 1 + return 0 + "## }; + + assert_eq!(gte(false, 0), Ok(1)); + assert_eq!(gte(true, 1), Ok(1)); + assert_eq!(gte(false, 1), Ok(0)); + assert_eq!(gte(true, 0), Ok(1)); +} + #[test] fn test_lte() { let lte = jit_function! { lte(a:bool, b:bool) -> i64 => r##" @@ -125,3 +186,18 @@ fn test_lte() { assert_eq!(lte(false, true), Ok(1)); assert_eq!(lte(true, false), Ok(0)); } + +#[test] +fn test_lte_with_integers() { + let lte = jit_function! { lte(a:bool, b:i64) -> i64 => r##" + def lte(a: bool, b: int): + if a <= b: + return 1 + return 0 + "## }; + + assert_eq!(lte(false, 0), Ok(1)); + assert_eq!(lte(true, 1), Ok(1)); + assert_eq!(lte(false, 1), Ok(1)); + assert_eq!(lte(true, 0), Ok(0)); +} From b69e6a910d5ae8be69eef2e5c6496a7c6d3701c9 Mon Sep 17 00:00:00 2001 From: Amuthan Mannar Date: Fri, 6 Oct 2023 18:18:12 +0530 Subject: [PATCH 129/893] Added missed tests in int_tests of jit --- jit/tests/int_tests.rs | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/jit/tests/int_tests.rs b/jit/tests/int_tests.rs index 314849a06e..9ce3f3b4a6 100644 --- a/jit/tests/int_tests.rs +++ b/jit/tests/int_tests.rs @@ -160,6 +160,52 @@ fn test_gt() { assert_eq!(gt(1, -1), Ok(1)); } +#[test] +fn test_lt() { + let lt = jit_function! { lt(a:i64, b:i64) -> i64 => r##" + def lt(a: int, b: int): + if a < b: + return 1 + return 0 + "## }; + + assert_eq!(lt(-1, -5), Ok(0)); + assert_eq!(lt(10, 0), Ok(0)); + assert_eq!(lt(0, 1), Ok(1)); + assert_eq!(lt(-10, -1), Ok(1)); + assert_eq!(lt(100, 100), Ok(0)); +} + +#[test] +fn test_gte() { + let gte = jit_function! { gte(a:i64, b:i64) -> i64 => r##" + def gte(a: int, b: int): + if a >= b: + return 1 + return 0 + "## }; + + assert_eq!(gte(-64, -64), Ok(1)); + assert_eq!(gte(100, -1), Ok(1)); + assert_eq!(gte(1, 2), Ok(0)); + assert_eq!(gte(1, 0), Ok(1)); +} + +#[test] +fn test_lte() { + let lte = jit_function! { lte(a:i64, b:i64) -> i64 => r##" + def lte(a: int, b: int): + if a <= b: + return 1 + return 0 + "## }; + + assert_eq!(lte(-100, -100), Ok(1)); + assert_eq!(lte(-100, 100), Ok(1)); + assert_eq!(lte(10, 1), Ok(0)); + assert_eq!(lte(0, -2), Ok(0)); +} + #[test] fn test_minus() { let minus = jit_function! { minus(a:i64) -> i64 => r##" From eb83b729b2cd1da84af3332e4e509651c9b6dfbf Mon Sep 17 00:00:00 2001 From: Amuthan Mannar Date: Sat, 7 Oct 2023 12:29:24 +0530 Subject: [PATCH 130/893] Formatted code files --- jit/src/instructions.rs | 18 ++++++++---------- jit/tests/bool_tests.rs | 1 - 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index 2959361a0f..ef099026e8 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -315,24 +315,22 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; - let a_type:Option = a.to_jit_type(); - let b_type:Option = b.to_jit_type(); + let a_type: Option = a.to_jit_type(); + let b_type: Option = b.to_jit_type(); match (a, b) { - (JitValue::Int(a), JitValue::Int(b)) | - (JitValue::Bool(a), JitValue::Bool(b)) | - (JitValue::Bool(a), JitValue::Int(b)) | - (JitValue::Int(a), JitValue::Bool(b)) - => { - + (JitValue::Int(a), JitValue::Int(b)) + | (JitValue::Bool(a), JitValue::Bool(b)) + | (JitValue::Bool(a), JitValue::Int(b)) + | (JitValue::Int(a), JitValue::Bool(b)) => { let operand_one = match a_type.unwrap() { JitType::Bool => self.builder.ins().uextend(types::I64, a), - _=> a + _ => a, }; let operand_two = match b_type.unwrap() { JitType::Bool => self.builder.ins().uextend(types::I64, b), - _=> b + _ => b, }; let cond = match op { diff --git a/jit/tests/bool_tests.rs b/jit/tests/bool_tests.rs index 26afb97780..191993938d 100644 --- a/jit/tests/bool_tests.rs +++ b/jit/tests/bool_tests.rs @@ -156,7 +156,6 @@ fn test_gte() { assert_eq!(gte(true, false), Ok(1)); } - #[test] fn test_gte_with_integers() { let gte = jit_function! { gte(a:bool, b:i64) -> i64 => r##" From 830389f62c9600f1f51526c85e4ae9f5edfb840d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dan=20N=C3=A4sman?= <30578250+dannasman@users.noreply.github.com> Date: Mon, 9 Oct 2023 09:16:38 +0300 Subject: [PATCH 131/893] implement dir for ByObjectRef (#5088) --- vm/src/builtins/object.rs | 19 ++----------------- vm/src/protocol/object.rs | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/vm/src/builtins/object.rs b/vm/src/builtins/object.rs index 1a91595445..351e559df6 100644 --- a/vm/src/builtins/object.rs +++ b/vm/src/builtins/object.rs @@ -1,4 +1,4 @@ -use super::{PyDict, PyDictRef, PyList, PyStr, PyStrRef, PyType, PyTypeRef}; +use super::{PyDictRef, PyList, PyStr, PyStrRef, PyType, PyTypeRef}; use crate::common::hash::PyHash; use crate::{ class::PyClassImpl, @@ -252,22 +252,7 @@ impl PyBaseObject { #[pymethod(magic)] pub fn dir(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let attributes = obj.class().get_attributes(); - - let dict = PyDict::from_attributes(attributes, vm)?.into_ref(&vm.ctx); - - // Get instance attributes: - if let Some(object_dict) = obj.dict() { - vm.call_method( - dict.as_object(), - identifier!(vm, update).as_str(), - (object_dict,), - )?; - } - - let attributes: Vec<_> = dict.into_iter().map(|(k, _v)| k).collect(); - - Ok(PyList::from(attributes)) + obj.dir(vm) } #[pymethod(magic)] diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index bebfd21446..639e24dda5 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -3,7 +3,7 @@ use crate::{ builtins::{ - pystr::AsPyStr, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyStr, PyStrRef, + pystr::AsPyStr, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, }, bytesinner::ByteInnerNewOptions, @@ -11,6 +11,7 @@ use crate::{ convert::{ToPyObject, ToPyResult}, dictdatatype::DictKey, function::{Either, OptionalArg, PyArithmeticValue, PySetterValue}, + object::PyPayload, protocol::{PyIter, PyMapping, PySequence}, types::{Constructor, PyComparisonOp}, AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine, @@ -62,6 +63,23 @@ impl PyObjectRef { } // PyObject *PyObject_Dir(PyObject *o) + pub fn dir(self, vm: &VirtualMachine) -> PyResult { + let attributes = self.class().get_attributes(); + + let dict = PyDict::from_attributes(attributes, vm)?.into_ref(&vm.ctx); + + if let Some(object_dict) = self.dict() { + vm.call_method( + dict.as_object(), + identifier!(vm, update).as_str(), + (object_dict,), + )?; + } + + let attributes: Vec<_> = dict.into_iter().map(|(k, _v)| k).collect(); + + Ok(PyList::from(attributes)) + } } impl PyObject { From 9241e2e5d5430f2e6f80c01bf37c000bde233210 Mon Sep 17 00:00:00 2001 From: Amuthan Mannar Date: Thu, 12 Oct 2023 16:42:33 +0530 Subject: [PATCH 132/893] [VM] Object pickling implementation for product object (python itertools) (#5089) * Implemented __reduce__, __setstate__ in product object --- Lib/test/test_itertools.py | 7 ++- vm/src/stdlib/itertools.rs | 93 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 91 insertions(+), 9 deletions(-) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 142fa04e38..cf1107c45a 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -208,7 +208,7 @@ def test_chain_setstate(self): it = chain() it.__setstate__((iter(['abc', 'def']), iter(['ghi']))) self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f']) - + # TODO: RUSTPYTHON @unittest.expectedFailure def test_combinations(self): @@ -1165,8 +1165,7 @@ def test_product_tuple_reuse(self): self.assertEqual(len(set(map(id, product('abc', 'def')))), 1) self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1) - # TODO: RUSTPYTHON - @unittest.expectedFailure + def test_product_pickling(self): # check copy, deepcopy, pickle for args, result in [ @@ -2297,7 +2296,7 @@ def __eq__(self, other): class SubclassWithKwargsTest(unittest.TestCase): - + # TODO: RUSTPYTHON @unittest.expectedFailure def test_keywords_in_subclass(self): diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 526ab61af3..3bc26c0a8f 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -2,6 +2,7 @@ pub(crate) use decl::make_module; #[pymodule(name = "itertools")] mod decl { + use crate::stdlib::itertools::decl::int::get_value; use crate::{ builtins::{ int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, @@ -110,7 +111,9 @@ mod decl { Ok(()) } } + impl SelfIter for PyItertoolsChain {} + impl IterNext for PyItertoolsChain { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let Some(source) = zelf.source.read().clone() else { @@ -201,6 +204,7 @@ mod decl { } impl SelfIter for PyItertoolsCompress {} + impl IterNext for PyItertoolsCompress { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { loop { @@ -268,7 +272,9 @@ mod decl { (zelf.class().to_owned(), (zelf.cur.read().clone(),)) } } + impl SelfIter for PyItertoolsCount {} + impl IterNext for PyItertoolsCount { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let mut cur = zelf.cur.write(); @@ -316,7 +322,9 @@ mod decl { #[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))] impl PyItertoolsCycle {} + impl SelfIter for PyItertoolsCycle {} + impl IterNext for PyItertoolsCycle { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let item = if let PyIterReturn::Return(item) = zelf.iter.next(vm)? { @@ -401,6 +409,7 @@ mod decl { } impl SelfIter for PyItertoolsRepeat {} + impl IterNext for PyItertoolsRepeat { fn next(zelf: &Py, _vm: &VirtualMachine) -> PyResult { if let Some(ref times) = zelf.times { @@ -466,7 +475,9 @@ mod decl { ) } } + impl SelfIter for PyItertoolsStarmap {} + impl IterNext for PyItertoolsStarmap { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let obj = zelf.iterable.next(vm)?; @@ -537,7 +548,9 @@ mod decl { Ok(()) } } + impl SelfIter for PyItertoolsTakewhile {} + impl IterNext for PyItertoolsTakewhile { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { if zelf.stop_flag.load() { @@ -618,7 +631,9 @@ mod decl { Ok(()) } } + impl SelfIter for PyItertoolsDropwhile {} + impl IterNext for PyItertoolsDropwhile { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let predicate = &zelf.predicate; @@ -629,7 +644,7 @@ mod decl { let obj = match iterable.next(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } }; let pred = predicate.clone(); @@ -737,7 +752,9 @@ mod decl { Ok(PyIterReturn::Return((new_value, new_key))) } } + impl SelfIter for PyItertoolsGroupBy {} + impl IterNext for PyItertoolsGroupBy { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let mut state = zelf.state.lock(); @@ -753,7 +770,7 @@ mod decl { let (value, new_key) = match zelf.advance(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } }; if !vm.bool_eq(&new_key, &old_key)? { @@ -764,7 +781,7 @@ mod decl { match zelf.advance(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } } }; @@ -797,7 +814,9 @@ mod decl { #[pyclass(with(IterNext, Iterable))] impl PyItertoolsGrouper {} + impl SelfIter for PyItertoolsGrouper {} + impl IterNext for PyItertoolsGrouper { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let old_key = { @@ -960,6 +979,7 @@ mod decl { } impl SelfIter for PyItertoolsIslice {} + impl IterNext for PyItertoolsIslice { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { while zelf.cur.load() < zelf.next.load() { @@ -1033,7 +1053,9 @@ mod decl { ) } } + impl SelfIter for PyItertoolsFilterFalse {} + impl IterNext for PyItertoolsFilterFalse { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let predicate = &zelf.predicate; @@ -1142,6 +1164,7 @@ mod decl { } impl SelfIter for PyItertoolsAccumulate {} + impl IterNext for PyItertoolsAccumulate { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let iterable = &zelf.iterable; @@ -1153,7 +1176,7 @@ mod decl { None => match iterable.next(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } }, Some(obj) => obj.clone(), @@ -1162,7 +1185,7 @@ mod decl { let obj = match iterable.next(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } }; match &zelf.binop { @@ -1348,7 +1371,60 @@ mod decl { self.cur.store(idxs.len() - 1); } } + + #[pymethod(magic)] + fn setstate(zelf: PyRef, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { + let args = state.as_slice(); + if args.len() != zelf.pools.len() { + let msg = "Invalid number of arguments".to_string(); + return Err(vm.new_type_error(msg)); + } + let mut idxs: PyRwLockWriteGuard<'_, Vec> = zelf.idxs.write(); + idxs.clear(); + for s in 0..args.len() { + let index = get_value(state.get(s).unwrap()).to_usize().unwrap(); + let pool_size = zelf.pools.get(s).unwrap().len(); + if pool_size == 0 { + zelf.stop.store(true); + return Ok(()); + } + if index >= pool_size { + idxs.push(pool_size - 1); + } else { + idxs.push(index); + } + } + zelf.stop.store(false); + Ok(()) + } + + #[pymethod(magic)] + fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyTupleRef { + let class = zelf.class().to_owned(); + + if zelf.stop.load() { + return vm.new_tuple((class, (vm.ctx.empty_tuple.clone(),))); + } + + let mut pools: Vec = Vec::new(); + for element in zelf.pools.iter() { + pools.push(element.clone().into_pytuple(vm).into()); + } + + let mut indices: Vec = Vec::new(); + + for item in &zelf.idxs.read()[..] { + indices.push(vm.new_pyobj(*item)); + } + + vm.new_tuple(( + class, + pools.clone().into_pytuple(vm), + indices.into_pytuple(vm), + )) + } } + impl SelfIter for PyItertoolsProduct {} impl IterNext for PyItertoolsProduct { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -1563,6 +1639,7 @@ mod decl { impl PyItertoolsCombinationsWithReplacement {} impl SelfIter for PyItertoolsCombinationsWithReplacement {} + impl IterNext for PyItertoolsCombinationsWithReplacement { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { // stop signal @@ -1679,7 +1756,9 @@ mod decl { )) } } + impl SelfIter for PyItertoolsPermutations {} + impl IterNext for PyItertoolsPermutations { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { // stop signal @@ -1802,7 +1881,9 @@ mod decl { Ok(()) } } + impl SelfIter for PyItertoolsZipLongest {} + impl IterNext for PyItertoolsZipLongest { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { if zelf.iterators.is_empty() { @@ -1851,7 +1932,9 @@ mod decl { #[pyclass(with(IterNext, Iterable, Constructor))] impl PyItertoolsPairwise {} + impl SelfIter for PyItertoolsPairwise {} + impl IterNext for PyItertoolsPairwise { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let old = match zelf.old.read().clone() { From 4e6172b99d64627f0b566abe72fdd6f72a2f2ca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dan=20N=C3=A4sman?= <30578250+dannasman@users.noreply.github.com> Date: Mon, 16 Oct 2023 20:39:29 +0300 Subject: [PATCH 133/893] Add object protocol correspoinding to PyObject_GetAIter (#5090) --- vm/src/protocol/object.rs | 11 +++++++++-- vm/src/stdlib/builtins.rs | 7 +------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 639e24dda5..0b87bbf9d3 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -3,8 +3,8 @@ use crate::{ builtins::{ - pystr::AsPyStr, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyStrRef, - PyTuple, PyTupleRef, PyType, PyTypeRef, + pystr::AsPyStr, PyAsyncGen, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, + PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, }, bytesinner::ByteInnerNewOptions, common::{hash::PyHash, str::to_ascii}, @@ -92,6 +92,13 @@ impl PyObject { } // PyObject *PyObject_GetAIter(PyObject *o) + pub fn get_aiter(&self, vm: &VirtualMachine) -> PyResult { + if self.payload_is::() { + vm.call_special_method(self, identifier!(vm, __aiter__), ()) + } else { + Err(vm.new_type_error("wrong argument type".to_owned())) + } + } pub fn has_attr<'a>(&self, attr_name: impl AsPyStr<'a>, vm: &VirtualMachine) -> PyResult { self.get_attr(attr_name, vm).map(|o| !vm.is_none(&o)) diff --git a/vm/src/stdlib/builtins.rs b/vm/src/stdlib/builtins.rs index 835f4152ea..4cc3a53634 100644 --- a/vm/src/stdlib/builtins.rs +++ b/vm/src/stdlib/builtins.rs @@ -9,7 +9,6 @@ pub use builtins::{ascii, print, reversed}; mod builtins { use crate::{ builtins::{ - asyncgenerator::PyAsyncGen, enumerate::PyReverseSequenceIterator, function::{PyCellRef, PyFunction}, int::PyIntRef, @@ -459,11 +458,7 @@ mod builtins { #[pyfunction] fn aiter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if iter_target.payload_is::() { - vm.call_special_method(&iter_target, identifier!(vm, __aiter__), ()) - } else { - Err(vm.new_type_error("wrong argument type".to_owned())) - } + iter_target.get_aiter(vm) } #[pyfunction] From 0e72ba88193f9d0a16555dba3da3adf3a6043fb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dan=20N=C3=A4sman?= <30578250+dannasman@users.noreply.github.com> Date: Wed, 18 Oct 2023 02:22:10 +0300 Subject: [PATCH 134/893] implement PyObject_Type and PyObject_TypeCheck (#5091) --- vm/src/protocol/object.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 0b87bbf9d3..28f481fe96 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -559,8 +559,14 @@ impl PyObject { // type protocol // PyObject *PyObject_Type(PyObject *o) + pub fn obj_type(&self) -> PyObjectRef { + self.class().to_owned().into() + } // int PyObject_TypeCheck(PyObject *o, PyTypeObject *type) + pub fn type_check(&self, typ: PyTypeRef) -> bool { + self.class().fast_isinstance(&typ) + } pub fn length_opt(&self, vm: &VirtualMachine) -> Option> { self.to_sequence(vm) From bda7c5cf06e66297a8340e7ac6c0a674fcd48e24 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Wed, 18 Oct 2023 15:27:51 -0700 Subject: [PATCH 135/893] Fix examples/package_embed (#5096) --- examples/package_embed.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/package_embed.rs b/examples/package_embed.rs index bd3d268856..b35d063928 100644 --- a/examples/package_embed.rs +++ b/examples/package_embed.rs @@ -4,8 +4,9 @@ use vm::{builtins::PyStrRef, Interpreter}; fn py_main(interp: &Interpreter) -> vm::PyResult { interp.enter(|vm| { + // Add local library path vm.insert_sys_path(vm.new_pyobj("examples")) - .expect("add path"); + .expect("add examples to sys.path failed"); let module = vm.import("package_embed", None, 0)?; let name_func = module.get_attr("context", vm)?; let result = name_func.call((), vm)?; @@ -15,7 +16,10 @@ fn py_main(interp: &Interpreter) -> vm::PyResult { } fn main() -> ExitCode { - let interp = vm::Interpreter::with_init(Default::default(), |vm| { + // Add standard library path + let mut settings = vm::Settings::default(); + settings.path_list.push("Lib".to_owned()); + let interp = vm::Interpreter::with_init(settings, |vm| { vm.add_native_modules(rustpython_stdlib::get_module_inits()); }); let result = py_main(&interp); From 03576615e4af4e776a42b05ff3df1287b3be0fcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dan=20N=C3=A4sman?= <30578250+dannasman@users.noreply.github.com> Date: Fri, 20 Oct 2023 04:03:34 +0300 Subject: [PATCH 136/893] fix type_check (#5097) --- vm/src/protocol/object.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 28f481fe96..dc7f487e85 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -565,7 +565,7 @@ impl PyObject { // int PyObject_TypeCheck(PyObject *o, PyTypeObject *type) pub fn type_check(&self, typ: PyTypeRef) -> bool { - self.class().fast_isinstance(&typ) + self.fast_isinstance(&typ) } pub fn length_opt(&self, vm: &VirtualMachine) -> Option> { From d32cb7efdd20c828004ba1ad2dff83d44f63c009 Mon Sep 17 00:00:00 2001 From: Jonas Zaddach Date: Fri, 20 Oct 2023 03:04:11 +0200 Subject: [PATCH 137/893] Add missing 'sysinfoapi' feature for winapi (#5094) --- vm/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index ed226016e3..7ba9a44f14 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -141,7 +141,7 @@ version = "0.3.9" features = [ "winsock2", "handleapi", "ws2def", "std", "winbase", "wincrypt", "fileapi", "processenv", "namedpipeapi", "winnt", "processthreadsapi", "errhandlingapi", "winuser", "synchapi", "wincon", - "impl-default", "vcruntime", "ifdef", "netioapi", "memoryapi", "profileapi", + "impl-default", "vcruntime", "ifdef", "netioapi", "memoryapi", "profileapi", "sysinfoapi" ] [target.'cfg(target_arch = "wasm32")'.dependencies] From a75f26b922fb2b88dcea2d338a9d11d751915541 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sun, 22 Oct 2023 18:01:06 -0700 Subject: [PATCH 138/893] Fake PEP-0695 with empty __type_params__ (#5098) --- vm/src/builtins/function.rs | 31 +++++++++++++++++++++++++++++-- vm/src/frame.rs | 3 ++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/vm/src/builtins/function.rs b/vm/src/builtins/function.rs index dc1764f48a..ab069e0e73 100644 --- a/vm/src/builtins/function.rs +++ b/vm/src/builtins/function.rs @@ -35,6 +35,7 @@ pub struct PyFunction { defaults_and_kwdefaults: PyMutex<(Option, Option)>, name: PyMutex, qualname: PyMutex, + type_params: PyMutex, #[cfg(feature = "jit")] jitted_code: OnceCell, } @@ -54,7 +55,8 @@ impl PyFunction { closure: Option>, defaults: Option, kw_only_defaults: Option, - qualname: PyMutex, + qualname: PyStrRef, + type_params: PyTupleRef, ) -> Self { let name = PyMutex::new(code.obj_name.to_owned()); PyFunction { @@ -63,7 +65,8 @@ impl PyFunction { closure, defaults_and_kwdefaults: PyMutex::new((defaults, kw_only_defaults)), name, - qualname, + qualname: PyMutex::new(qualname), + type_params: PyMutex::new(type_params), #[cfg(feature = "jit")] jitted_code: OnceCell::new(), } @@ -428,6 +431,30 @@ impl PyFunction { Ok(()) } + #[pygetset(magic)] + fn type_params(&self) -> PyTupleRef { + self.type_params.lock().clone() + } + + #[pygetset(magic, setter)] + fn set_type_params( + &self, + value: PySetterValue, + vm: &VirtualMachine, + ) -> PyResult<()> { + match value { + PySetterValue::Assign(value) => { + *self.type_params.lock() = value; + } + PySetterValue::Delete => { + return Err( + vm.new_type_error("__type_params__ must be set to a tuple object".to_string()) + ); + } + } + Ok(()) + } + #[cfg(feature = "jit")] #[pymethod(magic)] fn jit(zelf: PyRef, vm: &VirtualMachine) -> PyResult<()> { diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 719aa6288a..a0af23336f 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1642,7 +1642,8 @@ impl ExecutingFrame<'_> { closure, defaults, kw_only_defaults, - PyMutex::new(qualified_name.clone()), + qualified_name.clone(), + vm.ctx.empty_tuple.clone(), // FIXME: fake implementation ) .into_pyobject(vm); From af884cb284b8ca0367b5ba331acbc0d9fdf00518 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Sun, 22 Oct 2023 19:19:05 -0700 Subject: [PATCH 139/893] First step for Python 3.12 support (#5078) * Mark 3.12 * Update importlib from Python 3.12.0 * Update test_importlib from Python3.12 * Mark failings tests from importlib * Update test.support from Python3.12 * Fix unsupported parser feature * mark failing test * Update functools from Python 3.12 * manual type annotation * slice behavior changed in 3.12 * empty unittest.main returns non-zero * test_decimal from CPython 3.12 * Mark failing tests * Update test_unicode from CPython 3.12 * Update test_functools from Python 3.12 * Update enum from Python 3.12 * enum * Doc format changed * Update test_module from CPython --------- Co-authored-by: CPython developers <> --- .github/workflows/ci.yaml | 2 +- .github/workflows/cron-ci.yaml | 2 +- DEVELOPMENT.md | 2 +- Lib/enum.py | 254 ++--- Lib/functools.py | 183 ++-- Lib/importlib/__init__.py | 46 +- Lib/importlib/_abc.py | 15 - Lib/importlib/_bootstrap.py | 454 ++++++--- Lib/importlib/_bootstrap_external.py | 249 +++-- Lib/importlib/abc.py | 111 +-- Lib/importlib/metadata/__init__.py | 305 ++---- Lib/importlib/metadata/_adapters.py | 21 + Lib/importlib/metadata/_meta.py | 28 +- Lib/importlib/resources/_adapters.py | 4 +- Lib/importlib/resources/_common.py | 148 ++- Lib/importlib/resources/_itertools.py | 69 +- Lib/importlib/resources/_legacy.py | 3 +- Lib/importlib/resources/abc.py | 26 +- Lib/importlib/resources/readers.py | 50 +- Lib/importlib/resources/simple.py | 79 +- Lib/importlib/util.py | 144 +-- Lib/test/support/__init__.py | 435 +++++++-- Lib/test/support/bytecode_helper.py | 101 ++ Lib/test/support/import_helper.py | 28 + Lib/test/support/interpreters.py | 23 +- Lib/test/support/os_helper.py | 48 +- Lib/test/support/socket_helper.py | 77 +- Lib/test/support/testresult.py | 10 +- Lib/test/support/threading_helper.py | 27 +- Lib/test/support/warnings_helper.py | 2 +- Lib/test/test_decimal.py | 443 ++++++--- Lib/test/test_enum.py | 714 +++++++++----- Lib/test/test_functools.py | 682 ++++++++++++-- Lib/test/test_importlib/_context.py | 13 + Lib/test/test_importlib/_path.py | 109 +++ .../test_importlib/builtin/test_finder.py | 50 - .../extension/test_case_sensitivity.py | 2 +- .../test_importlib/extension/test_loader.py | 151 ++- .../extension/test_path_hook.py | 2 +- Lib/test/test_importlib/fixtures.py | 158 +++- Lib/test/test_importlib/frozen/test_finder.py | 48 - Lib/test/test_importlib/frozen/test_loader.py | 105 +-- .../test_importlib/import_/test___loader__.py | 43 - .../import_/test___package__.py | 40 +- Lib/test/test_importlib/import_/test_api.py | 5 - .../test_importlib/import_/test_caching.py | 25 +- .../test_importlib/import_/test_helpers.py | 192 ++++ .../test_importlib/import_/test_meta_path.py | 10 - Lib/test/test_importlib/import_/test_path.py | 73 +- Lib/test/test_importlib/resources/_path.py | 56 ++ .../{ => resources}/data01/__init__.py | 0 .../resources/data01/binary.file | Bin 0 -> 4 bytes .../data01/subdirectory/__init__.py | 0 .../resources/data01/subdirectory/binary.file | Bin 0 -> 4 bytes .../resources/data01/utf-16.file | Bin 0 -> 44 bytes .../{ => resources}/data01/utf-8.file | 0 .../{ => resources}/data02/__init__.py | 0 .../{ => resources}/data02/one/__init__.py | 0 .../{ => resources}/data02/one/resource1.txt | 0 .../subdirectory/subsubdir/resource.txt | 1 + .../{ => resources}/data02/two/__init__.py | 0 .../{ => resources}/data02/two/resource2.txt | 0 .../{ => resources}/data03/__init__.py | 0 .../data03/namespace/portion1/__init__.py | 0 .../data03/namespace/portion2/__init__.py | 0 .../data03/namespace/resource1.txt | 0 .../resources/namespacedata01/binary.file | Bin 0 -> 4 bytes .../resources/namespacedata01/utf-16.file | Bin 0 -> 44 bytes .../namespacedata01/utf-8.file | 0 .../test_compatibilty_files.py | 8 +- .../{ => resources}/test_contents.py | 2 +- .../test_importlib/resources/test_custom.py | 46 + .../test_importlib/resources/test_files.py | 113 +++ .../{ => resources}/test_open.py | 22 +- .../{ => resources}/test_path.py | 17 +- .../{ => resources}/test_read.py | 14 +- .../{ => resources}/test_reader.py | 16 + .../{ => resources}/test_resource.py | 110 +-- .../{ => resources}/update-zips.py | 0 Lib/test/test_importlib/resources/util.py | 51 +- .../{ => resources}/zipdata01/__init__.py | 0 .../resources/zipdata01/ziptestdata.zip | Bin 0 -> 876 bytes .../{ => resources}/zipdata02/__init__.py | 0 .../resources/zipdata02/ziptestdata.zip | Bin 0 -> 698 bytes .../source/test_case_sensitivity.py | 13 - .../test_importlib/source/test_file_loader.py | 1 - Lib/test/test_importlib/source/test_finder.py | 29 +- .../test_importlib/source/test_path_hook.py | 9 - Lib/test/test_importlib/test_abc.py | 118 +-- Lib/test/test_importlib/test_api.py | 71 +- Lib/test/test_importlib/test_files.py | 46 - Lib/test/test_importlib/test_locks.py | 5 + Lib/test/test_importlib/test_main.py | 119 ++- Lib/test/test_importlib/test_metadata_api.py | 112 +-- .../test_importlib/test_namespace_pkgs.py | 7 +- Lib/test/test_importlib/test_spec.py | 143 --- .../test_importlib/test_threaded_import.py | 11 +- Lib/test/test_importlib/test_util.py | 364 +++----- Lib/test/test_importlib/test_windows.py | 18 +- Lib/test/test_importlib/util.py | 34 +- .../__init__.py} | 55 +- Lib/test/test_module/bad_getattr.py | 4 + Lib/test/test_module/bad_getattr2.py | 7 + Lib/test/test_module/bad_getattr3.py | 5 + Lib/test/test_module/good_getattr.py | 11 + Lib/test/test_support.py | 93 +- Lib/test/test_unicode.py | 878 ++++++------------ README.md | 2 +- extra_tests/snippets/builtin_none.py | 3 +- extra_tests/snippets/builtin_slice.py | 17 +- extra_tests/snippets/syntax_async.py | 4 +- vm/src/version.rs | 4 +- whats_left.py | 4 +- 113 files changed, 4828 insertions(+), 3626 deletions(-) create mode 100644 Lib/test/test_importlib/_context.py create mode 100644 Lib/test/test_importlib/_path.py create mode 100644 Lib/test/test_importlib/import_/test_helpers.py create mode 100644 Lib/test/test_importlib/resources/_path.py rename Lib/test/test_importlib/{ => resources}/data01/__init__.py (100%) create mode 100644 Lib/test/test_importlib/resources/data01/binary.file rename Lib/test/test_importlib/{ => resources}/data01/subdirectory/__init__.py (100%) create mode 100644 Lib/test/test_importlib/resources/data01/subdirectory/binary.file create mode 100644 Lib/test/test_importlib/resources/data01/utf-16.file rename Lib/test/test_importlib/{ => resources}/data01/utf-8.file (100%) rename Lib/test/test_importlib/{ => resources}/data02/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data02/one/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data02/one/resource1.txt (100%) create mode 100644 Lib/test/test_importlib/resources/data02/subdirectory/subsubdir/resource.txt rename Lib/test/test_importlib/{ => resources}/data02/two/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data02/two/resource2.txt (100%) rename Lib/test/test_importlib/{ => resources}/data03/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data03/namespace/portion1/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data03/namespace/portion2/__init__.py (100%) rename Lib/test/test_importlib/{ => resources}/data03/namespace/resource1.txt (100%) create mode 100644 Lib/test/test_importlib/resources/namespacedata01/binary.file create mode 100644 Lib/test/test_importlib/resources/namespacedata01/utf-16.file rename Lib/test/test_importlib/{ => resources}/namespacedata01/utf-8.file (100%) rename Lib/test/test_importlib/{ => resources}/test_compatibilty_files.py (93%) rename Lib/test/test_importlib/{ => resources}/test_contents.py (97%) create mode 100644 Lib/test/test_importlib/resources/test_custom.py create mode 100644 Lib/test/test_importlib/resources/test_files.py rename Lib/test/test_importlib/{ => resources}/test_open.py (82%) rename Lib/test/test_importlib/{ => resources}/test_path.py (84%) rename Lib/test/test_importlib/{ => resources}/test_read.py (86%) rename Lib/test/test_importlib/{ => resources}/test_reader.py (85%) rename Lib/test/test_importlib/{ => resources}/test_resource.py (74%) rename Lib/test/test_importlib/{ => resources}/update-zips.py (100%) rename Lib/test/test_importlib/{ => resources}/zipdata01/__init__.py (100%) create mode 100644 Lib/test/test_importlib/resources/zipdata01/ziptestdata.zip rename Lib/test/test_importlib/{ => resources}/zipdata02/__init__.py (100%) create mode 100644 Lib/test/test_importlib/resources/zipdata02/ziptestdata.zip delete mode 100644 Lib/test/test_importlib/test_files.py rename Lib/test/{test_module.py => test_module/__init__.py} (90%) create mode 100644 Lib/test/test_module/bad_getattr.py create mode 100644 Lib/test/test_module/bad_getattr2.py create mode 100644 Lib/test/test_module/bad_getattr3.py create mode 100644 Lib/test/test_module/good_getattr.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b4e1eb1932..e616f6d10c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -105,7 +105,7 @@ env: test_weakref test_yield_from # Python version targeted by the CI. - PYTHON_VERSION: "3.11.4" + PYTHON_VERSION: "3.12.0" jobs: rust_tests: diff --git a/.github/workflows/cron-ci.yaml b/.github/workflows/cron-ci.yaml index ee90aac4a9..9176f232c7 100644 --- a/.github/workflows/cron-ci.yaml +++ b/.github/workflows/cron-ci.yaml @@ -7,7 +7,7 @@ name: Periodic checks/tasks env: CARGO_ARGS: --no-default-features --features stdlib,zlib,importlib,encodings,ssl,jit - PYTHON_VERSION: "3.11.4" + PYTHON_VERSION: "3.12.0" jobs: # codecov collects code coverage data from the rust tests, python snippets and python test suite. diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 4ef49abe94..7c79a011ba 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -25,7 +25,7 @@ RustPython requires the following: stable version: `rustup update stable` - If you do not have Rust installed, use [rustup](https://rustup.rs/) to do so. -- CPython version 3.11 or higher +- CPython version 3.12 or higher - CPython can be installed by your operating system's package manager, from the [Python website](https://www.python.org/downloads/), or using a third-party distribution, such as diff --git a/Lib/enum.py b/Lib/enum.py index 625e9ea56a..7cffb71863 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -190,41 +190,48 @@ class property(DynamicClassAttribute): a corresponding enum member. """ + member = None + _attr_type = None + _cls_type = None + def __get__(self, instance, ownerclass=None): if instance is None: - try: - return ownerclass._member_map_[self.name] - except KeyError: + if self.member is not None: + return self.member + else: raise AttributeError( '%r has no attribute %r' % (ownerclass, self.name) ) - else: - if self.fget is None: - # look for a member by this name. - try: - return ownerclass._member_map_[self.name] - except KeyError: - raise AttributeError( - '%r has no attribute %r' % (ownerclass, self.name) - ) from None - else: - return self.fget(instance) + if self.fget is not None: + # use previous enum.property + return self.fget(instance) + elif self._attr_type == 'attr': + # look up previous attibute + return getattr(self._cls_type, self.name) + elif self._attr_type == 'desc': + # use previous descriptor + return getattr(instance._value_, self.name) + # look for a member by this name. + try: + return ownerclass._member_map_[self.name] + except KeyError: + raise AttributeError( + '%r has no attribute %r' % (ownerclass, self.name) + ) from None def __set__(self, instance, value): - if self.fset is None: - raise AttributeError( - " cannot set attribute %r" % (self.clsname, self.name) - ) - else: + if self.fset is not None: return self.fset(instance, value) + raise AttributeError( + " cannot set attribute %r" % (self.clsname, self.name) + ) def __delete__(self, instance): - if self.fdel is None: - raise AttributeError( - " cannot delete attribute %r" % (self.clsname, self.name) - ) - else: + if self.fdel is not None: return self.fdel(instance) + raise AttributeError( + " cannot delete attribute %r" % (self.clsname, self.name) + ) def __set_name__(self, ownerclass, name): self.name = name @@ -312,27 +319,38 @@ def __set_name__(self, enum_class, member_name): enum_class._member_names_.append(member_name) # if necessary, get redirect in place and then add it to _member_map_ found_descriptor = None + descriptor_type = None + class_type = None for base in enum_class.__mro__[1:]: - descriptor = base.__dict__.get(member_name) - if descriptor is not None: - if isinstance(descriptor, (property, DynamicClassAttribute)): - found_descriptor = descriptor + attr = base.__dict__.get(member_name) + if attr is not None: + if isinstance(attr, (property, DynamicClassAttribute)): + found_descriptor = attr + class_type = base + descriptor_type = 'enum' break - elif ( - hasattr(descriptor, 'fget') and - hasattr(descriptor, 'fset') and - hasattr(descriptor, 'fdel') - ): - found_descriptor = descriptor + elif _is_descriptor(attr): + found_descriptor = attr + descriptor_type = descriptor_type or 'desc' + class_type = class_type or base continue + else: + descriptor_type = 'attr' + class_type = base if found_descriptor: redirect = property() redirect.member = enum_member redirect.__set_name__(enum_class, member_name) - # earlier descriptor found; copy fget, fset, fdel to this one. - redirect.fget = found_descriptor.fget - redirect.fset = found_descriptor.fset - redirect.fdel = found_descriptor.fdel + if descriptor_type in ('enum','desc'): + # earlier descriptor found; copy fget, fset, fdel to this one. + redirect.fget = getattr(found_descriptor, 'fget', None) + redirect._get = getattr(found_descriptor, '__get__', None) + redirect.fset = getattr(found_descriptor, 'fset', None) + redirect._set = getattr(found_descriptor, '__set__', None) + redirect.fdel = getattr(found_descriptor, 'fdel', None) + redirect._del = getattr(found_descriptor, '__delete__', None) + redirect._attr_type = descriptor_type + redirect._cls_type = class_type setattr(enum_class, member_name, redirect) else: setattr(enum_class, member_name, enum_member) @@ -521,8 +539,13 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k # # adjust the sunders _order_ = classdict.pop('_order_', None) + _gnv = classdict.get('_generate_next_value_') + if _gnv is not None and type(_gnv) is not staticmethod: + _gnv = staticmethod(_gnv) # convert to normal dict classdict = dict(classdict.items()) + if _gnv is not None: + classdict['_generate_next_value_'] = _gnv # # data type of member and the controlling Enum class member_type, first_enum = metacls._get_mixins_(cls, bases) @@ -674,7 +697,7 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k 'member order does not match _order_:\n %r\n %r' % (enum_class._member_names_, _order_) ) - + # return enum_class def __bool__(cls): @@ -683,7 +706,7 @@ def __bool__(cls): """ return True - def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None): + def __call__(cls, value, names=None, *values, module=None, qualname=None, type=None, start=1, boundary=None): """ Either returns an existing member, or creates a new enum class. @@ -691,6 +714,8 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s to an enumeration member (i.e. Color(3)) and for the functional API (i.e. Color = Enum('Color', names='RED GREEN BLUE')). + The value lookup branch is chosen if the enum is final. + When used for the functional API: `value` will be the name of the new class. @@ -708,12 +733,20 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s `type`, if set, will be mixed in as the first base class. """ - if names is None: # simple value lookup + if cls._member_map_: + # simple value lookup if members exist + if names: + value = (value, names) + values return cls.__new__(cls, value) # otherwise, functional API: we're creating a new Enum type + if names is None and type is None: + # no body? no data-type? possibly wrong usage + raise TypeError( + f"{cls} has no members; specify `names=()` if you meant to create a new, empty, enum" + ) return cls._create_( - value, - names, + class_name=value, + names=names, module=module, qualname=qualname, type=type, @@ -721,26 +754,16 @@ def __call__(cls, value, names=None, *, module=None, qualname=None, type=None, s boundary=boundary, ) - def __contains__(cls, member): - """ - Return True if member is a member of this enum - raises TypeError if member is not an enum member + def __contains__(cls, value): + """Return True if `value` is in `cls`. - note: in 3.12 TypeError will no longer be raised, and True will also be - returned if member is the value of a member in this enum + `value` is in `cls` if: + 1) `value` is a member of `cls`, or + 2) `value` is the value of one of the `cls`'s members. """ - if not isinstance(member, Enum): - import warnings - warnings.warn( - "in 3.12 __contains__ will no longer raise TypeError, but will return True or\n" - "False depending on whether the value is a member or the value of a member", - DeprecationWarning, - stacklevel=2, - ) - raise TypeError( - "unsupported operand type(s) for 'in': '%s' and '%s'" % ( - type(member).__qualname__, cls.__class__.__qualname__)) - return isinstance(member, cls) and member._name_ in cls._member_map_ + if isinstance(value, cls): + return True + return value in cls._value2member_map_ or value in cls._unhashable_values_ def __delattr__(cls, attr): # nicer error message when someone tries to delete an attribute @@ -767,22 +790,6 @@ def __dir__(cls): # return whatever mixed-in data type has return sorted(set(dir(cls._member_type_)) | interesting) - def __getattr__(cls, name): - """ - Return the enum member matching `name` - - We use __getattr__ instead of descriptors or inserting into the enum - class' __dict__ in order to support `name` and `value` being both - properties for enum members (which live in the class' __dict__) and - enum members themselves. - """ - if _is_dunder(name): - raise AttributeError(name) - try: - return cls._member_map_[name] - except KeyError: - raise AttributeError(name) from None - def __getitem__(cls, name): """ Return the member matching `name`. @@ -863,6 +870,8 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s value = first_enum._generate_next_value_(name, start, count, last_values[:]) last_values.append(value) names.append((name, value)) + if names is None: + names = () # Here, names is either an iterable of (name, value) or a mapping. for item in names: @@ -872,13 +881,15 @@ def _create_(cls, class_name, names, *, module=None, qualname=None, type=None, s member_name, member_value = item classdict[member_name] = member_value - # TODO: replace the frame hack if a blessed way to know the calling - # module is ever developed if module is None: try: - module = sys._getframe(2).f_globals['__name__'] - except (AttributeError, ValueError, KeyError): - pass + module = sys._getframemodulename(2) + except AttributeError: + # Fall back on _getframe if _getframemodulename is missing + try: + module = sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError, KeyError): + pass if module is None: _make_class_unpicklable(classdict) else: @@ -946,9 +957,6 @@ def _get_mixins_(mcls, class_name, bases): """ if not bases: return object, Enum - - mcls._check_for_existing_members_(class_name, bases) - # ensure final parent class is an Enum derivative, find any concrete # data type, and check that Enum has no members first_enum = bases[-1] @@ -969,12 +977,20 @@ def _find_data_repr_(mcls, class_name, bases): return base._value_repr_ elif '__repr__' in base.__dict__: # this is our data repr - return base.__dict__['__repr__'] + # double-check if a dataclass with a default __repr__ + if ( + '__dataclass_fields__' in base.__dict__ + and '__dataclass_params__' in base.__dict__ + and base.__dict__['__dataclass_params__'].repr + ): + return _dataclass_repr + else: + return base.__dict__['__repr__'] return None @classmethod def _find_data_type_(mcls, class_name, bases): - # a datatype has a __new__ method + # a datatype has a __new__ method, or a __dataclass_fields__ attribute data_types = set() base_chain = set() for chain in bases: @@ -988,8 +1004,6 @@ def _find_data_type_(mcls, class_name, bases): data_types.add(base._member_type_) break elif '__new__' in base.__dict__ or '__dataclass_fields__' in base.__dict__: - if isinstance(base, EnumType): - continue data_types.add(candidate or base) break else: @@ -1061,20 +1075,20 @@ class Enum(metaclass=EnumType): Access them by: - - attribute access:: + - attribute access: - >>> Color.RED - + >>> Color.RED + - value lookup: - >>> Color(1) - + >>> Color(1) + - name lookup: - >>> Color['RED'] - + >>> Color['RED'] + Enumerations can be iterated over, and know how many members they have: @@ -1088,6 +1102,13 @@ class Enum(metaclass=EnumType): attributes -- see the documentation for details. """ + @classmethod + def __signature__(cls): + if cls._member_names_: + return '(*values)' + else: + return '(new_class_name, /, names, *, module=None, qualname=None, type=None, start=1, boundary=None)' + def __new__(cls, value): # all enum instances are actually created during class construction # without calling this method; this method is called by the metaclass' @@ -1107,6 +1128,11 @@ def __new__(cls, value): for member in cls._member_map_.values(): if member._value_ == value: return member + # still not found -- verify that members exist, in-case somebody got here mistakenly + # (such as via super when trying to override __new__) + if not cls._member_map_: + raise TypeError("%r has no members defined" % cls) + # # still not found -- try _missing_ hook try: exc = None @@ -1142,6 +1168,7 @@ def __new__(cls, value): def __init__(self, *args, **kwds): pass + @staticmethod def _generate_next_value_(name, start, count, last_values): """ Generate the next value when not given. @@ -1236,10 +1263,10 @@ def __copy__(self): # enum.property is used to provide access to the `name` and # `value` attributes of enum members while keeping some measure of # protection from modification, while still allowing for an enumeration - # to have members named `name` and `value`. This works because enumeration - # members are not set directly on the enum class; they are kept in a - # separate structure, _member_map_, which is where enum.property looks for - # them + # to have members named `name` and `value`. This works because each + # instance of enum.property saves its companion member, which it returns + # on class lookup; on instance lookup it either executes a provided function + # or raises an AttributeError. @property def name(self): @@ -1290,6 +1317,7 @@ def __new__(cls, *values): member._value_ = value return member + @staticmethod def _generate_next_value_(name, start, count, last_values): """ Return the lower-cased version of the member name. @@ -1328,6 +1356,7 @@ class Flag(Enum, boundary=STRICT): _numeric_repr_ = repr + @staticmethod def _generate_next_value_(name, start, count, last_values): """ Generate the next value when not given. @@ -1566,10 +1595,13 @@ def unique(enumeration): (enumeration, alias_details)) return enumeration -def _power_of_two(value): - if value < 1: - return False - return value == 2 ** _high_bit(value) +def _dataclass_repr(self): + dcf = self.__dataclass_fields__ + return ', '.join( + '%s=%r' % (k, getattr(self, k)) + for k in dcf.keys() + if dcf[k].repr + ) def global_enum_repr(self): """ @@ -1713,10 +1745,12 @@ def convert_class(cls): value = gnv(name, 1, len(member_names), gnv_last_values) if value in value2member_map: # an alias to an existing member + member = value2member_map[value] redirect = property() + redirect.member = member redirect.__set_name__(enum_class, name) setattr(enum_class, name, redirect) - member_map[name] = value2member_map[value] + member_map[name] = member else: # create the member if use_args: @@ -1732,6 +1766,7 @@ def convert_class(cls): member.__objclass__ = enum_class member.__init__(value) redirect = property() + redirect.member = member redirect.__set_name__(enum_class, name) setattr(enum_class, name, redirect) member_map[name] = member @@ -1760,10 +1795,12 @@ def convert_class(cls): value = value.value if value in value2member_map: # an alias to an existing member + member = value2member_map[value] redirect = property() + redirect.member = member redirect.__set_name__(enum_class, name) setattr(enum_class, name, redirect) - member_map[name] = value2member_map[value] + member_map[name] = member else: # create the member if use_args: @@ -1780,6 +1817,7 @@ def convert_class(cls): member.__init__(value) member._sort_order_ = len(member_names) redirect = property() + redirect.member = member redirect.__set_name__(enum_class, name) setattr(enum_class, name, redirect) member_map[name] = member @@ -1903,8 +1941,8 @@ def _test_simple_enum(checked_enum, simple_enum): ... RED = auto() ... GREEN = auto() ... BLUE = auto() - >>> # TODO: RUSTPYTHON - >>> # _test_simple_enum(CheckedColor, Color) + ... # TODO: RUSTPYTHON + >>> _test_simple_enum(CheckedColor, Color) # doctest: +SKIP If differences are found, a :exc:`TypeError` is raised. """ diff --git a/Lib/functools.py b/Lib/functools.py index 8decc874e1..2ae4290f98 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -10,9 +10,9 @@ # See C source code for _functools credits/copyright __all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES', - 'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial', - 'partialmethod', 'singledispatch', 'singledispatchmethod', - "cached_property"] + 'total_ordering', 'cache', 'cmp_to_key', 'lru_cache', 'reduce', + 'partial', 'partialmethod', 'singledispatch', 'singledispatchmethod', + 'cached_property'] from abc import get_cache_token from collections import namedtuple @@ -30,7 +30,7 @@ # wrapper functions that can handle naive introspection WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__', - '__annotations__') + '__annotations__', '__type_params__') WRAPPER_UPDATES = ('__dict__',) def update_wrapper(wrapper, wrapped, @@ -86,82 +86,86 @@ def wraps(wrapped, # infinite recursion that could occur when the operator dispatch logic # detects a NotImplemented result and then calls a reflected method. -def _gt_from_lt(self, other, NotImplemented=NotImplemented): +def _gt_from_lt(self, other): 'Return a > b. Computed by @total_ordering from (not a < b) and (a != b).' - op_result = self.__lt__(other) + op_result = type(self).__lt__(self, other) if op_result is NotImplemented: return op_result return not op_result and self != other -def _le_from_lt(self, other, NotImplemented=NotImplemented): +def _le_from_lt(self, other): 'Return a <= b. Computed by @total_ordering from (a < b) or (a == b).' - op_result = self.__lt__(other) + op_result = type(self).__lt__(self, other) + if op_result is NotImplemented: + return op_result return op_result or self == other -def _ge_from_lt(self, other, NotImplemented=NotImplemented): +def _ge_from_lt(self, other): 'Return a >= b. Computed by @total_ordering from (not a < b).' - op_result = self.__lt__(other) + op_result = type(self).__lt__(self, other) if op_result is NotImplemented: return op_result return not op_result -def _ge_from_le(self, other, NotImplemented=NotImplemented): +def _ge_from_le(self, other): 'Return a >= b. Computed by @total_ordering from (not a <= b) or (a == b).' - op_result = self.__le__(other) + op_result = type(self).__le__(self, other) if op_result is NotImplemented: return op_result return not op_result or self == other -def _lt_from_le(self, other, NotImplemented=NotImplemented): +def _lt_from_le(self, other): 'Return a < b. Computed by @total_ordering from (a <= b) and (a != b).' - op_result = self.__le__(other) + op_result = type(self).__le__(self, other) if op_result is NotImplemented: return op_result return op_result and self != other -def _gt_from_le(self, other, NotImplemented=NotImplemented): +def _gt_from_le(self, other): 'Return a > b. Computed by @total_ordering from (not a <= b).' - op_result = self.__le__(other) + op_result = type(self).__le__(self, other) if op_result is NotImplemented: return op_result return not op_result -def _lt_from_gt(self, other, NotImplemented=NotImplemented): +def _lt_from_gt(self, other): 'Return a < b. Computed by @total_ordering from (not a > b) and (a != b).' - op_result = self.__gt__(other) + op_result = type(self).__gt__(self, other) if op_result is NotImplemented: return op_result return not op_result and self != other -def _ge_from_gt(self, other, NotImplemented=NotImplemented): +def _ge_from_gt(self, other): 'Return a >= b. Computed by @total_ordering from (a > b) or (a == b).' - op_result = self.__gt__(other) + op_result = type(self).__gt__(self, other) + if op_result is NotImplemented: + return op_result return op_result or self == other -def _le_from_gt(self, other, NotImplemented=NotImplemented): +def _le_from_gt(self, other): 'Return a <= b. Computed by @total_ordering from (not a > b).' - op_result = self.__gt__(other) + op_result = type(self).__gt__(self, other) if op_result is NotImplemented: return op_result return not op_result -def _le_from_ge(self, other, NotImplemented=NotImplemented): +def _le_from_ge(self, other): 'Return a <= b. Computed by @total_ordering from (not a >= b) or (a == b).' - op_result = self.__ge__(other) + op_result = type(self).__ge__(self, other) if op_result is NotImplemented: return op_result return not op_result or self == other -def _gt_from_ge(self, other, NotImplemented=NotImplemented): +def _gt_from_ge(self, other): 'Return a > b. Computed by @total_ordering from (a >= b) and (a != b).' - op_result = self.__ge__(other) + op_result = type(self).__ge__(self, other) if op_result is NotImplemented: return op_result return op_result and self != other -def _lt_from_ge(self, other, NotImplemented=NotImplemented): +def _lt_from_ge(self, other): 'Return a < b. Computed by @total_ordering from (not a >= b).' - op_result = self.__ge__(other) + op_result = type(self).__ge__(self, other) if op_result is NotImplemented: return op_result return not op_result @@ -232,14 +236,14 @@ def __ge__(self, other): def reduce(function, sequence, initial=_initial_missing): """ - reduce(function, sequence[, initial]) -> value + reduce(function, iterable[, initial]) -> value - Apply a function of two arguments cumulatively to the items of a sequence, - from left to right, so as to reduce the sequence to a single value. - For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates + Apply a function of two arguments cumulatively to the items of a sequence + or iterable, from left to right, so as to reduce the iterable to a single + value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates ((((1+2)+3)+4)+5). If initial is present, it is placed before the items - of the sequence in the calculation, and serves as a default when the - sequence is empty. + of the iterable in the calculation, and serves as a default when the + iterable is empty. """ it = iter(sequence) @@ -248,7 +252,8 @@ def reduce(function, sequence, initial=_initial_missing): try: value = next(it) except StopIteration: - raise TypeError("reduce() of empty sequence with no initial value") from None + raise TypeError( + "reduce() of empty iterable with no initial value") from None else: value = initial @@ -347,23 +352,7 @@ class partialmethod(object): callables as instance methods. """ - def __init__(*args, **keywords): - if len(args) >= 2: - self, func, *args = args - elif not args: - raise TypeError("descriptor '__init__' of partialmethod " - "needs an argument") - elif 'func' in keywords: - func = keywords.pop('func') - self, *args = args - import warnings - warnings.warn("Passing 'func' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) - else: - raise TypeError("type 'partialmethod' takes at least one argument, " - "got %d" % (len(args)-1)) - args = tuple(args) - + def __init__(self, func, /, *args, **keywords): if not callable(func) and not hasattr(func, "__get__"): raise TypeError("{!r} is not callable or a descriptor" .format(func)) @@ -381,7 +370,6 @@ def __init__(*args, **keywords): self.func = func self.args = args self.keywords = keywords - __init__.__text_signature__ = '($self, func, /, *args, **keywords)' def __repr__(self): args = ", ".join(map(repr, self.args)) @@ -427,6 +415,7 @@ def __isabstractmethod__(self): __class_getitem__ = classmethod(GenericAlias) + # Helper functions def _unwrap_partial(func): @@ -503,7 +492,7 @@ def lru_cache(maxsize=128, typed=False): with f.cache_info(). Clear the cache and statistics with f.cache_clear(). Access the underlying function with f.__wrapped__. - See: http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU) + See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU) """ @@ -520,6 +509,7 @@ def lru_cache(maxsize=128, typed=False): # The user_function was passed in directly via the maxsize argument user_function, maxsize = maxsize, 128 wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) + wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed} return update_wrapper(wrapper, user_function) elif maxsize is not None: raise TypeError( @@ -527,6 +517,7 @@ def lru_cache(maxsize=128, typed=False): def decorating_function(user_function): wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) + wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed} return update_wrapper(wrapper, user_function) return decorating_function @@ -653,6 +644,15 @@ def cache_clear(): pass +################################################################################ +### cache -- simplified access to the infinity cache +################################################################################ + +def cache(user_function, /): + 'Simple lightweight unbounded cache. Sometimes called "memoize".' + return lru_cache(maxsize=None)(user_function) + + ################################################################################ ### singledispatch() - single-dispatch generic function decorator ################################################################################ @@ -660,7 +660,7 @@ def cache_clear(): def _c3_merge(sequences): """Merges MROs in *sequences* to a single MRO using the C3 algorithm. - Adapted from http://www.python.org/download/releases/2.3/mro/. + Adapted from https://www.python.org/download/releases/2.3/mro/. """ result = [] @@ -740,6 +740,7 @@ def _compose_mro(cls, types): # Remove entries which are already present in the __mro__ or unrelated. def is_related(typ): return (typ not in bases and hasattr(typ, '__mro__') + and not isinstance(typ, GenericAlias) and issubclass(cls, typ)) types = [n for n in types if is_related(n)] # Remove entries which are strict bases of other entries (they will end up @@ -837,6 +838,17 @@ def dispatch(cls): dispatch_cache[cls] = impl return impl + def _is_union_type(cls): + from typing import get_origin, Union + return get_origin(cls) in {Union, types.UnionType} + + def _is_valid_dispatch_type(cls): + if isinstance(cls, type): + return True + from typing import get_args + return (_is_union_type(cls) and + all(isinstance(arg, type) for arg in get_args(cls))) + def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -844,9 +856,15 @@ def register(cls, func=None): """ nonlocal cache_token - if func is None: - if isinstance(cls, type): + if _is_valid_dispatch_type(cls): + if func is None: return lambda f: register(cls, f) + else: + if func is not None: + raise TypeError( + f"Invalid first argument to `register()`. " + f"{cls!r} is not a class or union type." + ) ann = getattr(cls, '__annotations__', {}) if not ann: raise TypeError( @@ -859,12 +877,25 @@ def register(cls, func=None): # only import typing if annotation parsing is necessary from typing import get_type_hints argname, cls = next(iter(get_type_hints(func).items())) - if not isinstance(cls, type): - raise TypeError( - f"Invalid annotation for {argname!r}. " - f"{cls!r} is not a class." - ) - registry[cls] = func + if not _is_valid_dispatch_type(cls): + if _is_union_type(cls): + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} not all arguments are classes." + ) + else: + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} is not a class." + ) + + if _is_union_type(cls): + from typing import get_args + + for arg in get_args(cls): + registry[arg] = func + else: + registry[cls] = func if cache_token is None and hasattr(cls, '__abstractmethods__'): cache_token = get_cache_token() dispatch_cache.clear() @@ -925,18 +956,16 @@ def __isabstractmethod__(self): ################################################################################ -### cached_property() - computed once per instance, cached as attribute +### cached_property() - property result cached as instance attribute ################################################################################ _NOT_FOUND = object() - class cached_property: def __init__(self, func): self.func = func self.attrname = None self.__doc__ = func.__doc__ - self.lock = RLock() def __set_name__(self, owner, name): if self.attrname is None: @@ -963,19 +992,15 @@ def __get__(self, instance, owner=None): raise TypeError(msg) from None val = cache.get(self.attrname, _NOT_FOUND) if val is _NOT_FOUND: - with self.lock: - # check if another thread filled cache while we awaited lock - val = cache.get(self.attrname, _NOT_FOUND) - if val is _NOT_FOUND: - val = self.func(instance) - try: - cache[self.attrname] = val - except TypeError: - msg = ( - f"The '__dict__' attribute on {type(instance).__name__!r} instance " - f"does not support item assignment for caching {self.attrname!r} property." - ) - raise TypeError(msg) from None + val = self.func(instance) + try: + cache[self.attrname] = val + except TypeError: + msg = ( + f"The '__dict__' attribute on {type(instance).__name__!r} instance " + f"does not support item assignment for caching {self.attrname!r} property." + ) + raise TypeError(msg) from None return val __class_getitem__ = classmethod(GenericAlias) diff --git a/Lib/importlib/__init__.py b/Lib/importlib/__init__.py index ce61883288..707c081cb2 100644 --- a/Lib/importlib/__init__.py +++ b/Lib/importlib/__init__.py @@ -70,41 +70,6 @@ def invalidate_caches(): finder.invalidate_caches() -def find_loader(name, path=None): - """Return the loader for the specified module. - - This is a backward-compatible wrapper around find_spec(). - - This function is deprecated in favor of importlib.util.find_spec(). - - """ - warnings.warn('Deprecated since Python 3.4 and slated for removal in ' - 'Python 3.12; use importlib.util.find_spec() instead', - DeprecationWarning, stacklevel=2) - try: - loader = sys.modules[name].__loader__ - if loader is None: - raise ValueError('{}.__loader__ is None'.format(name)) - else: - return loader - except KeyError: - pass - except AttributeError: - raise ValueError('{}.__loader__ is not set'.format(name)) from None - - spec = _bootstrap._find_spec(name, path) - # We won't worry about malformed specs (missing attributes). - if spec is None: - return None - if spec.loader is None: - if spec.submodule_search_locations is None: - raise ImportError('spec for {} missing loader'.format(name), - name=name) - raise ImportError('namespace packages do not have loaders', - name=name) - return spec.loader - - def import_module(name, package=None): """Import a module. @@ -116,9 +81,8 @@ def import_module(name, package=None): level = 0 if name.startswith('.'): if not package: - msg = ("the 'package' argument is required to perform a relative " - "import for {!r}") - raise TypeError(msg.format(name)) + raise TypeError("the 'package' argument is required to perform a " + f"relative import for {name!r}") for character in name: if character != '.': break @@ -144,8 +108,7 @@ def reload(module): raise TypeError("reload() argument must be a module") if sys.modules.get(name) is not module: - msg = "module {} not in sys.modules" - raise ImportError(msg.format(name), name=name) + raise ImportError(f"module {name} not in sys.modules", name=name) if name in _RELOADING: return _RELOADING[name] _RELOADING[name] = module @@ -155,8 +118,7 @@ def reload(module): try: parent = sys.modules[parent_name] except KeyError: - msg = "parent {!r} not in sys.modules" - raise ImportError(msg.format(parent_name), + raise ImportError(f"parent {parent_name!r} not in sys.modules", name=parent_name) from None else: pkgpath = parent.__path__ diff --git a/Lib/importlib/_abc.py b/Lib/importlib/_abc.py index f80348fc7f..693b466112 100644 --- a/Lib/importlib/_abc.py +++ b/Lib/importlib/_abc.py @@ -1,7 +1,6 @@ """Subset of importlib.abc used to reduce importlib.util imports.""" from . import _bootstrap import abc -import warnings class Loader(metaclass=abc.ABCMeta): @@ -38,17 +37,3 @@ def load_module(self, fullname): raise ImportError # Warning implemented in _load_module_shim(). return _bootstrap._load_module_shim(self, fullname) - - def module_repr(self, module): - """Return a module's repr. - - Used by the module type when the method does not raise - NotImplementedError. - - This method is deprecated. - - """ - warnings.warn("importlib.abc.Loader.module_repr() is deprecated and " - "slated for removal in Python 3.12", DeprecationWarning) - # The exception will cause ModuleType.__repr__ to ignore this method. - raise NotImplementedError diff --git a/Lib/importlib/_bootstrap.py b/Lib/importlib/_bootstrap.py index b1fdad8e6d..093a0b8245 100644 --- a/Lib/importlib/_bootstrap.py +++ b/Lib/importlib/_bootstrap.py @@ -51,17 +51,178 @@ def _new_module(name): # Module-level locking ######################################################## -# A dict mapping module names to weakrefs of _ModuleLock instances -# Dictionary protected by the global import lock +# For a list that can have a weakref to it. +class _List(list): + pass + + +# Copied from weakref.py with some simplifications and modifications unique to +# bootstrapping importlib. Many methods were simply deleting for simplicity, so if they +# are needed in the future they may work if simply copied back in. +class _WeakValueDictionary: + + def __init__(self): + self_weakref = _weakref.ref(self) + + # Inlined to avoid issues with inheriting from _weakref.ref before _weakref is + # set by _setup(). Since there's only one instance of this class, this is + # not expensive. + class KeyedRef(_weakref.ref): + + __slots__ = "key", + + def __new__(type, ob, key): + self = super().__new__(type, ob, type.remove) + self.key = key + return self + + def __init__(self, ob, key): + super().__init__(ob, self.remove) + + @staticmethod + def remove(wr): + nonlocal self_weakref + + self = self_weakref() + if self is not None: + if self._iterating: + self._pending_removals.append(wr.key) + else: + _weakref._remove_dead_weakref(self.data, wr.key) + + self._KeyedRef = KeyedRef + self.clear() + + def clear(self): + self._pending_removals = [] + self._iterating = set() + self.data = {} + + def _commit_removals(self): + pop = self._pending_removals.pop + d = self.data + while True: + try: + key = pop() + except IndexError: + return + _weakref._remove_dead_weakref(d, key) + + def get(self, key, default=None): + if self._pending_removals: + self._commit_removals() + try: + wr = self.data[key] + except KeyError: + return default + else: + if (o := wr()) is None: + return default + else: + return o + + def setdefault(self, key, default=None): + try: + o = self.data[key]() + except KeyError: + o = None + if o is None: + if self._pending_removals: + self._commit_removals() + self.data[key] = self._KeyedRef(default, key) + return default + else: + return o + + +# A dict mapping module names to weakrefs of _ModuleLock instances. +# Dictionary protected by the global import lock. _module_locks = {} -# A dict mapping thread ids to _ModuleLock instances -_blocking_on = {} + +# A dict mapping thread IDs to weakref'ed lists of _ModuleLock instances. +# This maps a thread to the module locks it is blocking on acquiring. The +# values are lists because a single thread could perform a re-entrant import +# and be "in the process" of blocking on locks for more than one module. A +# thread can be "in the process" because a thread cannot actually block on +# acquiring more than one lock but it can have set up bookkeeping that reflects +# that it intends to block on acquiring more than one lock. +# +# The dictionary uses a WeakValueDictionary to avoid keeping unnecessary +# lists around, regardless of GC runs. This way there's no memory leak if +# the list is no longer needed (GH-106176). +_blocking_on = None + + +class _BlockingOnManager: + """A context manager responsible to updating ``_blocking_on``.""" + def __init__(self, thread_id, lock): + self.thread_id = thread_id + self.lock = lock + + def __enter__(self): + """Mark the running thread as waiting for self.lock. via _blocking_on.""" + # Interactions with _blocking_on are *not* protected by the global + # import lock here because each thread only touches the state that it + # owns (state keyed on its thread id). The global import lock is + # re-entrant (i.e., a single thread may take it more than once) so it + # wouldn't help us be correct in the face of re-entrancy either. + + self.blocked_on = _blocking_on.setdefault(self.thread_id, _List()) + self.blocked_on.append(self.lock) + + def __exit__(self, *args, **kwargs): + """Remove self.lock from this thread's _blocking_on list.""" + self.blocked_on.remove(self.lock) class _DeadlockError(RuntimeError): pass + +def _has_deadlocked(target_id, *, seen_ids, candidate_ids, blocking_on): + """Check if 'target_id' is holding the same lock as another thread(s). + + The search within 'blocking_on' starts with the threads listed in + 'candidate_ids'. 'seen_ids' contains any threads that are considered + already traversed in the search. + + Keyword arguments: + target_id -- The thread id to try to reach. + seen_ids -- A set of threads that have already been visited. + candidate_ids -- The thread ids from which to begin. + blocking_on -- A dict representing the thread/blocking-on graph. This may + be the same object as the global '_blocking_on' but it is + a parameter to reduce the impact that global mutable + state has on the result of this function. + """ + if target_id in candidate_ids: + # If we have already reached the target_id, we're done - signal that it + # is reachable. + return True + + # Otherwise, try to reach the target_id from each of the given candidate_ids. + for tid in candidate_ids: + if not (candidate_blocking_on := blocking_on.get(tid)): + # There are no edges out from this node, skip it. + continue + elif tid in seen_ids: + # bpo 38091: the chain of tid's we encounter here eventually leads + # to a fixed point or a cycle, but does not reach target_id. + # This means we would not actually deadlock. This can happen if + # other threads are at the beginning of acquire() below. + return False + seen_ids.add(tid) + + # Follow the edges out from this thread. + edges = [lock.owner for lock in candidate_blocking_on] + if _has_deadlocked(target_id, seen_ids=seen_ids, candidate_ids=edges, + blocking_on=blocking_on): + return True + + return False + + class _ModuleLock: """A recursive lock implementation which is able to detect deadlocks (e.g. thread 1 trying to take locks A then B, and thread 2 trying to @@ -69,33 +230,76 @@ class _ModuleLock: """ def __init__(self, name): - self.lock = _thread.allocate_lock() + # Create an RLock for protecting the import process for the + # corresponding module. Since it is an RLock, a single thread will be + # able to take it more than once. This is necessary to support + # re-entrancy in the import system that arises from (at least) signal + # handlers and the garbage collector. Consider the case of: + # + # import foo + # -> ... + # -> importlib._bootstrap._ModuleLock.acquire + # -> ... + # -> + # -> __del__ + # -> import foo + # -> ... + # -> importlib._bootstrap._ModuleLock.acquire + # -> _BlockingOnManager.__enter__ + # + # If a different thread than the running one holds the lock then the + # thread will have to block on taking the lock, which is what we want + # for thread safety. + self.lock = _thread.RLock() self.wakeup = _thread.allocate_lock() + + # The name of the module for which this is a lock. self.name = name + + # Can end up being set to None if this lock is not owned by any thread + # or the thread identifier for the owning thread. self.owner = None - self.count = 0 - self.waiters = 0 + + # Represent the number of times the owning thread has acquired this lock + # via a list of True. This supports RLock-like ("re-entrant lock") + # behavior, necessary in case a single thread is following a circular + # import dependency and needs to take the lock for a single module + # more than once. + # + # Counts are represented as a list of True because list.append(True) + # and list.pop() are both atomic and thread-safe in CPython and it's hard + # to find another primitive with the same properties. + self.count = [] + + # This is a count of the number of threads that are blocking on + # self.wakeup.acquire() awaiting to get their turn holding this module + # lock. When the module lock is released, if this is greater than + # zero, it is decremented and `self.wakeup` is released one time. The + # intent is that this will let one other thread make more progress on + # acquiring this module lock. This repeats until all the threads have + # gotten a turn. + # + # This is incremented in self.acquire() when a thread notices it is + # going to have to wait for another thread to finish. + # + # See the comment above count for explanation of the representation. + self.waiters = [] def has_deadlock(self): - # Deadlock avoidance for concurrent circular imports. - me = _thread.get_ident() - tid = self.owner - seen = set() - while True: - lock = _blocking_on.get(tid) - if lock is None: - return False - tid = lock.owner - if tid == me: - return True - if tid in seen: - # bpo 38091: the chain of tid's we encounter here - # eventually leads to a fixpoint or a cycle, but - # does not reach 'me'. This means we would not - # actually deadlock. This can happen if other - # threads are at the beginning of acquire() below. - return False - seen.add(tid) + # To avoid deadlocks for concurrent or re-entrant circular imports, + # look at _blocking_on to see if any threads are blocking + # on getting the import lock for any module for which the import lock + # is held by this thread. + return _has_deadlocked( + # Try to find this thread. + target_id=_thread.get_ident(), + seen_ids=set(), + # Start from the thread that holds the import lock for this + # module. + candidate_ids=[self.owner], + # Use the global "blocking on" state. + blocking_on=_blocking_on, + ) def acquire(self): """ @@ -104,39 +308,82 @@ def acquire(self): Otherwise, the lock is always acquired and True is returned. """ tid = _thread.get_ident() - _blocking_on[tid] = self - try: + with _BlockingOnManager(tid, self): while True: + # Protect interaction with state on self with a per-module + # lock. This makes it safe for more than one thread to try to + # acquire the lock for a single module at the same time. with self.lock: - if self.count == 0 or self.owner == tid: + if self.count == [] or self.owner == tid: + # If the lock for this module is unowned then we can + # take the lock immediately and succeed. If the lock + # for this module is owned by the running thread then + # we can also allow the acquire to succeed. This + # supports circular imports (thread T imports module A + # which imports module B which imports module A). self.owner = tid - self.count += 1 + self.count.append(True) return True + + # At this point we know the lock is held (because count != + # 0) by another thread (because owner != tid). We'll have + # to get in line to take the module lock. + + # But first, check to see if this thread would create a + # deadlock by acquiring this module lock. If it would + # then just stop with an error. + # + # It's not clear who is expected to handle this error. + # There is one handler in _lock_unlock_module but many + # times this method is called when entering the context + # manager _ModuleLockManager instead - so _DeadlockError + # will just propagate up to application code. + # + # This seems to be more than just a hypothetical - + # https://stackoverflow.com/questions/59509154 + # https://github.com/encode/django-rest-framework/issues/7078 if self.has_deadlock(): - raise _DeadlockError('deadlock detected by %r' % self) + raise _DeadlockError(f'deadlock detected by {self!r}') + + # Check to see if we're going to be able to acquire the + # lock. If we are going to have to wait then increment + # the waiters so `self.release` will know to unblock us + # later on. We do this part non-blockingly so we don't + # get stuck here before we increment waiters. We have + # this extra acquire call (in addition to the one below, + # outside the self.lock context manager) to make sure + # self.wakeup is held when the next acquire is called (so + # we block). This is probably needlessly complex and we + # should just take self.wakeup in the return codepath + # above. if self.wakeup.acquire(False): - self.waiters += 1 - # Wait for a release() call + self.waiters.append(None) + + # Now take the lock in a blocking fashion. This won't + # complete until the thread holding this lock + # (self.owner) calls self.release. self.wakeup.acquire() + + # Taking the lock has served its purpose (making us wait), so we can + # give it up now. We'll take it w/o blocking again on the + # next iteration around this 'while' loop. self.wakeup.release() - finally: - del _blocking_on[tid] def release(self): tid = _thread.get_ident() with self.lock: if self.owner != tid: raise RuntimeError('cannot release un-acquired lock') - assert self.count > 0 - self.count -= 1 - if self.count == 0: + assert len(self.count) > 0 + self.count.pop() + if not len(self.count): self.owner = None - if self.waiters: - self.waiters -= 1 + if len(self.waiters) > 0: + self.waiters.pop() self.wakeup.release() def __repr__(self): - return '_ModuleLock({!r}) at {}'.format(self.name, id(self)) + return f'_ModuleLock({self.name!r}) at {id(self)}' class _DummyModuleLock: @@ -157,7 +404,7 @@ def release(self): self.count -= 1 def __repr__(self): - return '_DummyModuleLock({!r}) at {}'.format(self.name, id(self)) + return f'_DummyModuleLock({self.name!r}) at {id(self)}' class _ModuleLockManager: @@ -254,7 +501,7 @@ def _requires_builtin(fxn): """Decorator to verify the named module is built-in.""" def _requires_builtin_wrapper(self, fullname): if fullname not in sys.builtin_module_names: - raise ImportError('{!r} is not a built-in module'.format(fullname), + raise ImportError(f'{fullname!r} is not a built-in module', name=fullname) return fxn(self, fullname) _wrap(_requires_builtin_wrapper, fxn) @@ -265,7 +512,7 @@ def _requires_frozen(fxn): """Decorator to verify the named module is frozen.""" def _requires_frozen_wrapper(self, fullname): if not _imp.is_frozen(fullname): - raise ImportError('{!r} is not a frozen module'.format(fullname), + raise ImportError(f'{fullname!r} is not a frozen module', name=fullname) return fxn(self, fullname) _wrap(_requires_frozen_wrapper, fxn) @@ -297,11 +544,6 @@ def _module_repr(module): loader = getattr(module, '__loader__', None) if spec := getattr(module, "__spec__", None): return _module_repr_from_spec(spec) - elif hasattr(loader, 'module_repr'): - try: - return loader.module_repr(module) - except Exception: - pass # Fall through to a catch-all which always succeeds. try: name = module.__name__ @@ -311,11 +553,11 @@ def _module_repr(module): filename = module.__file__ except AttributeError: if loader is None: - return ''.format(name) + return f'' else: - return ''.format(name, loader) + return f'' else: - return ''.format(name, filename) + return f'' class ModuleSpec: @@ -369,14 +611,12 @@ def __init__(self, name, loader, *, origin=None, loader_state=None, self._cached = None def __repr__(self): - args = ['name={!r}'.format(self.name), - 'loader={!r}'.format(self.loader)] + args = [f'name={self.name!r}', f'loader={self.loader!r}'] if self.origin is not None: - args.append('origin={!r}'.format(self.origin)) + args.append(f'origin={self.origin!r}') if self.submodule_search_locations is not None: - args.append('submodule_search_locations={}' - .format(self.submodule_search_locations)) - return '{}({})'.format(self.__class__.__name__, ', '.join(args)) + args.append(f'submodule_search_locations={self.submodule_search_locations}') + return f'{self.__class__.__name__}({", ".join(args)})' def __eq__(self, other): smsl = self.submodule_search_locations @@ -583,18 +823,17 @@ def module_from_spec(spec): def _module_repr_from_spec(spec): """Return the repr to use for the module.""" - # We mostly replicate _module_repr() using the spec attributes. name = '?' if spec.name is None else spec.name if spec.origin is None: if spec.loader is None: - return ''.format(name) + return f'' else: - return ''.format(name, spec.loader) + return f'' else: if spec.has_location: - return ''.format(name, spec.origin) + return f'' else: - return ''.format(spec.name, spec.origin) + return f'' # Used by importlib.reload() and _load_module_shim(). @@ -603,7 +842,7 @@ def _exec(spec, module): name = spec.name with _ModuleLockManager(name): if sys.modules.get(name) is not module: - msg = 'module {!r} not in sys.modules'.format(name) + msg = f'module {name!r} not in sys.modules' raise ImportError(msg, name=name) try: if spec.loader is None: @@ -735,46 +974,18 @@ class BuiltinImporter: _ORIGIN = "built-in" - @staticmethod - def module_repr(module): - """Return repr for the module. - - The method is deprecated. The import machinery does the job itself. - - """ - _warnings.warn("BuiltinImporter.module_repr() is deprecated and " - "slated for removal in Python 3.12", DeprecationWarning) - return f'' - @classmethod def find_spec(cls, fullname, path=None, target=None): - if path is not None: - return None if _imp.is_builtin(fullname): return spec_from_loader(fullname, cls, origin=cls._ORIGIN) else: return None - @classmethod - def find_module(cls, fullname, path=None): - """Find the built-in module. - - If 'path' is ever specified then the search is considered a failure. - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("BuiltinImporter.find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - spec = cls.find_spec(fullname, path) - return spec.loader if spec is not None else None - @staticmethod def create_module(spec): """Create a built-in module""" if spec.name not in sys.builtin_module_names: - raise ImportError('{!r} is not a built-in module'.format(spec.name), + raise ImportError(f'{spec.name!r} is not a built-in module', name=spec.name) return _call_with_frames_removed(_imp.create_builtin, spec) @@ -815,17 +1026,6 @@ class FrozenImporter: _ORIGIN = "frozen" - @staticmethod - def module_repr(m): - """Return repr for the module. - - The method is deprecated. The import machinery does the job itself. - - """ - _warnings.warn("FrozenImporter.module_repr() is deprecated and " - "slated for removal in Python 3.12", DeprecationWarning) - return ''.format(m.__name__, FrozenImporter._ORIGIN) - @classmethod def _fix_up_module(cls, module): spec = module.__spec__ @@ -950,18 +1150,6 @@ def find_spec(cls, fullname, path=None, target=None): spec.submodule_search_locations.insert(0, pkgdir) return spec - @classmethod - def find_module(cls, fullname, path=None): - """Find a frozen module. - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("FrozenImporter.find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - return cls if _imp.is_frozen(fullname) else None - @staticmethod def create_module(spec): """Set __file__, if able.""" @@ -1041,17 +1229,7 @@ def _resolve_name(name, package, level): if len(bits) < level: raise ImportError('attempted relative import beyond top-level package') base = bits[0] - return '{}.{}'.format(base, name) if name else base - - -def _find_spec_legacy(finder, name, path): - msg = (f"{_object_name(finder)}.find_spec() not found; " - "falling back to find_module()") - _warnings.warn(msg, ImportWarning) - loader = finder.find_module(name, path) - if loader is None: - return None - return spec_from_loader(name, loader) + return f'{base}.{name}' if name else base def _find_spec(name, path, target=None): @@ -1074,9 +1252,7 @@ def _find_spec(name, path, target=None): try: find_spec = finder.find_spec except AttributeError: - spec = _find_spec_legacy(finder, name, path) - if spec is None: - continue + continue else: spec = find_spec(name, path, target) if spec is not None: @@ -1104,7 +1280,7 @@ def _find_spec(name, path, target=None): def _sanity_check(name, package, level): """Verify arguments are "sane".""" if not isinstance(name, str): - raise TypeError('module name must be str, not {}'.format(type(name))) + raise TypeError(f'module name must be str, not {type(name)}') if level < 0: raise ValueError('level must be >= 0') if level > 0: @@ -1134,13 +1310,13 @@ def _find_and_load_unlocked(name, import_): try: path = parent_module.__path__ except AttributeError: - msg = (_ERR_MSG + '; {!r} is not a package').format(name, parent) + msg = f'{_ERR_MSG_PREFIX}{name!r}; {parent!r} is not a package' raise ModuleNotFoundError(msg, name=name) from None parent_spec = parent_module.__spec__ child = name.rpartition('.')[2] spec = _find_spec(name, path) if spec is None: - raise ModuleNotFoundError(_ERR_MSG.format(name), name=name) + raise ModuleNotFoundError(f'{_ERR_MSG_PREFIX}{name!r}', name=name) else: if parent_spec: # Temporarily add child we are currently importing to parent's @@ -1185,8 +1361,7 @@ def _find_and_load(name, import_): _lock_unlock_module(name) if module is None: - message = ('import of {} halted; ' - 'None in sys.modules'.format(name)) + message = f'import of {name} halted; None in sys.modules' raise ModuleNotFoundError(message, name=name) return module @@ -1230,7 +1405,7 @@ def _handle_fromlist(module, fromlist, import_, *, recursive=False): _handle_fromlist(module, module.__all__, import_, recursive=True) elif not hasattr(module, x): - from_name = '{}.{}'.format(module.__name__, x) + from_name = f'{module.__name__}.{x}' try: _call_with_frames_removed(import_, from_name) except ModuleNotFoundError as exc: @@ -1257,7 +1432,7 @@ def _calc___package__(globals): if spec is not None and package != spec.parent: _warnings.warn("__package__ != __spec__.parent " f"({package!r} != {spec.parent!r})", - ImportWarning, stacklevel=3) + DeprecationWarning, stacklevel=3) return package elif spec is not None: return spec.parent @@ -1323,7 +1498,7 @@ def _setup(sys_module, _imp_module): modules, those two modules must be explicitly passed in. """ - global _imp, sys + global _imp, sys, _blocking_on _imp = _imp_module sys = sys_module @@ -1351,6 +1526,9 @@ def _setup(sys_module, _imp_module): builtin_module = sys.modules[builtin_name] setattr(self_module, builtin_name, builtin_module) + # Instantiation requires _weakref to have been set. + _blocking_on = _WeakValueDictionary() + def _install(sys_module, _imp_module): """Install importers for builtin and frozen modules""" diff --git a/Lib/importlib/_bootstrap_external.py b/Lib/importlib/_bootstrap_external.py index f603a89f7f..73ac4405cb 100644 --- a/Lib/importlib/_bootstrap_external.py +++ b/Lib/importlib/_bootstrap_external.py @@ -182,12 +182,22 @@ def _path_isabs(path): return path.startswith(path_separators) +def _path_abspath(path): + """Replacement for os.path.abspath.""" + if not _path_isabs(path): + for sep in path_separators: + path = path.removeprefix(f".{sep}") + return _path_join(_os.getcwd(), path) + else: + return path + + def _write_atomic(path, data, mode=0o666): """Best-effort function to write data to a path atomically. Be prepared to handle a FileExistsError if concurrent writing of the temporary file is attempted.""" # id() is used to generate a pseudo-random filename. - path_tmp = '{}.{}'.format(path, id(path)) + path_tmp = f'{path}.{id(path)}' fd = _os.open(path_tmp, _os.O_EXCL | _os.O_CREAT | _os.O_WRONLY, mode & 0o666) try: @@ -403,11 +413,45 @@ def _write_atomic(path, data, mode=0o666): # Python 3.11a7 3492 (make POP_JUMP_IF_NONE/NOT_NONE/TRUE/FALSE relative) # Python 3.11a7 3493 (Make JUMP_IF_TRUE_OR_POP/JUMP_IF_FALSE_OR_POP relative) # Python 3.11a7 3494 (New location info table) -# Python 3.11b4 3495 (Set line number of module's RESUME instr to 0 per PEP 626) -# Python 3.12 will start with magic number 3500 - +# Python 3.12a1 3500 (Remove PRECALL opcode) +# Python 3.12a1 3501 (YIELD_VALUE oparg == stack_depth) +# Python 3.12a1 3502 (LOAD_FAST_CHECK, no NULL-check in LOAD_FAST) +# Python 3.12a1 3503 (Shrink LOAD_METHOD cache) +# Python 3.12a1 3504 (Merge LOAD_METHOD back into LOAD_ATTR) +# Python 3.12a1 3505 (Specialization/Cache for FOR_ITER) +# Python 3.12a1 3506 (Add BINARY_SLICE and STORE_SLICE instructions) +# Python 3.12a1 3507 (Set lineno of module's RESUME to 0) +# Python 3.12a1 3508 (Add CLEANUP_THROW) +# Python 3.12a1 3509 (Conditional jumps only jump forward) +# Python 3.12a2 3510 (FOR_ITER leaves iterator on the stack) +# Python 3.12a2 3511 (Add STOPITERATION_ERROR instruction) +# Python 3.12a2 3512 (Remove all unused consts from code objects) +# Python 3.12a4 3513 (Add CALL_INTRINSIC_1 instruction, removed STOPITERATION_ERROR, PRINT_EXPR, IMPORT_STAR) +# Python 3.12a4 3514 (Remove ASYNC_GEN_WRAP, LIST_TO_TUPLE, and UNARY_POSITIVE) +# Python 3.12a5 3515 (Embed jump mask in COMPARE_OP oparg) +# Python 3.12a5 3516 (Add COMPARE_AND_BRANCH instruction) +# Python 3.12a5 3517 (Change YIELD_VALUE oparg to exception block depth) +# Python 3.12a6 3518 (Add RETURN_CONST instruction) +# Python 3.12a6 3519 (Modify SEND instruction) +# Python 3.12a6 3520 (Remove PREP_RERAISE_STAR, add CALL_INTRINSIC_2) +# Python 3.12a7 3521 (Shrink the LOAD_GLOBAL caches) +# Python 3.12a7 3522 (Removed JUMP_IF_FALSE_OR_POP/JUMP_IF_TRUE_OR_POP) +# Python 3.12a7 3523 (Convert COMPARE_AND_BRANCH back to COMPARE_OP) +# Python 3.12a7 3524 (Shrink the BINARY_SUBSCR caches) +# Python 3.12b1 3525 (Shrink the CALL caches) +# Python 3.12b1 3526 (Add instrumentation support) +# Python 3.12b1 3527 (Add LOAD_SUPER_ATTR) +# Python 3.12b1 3528 (Add LOAD_SUPER_ATTR_METHOD specialization) +# Python 3.12b1 3529 (Inline list/dict/set comprehensions) +# Python 3.12b1 3530 (Shrink the LOAD_SUPER_ATTR caches) +# Python 3.12b1 3531 (Add PEP 695 changes) + +# Python 3.13 will start with 3550 + +# Please don't copy-paste the same pre-release tag for new entries above!!! +# You should always use the *upcoming* tag. For example, if 3.12a6 came out +# a week ago, I should put "Python 3.12a7" next to my new magic number. -# # MAGIC must change whenever the bytecode emitted by the compiler may no # longer be understood by older implementations of the eval loop (usually # due to the addition of new opcodes). @@ -417,7 +461,7 @@ def _write_atomic(path, data, mode=0o666): # Whenever MAGIC_NUMBER is changed, the ranges in the magic_values array # in PC/launcher.c must also be updated. -MAGIC_NUMBER = (3495).to_bytes(2, 'little') + b'\r\n' +MAGIC_NUMBER = (3531).to_bytes(2, 'little') + b'\r\n' _RAW_MAGIC_NUMBER = int.from_bytes(MAGIC_NUMBER, 'little') # For import.c @@ -474,8 +518,8 @@ def cache_from_source(path, debug_override=None, *, optimization=None): optimization = str(optimization) if optimization != '': if not optimization.isalnum(): - raise ValueError('{!r} is not alphanumeric'.format(optimization)) - almost_filename = '{}.{}{}'.format(almost_filename, _OPT, optimization) + raise ValueError(f'{optimization!r} is not alphanumeric') + almost_filename = f'{almost_filename}.{_OPT}{optimization}' filename = almost_filename + BYTECODE_SUFFIXES[0] if sys.pycache_prefix is not None: # We need an absolute path to the py file to avoid the possibility of @@ -486,8 +530,7 @@ def cache_from_source(path, debug_override=None, *, optimization=None): # make it absolute (`C:\Somewhere\Foo\Bar`), then make it root-relative # (`Somewhere\Foo\Bar`), so we end up placing the bytecode file in an # unambiguous `C:\Bytecode\Somewhere\Foo\Bar\`. - if not _path_isabs(head): - head = _path_join(_os.getcwd(), head) + head = _path_abspath(head) # Strip initial drive from a Windows path. We know we have an absolute # path here, so the second part of the check rules out a POSIX path that @@ -619,26 +662,6 @@ def _wrap(new, old): return _check_name_wrapper -def _find_module_shim(self, fullname): - """Try to find a loader for the specified module by delegating to - self.find_loader(). - - This method is deprecated in favor of finder.find_spec(). - - """ - _warnings.warn("find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - # Call find_loader(). If it returns a string (indicating this - # is a namespace package portion), generate a warning and - # return None. - loader, portions = self.find_loader(fullname) - if loader is None and len(portions): - msg = 'Not importing directory {}: missing __init__' - _warnings.warn(msg.format(portions[0]), ImportWarning) - return loader - - def _classify_pyc(data, name, exc_details): """Perform basic validity checking of a pyc header and return the flags field, which determines how the pyc should be further validated against the source. @@ -733,7 +756,7 @@ def _compile_bytecode(data, name=None, bytecode_path=None, source_path=None): _imp._fix_co_filename(code, source_path) return code else: - raise ImportError('Non-code object in {!r}'.format(bytecode_path), + raise ImportError(f'Non-code object in {bytecode_path!r}', name=name, path=bytecode_path) @@ -800,11 +823,10 @@ def spec_from_file_location(name, location=None, *, loader=None, pass else: location = _os.fspath(location) - if not _path_isabs(location): - try: - location = _path_join(_os.getcwd(), location) - except OSError: - pass + try: + location = _path_abspath(location) + except OSError: + pass # If the location is on the filesystem, but doesn't actually exist, # we could return None here, indicating that the location is not @@ -846,6 +868,54 @@ def spec_from_file_location(name, location=None, *, loader=None, return spec +def _bless_my_loader(module_globals): + """Helper function for _warnings.c + + See GH#97850 for details. + """ + # 2022-10-06(warsaw): For now, this helper is only used in _warnings.c and + # that use case only has the module globals. This function could be + # extended to accept either that or a module object. However, in the + # latter case, it would be better to raise certain exceptions when looking + # at a module, which should have either a __loader__ or __spec__.loader. + # For backward compatibility, it is possible that we'll get an empty + # dictionary for the module globals, and that cannot raise an exception. + if not isinstance(module_globals, dict): + return None + + missing = object() + loader = module_globals.get('__loader__', None) + spec = module_globals.get('__spec__', missing) + + if loader is None: + if spec is missing: + # If working with a module: + # raise AttributeError('Module globals is missing a __spec__') + return None + elif spec is None: + raise ValueError('Module globals is missing a __spec__.loader') + + spec_loader = getattr(spec, 'loader', missing) + + if spec_loader in (missing, None): + if loader is None: + exc = AttributeError if spec_loader is missing else ValueError + raise exc('Module globals is missing a __spec__.loader') + _warnings.warn( + 'Module globals is missing a __spec__.loader', + DeprecationWarning) + spec_loader = loader + + assert spec_loader is not None + if loader is not None and loader != spec_loader: + _warnings.warn( + 'Module globals; __loader__ != __spec__.loader', + DeprecationWarning) + return loader + + return spec_loader + + # Loaders ##################################################################### class WindowsRegistryFinder: @@ -898,22 +968,6 @@ def find_spec(cls, fullname, path=None, target=None): origin=filepath) return spec - @classmethod - def find_module(cls, fullname, path=None): - """Find module named in the registry. - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("WindowsRegistryFinder.find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - spec = cls.find_spec(fullname, path) - if spec is not None: - return spec.loader - else: - return None - class _LoaderBasics: @@ -935,8 +989,8 @@ def exec_module(self, module): """Execute the module.""" code = self.get_code(module.__name__) if code is None: - raise ImportError('cannot load module {!r} when get_code() ' - 'returns None'.format(module.__name__)) + raise ImportError(f'cannot load module {module.__name__!r} when ' + 'get_code() returns None') _bootstrap._call_with_frames_removed(exec, code, module.__dict__) def load_module(self, fullname): @@ -1077,7 +1131,8 @@ def get_code(self, fullname): source_mtime is not None): if hash_based: if source_hash is None: - source_hash = _imp.source_hash(source_bytes) + source_hash = _imp.source_hash(_RAW_MAGIC_NUMBER, + source_bytes) data = _code_to_hash_pyc(code_object, source_hash, check_source) else: data = _code_to_timestamp_pyc(code_object, source_mtime, @@ -1321,7 +1376,7 @@ def __len__(self): return len(self._recalculate()) def __repr__(self): - return '_NamespacePath({!r})'.format(self._path) + return f'_NamespacePath({self._path!r})' def __contains__(self, item): return item in self._recalculate() @@ -1332,22 +1387,11 @@ def append(self, item): # This class is actually exposed publicly in a namespace package's __loader__ # attribute, so it should be available through a non-private name. -# https://bugs.python.org/issue35673 +# https://github.com/python/cpython/issues/92054 class NamespaceLoader: def __init__(self, name, path, path_finder): self._path = _NamespacePath(name, path, path_finder) - @staticmethod - def module_repr(module): - """Return repr for the module. - - The method is deprecated. The import machinery does the job itself. - - """ - _warnings.warn("NamespaceLoader.module_repr() is deprecated and " - "slated for removal in Python 3.12", DeprecationWarning) - return ''.format(module.__name__) - def is_package(self, fullname): return True @@ -1440,27 +1484,6 @@ def _path_importer_cache(cls, path): sys.path_importer_cache[path] = finder return finder - @classmethod - def _legacy_get_spec(cls, fullname, finder): - # This would be a good place for a DeprecationWarning if - # we ended up going that route. - if hasattr(finder, 'find_loader'): - msg = (f"{_bootstrap._object_name(finder)}.find_spec() not found; " - "falling back to find_loader()") - _warnings.warn(msg, ImportWarning) - loader, portions = finder.find_loader(fullname) - else: - msg = (f"{_bootstrap._object_name(finder)}.find_spec() not found; " - "falling back to find_module()") - _warnings.warn(msg, ImportWarning) - loader = finder.find_module(fullname) - portions = [] - if loader is not None: - return _bootstrap.spec_from_loader(fullname, loader) - spec = _bootstrap.ModuleSpec(fullname, None) - spec.submodule_search_locations = portions - return spec - @classmethod def _get_spec(cls, fullname, path, target=None): """Find the loader or namespace_path for this module/package name.""" @@ -1472,10 +1495,7 @@ def _get_spec(cls, fullname, path, target=None): continue finder = cls._path_importer_cache(entry) if finder is not None: - if hasattr(finder, 'find_spec'): - spec = finder.find_spec(fullname, target) - else: - spec = cls._legacy_get_spec(fullname, finder) + spec = finder.find_spec(fullname, target) if spec is None: continue if spec.loader is not None: @@ -1517,22 +1537,6 @@ def find_spec(cls, fullname, path=None, target=None): else: return spec - @classmethod - def find_module(cls, fullname, path=None): - """find the module on sys.path or 'path' based on sys.path_hooks and - sys.path_importer_cache. - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("PathFinder.find_module() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - spec = cls.find_spec(fullname, path) - if spec is None: - return None - return spec.loader - @staticmethod def find_distributions(*args, **kwargs): """ @@ -1567,10 +1571,8 @@ def __init__(self, path, *loader_details): # Base (directory) path if not path or path == '.': self.path = _os.getcwd() - elif not _path_isabs(path): - self.path = _path_join(_os.getcwd(), path) else: - self.path = path + self.path = _path_abspath(path) self._path_mtime = -1 self._path_cache = set() self._relaxed_path_cache = set() @@ -1579,23 +1581,6 @@ def invalidate_caches(self): """Invalidate the directory mtime.""" self._path_mtime = -1 - find_module = _find_module_shim - - def find_loader(self, fullname): - """Try to find a loader for the specified module, or the namespace - package portions. Returns (loader, list-of-portions). - - This method is deprecated. Use find_spec() instead. - - """ - _warnings.warn("FileFinder.find_loader() is deprecated and " - "slated for removal in Python 3.12; use find_spec() instead", - DeprecationWarning) - spec = self.find_spec(fullname) - if spec is None: - return None, [] - return spec.loader, spec.submodule_search_locations or [] - def _get_spec(self, loader_class, fullname, path, smsl, target): loader = loader_class(fullname, path) return spec_from_file_location(fullname, path, loader=loader, @@ -1675,7 +1660,7 @@ def _fill_cache(self): for item in contents: name, dot, suffix = item.partition('.') if dot: - new_name = '{}.{}'.format(name, suffix.lower()) + new_name = f'{name}.{suffix.lower()}' else: new_name = name lower_suffix_contents.add(new_name) @@ -1702,7 +1687,7 @@ def path_hook_for_FileFinder(path): return path_hook_for_FileFinder def __repr__(self): - return 'FileFinder({!r})'.format(self.path) + return f'FileFinder({self.path!r})' # Import setup ############################################################### @@ -1720,6 +1705,8 @@ def _fix_up_module(ns, name, pathname, cpathname=None): loader = SourceFileLoader(name, pathname) if not spec: spec = spec_from_file_location(name, pathname, loader=loader) + if cpathname: + spec.cached = _path_abspath(cpathname) try: ns['__spec__'] = spec ns['__loader__'] = loader diff --git a/Lib/importlib/abc.py b/Lib/importlib/abc.py index 3fa151f390..b56fa94eb9 100644 --- a/Lib/importlib/abc.py +++ b/Lib/importlib/abc.py @@ -15,20 +15,29 @@ import abc import warnings -# for compatibility with Python 3.10 -from .resources.abc import ResourceReader, Traversable, TraversableResources +from .resources import abc as _resources_abc __all__ = [ - 'Loader', 'Finder', 'MetaPathFinder', 'PathEntryFinder', + 'Loader', 'MetaPathFinder', 'PathEntryFinder', 'ResourceLoader', 'InspectLoader', 'ExecutionLoader', 'FileLoader', 'SourceLoader', - - # for compatibility with Python 3.10 - 'ResourceReader', 'Traversable', 'TraversableResources', ] +def __getattr__(name): + """ + For backwards compatibility, continue to make names + from _resources_abc available through this module. #93963 + """ + if name in _resources_abc.__all__: + obj = getattr(_resources_abc, name) + warnings._deprecated(f"{__name__}.{name}", remove=(3, 14)) + globals()[name] = obj + return obj + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') + + def _register(abstract_cls, *classes): for cls in classes: abstract_cls.register(cls) @@ -40,38 +49,6 @@ def _register(abstract_cls, *classes): abstract_cls.register(frozen_cls) -class Finder(metaclass=abc.ABCMeta): - - """Legacy abstract base class for import finders. - - It may be subclassed for compatibility with legacy third party - reimplementations of the import system. Otherwise, finder - implementations should derive from the more specific MetaPathFinder - or PathEntryFinder ABCs. - - Deprecated since Python 3.3 - """ - - def __init__(self): - warnings.warn("the Finder ABC is deprecated and " - "slated for removal in Python 3.12; use MetaPathFinder " - "or PathEntryFinder instead", - DeprecationWarning) - - @abc.abstractmethod - def find_module(self, fullname, path=None): - """An abstract method that should find a module. - The fullname is a str and the optional path is a str or None. - Returns a Loader object or None. - """ - warnings.warn("importlib.abc.Finder along with its find_module() " - "method are deprecated and " - "slated for removal in Python 3.12; use " - "MetaPathFinder.find_spec() or " - "PathEntryFinder.find_spec() instead", - DeprecationWarning) - - class MetaPathFinder(metaclass=abc.ABCMeta): """Abstract base class for import finders on sys.meta_path.""" @@ -79,27 +56,6 @@ class MetaPathFinder(metaclass=abc.ABCMeta): # We don't define find_spec() here since that would break # hasattr checks we do to support backward compatibility. - def find_module(self, fullname, path): - """Return a loader for the module. - - If no module is found, return None. The fullname is a str and - the path is a list of strings or None. - - This method is deprecated since Python 3.4 in favor of - finder.find_spec(). If find_spec() exists then backwards-compatible - functionality is provided for this method. - - """ - warnings.warn("MetaPathFinder.find_module() is deprecated since Python " - "3.4 in favor of MetaPathFinder.find_spec() and is " - "slated for removal in Python 3.12", - DeprecationWarning, - stacklevel=2) - if not hasattr(self, 'find_spec'): - return None - found = self.find_spec(fullname, path) - return found.loader if found is not None else None - def invalidate_caches(self): """An optional method for clearing the finder's cache, if any. This method is used by importlib.invalidate_caches(). @@ -113,43 +69,6 @@ class PathEntryFinder(metaclass=abc.ABCMeta): """Abstract base class for path entry finders used by PathFinder.""" - # We don't define find_spec() here since that would break - # hasattr checks we do to support backward compatibility. - - def find_loader(self, fullname): - """Return (loader, namespace portion) for the path entry. - - The fullname is a str. The namespace portion is a sequence of - path entries contributing to part of a namespace package. The - sequence may be empty. If loader is not None, the portion will - be ignored. - - The portion will be discarded if another path entry finder - locates the module as a normal module or package. - - This method is deprecated since Python 3.4 in favor of - finder.find_spec(). If find_spec() is provided than backwards-compatible - functionality is provided. - """ - warnings.warn("PathEntryFinder.find_loader() is deprecated since Python " - "3.4 in favor of PathEntryFinder.find_spec() " - "(available since 3.4)", - DeprecationWarning, - stacklevel=2) - if not hasattr(self, 'find_spec'): - return None, [] - found = self.find_spec(fullname) - if found is not None: - if not found.submodule_search_locations: - portions = [] - else: - portions = found.submodule_search_locations - return found.loader, portions - else: - return None, [] - - find_module = _bootstrap_external._find_module_shim - def invalidate_caches(self): """An optional method for clearing the finder's cache, if any. This method is used by PathFinder.invalidate_caches(). diff --git a/Lib/importlib/metadata/__init__.py b/Lib/importlib/metadata/__init__.py index 68828269fc..56ee403832 100644 --- a/Lib/importlib/metadata/__init__.py +++ b/Lib/importlib/metadata/__init__.py @@ -12,7 +12,9 @@ import functools import itertools import posixpath +import contextlib import collections +import inspect from . import _adapters, _meta from ._collections import FreezableDefaultDict, Pair @@ -24,7 +26,7 @@ from importlib import import_module from importlib.abc import MetaPathFinder from itertools import starmap -from typing import List, Mapping, Optional, Union +from typing import List, Mapping, Optional, cast __all__ = [ @@ -140,6 +142,7 @@ class DeprecatedTuple: 1 """ + # Do not remove prior to 2023-05-01 or Python 3.13 _warn = functools.partial( warnings.warn, "EntryPoint tuple interface is deprecated. Access members by name.", @@ -228,17 +231,6 @@ def _for(self, dist): vars(self).update(dist=dist) return self - def __iter__(self): - """ - Supply iter so one may construct dicts of EntryPoints by name. - """ - msg = ( - "Construction of dict of EntryPoints is deprecated in " - "favor of EntryPoints." - ) - warnings.warn(msg, DeprecationWarning) - return iter((self.name, self)) - def matches(self, **params): """ EntryPoint matches the given parameters. @@ -284,77 +276,7 @@ def __hash__(self): return hash(self._key()) -class DeprecatedList(list): - """ - Allow an otherwise immutable object to implement mutability - for compatibility. - - >>> recwarn = getfixture('recwarn') - >>> dl = DeprecatedList(range(3)) - >>> dl[0] = 1 - >>> dl.append(3) - >>> del dl[3] - >>> dl.reverse() - >>> dl.sort() - >>> dl.extend([4]) - >>> dl.pop(-1) - 4 - >>> dl.remove(1) - >>> dl += [5] - >>> dl + [6] - [1, 2, 5, 6] - >>> dl + (6,) - [1, 2, 5, 6] - >>> dl.insert(0, 0) - >>> dl - [0, 1, 2, 5] - >>> dl == [0, 1, 2, 5] - True - >>> dl == (0, 1, 2, 5) - True - >>> len(recwarn) - 1 - """ - - __slots__ = () - - _warn = functools.partial( - warnings.warn, - "EntryPoints list interface is deprecated. Cast to list if needed.", - DeprecationWarning, - stacklevel=2, - ) - - def _wrap_deprecated_method(method_name: str): # type: ignore - def wrapped(self, *args, **kwargs): - self._warn() - return getattr(super(), method_name)(*args, **kwargs) - - return method_name, wrapped - - locals().update( - map( - _wrap_deprecated_method, - '__setitem__ __delitem__ append reverse extend pop remove ' - '__iadd__ insert sort'.split(), - ) - ) - - def __add__(self, other): - if not isinstance(other, tuple): - self._warn() - other = tuple(other) - return self.__class__(tuple(self) + other) - - def __eq__(self, other): - if not isinstance(other, tuple): - self._warn() - other = tuple(other) - - return tuple(self).__eq__(other) - - -class EntryPoints(DeprecatedList): +class EntryPoints(tuple): """ An immutable collection of selectable EntryPoint objects. """ @@ -365,14 +287,6 @@ def __getitem__(self, name): # -> EntryPoint: """ Get the EntryPoint in self matching name. """ - if isinstance(name, int): - warnings.warn( - "Accessing entry points by index is deprecated. " - "Cast to tuple if needed.", - DeprecationWarning, - stacklevel=2, - ) - return super().__getitem__(name) try: return next(iter(self.select(name=name))) except StopIteration: @@ -396,10 +310,6 @@ def names(self): def groups(self): """ Return the set of all groups of all entry points. - - For coverage while SelectableGroups is present. - >>> EntryPoints().groups - set() """ return {ep.group for ep in self} @@ -415,101 +325,6 @@ def _from_text(text): ) -class Deprecated: - """ - Compatibility add-in for mapping to indicate that - mapping behavior is deprecated. - - >>> recwarn = getfixture('recwarn') - >>> class DeprecatedDict(Deprecated, dict): pass - >>> dd = DeprecatedDict(foo='bar') - >>> dd.get('baz', None) - >>> dd['foo'] - 'bar' - >>> list(dd) - ['foo'] - >>> list(dd.keys()) - ['foo'] - >>> 'foo' in dd - True - >>> list(dd.values()) - ['bar'] - >>> len(recwarn) - 1 - """ - - _warn = functools.partial( - warnings.warn, - "SelectableGroups dict interface is deprecated. Use select.", - DeprecationWarning, - stacklevel=2, - ) - - def __getitem__(self, name): - self._warn() - return super().__getitem__(name) - - def get(self, name, default=None): - self._warn() - return super().get(name, default) - - def __iter__(self): - self._warn() - return super().__iter__() - - def __contains__(self, *args): - self._warn() - return super().__contains__(*args) - - def keys(self): - self._warn() - return super().keys() - - def values(self): - self._warn() - return super().values() - - -class SelectableGroups(Deprecated, dict): - """ - A backward- and forward-compatible result from - entry_points that fully implements the dict interface. - """ - - @classmethod - def load(cls, eps): - by_group = operator.attrgetter('group') - ordered = sorted(eps, key=by_group) - grouped = itertools.groupby(ordered, by_group) - return cls((group, EntryPoints(eps)) for group, eps in grouped) - - @property - def _all(self): - """ - Reconstruct a list of all entrypoints from the groups. - """ - groups = super(Deprecated, self).values() - return EntryPoints(itertools.chain.from_iterable(groups)) - - @property - def groups(self): - return self._all.groups - - @property - def names(self): - """ - for coverage: - >>> SelectableGroups().names - set() - """ - return self._all.names - - def select(self, **params): - if not params: - return self - return self._all.select(**params) - - class PackagePath(pathlib.PurePosixPath): """A reference to a path in a package""" @@ -534,11 +349,30 @@ def __repr__(self): return f'' -class Distribution: +class DeprecatedNonAbstract: + def __new__(cls, *args, **kwargs): + all_names = { + name for subclass in inspect.getmro(cls) for name in vars(subclass) + } + abstract = { + name + for name in all_names + if getattr(getattr(cls, name), '__isabstractmethod__', False) + } + if abstract: + warnings.warn( + f"Unimplemented abstract methods {abstract}", + DeprecationWarning, + stacklevel=2, + ) + return super().__new__(cls) + + +class Distribution(DeprecatedNonAbstract): """A Python distribution package.""" @abc.abstractmethod - def read_text(self, filename): + def read_text(self, filename) -> Optional[str]: """Attempt to load metadata file given by the name. :param filename: The name of the file in the distribution info. @@ -612,7 +446,7 @@ def metadata(self) -> _meta.PackageMetadata: The returned object will have keys that name the various bits of metadata. See PEP 566 for details. """ - text = ( + opt_text = ( self.read_text('METADATA') or self.read_text('PKG-INFO') # This last clause is here to support old egg-info files. Its @@ -620,6 +454,7 @@ def metadata(self) -> _meta.PackageMetadata: # (which points to the egg-info file) attribute unchanged. or self.read_text('') ) + text = cast(str, opt_text) return _adapters.Message(email.message_from_string(text)) @property @@ -648,8 +483,8 @@ def files(self): :return: List of PackagePath for this distribution or None Result is `None` if the metadata file that enumerates files - (i.e. RECORD for dist-info or SOURCES.txt for egg-info) is - missing. + (i.e. RECORD for dist-info, or installed-files.txt or + SOURCES.txt for egg-info) is missing. Result may be empty if the metadata exists but is empty. """ @@ -662,9 +497,19 @@ def make_file(name, hash=None, size_str=None): @pass_none def make_files(lines): - return list(starmap(make_file, csv.reader(lines))) + return starmap(make_file, csv.reader(lines)) - return make_files(self._read_files_distinfo() or self._read_files_egginfo()) + @pass_none + def skip_missing_files(package_paths): + return list(filter(lambda path: path.locate().exists(), package_paths)) + + return skip_missing_files( + make_files( + self._read_files_distinfo() + or self._read_files_egginfo_installed() + or self._read_files_egginfo_sources() + ) + ) def _read_files_distinfo(self): """ @@ -673,10 +518,45 @@ def _read_files_distinfo(self): text = self.read_text('RECORD') return text and text.splitlines() - def _read_files_egginfo(self): + def _read_files_egginfo_installed(self): + """ + Read installed-files.txt and return lines in a similar + CSV-parsable format as RECORD: each file must be placed + relative to the site-packages directory and must also be + quoted (since file names can contain literal commas). + + This file is written when the package is installed by pip, + but it might not be written for other installation methods. + Assume the file is accurate if it exists. """ - SOURCES.txt might contain literal commas, so wrap each line - in quotes. + text = self.read_text('installed-files.txt') + # Prepend the .egg-info/ subdir to the lines in this file. + # But this subdir is only available from PathDistribution's + # self._path. + subdir = getattr(self, '_path', None) + if not text or not subdir: + return + + paths = ( + (subdir / name) + .resolve() + .relative_to(self.locate_file('').resolve()) + .as_posix() + for name in text.splitlines() + ) + return map('"{}"'.format, paths) + + def _read_files_egginfo_sources(self): + """ + Read SOURCES.txt and return lines in a similar CSV-parsable + format as RECORD: each file name must be quoted (since it + might contain literal commas). + + Note that SOURCES.txt is not a reliable source for what + files are installed by a package. This file is generated + for a source archive, and the files that are present + there (e.g. setup.py) may not correctly reflect the files + that are present after the package has been installed. """ text = self.read_text('SOURCES.txt') return text and map('"{}"'.format, text.splitlines()) @@ -1023,27 +903,19 @@ def version(distribution_name): """ -def entry_points(**params) -> Union[EntryPoints, SelectableGroups]: +def entry_points(**params) -> EntryPoints: """Return EntryPoint objects for all installed packages. Pass selection parameters (group or name) to filter the result to entry points matching those properties (see EntryPoints.select()). - For compatibility, returns ``SelectableGroups`` object unless - selection parameters are supplied. In the future, this function - will return ``EntryPoints`` instead of ``SelectableGroups`` - even when no selection parameters are supplied. - - For maximum future compatibility, pass selection parameters - or invoke ``.select`` with parameters on the result. - - :return: EntryPoints or SelectableGroups for all installed packages. + :return: EntryPoints for all installed packages. """ eps = itertools.chain.from_iterable( dist.entry_points for dist in _unique(distributions()) ) - return SelectableGroups.load(eps).select(**params) + return EntryPoints(eps).select(**params) def files(distribution_name): @@ -1087,8 +959,13 @@ def _top_level_declared(dist): def _top_level_inferred(dist): - return { - f.parts[0] if len(f.parts) > 1 else f.with_suffix('').name + opt_names = { + f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in always_iterable(dist.files) - if f.suffix == ".py" } + + @pass_none + def importable_name(name): + return '.' not in name + + return filter(importable_name, opt_names) diff --git a/Lib/importlib/metadata/_adapters.py b/Lib/importlib/metadata/_adapters.py index aa460d3eda..6aed69a308 100644 --- a/Lib/importlib/metadata/_adapters.py +++ b/Lib/importlib/metadata/_adapters.py @@ -1,3 +1,5 @@ +import functools +import warnings import re import textwrap import email.message @@ -5,6 +7,15 @@ from ._text import FoldedCase +# Do not remove prior to 2024-01-01 or Python 3.14 +_warn = functools.partial( + warnings.warn, + "Implicit None on return values is deprecated and will raise KeyErrors.", + DeprecationWarning, + stacklevel=2, +) + + class Message(email.message.Message): multiple_use_keys = set( map( @@ -39,6 +50,16 @@ def __init__(self, *args, **kwargs): def __iter__(self): return super().__iter__() + def __getitem__(self, item): + """ + Warn users that a ``KeyError`` can be expected when a + mising key is supplied. Ref python/importlib_metadata#371. + """ + res = super().__getitem__(item) + if res is None: + _warn() + return res + def _repair_headers(self): def redent(value): "Correct for RFC822 indentation" diff --git a/Lib/importlib/metadata/_meta.py b/Lib/importlib/metadata/_meta.py index d5c0576194..c9a7ef906a 100644 --- a/Lib/importlib/metadata/_meta.py +++ b/Lib/importlib/metadata/_meta.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Iterator, List, Protocol, TypeVar, Union +from typing import Protocol +from typing import Any, Dict, Iterator, List, Optional, TypeVar, Union, overload _T = TypeVar("_T") @@ -17,7 +18,21 @@ def __getitem__(self, key: str) -> str: def __iter__(self) -> Iterator[str]: ... # pragma: no cover - def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]: + @overload + def get(self, name: str, failobj: None = None) -> Optional[str]: + ... # pragma: no cover + + @overload + def get(self, name: str, failobj: _T) -> Union[str, _T]: + ... # pragma: no cover + + # overload per python/importlib_metadata#435 + @overload + def get_all(self, name: str, failobj: None = None) -> Optional[List[Any]]: + ... # pragma: no cover + + @overload + def get_all(self, name: str, failobj: _T) -> Union[List[Any], _T]: """ Return all values associated with a possibly multi-valued key. """ @@ -29,18 +44,19 @@ def json(self) -> Dict[str, Union[str, List[str]]]: """ -class SimplePath(Protocol): +class SimplePath(Protocol[_T]): """ A minimal subset of pathlib.Path required by PathDistribution. """ - def joinpath(self) -> 'SimplePath': + def joinpath(self) -> _T: ... # pragma: no cover - def __truediv__(self) -> 'SimplePath': + def __truediv__(self, other: Union[str, _T]) -> _T: ... # pragma: no cover - def parent(self) -> 'SimplePath': + @property + def parent(self) -> _T: ... # pragma: no cover def read_text(self) -> str: diff --git a/Lib/importlib/resources/_adapters.py b/Lib/importlib/resources/_adapters.py index ea363d86a5..50688fbb66 100644 --- a/Lib/importlib/resources/_adapters.py +++ b/Lib/importlib/resources/_adapters.py @@ -34,9 +34,7 @@ def _io_wrapper(file, mode='r', *args, **kwargs): return TextIOWrapper(file, *args, **kwargs) elif mode == 'rb': return file - raise ValueError( - "Invalid mode value '{}', only 'r' and 'rb' are supported".format(mode) - ) + raise ValueError(f"Invalid mode value '{mode}', only 'r' and 'rb' are supported") class CompatibilityFiles: diff --git a/Lib/importlib/resources/_common.py b/Lib/importlib/resources/_common.py index ca1fa8ab2f..b402e05116 100644 --- a/Lib/importlib/resources/_common.py +++ b/Lib/importlib/resources/_common.py @@ -5,25 +5,58 @@ import contextlib import types import importlib +import inspect +import warnings +import itertools -from typing import Union, Optional +from typing import Union, Optional, cast from .abc import ResourceReader, Traversable from ._adapters import wrap_spec Package = Union[types.ModuleType, str] +Anchor = Package -def files(package): - # type: (Package) -> Traversable +def package_to_anchor(func): """ - Get a Traversable resource from a package + Replace 'package' parameter as 'anchor' and warn about the change. + + Other errors should fall through. + + >>> files('a', 'b') + Traceback (most recent call last): + TypeError: files() takes from 0 to 1 positional arguments but 2 were given + """ + undefined = object() + + @functools.wraps(func) + def wrapper(anchor=undefined, package=undefined): + if package is not undefined: + if anchor is not undefined: + return func(anchor, package) + warnings.warn( + "First parameter to files is renamed to 'anchor'", + DeprecationWarning, + stacklevel=2, + ) + return func(package) + elif anchor is undefined: + return func() + return func(anchor) + + return wrapper + + +@package_to_anchor +def files(anchor: Optional[Anchor] = None) -> Traversable: + """ + Get a Traversable resource for an anchor. """ - return from_package(get_package(package)) + return from_package(resolve(anchor)) -def get_resource_reader(package): - # type: (types.ModuleType) -> Optional[ResourceReader] +def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]: """ Return the package's loader if it's a ResourceReader. """ @@ -39,24 +72,39 @@ def get_resource_reader(package): return reader(spec.name) # type: ignore -def resolve(cand): - # type: (Package) -> types.ModuleType - return cand if isinstance(cand, types.ModuleType) else importlib.import_module(cand) +@functools.singledispatch +def resolve(cand: Optional[Anchor]) -> types.ModuleType: + return cast(types.ModuleType, cand) + + +@resolve.register(str) # TODO: RUSTPYTHON; manual type annotation +def _(cand: str) -> types.ModuleType: + return importlib.import_module(cand) + +@resolve.register(type(None)) # TODO: RUSTPYTHON; manual type annotation +def _(cand: None) -> types.ModuleType: + return resolve(_infer_caller().f_globals['__name__']) -def get_package(package): - # type: (Package) -> types.ModuleType - """Take a package name or module object and return the module. - Raise an exception if the resolved module is not a package. +def _infer_caller(): """ - resolved = resolve(package) - if wrap_spec(resolved).submodule_search_locations is None: - raise TypeError(f'{package!r} is not a package') - return resolved + Walk the stack and find the frame of the first caller not in this module. + """ + + def is_this_file(frame_info): + return frame_info.filename == __file__ + + def is_wrapper(frame_info): + return frame_info.function == 'wrapper' + + not_this_file = itertools.filterfalse(is_this_file, inspect.stack()) + # also exclude 'wrapper' due to singledispatch in the call stack + callers = itertools.filterfalse(is_wrapper, not_this_file) + return next(callers).frame -def from_package(package): +def from_package(package: types.ModuleType): """ Return a Traversable object for the given package. @@ -67,10 +115,14 @@ def from_package(package): @contextlib.contextmanager -def _tempfile(reader, suffix='', - # gh-93353: Keep a reference to call os.remove() in late Python - # finalization. - *, _os_remove=os.remove): +def _tempfile( + reader, + suffix='', + # gh-93353: Keep a reference to call os.remove() in late Python + # finalization. + *, + _os_remove=os.remove, +): # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try' # blocks due to the need to close the temporary file to work on Windows # properly. @@ -89,13 +141,30 @@ def _tempfile(reader, suffix='', pass +def _temp_file(path): + return _tempfile(path.read_bytes, suffix=path.name) + + +def _is_present_dir(path: Traversable) -> bool: + """ + Some Traversables implement ``is_dir()`` to raise an + exception (i.e. ``FileNotFoundError``) when the + directory doesn't exist. This function wraps that call + to always return a boolean and only return True + if there's a dir and it exists. + """ + with contextlib.suppress(FileNotFoundError): + return path.is_dir() + return False + + @functools.singledispatch def as_file(path): """ Given a Traversable object, return that object as a path on the local file system in a context manager. """ - return _tempfile(path.read_bytes, suffix=path.name) + return _temp_dir(path) if _is_present_dir(path) else _temp_file(path) @as_file.register(pathlib.Path) @@ -105,3 +174,34 @@ def _(path): Degenerate behavior for pathlib.Path objects. """ yield path + + +@contextlib.contextmanager +def _temp_path(dir: tempfile.TemporaryDirectory): + """ + Wrap tempfile.TemporyDirectory to return a pathlib object. + """ + with dir as result: + yield pathlib.Path(result) + + +@contextlib.contextmanager +def _temp_dir(path): + """ + Given a traversable dir, recursively replicate the whole tree + to the file system in a context manager. + """ + assert path.is_dir() + with _temp_path(tempfile.TemporaryDirectory()) as temp_dir: + yield _write_contents(temp_dir, path) + + +def _write_contents(target, source): + child = target.joinpath(source.name) + if source.is_dir(): + child.mkdir() + for item in source.iterdir(): + _write_contents(child, item) + else: + child.write_bytes(source.read_bytes()) + return child diff --git a/Lib/importlib/resources/_itertools.py b/Lib/importlib/resources/_itertools.py index cce05582ff..7b775ef5ae 100644 --- a/Lib/importlib/resources/_itertools.py +++ b/Lib/importlib/resources/_itertools.py @@ -1,35 +1,38 @@ -from itertools import filterfalse +# from more_itertools 9.0 +def only(iterable, default=None, too_long=None): + """If *iterable* has only one item, return it. + If it has zero items, return *default*. + If it has more than one item, raise the exception given by *too_long*, + which is ``ValueError`` by default. + >>> only([], default='missing') + 'missing' + >>> only([1]) + 1 + >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 1, 2, + and perhaps more.' + >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError + Note that :func:`only` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check + iterable contents less destructively. + """ + it = iter(iterable) + first_value = next(it, default) -from typing import ( - Callable, - Iterable, - Iterator, - Optional, - Set, - TypeVar, - Union, -) - -# Type and type variable definitions -_T = TypeVar('_T') -_U = TypeVar('_U') - - -def unique_everseen( - iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = None -) -> Iterator[_T]: - "List unique elements, preserving order. Remember all elements ever seen." - # unique_everseen('AAAABBBCCDAABBB') --> A B C D - # unique_everseen('ABBCcAD', str.lower) --> A B C D - seen: Set[Union[_T, _U]] = set() - seen_add = seen.add - if key is None: - for element in filterfalse(seen.__contains__, iterable): - seen_add(element) - yield element + try: + second_value = next(it) + except StopIteration: + pass else: - for element in iterable: - k = key(element) - if k not in seen: - seen_add(k) - yield element + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value diff --git a/Lib/importlib/resources/_legacy.py b/Lib/importlib/resources/_legacy.py index 1d5d3f1fbb..b1ea8105da 100644 --- a/Lib/importlib/resources/_legacy.py +++ b/Lib/importlib/resources/_legacy.py @@ -27,8 +27,7 @@ def wrapper(*args, **kwargs): return wrapper -def normalize_path(path): - # type: (Any) -> str +def normalize_path(path: Any) -> str: """Normalize a path by ensuring it is a string. If the resulting string contains path separators, an exception is raised. diff --git a/Lib/importlib/resources/abc.py b/Lib/importlib/resources/abc.py index 0b7bfdc415..6750a7aaf1 100644 --- a/Lib/importlib/resources/abc.py +++ b/Lib/importlib/resources/abc.py @@ -1,6 +1,8 @@ import abc import io +import itertools import os +import pathlib from typing import Any, BinaryIO, Iterable, Iterator, NoReturn, Text, Optional from typing import runtime_checkable, Protocol from typing import Union @@ -53,6 +55,10 @@ def contents(self) -> Iterable[str]: raise FileNotFoundError +class TraversalError(Exception): + pass + + @runtime_checkable class Traversable(Protocol): """ @@ -95,7 +101,6 @@ def is_file(self) -> bool: Return True if self is a file """ - @abc.abstractmethod def joinpath(self, *descendants: StrPath) -> "Traversable": """ Return Traversable resolved with any descendants applied. @@ -104,6 +109,22 @@ def joinpath(self, *descendants: StrPath) -> "Traversable": and each may contain multiple levels separated by ``posixpath.sep`` (``/``). """ + if not descendants: + return self + names = itertools.chain.from_iterable( + path.parts for path in map(pathlib.PurePosixPath, descendants) + ) + target = next(names) + matches = ( + traversable for traversable in self.iterdir() if traversable.name == target + ) + try: + match = next(matches) + except StopIteration: + raise TraversalError( + "Target not found during traversal.", target, list(names) + ) + return match.joinpath(*names) def __truediv__(self, child: StrPath) -> "Traversable": """ @@ -121,7 +142,8 @@ def open(self, mode='r', *args, **kwargs): accepted by io.TextIOWrapper. """ - @abc.abstractproperty + @property + @abc.abstractmethod def name(self) -> str: """ The base name of this object without any parent references. diff --git a/Lib/importlib/resources/readers.py b/Lib/importlib/resources/readers.py index b470a2062b..c3cdf769cb 100644 --- a/Lib/importlib/resources/readers.py +++ b/Lib/importlib/resources/readers.py @@ -1,11 +1,12 @@ import collections -import operator +import itertools import pathlib +import operator import zipfile from . import abc -from ._itertools import unique_everseen +from ._itertools import only def remove_duplicates(items): @@ -41,8 +42,10 @@ def open_resource(self, resource): raise FileNotFoundError(exc.args[0]) def is_resource(self, path): - # workaround for `zipfile.Path.is_file` returning true - # for non-existent paths. + """ + Workaround for `zipfile.Path.is_file` returning true + for non-existent paths. + """ target = self.files().joinpath(path) return target.is_file() and target.exists() @@ -67,8 +70,10 @@ def __init__(self, *paths): raise NotADirectoryError('MultiplexedPath only supports directories') def iterdir(self): - files = (file for path in self._paths for file in path.iterdir()) - return unique_everseen(files, key=operator.attrgetter('name')) + children = (child for path in self._paths for child in path.iterdir()) + by_name = operator.attrgetter('name') + groups = itertools.groupby(sorted(children, key=by_name), key=by_name) + return map(self._follow, (locs for name, locs in groups)) def read_bytes(self): raise FileNotFoundError(f'{self} is not a file') @@ -82,15 +87,32 @@ def is_dir(self): def is_file(self): return False - def joinpath(self, child): - # first try to find child in current paths - for file in self.iterdir(): - if file.name == child: - return file - # if it does not exist, construct it with the first path - return self._paths[0] / child + def joinpath(self, *descendants): + try: + return super().joinpath(*descendants) + except abc.TraversalError: + # One of the paths did not resolve (a directory does not exist). + # Just return something that will not exist. + return self._paths[0].joinpath(*descendants) + + @classmethod + def _follow(cls, children): + """ + Construct a MultiplexedPath if needed. + + If children contains a sole element, return it. + Otherwise, return a MultiplexedPath of the items. + Unless one of the items is not a Directory, then return the first. + """ + subdirs, one_dir, one_file = itertools.tee(children, 3) - __truediv__ = joinpath + try: + return only(one_dir) + except ValueError: + try: + return cls(*subdirs) + except NotADirectoryError: + return next(one_file) def open(self, *args, **kwargs): raise FileNotFoundError(f'{self} is not a file') diff --git a/Lib/importlib/resources/simple.py b/Lib/importlib/resources/simple.py index d0fbf23776..7770c922c8 100644 --- a/Lib/importlib/resources/simple.py +++ b/Lib/importlib/resources/simple.py @@ -16,31 +16,28 @@ class SimpleReader(abc.ABC): provider. """ - @abc.abstractproperty - def package(self): - # type: () -> str + @property + @abc.abstractmethod + def package(self) -> str: """ The name of the package for which this reader loads resources. """ @abc.abstractmethod - def children(self): - # type: () -> List['SimpleReader'] + def children(self) -> List['SimpleReader']: """ Obtain an iterable of SimpleReader for available child containers (e.g. directories). """ @abc.abstractmethod - def resources(self): - # type: () -> List[str] + def resources(self) -> List[str]: """ Obtain available named resources for this virtual package. """ @abc.abstractmethod - def open_binary(self, resource): - # type: (str) -> BinaryIO + def open_binary(self, resource: str) -> BinaryIO: """ Obtain a File-like for a named resource. """ @@ -50,13 +47,35 @@ def name(self): return self.package.split('.')[-1] +class ResourceContainer(Traversable): + """ + Traversable container for a package's resources via its reader. + """ + + def __init__(self, reader: SimpleReader): + self.reader = reader + + def is_dir(self): + return True + + def is_file(self): + return False + + def iterdir(self): + files = (ResourceHandle(self, name) for name in self.reader.resources) + dirs = map(ResourceContainer, self.reader.children()) + return itertools.chain(files, dirs) + + def open(self, *args, **kwargs): + raise IsADirectoryError() + + class ResourceHandle(Traversable): """ Handle to a named resource in a ResourceReader. """ - def __init__(self, parent, name): - # type: (ResourceContainer, str) -> None + def __init__(self, parent: ResourceContainer, name: str): self.parent = parent self.name = name # type: ignore @@ -76,44 +95,6 @@ def joinpath(self, name): raise RuntimeError("Cannot traverse into a resource") -class ResourceContainer(Traversable): - """ - Traversable container for a package's resources via its reader. - """ - - def __init__(self, reader): - # type: (SimpleReader) -> None - self.reader = reader - - def is_dir(self): - return True - - def is_file(self): - return False - - def iterdir(self): - files = (ResourceHandle(self, name) for name in self.reader.resources) - dirs = map(ResourceContainer, self.reader.children()) - return itertools.chain(files, dirs) - - def open(self, *args, **kwargs): - raise IsADirectoryError() - - @staticmethod - def _flatten(compound_names): - for name in compound_names: - yield from name.split('/') - - def joinpath(self, *descendants): - if not descendants: - return self - names = self._flatten(descendants) - target = next(names) - return next( - traversable for traversable in self.iterdir() if traversable.name == target - ).joinpath(*names) - - class TraversableReader(TraversableResources, SimpleReader): """ A TraversableResources based on SimpleReader. Resource providers diff --git a/Lib/importlib/util.py b/Lib/importlib/util.py index 8623c89840..f4d6e82331 100644 --- a/Lib/importlib/util.py +++ b/Lib/importlib/util.py @@ -11,12 +11,9 @@ from ._bootstrap_external import source_from_cache from ._bootstrap_external import spec_from_file_location -from contextlib import contextmanager import _imp -import functools import sys import types -import warnings def source_hash(source_bytes): @@ -63,10 +60,10 @@ def _find_spec_from_path(name, path=None): try: spec = module.__spec__ except AttributeError: - raise ValueError('{}.__spec__ is not set'.format(name)) from None + raise ValueError(f'{name}.__spec__ is not set') from None else: if spec is None: - raise ValueError('{}.__spec__ is None'.format(name)) + raise ValueError(f'{name}.__spec__ is None') return spec @@ -108,115 +105,64 @@ def find_spec(name, package=None): try: spec = module.__spec__ except AttributeError: - raise ValueError('{}.__spec__ is not set'.format(name)) from None + raise ValueError(f'{name}.__spec__ is not set') from None else: if spec is None: - raise ValueError('{}.__spec__ is None'.format(name)) + raise ValueError(f'{name}.__spec__ is None') return spec -@contextmanager -def _module_to_load(name): - is_reload = name in sys.modules - - module = sys.modules.get(name) - if not is_reload: - # This must be done before open() is called as the 'io' module - # implicitly imports 'locale' and would otherwise trigger an - # infinite loop. - module = type(sys)(name) - # This must be done before putting the module in sys.modules - # (otherwise an optimization shortcut in import.c becomes wrong) - module.__initializing__ = True - sys.modules[name] = module - try: - yield module - except Exception: - if not is_reload: - try: - del sys.modules[name] - except KeyError: - pass - finally: - module.__initializing__ = False +# Normally we would use contextlib.contextmanager. However, this module +# is imported by runpy, which means we want to avoid any unnecessary +# dependencies. Thus we use a class. +class _incompatible_extension_module_restrictions: + """A context manager that can temporarily skip the compatibility check. -def set_package(fxn): - """Set __package__ on the returned module. + NOTE: This function is meant to accommodate an unusual case; one + which is likely to eventually go away. There's is a pretty good + chance this is not what you were looking for. - This function is deprecated. + WARNING: Using this function to disable the check can lead to + unexpected behavior and even crashes. It should only be used during + extension module development. - """ - @functools.wraps(fxn) - def set_package_wrapper(*args, **kwargs): - warnings.warn('The import system now takes care of this automatically; ' - 'this decorator is slated for removal in Python 3.12', - DeprecationWarning, stacklevel=2) - module = fxn(*args, **kwargs) - if getattr(module, '__package__', None) is None: - module.__package__ = module.__name__ - if not hasattr(module, '__path__'): - module.__package__ = module.__package__.rpartition('.')[0] - return module - return set_package_wrapper + If "disable_check" is True then the compatibility check will not + happen while the context manager is active. Otherwise the check + *will* happen. + Normally, extensions that do not support multiple interpreters + may not be imported in a subinterpreter. That implies modules + that do not implement multi-phase init or that explicitly of out. -def set_loader(fxn): - """Set __loader__ on the returned module. + Likewise for modules import in a subinterpeter with its own GIL + when the extension does not support a per-interpreter GIL. This + implies the module does not have a Py_mod_multiple_interpreters slot + set to Py_MOD_PER_INTERPRETER_GIL_SUPPORTED. - This function is deprecated. + In both cases, this context manager may be used to temporarily + disable the check for compatible extension modules. + You can get the same effect as this function by implementing the + basic interface of multi-phase init (PEP 489) and lying about + support for mulitple interpreters (or per-interpreter GIL). """ - @functools.wraps(fxn) - def set_loader_wrapper(self, *args, **kwargs): - warnings.warn('The import system now takes care of this automatically; ' - 'this decorator is slated for removal in Python 3.12', - DeprecationWarning, stacklevel=2) - module = fxn(self, *args, **kwargs) - if getattr(module, '__loader__', None) is None: - module.__loader__ = self - return module - return set_loader_wrapper - - -def module_for_loader(fxn): - """Decorator to handle selecting the proper module for loaders. - - The decorated function is passed the module to use instead of the module - name. The module passed in to the function is either from sys.modules if - it already exists or is a new module. If the module is new, then __name__ - is set the first argument to the method, __loader__ is set to self, and - __package__ is set accordingly (if self.is_package() is defined) will be set - before it is passed to the decorated function (if self.is_package() does - not work for the module it will be set post-load). - - If an exception is raised and the decorator created the module it is - subsequently removed from sys.modules. - - The decorator assumes that the decorated function takes the module name as - the second argument. - """ - warnings.warn('The import system now takes care of this automatically; ' - 'this decorator is slated for removal in Python 3.12', - DeprecationWarning, stacklevel=2) - @functools.wraps(fxn) - def module_for_loader_wrapper(self, fullname, *args, **kwargs): - with _module_to_load(fullname) as module: - module.__loader__ = self - try: - is_package = self.is_package(fullname) - except (ImportError, AttributeError): - pass - else: - if is_package: - module.__package__ = fullname - else: - module.__package__ = fullname.rpartition('.')[0] - # If __package__ was not set above, __import__() will do it later. - return fxn(self, module, *args, **kwargs) - - return module_for_loader_wrapper + def __init__(self, *, disable_check): + self.disable_check = bool(disable_check) + + def __enter__(self): + self.old = _imp._override_multi_interp_extensions_check(self.override) + return self + + def __exit__(self, *args): + old = self.old + del self.old + _imp._override_multi_interp_extensions_check(old) + + @property + def override(self): + return -1 if self.disable_check else 1 class _LazyModule(types.ModuleType): diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index e9736cd5ba..975ff21101 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -4,13 +4,16 @@ raise ImportError('support must be imported from the test package') import contextlib +import dataclasses import functools import getpass +import opcode import os import re import stat import sys import sysconfig +import textwrap import time import types import unittest @@ -19,11 +22,6 @@ from .testresult import get_test_runner -try: - from _testcapi import unicode_legacy_string -except ImportError: - unicode_legacy_string = None - __all__ = [ # globals "PIPE_MAX_SIZE", "verbose", "max_memuse", "use_resources", "failfast", @@ -36,7 +34,7 @@ "is_resource_enabled", "requires", "requires_freebsd_version", "requires_linux_version", "requires_mac_ver", "check_syntax_error", - "BasicTestRunner", "run_unittest", "run_doctest", + "run_unittest", "run_doctest", "requires_gzip", "requires_bz2", "requires_lzma", "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute", "requires_IEEE_754", "requires_zlib", @@ -46,9 +44,12 @@ "anticipate_failure", "load_package_tests", "detect_api_mismatch", "check__all__", "skip_if_buggy_ucrt_strfptime", "check_disallow_instantiation", "check_sanitizer", "skip_if_sanitizer", + "requires_limited_api", "requires_specialization", # sys "is_jython", "is_android", "is_emscripten", "is_wasi", "check_impl_detail", "unix_shell", "setswitchinterval", + # os + "get_pagesize", # network "open_urlresource", # processes @@ -59,6 +60,8 @@ "run_with_tz", "PGO", "missing_compiler_executable", "ALWAYS_EQ", "NEVER_EQ", "LARGEST", "SMALLEST", "LOOPBACK_TIMEOUT", "INTERNET_TIMEOUT", "SHORT_TIMEOUT", "LONG_TIMEOUT", + "Py_DEBUG", "EXCEEDS_RECURSION_LIMIT", "C_RECURSION_LIMIT", + "skip_on_s390x", ] @@ -116,17 +119,20 @@ class Error(Exception): class TestFailed(Error): """Test failed.""" + def __init__(self, msg, *args, stats=None): + self.msg = msg + self.stats = stats + super().__init__(msg, *args) + + def __str__(self): + return self.msg class TestFailedWithDetails(TestFailed): """Test failed.""" - def __init__(self, msg, errors, failures): - self.msg = msg + def __init__(self, msg, errors, failures, stats): self.errors = errors self.failures = failures - super().__init__(msg, errors, failures) - - def __str__(self): - return self.msg + super().__init__(msg, errors, failures, stats=stats) class TestDidNotRun(Error): """Test did not run any subtests.""" @@ -408,7 +414,7 @@ def check_sanitizer(*, address=False, memory=False, ub=False): ) address_sanitizer = ( '-fsanitize=address' in _cflags or - '--with-memory-sanitizer' in _config_args + '--with-address-sanitizer' in _config_args ) ub_sanitizer = ( '-fsanitize=undefined' in _cflags or @@ -500,9 +506,16 @@ def has_no_debug_ranges(): def requires_debug_ranges(reason='requires co_positions / debug_ranges'): return unittest.skipIf(has_no_debug_ranges(), reason) -requires_legacy_unicode_capi = unittest.skipUnless(unicode_legacy_string, - 'requires legacy Unicode C API') +def requires_legacy_unicode_capi(): + try: + from _testcapi import unicode_legacy_string + except ImportError: + unicode_legacy_string = None + + return unittest.skipUnless(unicode_legacy_string, + 'requires legacy Unicode C API') +# Is not actually used in tests, but is kept for compatibility. is_jython = sys.platform.startswith('java') is_android = hasattr(sys, 'getandroidapilevel') @@ -578,7 +591,8 @@ def darwin_malloc_err_warning(test_name): msg = ' NOTICE ' detail = (f'{test_name} may generate "malloc can\'t allocate region"\n' 'warnings on macOS systems. This behavior is known. Do not\n' - 'report a bug unless tests are also failing. See bpo-40928.') + 'report a bug unless tests are also failing.\n' + 'See https://github.com/python/cpython/issues/85100') padding, _ = shutil.get_terminal_size() print(msg.center(padding, '-')) @@ -612,6 +626,14 @@ def sortdict(dict): withcommas = ", ".join(reprpairs) return "{%s}" % withcommas + +def run_code(code: str) -> dict[str, object]: + """Run a piece of code after dedenting it, and return its global namespace.""" + ns = {} + exec(textwrap.dedent(code), ns) + return ns + + def check_syntax_error(testcase, statement, errtext='', *, lineno=None, offset=None): with testcase.assertRaisesRegex(SyntaxError, errtext) as cm: compile(statement, '', 'exec') @@ -994,12 +1016,6 @@ def wrapper(self): #======================================================================= # unittest integration. -class BasicTestRunner: - def run(self, test): - result = unittest.TestResult() - test(result) - return result - def _id(obj): return obj @@ -1078,6 +1094,18 @@ def refcount_test(test): return no_tracing(cpython_only(test)) +def requires_limited_api(test): + try: + import _testcapi + except ImportError: + return unittest.skip('needs _testcapi module')(test) + return unittest.skipUnless( + _testcapi.LIMITED_API_AVAILABLE, 'needs Limited API support')(test) + +def requires_specialization(test): + return unittest.skipUnless( + opcode.ENABLE_SPECIALIZATION, "requires specialization")(test) + def _filter_suite(suite, pred): """Recursively filter test cases in a suite based on a predicate.""" newtests = [] @@ -1090,6 +1118,29 @@ def _filter_suite(suite, pred): newtests.append(test) suite._tests = newtests +@dataclasses.dataclass(slots=True) +class TestStats: + tests_run: int = 0 + failures: int = 0 + skipped: int = 0 + + @staticmethod + def from_unittest(result): + return TestStats(result.testsRun, + len(result.failures), + len(result.skipped)) + + @staticmethod + def from_doctest(results): + return TestStats(results.attempted, + results.failed) + + def accumulate(self, stats): + self.tests_run += stats.tests_run + self.failures += stats.failures + self.skipped += stats.skipped + + def _run_suite(suite): """Run tests from a unittest.TestSuite-derived class.""" runner = get_test_runner(sys.stdout, @@ -1101,9 +1152,10 @@ def _run_suite(suite): if junit_xml_list is not None: junit_xml_list.append(result.get_xml_element()) - if not result.testsRun and not result.skipped: + if not result.testsRun and not result.skipped and not result.errors: raise TestDidNotRun if not result.wasSuccessful(): + stats = TestStats.from_unittest(result) if len(result.errors) == 1 and not result.failures: err = result.errors[0][1] elif len(result.failures) == 1 and not result.errors: @@ -1113,7 +1165,8 @@ def _run_suite(suite): if not verbose: err += "; run in verbose mode for details" errors = [(str(tc), exc_str) for tc, exc_str in result.errors] failures = [(str(tc), exc_str) for tc, exc_str in result.failures] - raise TestFailedWithDetails(err, errors, failures) + raise TestFailedWithDetails(err, errors, failures, stats=stats) + return result # By default, don't filter tests @@ -1144,7 +1197,6 @@ def _is_full_match_test(pattern): def set_match_tests(accept_patterns=None, ignore_patterns=None): global _match_test_func, _accept_test_patterns, _ignore_test_patterns - if accept_patterns is None: accept_patterns = () if ignore_patterns is None: @@ -1222,7 +1274,7 @@ def run_unittest(*classes): else: suite.addTest(loader.loadTestsFromTestCase(cls)) _filter_suite(suite, match_test) - _run_suite(suite) + return _run_suite(suite) #======================================================================= # Check for the presence of docstrings. @@ -1262,13 +1314,18 @@ def run_doctest(module, verbosity=None, optionflags=0): else: verbosity = None - f, t = doctest.testmod(module, verbose=verbosity, optionflags=optionflags) - if f: - raise TestFailed("%d of %d doctests failed" % (f, t)) + results = doctest.testmod(module, + verbose=verbosity, + optionflags=optionflags) + if results.failed: + stats = TestStats.from_doctest(results) + raise TestFailed(f"{results.failed} of {results.attempted} " + f"doctests failed", + stats=stats) if verbose: print('doctest (%s) ... %d tests with zero failures' % - (module.__name__, t)) - return f, t + (module.__name__, results.attempted)) + return results #======================================================================= @@ -1792,6 +1849,25 @@ def run_in_subinterp(code): Run code in a subinterpreter. Raise unittest.SkipTest if the tracemalloc module is enabled. """ + _check_tracemalloc() + import _testcapi + return _testcapi.run_in_subinterp(code) + + +def run_in_subinterp_with_config(code, *, own_gil=None, **config): + """ + Run code in a subinterpreter. Raise unittest.SkipTest if the tracemalloc + module is enabled. + """ + _check_tracemalloc() + import _testcapi + if own_gil is not None: + assert 'gil' not in config, (own_gil, config) + config['gil'] = 2 if own_gil else 1 + return _testcapi.run_in_subinterp_with_config(code, **config) + + +def _check_tracemalloc(): # Issue #10915, #15751: PyGILState_*() functions don't work with # sub-interpreters, the tracemalloc module uses these functions internally try: @@ -1803,8 +1879,6 @@ def run_in_subinterp(code): raise unittest.SkipTest("run_in_subinterp() cannot be used " "if tracemalloc module is tracing " "memory allocations") - import _testcapi - return _testcapi.run_in_subinterp(code) # TODO: RUSTPYTHON (comment out before) @@ -1836,15 +1910,16 @@ def missing_compiler_executable(cmd_names=[]): missing. """ - # TODO (PEP 632): alternate check without using distutils - from distutils import ccompiler, sysconfig, spawn, errors + from setuptools._distutils import ccompiler, sysconfig, spawn + from setuptools import errors + compiler = ccompiler.new_compiler() sysconfig.customize_compiler(compiler) if compiler.compiler_type == "msvc": # MSVC has no executables, so check whether initialization succeeds try: compiler.initialize() - except errors.DistutilsPlatformError: + except errors.PlatformError: return "msvc" for name in compiler.executables: if cmd_names and name not in cmd_names: @@ -1875,6 +1950,18 @@ def setswitchinterval(interval): return sys.setswitchinterval(interval) +def get_pagesize(): + """Get size of a page in bytes.""" + try: + page_size = os.sysconf('SC_PAGESIZE') + except (ValueError, AttributeError): + try: + page_size = os.sysconf('SC_PAGE_SIZE') + except (ValueError, AttributeError): + page_size = 4096 + return page_size + + @contextlib.contextmanager def disable_faulthandler(): import faulthandler @@ -2092,31 +2179,26 @@ def wait_process(pid, *, exitcode, timeout=None): if timeout is None: timeout = LONG_TIMEOUT - t0 = time.monotonic() - sleep = 0.001 - max_sleep = 0.1 - while True: + + start_time = time.monotonic() + for _ in sleeping_retry(timeout, error=False): pid2, status = os.waitpid(pid, os.WNOHANG) if pid2 != 0: break - # process is still running - - dt = time.monotonic() - t0 - if dt > timeout: - try: - os.kill(pid, signal.SIGKILL) - os.waitpid(pid, 0) - except OSError: - # Ignore errors like ChildProcessError or PermissionError - pass - - raise AssertionError(f"process {pid} is still running " - f"after {dt:.1f} seconds") + # rety: the process is still running + else: + try: + os.kill(pid, signal.SIGKILL) + os.waitpid(pid, 0) + except OSError: + # Ignore errors like ChildProcessError or PermissionError + pass - sleep = min(sleep * 2, max_sleep) - time.sleep(sleep) + dt = time.monotonic() - start_time + raise AssertionError(f"process {pid} is still running " + f"after {dt:.1f} seconds") else: - # Windows implementation + # Windows implementation: don't support timeout :-( pid2, status = os.waitpid(pid, 0) exitcode2 = os.waitstatus_to_exitcode(status) @@ -2168,20 +2250,61 @@ def check_disallow_instantiation(testcase, tp, *args, **kwds): msg = f"cannot create '{re.escape(qualname)}' instances" testcase.assertRaisesRegex(TypeError, msg, tp, *args, **kwds) +def get_recursion_depth(): + """Get the recursion depth of the caller function. + + In the __main__ module, at the module level, it should be 1. + """ + try: + import _testinternalcapi + depth = _testinternalcapi.get_recursion_depth() + except (ImportError, RecursionError) as exc: + # sys._getframe() + frame.f_back implementation. + try: + depth = 0 + frame = sys._getframe() + while frame is not None: + depth += 1 + frame = frame.f_back + finally: + # Break any reference cycles. + frame = None + + # Ignore get_recursion_depth() frame. + return max(depth - 1, 1) + +def get_recursion_available(): + """Get the number of available frames before RecursionError. + + It depends on the current recursion depth of the caller function and + sys.getrecursionlimit(). + """ + limit = sys.getrecursionlimit() + depth = get_recursion_depth() + return limit - depth + @contextlib.contextmanager -def infinite_recursion(max_depth=75): +def set_recursion_limit(limit): + """Temporarily change the recursion limit.""" + original_limit = sys.getrecursionlimit() + try: + sys.setrecursionlimit(limit) + yield + finally: + sys.setrecursionlimit(original_limit) + +def infinite_recursion(max_depth=100): """Set a lower limit for tests that interact with infinite recursions (e.g test_ast.ASTHelpers_Test.test_recursion_direct) since on some debug windows builds, due to not enough functions being inlined the stack size might not handle the default recursion limit (1000). See bpo-11105 for details.""" - - original_depth = sys.getrecursionlimit() - try: - sys.setrecursionlimit(max_depth) - yield - finally: - sys.setrecursionlimit(original_depth) + if max_depth < 3: + raise ValueError("max_depth must be at least 3, got {max_depth}") + depth = get_recursion_depth() + depth = max(depth - 1, 1) # Ignore infinite_recursion() frame. + limit = depth + max_depth + return set_recursion_limit(limit) def ignore_deprecations_from(module: str, *, like: str) -> object: token = object() @@ -2230,6 +2353,180 @@ def requires_venv_with_pip(): return unittest.skipUnless(ctypes, 'venv: pip requires ctypes') +@functools.cache +def _findwheel(pkgname): + """Try to find a wheel with the package specified as pkgname. + + If set, the wheels are searched for in WHEEL_PKG_DIR (see ensurepip). + Otherwise, they are searched for in the test directory. + """ + wheel_dir = sysconfig.get_config_var('WHEEL_PKG_DIR') or TEST_HOME_DIR + filenames = os.listdir(wheel_dir) + filenames = sorted(filenames, reverse=True) # approximate "newest" first + for filename in filenames: + # filename is like 'setuptools-67.6.1-py3-none-any.whl' + if not filename.endswith(".whl"): + continue + prefix = pkgname + '-' + if filename.startswith(prefix): + return os.path.join(wheel_dir, filename) + raise FileNotFoundError(f"No wheel for {pkgname} found in {wheel_dir}") + + +# Context manager that creates a virtual environment, install setuptools and wheel in it +# and returns the path to the venv directory and the path to the python executable +@contextlib.contextmanager +def setup_venv_with_pip_setuptools_wheel(venv_dir): + import subprocess + from .os_helper import temp_cwd + + with temp_cwd() as temp_dir: + # Create virtual environment to get setuptools + cmd = [sys.executable, '-X', 'dev', '-m', 'venv', venv_dir] + if verbose: + print() + print('Run:', ' '.join(cmd)) + subprocess.run(cmd, check=True) + + venv = os.path.join(temp_dir, venv_dir) + + # Get the Python executable of the venv + python_exe = os.path.basename(sys.executable) + if sys.platform == 'win32': + python = os.path.join(venv, 'Scripts', python_exe) + else: + python = os.path.join(venv, 'bin', python_exe) + + cmd = [python, '-X', 'dev', + '-m', 'pip', 'install', + _findwheel('setuptools'), + _findwheel('wheel')] + if verbose: + print() + print('Run:', ' '.join(cmd)) + subprocess.run(cmd, check=True) + + yield python + + +# True if Python is built with the Py_DEBUG macro defined: if +# Python is built in debug mode (./configure --with-pydebug). +Py_DEBUG = hasattr(sys, 'gettotalrefcount') + + +def late_deletion(obj): + """ + Keep a Python alive as long as possible. + + Create a reference cycle and store the cycle in an object deleted late in + Python finalization. Try to keep the object alive until the very last + garbage collection. + + The function keeps a strong reference by design. It should be called in a + subprocess to not mark a test as "leaking a reference". + """ + + # Late CPython finalization: + # - finalize_interp_clear() + # - _PyInterpreterState_Clear(): Clear PyInterpreterState members + # (ex: codec_search_path, before_forkers) + # - clear os.register_at_fork() callbacks + # - clear codecs.register() callbacks + + ref_cycle = [obj] + ref_cycle.append(ref_cycle) + + # Store a reference in PyInterpreterState.codec_search_path + import codecs + def search_func(encoding): + return None + search_func.reference = ref_cycle + codecs.register(search_func) + + if hasattr(os, 'register_at_fork'): + # Store a reference in PyInterpreterState.before_forkers + def atfork_func(): + pass + atfork_func.reference = ref_cycle + os.register_at_fork(before=atfork_func) + + +def busy_retry(timeout, err_msg=None, /, *, error=True): + """ + Run the loop body until "break" stops the loop. + + After *timeout* seconds, raise an AssertionError if *error* is true, + or just stop if *error is false. + + Example: + + for _ in support.busy_retry(support.SHORT_TIMEOUT): + if check(): + break + + Example of error=False usage: + + for _ in support.busy_retry(support.SHORT_TIMEOUT, error=False): + if check(): + break + else: + raise RuntimeError('my custom error') + + """ + if timeout <= 0: + raise ValueError("timeout must be greater than zero") + + start_time = time.monotonic() + deadline = start_time + timeout + + while True: + yield + + if time.monotonic() >= deadline: + break + + if error: + dt = time.monotonic() - start_time + msg = f"timeout ({dt:.1f} seconds)" + if err_msg: + msg = f"{msg}: {err_msg}" + raise AssertionError(msg) + + +def sleeping_retry(timeout, err_msg=None, /, + *, init_delay=0.010, max_delay=1.0, error=True): + """ + Wait strategy that applies exponential backoff. + + Run the loop body until "break" stops the loop. Sleep at each loop + iteration, but not at the first iteration. The sleep delay is doubled at + each iteration (up to *max_delay* seconds). + + See busy_retry() documentation for the parameters usage. + + Example raising an exception after SHORT_TIMEOUT seconds: + + for _ in support.sleeping_retry(support.SHORT_TIMEOUT): + if check(): + break + + Example of error=False usage: + + for _ in support.sleeping_retry(support.SHORT_TIMEOUT, error=False): + if check(): + break + else: + raise RuntimeError('my custom error') + """ + + delay = init_delay + for _ in busy_retry(timeout, err_msg, error=error): + yield + + time.sleep(delay) + delay = min(delay * 2, max_delay) + + @contextlib.contextmanager def adjust_int_max_str_digits(max_digits): """Temporarily change the integer string conversion length limit.""" @@ -2239,3 +2536,13 @@ def adjust_int_max_str_digits(max_digits): yield finally: sys.set_int_max_str_digits(current) + +#For recursion tests, easily exceeds default recursion limit +EXCEEDS_RECURSION_LIMIT = 5000 + +# The default C recursion limit (from Include/cpython/pystate.h). +C_RECURSION_LIMIT = 1500 + +#Windows doesn't have os.uname() but it doesn't support s390x. +skip_on_s390x = unittest.skipIf(hasattr(os, 'uname') and os.uname().machine == 's390x', + 'skipped on s390x') diff --git a/Lib/test/support/bytecode_helper.py b/Lib/test/support/bytecode_helper.py index 471d4a68f9..388d126677 100644 --- a/Lib/test/support/bytecode_helper.py +++ b/Lib/test/support/bytecode_helper.py @@ -3,6 +3,7 @@ import unittest import dis import io +from _testinternalcapi import compiler_codegen, optimize_cfg, assemble_code_object _UNSPECIFIED = object() @@ -16,6 +17,7 @@ def get_disassembly_as_string(self, co): def assertInBytecode(self, x, opname, argval=_UNSPECIFIED): """Returns instr if opname is found, otherwise throws AssertionError""" + self.assertIn(opname, dis.opmap) for instr in dis.get_instructions(x): if instr.opname == opname: if argval is _UNSPECIFIED or instr.argval == argval: @@ -30,6 +32,7 @@ def assertInBytecode(self, x, opname, argval=_UNSPECIFIED): def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED): """Throws AssertionError if opname is found""" + self.assertIn(opname, dis.opmap) for instr in dis.get_instructions(x): if instr.opname == opname: disassembly = self.get_disassembly_as_string(x) @@ -40,3 +43,101 @@ def assertNotInBytecode(self, x, opname, argval=_UNSPECIFIED): msg = '(%s,%r) occurs in bytecode:\n%s' msg = msg % (opname, argval, disassembly) self.fail(msg) + +class CompilationStepTestCase(unittest.TestCase): + + HAS_ARG = set(dis.hasarg) + HAS_TARGET = set(dis.hasjrel + dis.hasjabs + dis.hasexc) + HAS_ARG_OR_TARGET = HAS_ARG.union(HAS_TARGET) + + class Label: + pass + + def assertInstructionsMatch(self, actual_, expected_): + # get two lists where each entry is a label or + # an instruction tuple. Normalize the labels to the + # instruction count of the target, and compare the lists. + + self.assertIsInstance(actual_, list) + self.assertIsInstance(expected_, list) + + actual = self.normalize_insts(actual_) + expected = self.normalize_insts(expected_) + self.assertEqual(len(actual), len(expected)) + + # compare instructions + for act, exp in zip(actual, expected): + if isinstance(act, int): + self.assertEqual(exp, act) + continue + self.assertIsInstance(exp, tuple) + self.assertIsInstance(act, tuple) + # crop comparison to the provided expected values + if len(act) > len(exp): + act = act[:len(exp)] + self.assertEqual(exp, act) + + def resolveAndRemoveLabels(self, insts): + idx = 0 + res = [] + for item in insts: + assert isinstance(item, (self.Label, tuple)) + if isinstance(item, self.Label): + item.value = idx + else: + idx += 1 + res.append(item) + + return res + + def normalize_insts(self, insts): + """ Map labels to instruction index. + Map opcodes to opnames. + """ + insts = self.resolveAndRemoveLabels(insts) + res = [] + for item in insts: + assert isinstance(item, tuple) + opcode, oparg, *loc = item + opcode = dis.opmap.get(opcode, opcode) + if isinstance(oparg, self.Label): + arg = oparg.value + else: + arg = oparg if opcode in self.HAS_ARG else None + opcode = dis.opname[opcode] + res.append((opcode, arg, *loc)) + return res + + def complete_insts_info(self, insts): + # fill in omitted fields in location, and oparg 0 for ops with no arg. + res = [] + for item in insts: + assert isinstance(item, tuple) + inst = list(item) + opcode = dis.opmap[inst[0]] + oparg = inst[1] + loc = inst[2:] + [-1] * (6 - len(inst)) + res.append((opcode, oparg, *loc)) + return res + + +class CodegenTestCase(CompilationStepTestCase): + + def generate_code(self, ast): + insts, _ = compiler_codegen(ast, "my_file.py", 0) + return insts + + +class CfgOptimizationTestCase(CompilationStepTestCase): + + def get_optimized(self, insts, consts, nlocals=0): + insts = self.normalize_insts(insts) + insts = self.complete_insts_info(insts) + insts = optimize_cfg(insts, consts, nlocals) + return insts, consts + +class AssemblerTestCase(CompilationStepTestCase): + + def get_code_object(self, filename, insts, metadata): + co = assemble_code_object(filename, insts, metadata) + return co diff --git a/Lib/test/support/import_helper.py b/Lib/test/support/import_helper.py index 5201dc84cf..67f18e530e 100644 --- a/Lib/test/support/import_helper.py +++ b/Lib/test/support/import_helper.py @@ -105,6 +105,26 @@ def frozen_modules(enabled=True): _imp._override_frozen_modules_for_tests(0) +@contextlib.contextmanager +def multi_interp_extensions_check(enabled=True): + """Force legacy modules to be allowed in subinterpreters (or not). + + ("legacy" == single-phase init) + + This only applies to modules that haven't been imported yet. + It overrides the PyInterpreterConfig.check_multi_interp_extensions + setting (see support.run_in_subinterp_with_config() and + _xxsubinterpreters.create()). + + Also see importlib.utils.allowing_all_extensions(). + """ + old = _imp._override_multi_interp_extensions_check(1 if enabled else -1) + try: + yield + finally: + _imp._override_multi_interp_extensions_check(old) + + def import_fresh_module(name, fresh=(), blocked=(), *, deprecated=False, usefrozen=False, @@ -246,3 +266,11 @@ def modules_cleanup(oldmodules): # do currently). Implicitly imported *real* modules should be left alone # (see issue 10556). sys.modules.update(oldmodules) + + +def mock_register_at_fork(func): + # bpo-30599: Mock os.register_at_fork() when importing the random module, + # since this function doesn't allow to unregister callbacks and would leak + # memory. + from unittest import mock + return mock.patch('os.register_at_fork', create=True)(func) diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters.py index 2935708f9d..5c484d1170 100644 --- a/Lib/test/support/interpreters.py +++ b/Lib/test/support/interpreters.py @@ -2,11 +2,12 @@ import time import _xxsubinterpreters as _interpreters +import _xxinterpchannels as _channels # aliases: -from _xxsubinterpreters import ( +from _xxsubinterpreters import is_shareable, RunFailedError +from _xxinterpchannels import ( ChannelError, ChannelNotFoundError, ChannelEmptyError, - is_shareable, ) @@ -102,7 +103,7 @@ def create_channel(): The channel may be used to pass data safely between interpreters. """ - cid = _interpreters.channel_create() + cid = _channels.create() recv, send = RecvChannel(cid), SendChannel(cid) return recv, send @@ -110,14 +111,14 @@ def create_channel(): def list_all_channels(): """Return a list of (recv, send) for all open channels.""" return [(RecvChannel(cid), SendChannel(cid)) - for cid in _interpreters.channel_list_all()] + for cid in _channels.list_all()] class _ChannelEnd: """The base class for RecvChannel and SendChannel.""" def __init__(self, id): - if not isinstance(id, (int, _interpreters.ChannelID)): + if not isinstance(id, (int, _channels.ChannelID)): raise TypeError(f'id must be an int, got {id!r}') self._id = id @@ -152,10 +153,10 @@ def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds This blocks until an object has been sent, if none have been sent already. """ - obj = _interpreters.channel_recv(self._id, _sentinel) + obj = _channels.recv(self._id, _sentinel) while obj is _sentinel: time.sleep(_delay) - obj = _interpreters.channel_recv(self._id, _sentinel) + obj = _channels.recv(self._id, _sentinel) return obj def recv_nowait(self, default=_NOT_SET): @@ -166,9 +167,9 @@ def recv_nowait(self, default=_NOT_SET): is the same as recv(). """ if default is _NOT_SET: - return _interpreters.channel_recv(self._id) + return _channels.recv(self._id) else: - return _interpreters.channel_recv(self._id, default) + return _channels.recv(self._id, default) class SendChannel(_ChannelEnd): @@ -179,7 +180,7 @@ def send(self, obj): This blocks until the object is received. """ - _interpreters.channel_send(self._id, obj) + _channels.send(self._id, obj) # XXX We are missing a low-level channel_send_wait(). # See bpo-32604 and gh-19829. # Until that shows up we fake it: @@ -194,4 +195,4 @@ def send_nowait(self, obj): # XXX Note that at the moment channel_send() only ever returns # None. This should be fixed when channel_send_wait() is added. # See bpo-32604 and gh-19829. - return _interpreters.channel_send(self._id, obj) + return _channels.send(self._id, obj) diff --git a/Lib/test/support/os_helper.py b/Lib/test/support/os_helper.py index f599cc7521..821a4b1ffd 100644 --- a/Lib/test/support/os_helper.py +++ b/Lib/test/support/os_helper.py @@ -4,6 +4,7 @@ import os import re import stat +import string import sys import time import unittest @@ -11,11 +12,7 @@ # Filename used for testing -if os.name == 'java': - # Jython disallows @ in module names - TESTFN_ASCII = '$test' -else: - TESTFN_ASCII = '@test' +TESTFN_ASCII = '@test' # Disambiguate TESTFN for parallel testing, while letting it remain a valid # module name. @@ -141,6 +138,11 @@ try: name.decode(sys.getfilesystemencoding()) except UnicodeDecodeError: + try: + name.decode(sys.getfilesystemencoding(), + sys.getfilesystemencodeerrors()) + except UnicodeDecodeError: + continue TESTFN_UNDECODABLE = os.fsencode(TESTFN_ASCII) + name break @@ -567,7 +569,7 @@ def fs_is_case_insensitive(directory): class FakePath: - """Simple implementing of the path protocol. + """Simple implementation of the path protocol. """ def __init__(self, path): self.path = path @@ -715,3 +717,37 @@ def __exit__(self, *ignore_exc): else: self._environ[k] = v os.environ = self._environ + + +try: + import ctypes + kernel32 = ctypes.WinDLL('kernel32', use_last_error=True) + + ERROR_FILE_NOT_FOUND = 2 + DDD_REMOVE_DEFINITION = 2 + DDD_EXACT_MATCH_ON_REMOVE = 4 + DDD_NO_BROADCAST_SYSTEM = 8 +except (ImportError, AttributeError): + def subst_drive(path): + raise unittest.SkipTest('ctypes or kernel32 is not available') +else: + @contextlib.contextmanager + def subst_drive(path): + """Temporarily yield a substitute drive for a given path.""" + for c in reversed(string.ascii_uppercase): + drive = f'{c}:' + if (not kernel32.QueryDosDeviceW(drive, None, 0) and + ctypes.get_last_error() == ERROR_FILE_NOT_FOUND): + break + else: + raise unittest.SkipTest('no available logical drive') + if not kernel32.DefineDosDeviceW( + DDD_NO_BROADCAST_SYSTEM, drive, path): + raise ctypes.WinError(ctypes.get_last_error()) + try: + yield drive + finally: + if not kernel32.DefineDosDeviceW( + DDD_REMOVE_DEFINITION | DDD_EXACT_MATCH_ON_REMOVE, + drive, path): + raise ctypes.WinError(ctypes.get_last_error()) diff --git a/Lib/test/support/socket_helper.py b/Lib/test/support/socket_helper.py index 42b2a93398..d9c087c251 100644 --- a/Lib/test/support/socket_helper.py +++ b/Lib/test/support/socket_helper.py @@ -1,8 +1,11 @@ import contextlib import errno +import os.path import socket -import unittest import sys +import subprocess +import tempfile +import unittest from .. import support from . import warnings_helper @@ -61,7 +64,7 @@ def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): http://bugs.python.org/issue2550 for more info. The following site also has a very thorough description about the implications of both REUSEADDR and EXCLUSIVEADDRUSE on Windows: - http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx) + https://learn.microsoft.com/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse XXX: although this approach is a vast improvement on previous attempts to elicit unused ports, it rests heavily on the assumption that the ephemeral @@ -270,3 +273,73 @@ def filter_error(err): # __cause__ or __context__? finally: socket.setdefaulttimeout(old_timeout) + + +def create_unix_domain_name(): + """ + Create a UNIX domain name: socket.bind() argument of a AF_UNIX socket. + + Return a path relative to the current directory to get a short path + (around 27 ASCII characters). + """ + return tempfile.mktemp(prefix="test_python_", suffix='.sock', + dir=os.path.curdir) + + +# consider that sysctl values should not change while tests are running +_sysctl_cache = {} + +def _get_sysctl(name): + """Get a sysctl value as an integer.""" + try: + return _sysctl_cache[name] + except KeyError: + pass + + # At least Linux and FreeBSD support the "-n" option + cmd = ['sysctl', '-n', name] + proc = subprocess.run(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) + if proc.returncode: + support.print_warning(f'{" ".join(cmd)!r} command failed with ' + f'exit code {proc.returncode}') + # cache the error to only log the warning once + _sysctl_cache[name] = None + return None + output = proc.stdout + + # Parse '0\n' to get '0' + try: + value = int(output.strip()) + except Exception as exc: + support.print_warning(f'Failed to parse {" ".join(cmd)!r} ' + f'command output {output!r}: {exc!r}') + # cache the error to only log the warning once + _sysctl_cache[name] = None + return None + + _sysctl_cache[name] = value + return value + + +def tcp_blackhole(): + if not sys.platform.startswith('freebsd'): + return False + + # gh-109015: test if FreeBSD TCP blackhole is enabled + value = _get_sysctl('net.inet.tcp.blackhole') + if value is None: + # don't skip if we fail to get the sysctl value + return False + return (value != 0) + + +def skip_if_tcp_blackhole(test): + """Decorator skipping test if TCP blackhole is enabled.""" + skip_if = unittest.skipIf( + tcp_blackhole(), + "TCP blackhole is enabled (sysctl net.inet.tcp.blackhole)" + ) + return skip_if(test) diff --git a/Lib/test/support/testresult.py b/Lib/test/support/testresult.py index 2cd1366cd8..de23fdd59d 100644 --- a/Lib/test/support/testresult.py +++ b/Lib/test/support/testresult.py @@ -8,6 +8,7 @@ import time import traceback import unittest +from test import support class RegressionTestResult(unittest.TextTestResult): USE_XML = False @@ -18,10 +19,13 @@ def __init__(self, stream, descriptions, verbosity): self.buffer = True if self.USE_XML: from xml.etree import ElementTree as ET - from datetime import datetime + from datetime import datetime, UTC self.__ET = ET self.__suite = ET.Element('testsuite') - self.__suite.set('start', datetime.utcnow().isoformat(' ')) + self.__suite.set('start', + datetime.now(UTC) + .replace(tzinfo=None) + .isoformat(' ')) self.__e = None self.__start_time = None @@ -109,6 +113,8 @@ def addExpectedFailure(self, test, err): def addFailure(self, test, err): self._add_result(test, True, failure=self.__makeErrorDict(*err)) super().addFailure(test, err) + if support.failfast: + self.stop() def addSkip(self, test, reason): self._add_result(test, skipped=reason) diff --git a/Lib/test/support/threading_helper.py b/Lib/test/support/threading_helper.py index 26cbc6f4d2..7f16050f32 100644 --- a/Lib/test/support/threading_helper.py +++ b/Lib/test/support/threading_helper.py @@ -88,19 +88,17 @@ def wait_threads_exit(timeout=None): yield finally: start_time = time.monotonic() - deadline = start_time + timeout - while True: + for _ in support.sleeping_retry(timeout, error=False): + support.gc_collect() count = _thread._count() if count <= old_count: break - if time.monotonic() > deadline: - dt = time.monotonic() - start_time - msg = (f"wait_threads() failed to cleanup {count - old_count} " - f"threads after {dt:.1f} seconds " - f"(count: {count}, old count: {old_count})") - raise AssertionError(msg) - time.sleep(0.010) - support.gc_collect() + else: + dt = time.monotonic() - start_time + msg = (f"wait_threads() failed to cleanup {count - old_count} " + f"threads after {dt:.1f} seconds " + f"(count: {count}, old count: {old_count})") + raise AssertionError(msg) def join_thread(thread, timeout=None): @@ -117,7 +115,11 @@ def join_thread(thread, timeout=None): @contextlib.contextmanager def start_threads(threads, unlock=None): - import faulthandler + try: + import faulthandler + except ImportError: + # It isn't supported on subinterpreters yet. + faulthandler = None threads = list(threads) started = [] try: @@ -149,7 +151,8 @@ def start_threads(threads, unlock=None): finally: started = [t for t in started if t.is_alive()] if started: - faulthandler.dump_traceback(sys.stdout) + if faulthandler is not None: + faulthandler.dump_traceback(sys.stdout) raise AssertionError('Unable to join %d threads' % len(started)) diff --git a/Lib/test/support/warnings_helper.py b/Lib/test/support/warnings_helper.py index 28e96f88b2..c1bf056230 100644 --- a/Lib/test/support/warnings_helper.py +++ b/Lib/test/support/warnings_helper.py @@ -44,7 +44,7 @@ def check_syntax_warning(testcase, statement, errtext='', def ignore_warnings(*, category): - """Decorator to suppress deprecation warnings. + """Decorator to suppress warnings. Use of context managers to hide warnings make diffs more noisy and tools like 'git blame' less useful. diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 47e10bf2a6..472397e8df 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -20,7 +20,7 @@ This test module can be called from command line with one parameter (Arithmetic or Behaviour) to test each part, or without parameter to test both parts. If -you're working through IDLE, you can import this test module and call test_main() +you're working through IDLE, you can import this test module and call test() with the corresponding argument. """ @@ -32,13 +32,14 @@ import unittest import numbers import locale -from test.support import (run_unittest, run_doctest, is_resource_enabled, +from test.support import (is_resource_enabled, requires_IEEE_754, requires_docstrings, requires_legacy_unicode_capi, check_sanitizer) from test.support import (TestFailed, run_with_locale, cpython_only, - darwin_malloc_err_warning) + darwin_malloc_err_warning, is_emscripten) from test.support.import_helper import import_fresh_module +from test.support import threading_helper from test.support import warnings_helper import random import inspect @@ -61,6 +62,7 @@ fractions = {C:cfractions, P:pfractions} sys.modules['decimal'] = orig_sys_decimal +requires_cdecimal = unittest.skipUnless(C, "test requires C version") # Useful Test Constant Signals = { @@ -98,7 +100,7 @@ def assert_signals(cls, context, attr, expected): ] # Tests are built around these assumed context defaults. -# test_main() restores the original context. +# test() restores the original context. ORIGINAL_CONTEXT = { C: C.getcontext().copy() if C else None, P: P.getcontext().copy() @@ -132,7 +134,7 @@ def init(m): EXTRA_FUNCTIONALITY, "test requires regular build") -class IBMTestCases(unittest.TestCase): +class IBMTestCases: """Class which tests the Decimal class against the IBM test cases.""" def setUp(self): @@ -487,14 +489,10 @@ def change_max_exponent(self, exp): def change_clamp(self, clamp): self.context.clamp = clamp -class CIBMTestCases(IBMTestCases): - decimal = C -class PyIBMTestCases(IBMTestCases): - decimal = P # The following classes test the behaviour of Decimal according to PEP 327 -class ExplicitConstructionTest(unittest.TestCase): +class ExplicitConstructionTest: '''Unit tests for Explicit Construction cases of Decimal.''' def test_explicit_empty(self): @@ -589,7 +587,7 @@ def test_explicit_from_string(self): self.assertRaises(InvalidOperation, Decimal, "1_2_\u00003") @cpython_only - @requires_legacy_unicode_capi + @requires_legacy_unicode_capi() @warnings_helper.ignore_warnings(category=DeprecationWarning) def test_from_legacy_strings(self): import _testcapi @@ -839,12 +837,13 @@ def test_unicode_digits(self): for input, expected in test_values.items(): self.assertEqual(str(Decimal(input)), expected) -class CExplicitConstructionTest(ExplicitConstructionTest): +@requires_cdecimal +class CExplicitConstructionTest(ExplicitConstructionTest, unittest.TestCase): decimal = C -class PyExplicitConstructionTest(ExplicitConstructionTest): +class PyExplicitConstructionTest(ExplicitConstructionTest, unittest.TestCase): decimal = P -class ImplicitConstructionTest(unittest.TestCase): +class ImplicitConstructionTest: '''Unit tests for Implicit Construction cases of Decimal.''' def test_implicit_from_None(self): @@ -921,13 +920,16 @@ def __ne__(self, other): self.assertEqual(eval('Decimal(10)' + sym + 'E()'), '10' + rop + 'str') -class CImplicitConstructionTest(ImplicitConstructionTest): +@requires_cdecimal +class CImplicitConstructionTest(ImplicitConstructionTest, unittest.TestCase): decimal = C -class PyImplicitConstructionTest(ImplicitConstructionTest): +class PyImplicitConstructionTest(ImplicitConstructionTest, unittest.TestCase): decimal = P -class FormatTest(unittest.TestCase): +class FormatTest: '''Unit tests for the format function.''' + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_formatting(self): Decimal = self.decimal.Decimal @@ -1073,6 +1075,57 @@ def test_formatting(self): (',e', '123456', '1.23456e+5'), (',E', '123456', '1.23456E+5'), + # negative zero: default behavior + ('.1f', '-0', '-0.0'), + ('.1f', '-.0', '-0.0'), + ('.1f', '-.01', '-0.0'), + + # negative zero: z option + ('z.1f', '0.', '0.0'), + ('z6.1f', '0.', ' 0.0'), + ('z6.1f', '-1.', ' -1.0'), + ('z.1f', '-0.', '0.0'), + ('z.1f', '.01', '0.0'), + ('z.1f', '-.01', '0.0'), + ('z.2f', '0.', '0.00'), + ('z.2f', '-0.', '0.00'), + ('z.2f', '.001', '0.00'), + ('z.2f', '-.001', '0.00'), + + ('z.1e', '0.', '0.0e+1'), + ('z.1e', '-0.', '0.0e+1'), + ('z.1E', '0.', '0.0E+1'), + ('z.1E', '-0.', '0.0E+1'), + + ('z.2e', '-0.001', '-1.00e-3'), # tests for mishandled rounding + ('z.2g', '-0.001', '-0.001'), + ('z.2%', '-0.001', '-0.10%'), + + ('zf', '-0.0000', '0.0000'), # non-normalized form is preserved + + ('z.1f', '-00000.000001', '0.0'), + ('z.1f', '-00000.', '0.0'), + ('z.1f', '-.0000000000', '0.0'), + + ('z.2f', '-00000.000001', '0.00'), + ('z.2f', '-00000.', '0.00'), + ('z.2f', '-.0000000000', '0.00'), + + ('z.1f', '.09', '0.1'), + ('z.1f', '-.09', '-0.1'), + + (' z.0f', '-0.', ' 0'), + ('+z.0f', '-0.', '+0'), + ('-z.0f', '-0.', '0'), + (' z.0f', '-1.', '-1'), + ('+z.0f', '-1.', '-1'), + ('-z.0f', '-1.', '-1'), + + ('z>6.1f', '-0.', 'zz-0.0'), + ('z>z6.1f', '-0.', 'zzz0.0'), + ('x>z6.1f', '-0.', 'xxx0.0'), + ('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char + # issue 6850 ('a=-7.0', '0.12345', 'aaaa0.1'), @@ -1087,6 +1140,17 @@ def test_formatting(self): # bytes format argument self.assertRaises(TypeError, Decimal(1).__format__, b'-020') + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_negative_zero_format_directed_rounding(self): + with self.decimal.localcontext() as ctx: + ctx.rounding = ROUND_CEILING + self.assertEqual(format(self.decimal.Decimal('-0.001'), 'z.2f'), + '0.00') + + def test_negative_zero_bad_format(self): + self.assertRaises(ValueError, format, self.decimal.Decimal('1.23'), 'fz') + def test_n_format(self): Decimal = self.decimal.Decimal @@ -1205,12 +1269,13 @@ def __init__(self, a): a = A.from_float(42) self.assertEqual(self.decimal.Decimal, a.a_type) -class CFormatTest(FormatTest): +@requires_cdecimal +class CFormatTest(FormatTest, unittest.TestCase): decimal = C -class PyFormatTest(FormatTest): +class PyFormatTest(FormatTest, unittest.TestCase): decimal = P -class ArithmeticOperatorsTest(unittest.TestCase): +class ArithmeticOperatorsTest: '''Unit tests for all arithmetic operators, binary and unary.''' def test_addition(self): @@ -1466,14 +1531,17 @@ def test_nan_comparisons(self): equality_ops = operator.eq, operator.ne # results when InvalidOperation is not trapped - for x, y in qnan_pairs + snan_pairs: - for op in order_ops + equality_ops: - got = op(x, y) - expected = True if op is operator.ne else False - self.assertIs(expected, got, - "expected {0!r} for operator.{1}({2!r}, {3!r}); " - "got {4!r}".format( - expected, op.__name__, x, y, got)) + with localcontext() as ctx: + ctx.traps[InvalidOperation] = 0 + + for x, y in qnan_pairs + snan_pairs: + for op in order_ops + equality_ops: + got = op(x, y) + expected = True if op is operator.ne else False + self.assertIs(expected, got, + "expected {0!r} for operator.{1}({2!r}, {3!r}); " + "got {4!r}".format( + expected, op.__name__, x, y, got)) # repeat the above, but this time trap the InvalidOperation with localcontext() as ctx: @@ -1505,9 +1573,10 @@ def test_copy_sign(self): self.assertEqual(Decimal(1).copy_sign(-2), d) self.assertRaises(TypeError, Decimal(1).copy_sign, '-2') -class CArithmeticOperatorsTest(ArithmeticOperatorsTest): +@requires_cdecimal +class CArithmeticOperatorsTest(ArithmeticOperatorsTest, unittest.TestCase): decimal = C -class PyArithmeticOperatorsTest(ArithmeticOperatorsTest): +class PyArithmeticOperatorsTest(ArithmeticOperatorsTest, unittest.TestCase): decimal = P # The following are two functions used to test threading in the next class @@ -1595,7 +1664,9 @@ def thfunc2(cls): for sig in Overflow, Underflow, DivisionByZero, InvalidOperation: cls.assertFalse(thiscontext.flags[sig]) -class ThreadingTest(unittest.TestCase): + +@threading_helper.requires_working_threading() +class ThreadingTest: '''Unit tests for thread local contexts in Decimal.''' # Take care executing this test from IDLE, there's an issue in threading @@ -1640,13 +1711,14 @@ def test_threading(self): DefaultContext.Emin = save_emin -class CThreadingTest(ThreadingTest): +@requires_cdecimal +class CThreadingTest(ThreadingTest, unittest.TestCase): decimal = C -class PyThreadingTest(ThreadingTest): +class PyThreadingTest(ThreadingTest, unittest.TestCase): decimal = P -class UsabilityTest(unittest.TestCase): +class UsabilityTest: '''Unit tests for Usability cases of Decimal.''' def test_comparison_operators(self): @@ -2466,12 +2538,22 @@ def test_conversions_from_int(self): self.assertEqual(Decimal(-12).fma(45, Decimal(67)), Decimal(-12).fma(Decimal(45), Decimal(67))) -class CUsabilityTest(UsabilityTest): +@requires_cdecimal +class CUsabilityTest(UsabilityTest, unittest.TestCase): decimal = C -class PyUsabilityTest(UsabilityTest): +class PyUsabilityTest(UsabilityTest, unittest.TestCase): decimal = P -class PythonAPItests(unittest.TestCase): + def setUp(self): + super().setUp() + self._previous_int_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(7000) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_int_limit) + super().tearDown() + +class PythonAPItests: def test_abc(self): Decimal = self.decimal.Decimal @@ -2549,6 +2631,13 @@ def test_int(self): self.assertRaises(OverflowError, int, Decimal('inf')) self.assertRaises(OverflowError, int, Decimal('-inf')) + @cpython_only + def test_small_ints(self): + Decimal = self.decimal.Decimal + # bpo-46361 + for x in range(-5, 257): + self.assertIs(int(Decimal(x)), x) + def test_trunc(self): Decimal = self.decimal.Decimal @@ -2815,12 +2904,13 @@ def test_exception_hierarchy(self): self.assertTrue(issubclass(decimal.DivisionUndefined, ZeroDivisionError)) self.assertTrue(issubclass(decimal.InvalidContext, InvalidOperation)) -class CPythonAPItests(PythonAPItests): +@requires_cdecimal +class CPythonAPItests(PythonAPItests, unittest.TestCase): decimal = C -class PyPythonAPItests(PythonAPItests): +class PyPythonAPItests(PythonAPItests, unittest.TestCase): decimal = P -class ContextAPItests(unittest.TestCase): +class ContextAPItests: def test_none_args(self): Context = self.decimal.Context @@ -2843,7 +2933,7 @@ def test_none_args(self): Overflow]) @cpython_only - @requires_legacy_unicode_capi + @requires_legacy_unicode_capi() @warnings_helper.ignore_warnings(category=DeprecationWarning) def test_from_legacy_strings(self): import _testcapi @@ -3566,12 +3656,13 @@ def test_to_integral_value(self): self.assertRaises(TypeError, c.to_integral_value, '10') self.assertRaises(TypeError, c.to_integral_value, 10, 'x') -class CContextAPItests(ContextAPItests): +@requires_cdecimal +class CContextAPItests(ContextAPItests, unittest.TestCase): decimal = C -class PyContextAPItests(ContextAPItests): +class PyContextAPItests(ContextAPItests, unittest.TestCase): decimal = P -class ContextWithStatement(unittest.TestCase): +class ContextWithStatement: # Can't do these as docstrings until Python 2.6 # as doctest can't handle __future__ statements @@ -3605,6 +3696,48 @@ def test_localcontextarg(self): self.assertIsNot(new_ctx, set_ctx, 'did not copy the context') self.assertIs(set_ctx, enter_ctx, '__enter__ returned wrong context') + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_localcontext_kwargs(self): + with self.decimal.localcontext( + prec=10, rounding=ROUND_HALF_DOWN, + Emin=-20, Emax=20, capitals=0, + clamp=1 + ) as ctx: + self.assertEqual(ctx.prec, 10) + self.assertEqual(ctx.rounding, self.decimal.ROUND_HALF_DOWN) + self.assertEqual(ctx.Emin, -20) + self.assertEqual(ctx.Emax, 20) + self.assertEqual(ctx.capitals, 0) + self.assertEqual(ctx.clamp, 1) + + self.assertRaises(TypeError, self.decimal.localcontext, precision=10) + + self.assertRaises(ValueError, self.decimal.localcontext, Emin=1) + self.assertRaises(ValueError, self.decimal.localcontext, Emax=-1) + self.assertRaises(ValueError, self.decimal.localcontext, capitals=2) + self.assertRaises(ValueError, self.decimal.localcontext, clamp=2) + + self.assertRaises(TypeError, self.decimal.localcontext, rounding="") + self.assertRaises(TypeError, self.decimal.localcontext, rounding=1) + + self.assertRaises(TypeError, self.decimal.localcontext, flags="") + self.assertRaises(TypeError, self.decimal.localcontext, traps="") + self.assertRaises(TypeError, self.decimal.localcontext, Emin="") + self.assertRaises(TypeError, self.decimal.localcontext, Emax="") + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_local_context_kwargs_does_not_overwrite_existing_argument(self): + ctx = self.decimal.getcontext() + orig_prec = ctx.prec + with self.decimal.localcontext(prec=10) as ctx2: + self.assertEqual(ctx2.prec, 10) + self.assertEqual(ctx.prec, orig_prec) + with self.decimal.localcontext(prec=20) as ctx2: + self.assertEqual(ctx2.prec, 20) + self.assertEqual(ctx.prec, orig_prec) + def test_nested_with_statements(self): # Use a copy of the supplied context in the block Decimal = self.decimal.Decimal @@ -3697,12 +3830,13 @@ def test_with_statements_gc3(self): self.assertEqual(c4.prec, 4) del c4 -class CContextWithStatement(ContextWithStatement): +@requires_cdecimal +class CContextWithStatement(ContextWithStatement, unittest.TestCase): decimal = C -class PyContextWithStatement(ContextWithStatement): +class PyContextWithStatement(ContextWithStatement, unittest.TestCase): decimal = P -class ContextFlags(unittest.TestCase): +class ContextFlags: def test_flags_irrelevant(self): # check that the result (numeric result + flags raised) of an @@ -3969,12 +4103,13 @@ def test_float_operation_default(self): self.assertTrue(context.traps[FloatOperation]) self.assertTrue(context.traps[Inexact]) -class CContextFlags(ContextFlags): +@requires_cdecimal +class CContextFlags(ContextFlags, unittest.TestCase): decimal = C -class PyContextFlags(ContextFlags): +class PyContextFlags(ContextFlags, unittest.TestCase): decimal = P -class SpecialContexts(unittest.TestCase): +class SpecialContexts: """Test the context templates.""" def test_context_templates(self): @@ -4054,12 +4189,13 @@ def test_default_context(self): if ex: raise ex -class CSpecialContexts(SpecialContexts): +@requires_cdecimal +class CSpecialContexts(SpecialContexts, unittest.TestCase): decimal = C -class PySpecialContexts(SpecialContexts): +class PySpecialContexts(SpecialContexts, unittest.TestCase): decimal = P -class ContextInputValidation(unittest.TestCase): +class ContextInputValidation: def test_invalid_context(self): Context = self.decimal.Context @@ -4121,12 +4257,13 @@ def test_invalid_context(self): self.assertRaises(TypeError, Context, flags=(0,1)) self.assertRaises(TypeError, Context, traps=(1,0)) -class CContextInputValidation(ContextInputValidation): +@requires_cdecimal +class CContextInputValidation(ContextInputValidation, unittest.TestCase): decimal = C -class PyContextInputValidation(ContextInputValidation): +class PyContextInputValidation(ContextInputValidation, unittest.TestCase): decimal = P -class ContextSubclassing(unittest.TestCase): +class ContextSubclassing: def test_context_subclassing(self): decimal = self.decimal @@ -4235,12 +4372,14 @@ def __init__(self, prec=None, rounding=None, Emin=None, Emax=None, for signal in OrderedSignals[decimal]: self.assertFalse(c.traps[signal]) -class CContextSubclassing(ContextSubclassing): +@requires_cdecimal +class CContextSubclassing(ContextSubclassing, unittest.TestCase): decimal = C -class PyContextSubclassing(ContextSubclassing): +class PyContextSubclassing(ContextSubclassing, unittest.TestCase): decimal = P @skip_if_extra_functionality +@requires_cdecimal class CheckAttributes(unittest.TestCase): def test_module_attributes(self): @@ -4270,7 +4409,7 @@ def test_decimal_attributes(self): y = [s for s in dir(C.Decimal(9)) if '__' in s or not s.startswith('_')] self.assertEqual(set(x) - set(y), set()) -class Coverage(unittest.TestCase): +class Coverage: def test_adjusted(self): Decimal = self.decimal.Decimal @@ -4527,11 +4666,21 @@ def test_copy(self): y = c.copy_sign(x, 1) self.assertEqual(y, -x) -class CCoverage(Coverage): +@requires_cdecimal +class CCoverage(Coverage, unittest.TestCase): decimal = C -class PyCoverage(Coverage): +class PyCoverage(Coverage, unittest.TestCase): decimal = P + def setUp(self): + super().setUp() + self._previous_int_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(7000) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_int_limit) + super().tearDown() + class PyFunctionality(unittest.TestCase): """Extra functionality in decimal.py""" @@ -4773,6 +4922,7 @@ def test_constants(self): self.assertEqual(C.DecTraps, C.DecErrors|C.DecOverflow|C.DecUnderflow) +@requires_cdecimal class CWhitebox(unittest.TestCase): """Whitebox testing for _decimal""" @@ -5426,6 +5576,7 @@ def test_from_tuple(self): with localcontext() as c: + c.prec = 9 c.traps[InvalidOperation] = True c.traps[Overflow] = True c.traps[Underflow] = True @@ -5510,6 +5661,7 @@ def __abs__(self): # Issue 41540: @unittest.skipIf(sys.platform.startswith("aix"), "AIX: default ulimit: test is flaky because of extreme over-allocation") + @unittest.skipIf(is_emscripten, "Test is unstable on Emscripten") @unittest.skipIf(check_sanitizer(address=True, memory=True), "ASAN/MSAN sanitizer defaults to crashing " "instead of returning NULL for malloc failure.") @@ -5548,8 +5700,38 @@ def test_maxcontext_exact_arith(self): self.assertEqual(Decimal(400) ** -1, Decimal('0.0025')) + def test_c_signaldict_segfault(self): + # See gh-106263 for details. + SignalDict = type(C.Context().flags) + sd = SignalDict() + err_msg = "invalid signal dict" + + with self.assertRaisesRegex(ValueError, err_msg): + len(sd) + + with self.assertRaisesRegex(ValueError, err_msg): + iter(sd) + + with self.assertRaisesRegex(ValueError, err_msg): + repr(sd) + + with self.assertRaisesRegex(ValueError, err_msg): + sd[C.InvalidOperation] = True + + with self.assertRaisesRegex(ValueError, err_msg): + sd[C.InvalidOperation] + + with self.assertRaisesRegex(ValueError, err_msg): + sd == C.Context().flags + + with self.assertRaisesRegex(ValueError, err_msg): + C.Context().flags == sd + + with self.assertRaisesRegex(ValueError, err_msg): + sd.copy() + @requires_docstrings -@unittest.skipUnless(C, "test requires C version") +@requires_cdecimal class SignatureTest(unittest.TestCase): """Function signatures""" @@ -5685,52 +5867,10 @@ def doit(ty): doit('Context') -all_tests = [ - CExplicitConstructionTest, PyExplicitConstructionTest, - CImplicitConstructionTest, PyImplicitConstructionTest, - CFormatTest, PyFormatTest, - CArithmeticOperatorsTest, PyArithmeticOperatorsTest, - CThreadingTest, PyThreadingTest, - CUsabilityTest, PyUsabilityTest, - CPythonAPItests, PyPythonAPItests, - CContextAPItests, PyContextAPItests, - CContextWithStatement, PyContextWithStatement, - CContextFlags, PyContextFlags, - CSpecialContexts, PySpecialContexts, - CContextInputValidation, PyContextInputValidation, - CContextSubclassing, PyContextSubclassing, - CCoverage, PyCoverage, - CFunctionality, PyFunctionality, - CWhitebox, PyWhitebox, - CIBMTestCases, PyIBMTestCases, -] - -# Delete C tests if _decimal.so is not present. -if not C: - all_tests = all_tests[1::2] -else: - all_tests.insert(0, CheckAttributes) - all_tests.insert(1, SignatureTest) - - -def test_main(arith=None, verbose=None, todo_tests=None, debug=None): - """ Execute the tests. - - Runs all arithmetic tests if arith is True or if the "decimal" resource - is enabled in regrtest.py - """ - - init(C) - init(P) - global TEST_ALL, DEBUG - TEST_ALL = arith if arith is not None else is_resource_enabled('decimal') - DEBUG = debug - - if todo_tests is None: - test_classes = all_tests - else: - test_classes = [CIBMTestCases, PyIBMTestCases] - +def load_tests(loader, tests, pattern): + if TODO_TESTS is not None: + # Run only Arithmetic tests + tests = loader.suiteClass() # Dynamically build custom test definition for each file in the test # directory and add the definitions to the DecimalTest class. This # procedure insures that new files do not get skipped. @@ -5738,34 +5878,69 @@ def test_main(arith=None, verbose=None, todo_tests=None, debug=None): if '.decTest' not in filename or filename.startswith("."): continue head, tail = filename.split('.') - if todo_tests is not None and head not in todo_tests: + if TODO_TESTS is not None and head not in TODO_TESTS: continue tester = lambda self, f=filename: self.eval_file(directory + f) - setattr(CIBMTestCases, 'test_' + head, tester) - setattr(PyIBMTestCases, 'test_' + head, tester) + setattr(IBMTestCases, 'test_' + head, tester) del filename, head, tail, tester + for prefix, mod in ('C', C), ('Py', P): + if not mod: + continue + test_class = type(prefix + 'IBMTestCases', + (IBMTestCases, unittest.TestCase), + {'decimal': mod}) + tests.addTest(loader.loadTestsFromTestCase(test_class)) + + if TODO_TESTS is None: + from doctest import DocTestSuite, IGNORE_EXCEPTION_DETAIL + for mod in C, P: + if not mod: + continue + def setUp(slf, mod=mod): + sys.modules['decimal'] = mod + def tearDown(slf): + sys.modules['decimal'] = orig_sys_decimal + optionflags = IGNORE_EXCEPTION_DETAIL if mod is C else 0 + sys.modules['decimal'] = mod + tests.addTest(DocTestSuite(mod, setUp=setUp, tearDown=tearDown, + optionflags=optionflags)) + sys.modules['decimal'] = orig_sys_decimal + return tests + +def setUpModule(): + init(C) + init(P) + global TEST_ALL + TEST_ALL = ARITH if ARITH is not None else is_resource_enabled('decimal') + +def tearDownModule(): + if C: C.setcontext(ORIGINAL_CONTEXT[C]) + P.setcontext(ORIGINAL_CONTEXT[P]) + if not C: + warnings.warn('C tests skipped: no module named _decimal.', + UserWarning) + if not orig_sys_decimal is sys.modules['decimal']: + raise TestFailed("Internal error: unbalanced number of changes to " + "sys.modules['decimal'].") + + +ARITH = None +TEST_ALL = True +TODO_TESTS = None +DEBUG = False + +def test(arith=None, verbose=None, todo_tests=None, debug=None): + """ Execute the tests. + Runs all arithmetic tests if arith is True or if the "decimal" resource + is enabled in regrtest.py + """ - try: - run_unittest(*test_classes) - if todo_tests is None: - from doctest import IGNORE_EXCEPTION_DETAIL - savedecimal = sys.modules['decimal'] - if C: - sys.modules['decimal'] = C - run_doctest(C, verbose, optionflags=IGNORE_EXCEPTION_DETAIL) - sys.modules['decimal'] = P - run_doctest(P, verbose) - sys.modules['decimal'] = savedecimal - finally: - if C: C.setcontext(ORIGINAL_CONTEXT[C]) - P.setcontext(ORIGINAL_CONTEXT[P]) - if not C: - warnings.warn('C tests skipped: no module named _decimal.', - UserWarning) - if not orig_sys_decimal is sys.modules['decimal']: - raise TestFailed("Internal error: unbalanced number of changes to " - "sys.modules['decimal'].") + global ARITH, TODO_TESTS, DEBUG + ARITH = arith + TODO_TESTS = todo_tests + DEBUG = debug + unittest.main(__name__, verbosity=2 if verbose else 1, exit=False, argv=[__name__]) if __name__ == '__main__': @@ -5776,8 +5951,8 @@ def test_main(arith=None, verbose=None, todo_tests=None, debug=None): (opt, args) = p.parse_args() if opt.skip: - test_main(arith=False, verbose=True) + test(arith=False, verbose=True) elif args: - test_main(arith=True, verbose=True, todo_tests=args, debug=opt.debug) + test(arith=True, verbose=True, todo_tests=args, debug=opt.debug) else: - test_main(arith=True, verbose=True) + test(arith=True, verbose=True) diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 1c307e75ee..3989b7d674 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -20,7 +20,6 @@ from test import support from test.support import ALWAYS_EQ from test.support import threading_helper -from textwrap import dedent from datetime import timedelta python_version = sys.version_info[:2] @@ -237,11 +236,83 @@ class _EnumTests: values = None def setUp(self): - class BaseEnum(self.enum_type): + if self.__class__.__name__[-5:] == 'Class': + class BaseEnum(self.enum_type): + @enum.property + def first(self): + return '%s is first!' % self.name + class MainEnum(BaseEnum): + first = auto() + second = auto() + third = auto() + if issubclass(self.enum_type, Flag): + dupe = 3 + else: + dupe = third + self.MainEnum = MainEnum + # + class NewStrEnum(self.enum_type): + def __str__(self): + return self.name.upper() + first = auto() + self.NewStrEnum = NewStrEnum + # + class NewFormatEnum(self.enum_type): + def __format__(self, spec): + return self.name.upper() + first = auto() + self.NewFormatEnum = NewFormatEnum + # + class NewStrFormatEnum(self.enum_type): + def __str__(self): + return self.name.title() + def __format__(self, spec): + return ''.join(reversed(self.name)) + first = auto() + self.NewStrFormatEnum = NewStrFormatEnum + # + class NewBaseEnum(self.enum_type): + def __str__(self): + return self.name.title() + def __format__(self, spec): + return ''.join(reversed(self.name)) + self.NewBaseEnum = NewBaseEnum + class NewSubEnum(NewBaseEnum): + first = auto() + self.NewSubEnum = NewSubEnum + # + class LazyGNV(self.enum_type): + def _generate_next_value_(name, start, last, values): + pass + self.LazyGNV = LazyGNV + # + class BusyGNV(self.enum_type): + @staticmethod + def _generate_next_value_(name, start, last, values): + pass + self.BusyGNV = BusyGNV + # + self.is_flag = False + self.names = ['first', 'second', 'third'] + if issubclass(MainEnum, StrEnum): + self.values = self.names + elif MainEnum._member_type_ is str: + self.values = ['1', '2', '3'] + elif issubclass(self.enum_type, Flag): + self.values = [1, 2, 4] + self.is_flag = True + self.dupe2 = MainEnum(5) + else: + self.values = self.values or [1, 2, 3] + # + if not getattr(self, 'source_values', False): + self.source_values = self.values + elif self.__class__.__name__[-8:] == 'Function': @enum.property def first(self): return '%s is first!' % self.name - class MainEnum(BaseEnum): + BaseEnum = self.enum_type('BaseEnum', {'first':first}) + # first = auto() second = auto() third = auto() @@ -249,52 +320,58 @@ class MainEnum(BaseEnum): dupe = 3 else: dupe = third - self.MainEnum = MainEnum - # - class NewStrEnum(self.enum_type): + self.MainEnum = MainEnum = BaseEnum('MainEnum', dict(first=first, second=second, third=third, dupe=dupe)) + # def __str__(self): return self.name.upper() first = auto() - self.NewStrEnum = NewStrEnum - # - class NewFormatEnum(self.enum_type): + self.NewStrEnum = self.enum_type('NewStrEnum', (('first',first),('__str__',__str__))) + # def __format__(self, spec): return self.name.upper() first = auto() - self.NewFormatEnum = NewFormatEnum - # - class NewStrFormatEnum(self.enum_type): + self.NewFormatEnum = self.enum_type('NewFormatEnum', [('first',first),('__format__',__format__)]) + # def __str__(self): return self.name.title() def __format__(self, spec): return ''.join(reversed(self.name)) first = auto() - self.NewStrFormatEnum = NewStrFormatEnum - # - class NewBaseEnum(self.enum_type): + self.NewStrFormatEnum = self.enum_type('NewStrFormatEnum', dict(first=first, __format__=__format__, __str__=__str__)) + # def __str__(self): return self.name.title() def __format__(self, spec): return ''.join(reversed(self.name)) - class NewSubEnum(NewBaseEnum): - first = auto() - self.NewSubEnum = NewSubEnum - # - self.is_flag = False - self.names = ['first', 'second', 'third'] - if issubclass(MainEnum, StrEnum): - self.values = self.names - elif MainEnum._member_type_ is str: - self.values = ['1', '2', '3'] - elif issubclass(self.enum_type, Flag): - self.values = [1, 2, 4] - self.is_flag = True - self.dupe2 = MainEnum(5) + self.NewBaseEnum = self.enum_type('NewBaseEnum', dict(__format__=__format__, __str__=__str__)) + self.NewSubEnum = self.NewBaseEnum('NewSubEnum', 'first') + # + def _generate_next_value_(name, start, last, values): + pass + self.LazyGNV = self.enum_type('LazyGNV', {'_generate_next_value_':_generate_next_value_}) + # + @staticmethod + def _generate_next_value_(name, start, last, values): + pass + self.BusyGNV = self.enum_type('BusyGNV', {'_generate_next_value_':_generate_next_value_}) + # + self.is_flag = False + self.names = ['first', 'second', 'third'] + if issubclass(MainEnum, StrEnum): + self.values = self.names + elif MainEnum._member_type_ is str: + self.values = ['1', '2', '3'] + elif issubclass(self.enum_type, Flag): + self.values = [1, 2, 4] + self.is_flag = True + self.dupe2 = MainEnum(5) + else: + self.values = self.values or [1, 2, 3] + # + if not getattr(self, 'source_values', False): + self.source_values = self.values else: - self.values = self.values or [1, 2, 3] - # - if not getattr(self, 'source_values', False): - self.source_values = self.values + raise ValueError('unknown enum style: %r' % self.__class__.__name__) def assertFormatIsValue(self, spec, member): self.assertEqual(spec.format(member), spec.format(member.value)) @@ -322,6 +399,17 @@ def spam(cls): with self.assertRaises(AttributeError): del Season.SPRING.name + def test_bad_new_super(self): + with self.assertRaisesRegex( + TypeError, + 'has no members defined', + ): + class BadSuper(self.enum_type): + def __new__(cls, value): + obj = super().__new__(cls, value) + return obj + failed = 1 + def test_basics(self): TE = self.MainEnum if self.is_flag: @@ -373,19 +461,12 @@ def test_changing_member_fails(self): with self.assertRaises(AttributeError): self.MainEnum.second = 'really first' - @unittest.skipIf( - python_version >= (3, 12), - '__contains__ now returns True/False for all inputs', - ) - def test_contains_er(self): + def test_contains_tf(self): MainEnum = self.MainEnum - self.assertIn(MainEnum.third, MainEnum) - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - self.source_values[1] in MainEnum - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'first' in MainEnum + self.assertIn(MainEnum.first, MainEnum) + self.assertTrue(self.values[0] in MainEnum) + if type(self) not in (TestStrEnumClass, TestStrEnumFunction): + self.assertFalse('first' in MainEnum) val = MainEnum.dupe self.assertIn(val, MainEnum) # @@ -393,23 +474,43 @@ class OtherEnum(Enum): one = auto() two = auto() self.assertNotIn(OtherEnum.two, MainEnum) - - @unittest.skipIf( - python_version < (3, 12), - '__contains__ works only with enum memmbers before 3.12', - ) - def test_contains_tf(self): + # + if MainEnum._member_type_ is object: + # enums without mixed data types will always be False + class NotEqualEnum(self.enum_type): + this = self.source_values[0] + that = self.source_values[1] + self.assertNotIn(NotEqualEnum.this, MainEnum) + self.assertNotIn(NotEqualEnum.that, MainEnum) + else: + # enums with mixed data types may be True + class EqualEnum(self.enum_type): + this = self.source_values[0] + that = self.source_values[1] + self.assertIn(EqualEnum.this, MainEnum) + self.assertIn(EqualEnum.that, MainEnum) + + def test_contains_same_name_diff_enum_diff_values(self): MainEnum = self.MainEnum - self.assertIn(MainEnum.first, MainEnum) - self.assertTrue(self.source_values[0] in MainEnum) - self.assertFalse('first' in MainEnum) - val = MainEnum.dupe - self.assertIn(val, MainEnum) # class OtherEnum(Enum): - one = auto() - two = auto() - self.assertNotIn(OtherEnum.two, MainEnum) + first = "brand" + second = "new" + third = "values" + # + self.assertIn(MainEnum.first, MainEnum) + self.assertIn(MainEnum.second, MainEnum) + self.assertIn(MainEnum.third, MainEnum) + self.assertNotIn(MainEnum.first, OtherEnum) + self.assertNotIn(MainEnum.second, OtherEnum) + self.assertNotIn(MainEnum.third, OtherEnum) + # + self.assertIn(OtherEnum.first, OtherEnum) + self.assertIn(OtherEnum.second, OtherEnum) + self.assertIn(OtherEnum.third, OtherEnum) + self.assertNotIn(OtherEnum.first, MainEnum) + self.assertNotIn(OtherEnum.second, MainEnum) + self.assertNotIn(OtherEnum.third, MainEnum) def test_dir_on_class(self): TE = self.MainEnum @@ -459,10 +560,20 @@ class SubEnum(SuperEnum): self.assertTrue('description' not in dir(SubEnum)) self.assertTrue('description' in dir(SubEnum.sample), dir(SubEnum.sample)) + def test_empty_enum_has_no_values(self): + with self.assertRaisesRegex(TypeError, "<.... 'NewBaseEnum'> has no members"): + self.NewBaseEnum(7) + def test_enum_in_enum_out(self): Main = self.MainEnum self.assertIs(Main(Main.first), Main.first) + def test_gnv_is_static(self): + lazy = self.LazyGNV + busy = self.BusyGNV + self.assertTrue(type(lazy.__dict__['_generate_next_value_']) is staticmethod) + self.assertTrue(type(busy.__dict__['_generate_next_value_']) is staticmethod) + def test_hash(self): MainEnum = self.MainEnum mapping = {} @@ -499,7 +610,7 @@ def __repr__(self): def test_overridden_str(self): # TODO: RUSTPYTHON, format(NS.first) does not use __str__ - if isinstance(self, TestIntFlag) or isinstance(self, TestIntEnum) or isinstance(self, TestMinimalFloat): + if self.__class__ in (TestIntFlagFunction, TestIntFlagClass, TestIntEnumFunction, TestIntEnumClass, TestMinimalFloatFunction, TestMinimalFloatClass): self.skipTest("format(NS.first) does not use __str__") NS = self.NewStrEnum self.assertEqual(str(NS.first), NS.first.name.upper()) @@ -883,80 +994,192 @@ class OpenXYZ(self.enum_type): self.assertTrue(~OpenXYZ(0), (X|Y|Z)) -class TestPlainEnum(_EnumTests, _PlainOutputTests, unittest.TestCase): +class TestPlainEnumClass(_EnumTests, _PlainOutputTests, unittest.TestCase): enum_type = Enum -class TestPlainFlag(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): +class TestPlainEnumFunction(_EnumTests, _PlainOutputTests, unittest.TestCase): + enum_type = Enum + + +class TestPlainFlagClass(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): enum_type = Flag -class TestIntEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): +class TestPlainFlagFunction(_EnumTests, _PlainOutputTests, _FlagTests, unittest.TestCase): + enum_type = Flag + + +class TestIntEnumClass(_EnumTests, _MinimalOutputTests, unittest.TestCase): + enum_type = IntEnum + # + def test_shadowed_attr(self): + class Number(IntEnum): + divisor = 1 + numerator = 2 + # + self.assertEqual(Number.divisor.numerator, 1) + self.assertIs(Number.numerator.divisor, Number.divisor) + + +class TestIntEnumFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase): enum_type = IntEnum + # + def test_shadowed_attr(self): + Number = IntEnum('Number', ('divisor', 'numerator')) + # + self.assertEqual(Number.divisor.numerator, 1) + self.assertIs(Number.numerator.divisor, Number.divisor) -class TestStrEnum(_EnumTests, _MinimalOutputTests, unittest.TestCase): +class TestStrEnumClass(_EnumTests, _MinimalOutputTests, unittest.TestCase): enum_type = StrEnum + # + def test_shadowed_attr(self): + class Book(StrEnum): + author = 'author' + title = 'title' + # + self.assertEqual(Book.author.title(), 'Author') + self.assertEqual(Book.title.title(), 'Title') + self.assertIs(Book.title.author, Book.author) + + +class TestStrEnumFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase): + enum_type = StrEnum + # + def test_shadowed_attr(self): + Book = StrEnum('Book', ('author', 'title')) + # + self.assertEqual(Book.author.title(), 'Author') + self.assertEqual(Book.title.title(), 'Title') + self.assertIs(Book.title.author, Book.author) -class TestIntFlag(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase): +class TestIntFlagClass(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase): enum_type = IntFlag -class TestMixedInt(_EnumTests, _MixedOutputTests, unittest.TestCase): +class TestIntFlagFunction(_EnumTests, _MinimalOutputTests, _FlagTests, unittest.TestCase): + enum_type = IntFlag + + +class TestMixedIntClass(_EnumTests, _MixedOutputTests, unittest.TestCase): class enum_type(int, Enum): pass -class TestMixedStr(_EnumTests, _MixedOutputTests, unittest.TestCase): +class TestMixedIntFunction(_EnumTests, _MixedOutputTests, unittest.TestCase): + enum_type = Enum('enum_type', type=int) + + +class TestMixedStrClass(_EnumTests, _MixedOutputTests, unittest.TestCase): class enum_type(str, Enum): pass -class TestMixedIntFlag(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase): +class TestMixedStrFunction(_EnumTests, _MixedOutputTests, unittest.TestCase): + enum_type = Enum('enum_type', type=str) + + +class TestMixedIntFlagClass(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase): class enum_type(int, Flag): pass -class TestMixedDate(_EnumTests, _MixedOutputTests, unittest.TestCase): +class TestMixedIntFlagFunction(_EnumTests, _MixedOutputTests, _FlagTests, unittest.TestCase): + enum_type = Flag('enum_type', type=int) + +class TestMixedDateClass(_EnumTests, _MixedOutputTests, unittest.TestCase): + # values = [date(2021, 12, 25), date(2020, 3, 15), date(2019, 11, 27)] source_values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] - + # class enum_type(date, Enum): + @staticmethod def _generate_next_value_(name, start, count, last_values): values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] return values[count] -class TestMinimalDate(_EnumTests, _MinimalOutputTests, unittest.TestCase): +class TestMixedDateFunction(_EnumTests, _MixedOutputTests, unittest.TestCase): + # + values = [date(2021, 12, 25), date(2020, 3, 15), date(2019, 11, 27)] + source_values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] + # + # staticmethod decorator will be added by EnumType if not present + def _generate_next_value_(name, start, count, last_values): + values = [(2021, 12, 25), (2020, 3, 15), (2019, 11, 27)] + return values[count] + # + enum_type = Enum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=date) + +class TestMinimalDateClass(_EnumTests, _MinimalOutputTests, unittest.TestCase): + # values = [date(2023, 12, 1), date(2016, 2, 29), date(2009, 1, 1)] source_values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] - + # class enum_type(date, ReprEnum): + # staticmethod decorator will be added by EnumType if absent def _generate_next_value_(name, start, count, last_values): values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] return values[count] -class TestMixedFloat(_EnumTests, _MixedOutputTests, unittest.TestCase): +class TestMinimalDateFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase): + # + values = [date(2023, 12, 1), date(2016, 2, 29), date(2009, 1, 1)] + source_values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] + # + @staticmethod + def _generate_next_value_(name, start, count, last_values): + values = [(2023, 12, 1), (2016, 2, 29), (2009, 1, 1)] + return values[count] + # + enum_type = ReprEnum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=date) - values = [1.1, 2.2, 3.3] +class TestMixedFloatClass(_EnumTests, _MixedOutputTests, unittest.TestCase): + # + values = [1.1, 2.2, 3.3] + # class enum_type(float, Enum): def _generate_next_value_(name, start, count, last_values): values = [1.1, 2.2, 3.3] return values[count] -class TestMinimalFloat(_EnumTests, _MinimalOutputTests, unittest.TestCase): +class TestMixedFloatFunction(_EnumTests, _MixedOutputTests, unittest.TestCase): + # + values = [1.1, 2.2, 3.3] + # + def _generate_next_value_(name, start, count, last_values): + values = [1.1, 2.2, 3.3] + return values[count] + # + enum_type = Enum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=float) - values = [4.4, 5.5, 6.6] +class TestMinimalFloatClass(_EnumTests, _MinimalOutputTests, unittest.TestCase): + # + values = [4.4, 5.5, 6.6] + # class enum_type(float, ReprEnum): def _generate_next_value_(name, start, count, last_values): values = [4.4, 5.5, 6.6] return values[count] +class TestMinimalFloatFunction(_EnumTests, _MinimalOutputTests, unittest.TestCase): + # + values = [4.4, 5.5, 6.6] + # + def _generate_next_value_(name, start, count, last_values): + values = [4.4, 5.5, 6.6] + return values[count] + # + enum_type = ReprEnum('enum_type', {'_generate_next_value_':_generate_next_value_}, type=float) + + class TestSpecial(unittest.TestCase): """ various operations that are not attributable to every possible enum @@ -1244,6 +1467,28 @@ class Huh(Enum): self.assertEqual(Huh.name.name, 'name') self.assertEqual(Huh.name.value, 1) + def test_contains_name_and_value_overlap(self): + class IntEnum1(IntEnum): + X = 1 + class IntEnum2(IntEnum): + X = 1 + class IntEnum3(IntEnum): + X = 2 + class IntEnum4(IntEnum): + Y = 1 + self.assertIn(IntEnum1.X, IntEnum1) + self.assertIn(IntEnum1.X, IntEnum2) + self.assertNotIn(IntEnum1.X, IntEnum3) + self.assertIn(IntEnum1.X, IntEnum4) + + def test_contains_different_types_same_members(self): + class IntEnum1(IntEnum): + X = 1 + class IntFlag1(IntFlag): + X = 1 + self.assertIn(IntEnum1.X, IntFlag1) + self.assertIn(IntFlag1.X, IntEnum1) + def test_inherited_data_type(self): class HexInt(int): __qualname__ = 'HexInt' @@ -1262,7 +1507,6 @@ class MyEnum(HexInt, enum.Enum): # class SillyInt(HexInt): __qualname__ = 'SillyInt' - pass class MyOtherEnum(SillyInt, enum.Enum): __qualname__ = 'MyOtherEnum' D = 4 @@ -1396,6 +1640,21 @@ def test_programmatic_function_type_from_subclass_with_start(self): self.assertIn(e, MinorEnum) self.assertIs(type(e), MinorEnum) + def test_programmatic_function_is_value_call(self): + class TwoPart(Enum): + ONE = 1, 1.0 + TWO = 2, 2.0 + THREE = 3, 3.0 + self.assertRaisesRegex(ValueError, '1 is not a valid .*TwoPart', TwoPart, 1) + self.assertIs(TwoPart((1, 1.0)), TwoPart.ONE) + self.assertIs(TwoPart(1, 1.0), TwoPart.ONE) + class ThreePart(Enum): + ONE = 1, 1.0, 'one' + TWO = 2, 2.0, 'two' + THREE = 3, 3.0, 'three' + self.assertIs(ThreePart((3, 3.0, 'three')), ThreePart.THREE) + self.assertIs(ThreePart(3, 3.0, 'three'), ThreePart.THREE) + # TODO: RUSTPYTHON, AssertionError: is not @unittest.expectedFailure def test_intenum_from_bytes(self): @@ -1539,7 +1798,7 @@ class MoreColor(Color): class EvenMoreColor(Color, IntEnum): chartruese = 7 # - with self.assertRaisesRegex(TypeError, " cannot extend "): + with self.assertRaisesRegex(ValueError, r"\(.Foo., \(.pink., .black.\)\) is not a valid .*Color"): Color('Foo', ('pink', 'black')) def test_exclude_methods(self): @@ -2733,14 +2992,15 @@ class Private(Enum): self.assertEqual(Private._Private__corporal, 'Radar') self.assertEqual(Private._Private__major_, 'Hoolihan') - @unittest.skip("Accessing all values retained for performance reasons, see GH-93910") - def test_exception_for_member_from_member_access(self): - with self.assertRaisesRegex(AttributeError, " member has no attribute .NO."): - class Di(Enum): - YES = 1 - NO = 0 - nope = Di.YES.NO - + def test_member_from_member_access(self): + class Di(Enum): + YES = 1 + NO = 0 + name = 3 + warn = Di.YES.NO + self.assertIs(warn, Di.NO) + self.assertIs(Di.name, Di['name']) + self.assertEqual(Di.name.name, 'name') def test_dynamic_members_with_static_methods(self): # @@ -2771,20 +3031,69 @@ def upper(self): def test_repr_with_dataclass(self): "ensure dataclass-mixin has correct repr()" - from dataclasses import dataclass - @dataclass + # + # check overridden dataclass __repr__ is used + # + from dataclasses import dataclass, field + @dataclass(repr=False) class Foo: __qualname__ = 'Foo' a: int + def __repr__(self): + return 'ha hah!' class Entries(Foo, Enum): ENTRY1 = 1 + self.assertEqual(repr(Entries.ENTRY1), '') + self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value) self.assertTrue(isinstance(Entries.ENTRY1, Foo)) self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_) - self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value) - self.assertEqual(repr(Entries.ENTRY1), '') - - def test_repr_with_init_data_type_mixin(self): - # non-data_type is a mixin that doesn't define __new__ + # + # check auto-generated dataclass __repr__ is not used + # + @dataclass + class CreatureDataMixin: + __qualname__ = 'CreatureDataMixin' + size: str + legs: int + tail: bool = field(repr=False, default=True) + class Creature(CreatureDataMixin, Enum): + __qualname__ = 'Creature' + BEETLE = ('small', 6) + DOG = ('medium', 4) + self.assertEqual(repr(Creature.DOG), "") + # + # check inherited repr used + # + class Huh: + def __repr__(self): + return 'inherited' + @dataclass(repr=False) + class CreatureDataMixin(Huh): + __qualname__ = 'CreatureDataMixin' + size: str + legs: int + tail: bool = field(repr=False, default=True) + class Creature(CreatureDataMixin, Enum): + __qualname__ = 'Creature' + BEETLE = ('small', 6) + DOG = ('medium', 4) + self.assertEqual(repr(Creature.DOG), "") + # + # check default object.__repr__ used if nothing provided + # + @dataclass(repr=False) + class CreatureDataMixin: + __qualname__ = 'CreatureDataMixin' + size: str + legs: int + tail: bool = field(repr=False, default=True) + class Creature(CreatureDataMixin, Enum): + __qualname__ = 'Creature' + BEETLE = ('small', 6) + DOG = ('medium', 4) + self.assertRegex(repr(Creature.DOG), "") + + def test_repr_with_init_mixin(self): class Foo: def __init__(self, a): self.a = a @@ -2795,7 +3104,7 @@ class Entries(Foo, Enum): # self.assertEqual(repr(Entries.ENTRY1), 'Foo(a=1)') - def test_repr_and_str_with_non_data_type_mixin(self): + def test_repr_and_str_with_no_init_mixin(self): # non-data_type is a mixin that doesn't define __new__ class Foo: def __repr__(self): @@ -3250,32 +3559,6 @@ def test_pickle(self): test_pickle_dump_load(self.assertEqual, IntFlagStoogesWithZero.CURLY|IntFlagStoogesWithZero.BIG) - @unittest.skipIf( - python_version >= (3, 12), - '__contains__ now returns True/False for all inputs', - ) - def test_contains_er(self): - Open = self.Open - Color = self.Color - self.assertFalse(Color.BLACK in Open) - self.assertFalse(Open.RO in Color) - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'BLACK' in Color - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'RO' in Open - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 1 in Color - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 1 in Open - - @unittest.skipIf( - python_version < (3, 12), - '__contains__ only works with enum memmbers before 3.12', - ) def test_contains_tf(self): Open = self.Open Color = self.Color @@ -3283,6 +3566,8 @@ def test_contains_tf(self): self.assertFalse(Open.RO in Color) self.assertFalse('BLACK' in Color) self.assertFalse('RO' in Open) + self.assertTrue(Color.BLACK in Color) + self.assertTrue(Open.RO in Open) self.assertTrue(1 in Color) self.assertTrue(1 in Open) @@ -3449,9 +3734,8 @@ def cycle_enum(): threading.Thread(target=cycle_enum) for _ in range(8) ] - with threading_helper.wait_threads_exit(): - with threading_helper.start_threads(threads): - pass + with threading_helper.start_threads(threads): + pass # check that only 248 members were created self.assertFalse( failed, @@ -3827,41 +4111,11 @@ def test_programatic_function_from_empty_tuple(self): self.assertEqual(len(lst), len(Thing)) self.assertEqual(len(Thing), 0, Thing) - @unittest.skipIf( - python_version >= (3, 12), - '__contains__ now returns True/False for all inputs', - ) - def test_contains_er(self): - Open = self.Open - Color = self.Color - self.assertTrue(Color.GREEN in Color) - self.assertTrue(Open.RW in Open) - self.assertFalse(Color.GREEN in Open) - self.assertFalse(Open.RW in Color) - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'GREEN' in Color - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 'RW' in Open - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 2 in Color - with self.assertRaises(TypeError): - with self.assertWarns(DeprecationWarning): - 2 in Open - - @unittest.skipIf( - python_version < (3, 12), - '__contains__ only works with enum memmbers before 3.12', - ) def test_contains_tf(self): Open = self.Open Color = self.Color self.assertTrue(Color.GREEN in Color) self.assertTrue(Open.RW in Open) - self.assertTrue(Color.GREEN in Open) - self.assertTrue(Open.RW in Color) self.assertFalse('GREEN' in Color) self.assertFalse('RW' in Open) self.assertTrue(2 in Color) @@ -3967,6 +4221,7 @@ class Color(StrMixin, AllMixin, IntFlag): self.assertEqual(Color.ALL.value, 7) self.assertEqual(str(Color.BLUE), 'blue') + @unittest.skip("TODO: RUSTPYTHON; flaky test") @threading_helper.reap_threads @threading_helper.requires_working_threading() def test_unique_composite(self): @@ -3998,9 +4253,8 @@ def cycle_enum(): threading.Thread(target=cycle_enum) for _ in range(8) ] - with threading_helper.wait_threads_exit(): - with threading_helper.start_threads(threads): - pass + with threading_helper.start_threads(threads): + pass # check that only 248 members were created self.assertFalse( failed, @@ -4417,118 +4671,87 @@ class TestEnumTypeSubclassing(unittest.TestCase): Help on class Color in module %s: class Color(enum.Enum) - | Create a collection of name/value pairs. - |\x20\x20 - | Example enumeration: - |\x20\x20 - | >>> class Color(Enum): - | ... RED = 1 - | ... BLUE = 2 - | ... GREEN = 3 - |\x20\x20 - | Access them by: - |\x20\x20 - | - attribute access:: - |\x20\x20 - | >>> Color.RED - | - |\x20\x20 - | - value lookup: - |\x20\x20 - | >>> Color(1) - | - |\x20\x20 - | - name lookup: - |\x20\x20 - | >>> Color['RED'] - | - |\x20\x20 - | Enumerations can be iterated over, and know how many members they have: - |\x20\x20 - | >>> len(Color) - | 3 - |\x20\x20 - | >>> list(Color) - | [, , ] - |\x20\x20 - | Methods can be added to enumerations, and members can have their own - | attributes -- see the documentation for details. - |\x20\x20 + | Color(*values) + | | Method resolution order: | Color | enum.Enum | builtins.object - |\x20\x20 + | | Data and other attributes defined here: - |\x20\x20 + | | CYAN = - |\x20\x20 + | | MAGENTA = - |\x20\x20 + | | YELLOW = - |\x20\x20 + | | ---------------------------------------------------------------------- | Data descriptors inherited from enum.Enum: - |\x20\x20 + | | name | The name of the Enum member. - |\x20\x20 + | | value | The value of the Enum member. - |\x20\x20 + | | ---------------------------------------------------------------------- | Methods inherited from enum.EnumType: - |\x20\x20 - | __contains__(member) from enum.EnumType - | Return True if member is a member of this enum - | raises TypeError if member is not an enum member - |\x20\x20\x20\x20\x20\x20 - | note: in 3.12 TypeError will no longer be raised, and True will also be - | returned if member is the value of a member in this enum - |\x20\x20 + | + | __contains__(value) from enum.EnumType + | Return True if `value` is in `cls`. + | + | `value` is in `cls` if: + | 1) `value` is a member of `cls`, or + | 2) `value` is the value of one of the `cls`'s members. + | | __getitem__(name) from enum.EnumType | Return the member matching `name`. - |\x20\x20 + | | __iter__() from enum.EnumType | Return members in definition order. - |\x20\x20 + | | __len__() from enum.EnumType | Return the number of members (no aliases) - |\x20\x20 + | | ---------------------------------------------------------------------- - | Data descriptors inherited from enum.EnumType: - |\x20\x20 - | __members__""" + | Readonly properties inherited from enum.EnumType: + | + | __members__ + | Returns a mapping of member name->value. + | + | This mapping lists all enum members, including aliases. Note that this + | is a read-only view of the internal mapping.""" expected_help_output_without_docs = """\ Help on class Color in module %s: class Color(enum.Enum) - | Color(value, names=None, *, module=None, qualname=None, type=None, start=1) - |\x20\x20 + | Color(*values) + | | Method resolution order: | Color | enum.Enum | builtins.object - |\x20\x20 + | | Data and other attributes defined here: - |\x20\x20 + | | YELLOW = - |\x20\x20 + | | MAGENTA = - |\x20\x20 + | | CYAN = - |\x20\x20 + | | ---------------------------------------------------------------------- | Data descriptors inherited from enum.Enum: - |\x20\x20 + | | name - |\x20\x20 + | | value - |\x20\x20 + | | ---------------------------------------------------------------------- | Data descriptors inherited from enum.EnumType: - |\x20\x20 + | | __members__""" class TestStdLib(unittest.TestCase): @@ -4540,6 +4763,8 @@ class Color(Enum): MAGENTA = 2 YELLOW = 3 + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_pydoc(self): # indirectly test __objclass__ if StrEnum.__doc__ is None: @@ -4651,6 +4876,29 @@ def test_inspect_classify_class_attrs(self): if failed: self.fail("result does not equal expected, see print above") + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_inspect_signatures(self): + from inspect import signature, Signature, Parameter + self.assertEqual( + signature(Enum), + Signature([ + Parameter('new_class_name', Parameter.POSITIONAL_ONLY), + Parameter('names', Parameter.POSITIONAL_OR_KEYWORD), + Parameter('module', Parameter.KEYWORD_ONLY, default=None), + Parameter('qualname', Parameter.KEYWORD_ONLY, default=None), + Parameter('type', Parameter.KEYWORD_ONLY, default=None), + Parameter('start', Parameter.KEYWORD_ONLY, default=1), + Parameter('boundary', Parameter.KEYWORD_ONLY, default=None), + ]), + ) + self.assertEqual( + signature(enum.FlagBoundary), + Signature([ + Parameter('values', Parameter.VAR_POSITIONAL), + ]), + ) + # TODO: RUSTPYTHON, len is often/always > 256 @unittest.expectedFailure def test_test_simple_enum(self): @@ -4756,11 +5004,6 @@ class Quadruple(Enum): COMPLEX_A = 2j COMPLEX_B = 3j -class _ModuleWrapper: - """We use this class as a namespace for swapping modules.""" - def __init__(self, module): - self.__dict__.update(module.__dict__) - class TestConvert(unittest.TestCase): def tearDown(self): # Reset the module-level test variables to their original integer @@ -4800,12 +5043,6 @@ def test_convert_int(self): self.assertEqual(test_type.CONVERT_TEST_NAME_D, 5) self.assertEqual(test_type.CONVERT_TEST_NAME_E, 5) # Ensure that test_type only picked up names matching the filter. - int_dir = dir(int) + [ - 'CONVERT_TEST_NAME_A', 'CONVERT_TEST_NAME_B', 'CONVERT_TEST_NAME_C', - 'CONVERT_TEST_NAME_D', 'CONVERT_TEST_NAME_E', 'CONVERT_TEST_NAME_F', - 'CONVERT_TEST_SIGABRT', 'CONVERT_TEST_SIGIOT', - 'CONVERT_TEST_EIO', 'CONVERT_TEST_EBUS', - ] extra = [name for name in dir(test_type) if name not in enum_dir(test_type)] missing = [name for name in enum_dir(test_type) if name not in dir(test_type)] self.assertEqual( @@ -4847,7 +5084,6 @@ def test_convert_str(self): self.assertEqual(test_type.CONVERT_STR_TEST_1, 'hello') self.assertEqual(test_type.CONVERT_STR_TEST_2, 'goodbye') # Ensure that test_type only picked up names matching the filter. - str_dir = dir(str) + ['CONVERT_STR_TEST_1', 'CONVERT_STR_TEST_2'] extra = [name for name in dir(test_type) if name not in enum_dir(test_type)] missing = [name for name in enum_dir(test_type) if name not in dir(test_type)] self.assertEqual( @@ -4915,8 +5151,6 @@ def member_dir(member): allowed.add(name) return sorted(allowed) -missing = object() - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 71748df8ed..fb2dcf7a51 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -13,19 +13,29 @@ import typing import unittest import unittest.mock +import weakref +import gc from weakref import proxy import contextlib +from inspect import Signature from test.support import import_helper from test.support import threading_helper import functools -py_functools = import_helper.import_fresh_module('functools', blocked=['_functools']) -c_functools = import_helper.import_fresh_module('functools', fresh=['_functools']) +py_functools = import_helper.import_fresh_module('functools', + blocked=['_functools']) +c_functools = import_helper.import_fresh_module('functools', + fresh=['_functools']) decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal']) +_partial_types = [py_functools.partial] +if c_functools: + _partial_types.append(c_functools.partial) + + @contextlib.contextmanager def replaced_module(name, replacement): original_module = sys.modules[name] @@ -162,6 +172,7 @@ def test_weakref(self): p = proxy(f) self.assertEqual(f.func, p.func) f = None + support.gc_collect() # For PyPy or other GCs. self.assertRaises(ReferenceError, getattr, p, 'func') def test_with_bound_and_unbound_methods(self): @@ -196,7 +207,7 @@ def test_repr(self): kwargs = {'a': object(), 'b': object()} kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 'b={b!r}, a={a!r}'.format_map(kwargs)] - if self.partial in (c_functools.partial, py_functools.partial): + if self.partial in _partial_types: name = 'functools.partial' else: name = self.partial.__name__ @@ -218,7 +229,7 @@ def test_repr(self): for kwargs_repr in kwargs_reprs]) def test_recursive_repr(self): - if self.partial in (c_functools.partial, py_functools.partial): + if self.partial in _partial_types: name = 'functools.partial' else: name = self.partial.__name__ @@ -245,7 +256,7 @@ def test_recursive_repr(self): f.__setstate__((capture, (), {}, {})) def test_pickle(self): - with self.AllowPickle(): + with replaced_module('functools', self.module): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -328,7 +339,7 @@ def test_setstate_subclasses(self): self.assertIs(type(r[0]), tuple) def test_recursive_pickle(self): - with self.AllowPickle(): + with replaced_module('functools', self.module): f = self.partial(capture) f.__setstate__((f, (), {}, {})) try: @@ -382,24 +393,9 @@ def __getitem__(self, key): @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestPartialC(TestPartial, unittest.TestCase): if c_functools: + module = c_functools partial = c_functools.partial - class AllowPickle: - def __enter__(self): - return self - def __exit__(self, type, value, tb): - return False - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_pickle(self): - super().test_pickle() - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_recursive_pickle(self): - super().test_recursive_pickle() - # TODO: RUSTPYTHON @unittest.expectedFailure def test_attributes_unwritable(self): @@ -444,15 +440,9 @@ def __str__(self): class TestPartialPy(TestPartial, unittest.TestCase): + module = py_functools partial = py_functools.partial - class AllowPickle: - def __init__(self): - self._cm = replaced_module("functools", py_functools) - def __enter__(self): - return self._cm.__enter__() - def __exit__(self, type, value, tb): - return self._cm.__exit__(type, value, tb) if c_functools: class CPartialSubclass(c_functools.partial): @@ -579,11 +569,9 @@ class B(object): with self.assertRaises(TypeError): class B: method = functools.partialmethod() - with self.assertWarns(DeprecationWarning): + with self.assertRaises(TypeError): class B: method = functools.partialmethod(func=capture, a=1) - b = B() - self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3})) def test_repr(self): self.assertEqual(repr(vars(self.A)['both']), @@ -634,6 +622,8 @@ def check_wrapper(self, wrapper, wrapped, def _default_update(self): + # XXX: RUSTPYTHON; f[T] is not supported yet + # def f[T](a:'This is a new annotation'): def f(a:'This is a new annotation'): """This is a test""" pass @@ -644,15 +634,19 @@ def wrapper(b:'This is the prior annotation'): functools.update_wrapper(wrapper, f) return wrapper, f + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_default_update(self): wrapper, f = self._default_update() self.check_wrapper(wrapper, f) + T, = f.__type_params__ self.assertIs(wrapper.__wrapped__, f) self.assertEqual(wrapper.__name__, 'f') self.assertEqual(wrapper.__qualname__, f.__qualname__) self.assertEqual(wrapper.attr, 'This is also a test') self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') self.assertNotIn('b', wrapper.__annotations__) + self.assertEqual(wrapper.__type_params__, (T,)) @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") @@ -959,6 +953,10 @@ def mycmp(x, y): self.assertRaises(TypeError, hash, k) self.assertNotIsInstance(k, collections.abc.Hashable) + def test_cmp_to_signature(self): + self.assertEqual(str(Signature.from_callable(self.cmp_to_key)), + '(mycmp)') + @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestCmpToKeyC(TestCmpToKey, unittest.TestCase): @@ -1000,6 +998,18 @@ def test_sort_int(self): def test_sort_int_str(self): super().test_sort_int_str() + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_cmp_to_signature(self): + super().test_cmp_to_signature() + + @support.cpython_only + def test_disallow_instantiation(self): + # Ensure that the type disallows instantiation (bpo-43916) + support.check_disallow_instantiation( + self, type(c_functools.cmp_to_key(None)) + ) + class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase): cmp_to_key = staticmethod(py_functools.cmp_to_key) @@ -1093,6 +1103,73 @@ def test_no_operations_defined(self): class A: pass + def test_notimplemented(self): + # Verify NotImplemented results are correctly handled + @functools.total_ordering + class ImplementsLessThan: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsLessThan): + return self.value == other.value + return False + def __lt__(self, other): + if isinstance(other, ImplementsLessThan): + return self.value < other.value + return NotImplemented + + @functools.total_ordering + class ImplementsLessThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value == other.value + return False + def __le__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value <= other.value + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThan: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value == other.value + return False + def __gt__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value > other.value + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value == other.value + return False + def __ge__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value >= other.value + return NotImplemented + + self.assertIs(ImplementsLessThan(1).__le__(1), NotImplemented) + self.assertIs(ImplementsLessThan(1).__gt__(1), NotImplemented) + self.assertIs(ImplementsLessThan(1).__ge__(1), NotImplemented) + self.assertIs(ImplementsLessThanEqualTo(1).__lt__(1), NotImplemented) + self.assertIs(ImplementsLessThanEqualTo(1).__gt__(1), NotImplemented) + self.assertIs(ImplementsLessThanEqualTo(1).__ge__(1), NotImplemented) + self.assertIs(ImplementsGreaterThan(1).__lt__(1), NotImplemented) + self.assertIs(ImplementsGreaterThan(1).__gt__(1), NotImplemented) + self.assertIs(ImplementsGreaterThan(1).__ge__(1), NotImplemented) + self.assertIs(ImplementsGreaterThanEqualTo(1).__lt__(1), NotImplemented) + self.assertIs(ImplementsGreaterThanEqualTo(1).__le__(1), NotImplemented) + self.assertIs(ImplementsGreaterThanEqualTo(1).__gt__(1), NotImplemented) + def test_type_error_when_not_implemented(self): # bug 10042; ensure stack overflow does not occur # when decorated types return NotImplemented @@ -1208,6 +1285,34 @@ def test_pickle(self): method_copy = pickle.loads(pickle.dumps(method, proto)) self.assertIs(method_copy, method) + + def test_total_ordering_for_metaclasses_issue_44605(self): + + @functools.total_ordering + class SortableMeta(type): + def __new__(cls, name, bases, ns): + return super().__new__(cls, name, bases, ns) + + def __lt__(self, other): + if not isinstance(other, SortableMeta): + pass + return self.__name__ < other.__name__ + + def __eq__(self, other): + if not isinstance(other, SortableMeta): + pass + return self.__name__ == other.__name__ + + class B(metaclass=SortableMeta): + pass + + class A(metaclass=SortableMeta): + pass + + self.assertTrue(A < B) + self.assertFalse(A > B) + + @functools.total_ordering class Orderable_LT: def __init__(self, value): @@ -1218,6 +1323,25 @@ def __eq__(self, other): return self.value == other.value +class TestCache: + # This tests that the pass-through is working as designed. + # The underlying functionality is tested in TestLRU. + + def test_cache(self): + @self.module.cache + def fib(n): + if n < 2: + return n + return fib(n-1) + fib(n-2) + self.assertEqual([fib(n) for n in range(16)], + [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) + self.assertEqual(fib.cache_info(), + self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) + fib.cache_clear() + self.assertEqual(fib.cache_info(), + self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) + + class TestLRU: def test_lru(self): @@ -1411,7 +1535,7 @@ def test_lru_reentrancy_with_len(self): def test_lru_star_arg_handling(self): # Test regression that arose in ea064ff3c10f - @functools.lru_cache() + @self.module.lru_cache() def f(*args): return args @@ -1423,11 +1547,11 @@ def test_lru_type_error(self): # lru_cache was leaking when one of the arguments # wasn't cacheable. - @functools.lru_cache(maxsize=None) + @self.module.lru_cache(maxsize=None) def infinite_cache(o): pass - @functools.lru_cache(maxsize=10) + @self.module.lru_cache(maxsize=10) def limited_cache(o): pass @@ -1492,6 +1616,33 @@ def square(x): self.assertEqual(square.cache_info().hits, 4) self.assertEqual(square.cache_info().misses, 4) + def test_lru_cache_typed_is_not_recursive(self): + cached = self.module.lru_cache(typed=True)(repr) + + self.assertEqual(cached(1), '1') + self.assertEqual(cached(True), 'True') + self.assertEqual(cached(1.0), '1.0') + self.assertEqual(cached(0), '0') + self.assertEqual(cached(False), 'False') + self.assertEqual(cached(0.0), '0.0') + + self.assertEqual(cached((1,)), '(1,)') + self.assertEqual(cached((True,)), '(1,)') + self.assertEqual(cached((1.0,)), '(1,)') + self.assertEqual(cached((0,)), '(0,)') + self.assertEqual(cached((False,)), '(0,)') + self.assertEqual(cached((0.0,)), '(0,)') + + class T(tuple): + pass + + self.assertEqual(cached(T((1,))), '(1,)') + self.assertEqual(cached(T((True,))), '(1,)') + self.assertEqual(cached(T((1.0,))), '(1,)') + self.assertEqual(cached(T((0,))), '(0,)') + self.assertEqual(cached(T((False,))), '(0,)') + self.assertEqual(cached(T((0.0,))), '(0,)') + def test_lru_with_keyword_args(self): @self.module.lru_cache() def fib(n): @@ -1542,6 +1693,7 @@ def f(zomg: 'zomg_annotation'): # TODO: RUSTPYTHON @unittest.expectedFailure + @threading_helper.requires_working_threading() def test_lru_cache_threaded(self): n, m = 5, 11 def orig(x, y): @@ -1590,6 +1742,7 @@ def clear(): finally: sys.setswitchinterval(orig_si) + @threading_helper.requires_working_threading() def test_lru_cache_threaded2(self): # Simultaneous call with the same arguments n, m = 5, 7 @@ -1617,6 +1770,7 @@ def test(): pause.reset() self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) + @threading_helper.requires_working_threading() def test_lru_cache_threaded3(self): @self.module.lru_cache(maxsize=2) def f(x): @@ -1717,14 +1871,62 @@ def orig(x, y): f_copy = copy.deepcopy(f) self.assertIs(f_copy, f) + def test_lru_cache_parameters(self): + @self.module.lru_cache(maxsize=2) + def f(): + return 1 + self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False}) + + @self.module.lru_cache(maxsize=1000, typed=True) + def f(): + return 1 + self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True}) + + def test_lru_cache_weakrefable(self): + @self.module.lru_cache + def test_function(x): + return x + + class A: + @self.module.lru_cache + def test_method(self, x): + return (self, x) + + @staticmethod + @self.module.lru_cache + def test_staticmethod(x): + return (self, x) + + refs = [weakref.ref(test_function), + weakref.ref(A.test_method), + weakref.ref(A.test_staticmethod)] + + for ref in refs: + self.assertIsNotNone(ref()) + + del A + del test_function + gc.collect() + + for ref in refs: + self.assertIsNone(ref()) + + def test_common_signatures(self): + def orig(): ... + lru = self.module.lru_cache(1)(orig) + + self.assertEqual(str(Signature.from_callable(lru.cache_info)), '()') + self.assertEqual(str(Signature.from_callable(lru.cache_clear)), '()') + @py_functools.lru_cache() def py_cached_func(x, y): return 3 * x + y -@c_functools.lru_cache() -def c_cached_func(x, y): - return 3 * x + y +if c_functools: + @c_functools.lru_cache() + def c_cached_func(x, y): + return 3 * x + y class TestLRUPy(TestLRU, unittest.TestCase): @@ -1741,18 +1943,20 @@ def cached_staticmeth(x, y): return 3 * x + y +@unittest.skipUnless(c_functools, 'requires the C _functools module') class TestLRUC(TestLRU, unittest.TestCase): - module = c_functools - cached_func = c_cached_func, + if c_functools: + module = c_functools + cached_func = c_cached_func, - @module.lru_cache() - def cached_meth(self, x, y): - return 3 * x + y + @module.lru_cache() + def cached_meth(self, x, y): + return 3 * x + y - @staticmethod - @module.lru_cache() - def cached_staticmeth(x, y): - return 3 * x + y + @staticmethod + @module.lru_cache() + def cached_staticmeth(x, y): + return 3 * x + y class TestSingleDispatch(unittest.TestCase): @@ -1867,7 +2071,7 @@ class D(collections.defaultdict): c.MutableSequence.register(D) bases = [c.MutableSequence, c.MutableMapping] for haystack in permutations(bases): - m = mro(D, bases) + m = mro(D, haystack) self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, collections.defaultdict, dict, c.MutableMapping, c.Mapping, c.Collection, c.Sized, c.Iterable, c.Container, @@ -2370,7 +2574,7 @@ def _(cls, arg): self.assertEqual(A.t(0.0).arg, "base") def test_abstractmethod_register(self): - class Abstract(abc.ABCMeta): + class Abstract(metaclass=abc.ABCMeta): @functools.singledispatchmethod @abc.abstractmethod @@ -2378,6 +2582,10 @@ def add(self, x, y): pass self.assertTrue(Abstract.add.__isabstractmethod__) + self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__) + + with self.assertRaises(TypeError): + Abstract() def test_type_ann_register(self): class A: @@ -2396,6 +2604,183 @@ def _(self, arg: str): self.assertEqual(a.t(''), "str") self.assertEqual(a.t(0.0), "base") + def test_staticmethod_type_ann_register(self): + class A: + @functools.singledispatchmethod + @staticmethod + def t(arg): + return arg + @t.register + @staticmethod + def _(arg: int): + return isinstance(arg, int) + @t.register + @staticmethod + def _(arg: str): + return isinstance(arg, str) + a = A() + + self.assertTrue(A.t(0)) + self.assertTrue(A.t('')) + self.assertEqual(A.t(0.0), 0.0) + + def test_classmethod_type_ann_register(self): + class A: + def __init__(self, arg): + self.arg = arg + + @functools.singledispatchmethod + @classmethod + def t(cls, arg): + return cls("base") + @t.register + @classmethod + def _(cls, arg: int): + return cls("int") + @t.register + @classmethod + def _(cls, arg: str): + return cls("str") + + self.assertEqual(A.t(0).arg, "int") + self.assertEqual(A.t('').arg, "str") + self.assertEqual(A.t(0.0).arg, "base") + + def test_method_wrapping_attributes(self): + class A: + @functools.singledispatchmethod + def func(self, arg: int) -> str: + """My function docstring""" + return str(arg) + @functools.singledispatchmethod + @classmethod + def cls_func(cls, arg: int) -> str: + """My function docstring""" + return str(arg) + @functools.singledispatchmethod + @staticmethod + def static_func(arg: int) -> str: + """My function docstring""" + return str(arg) + + for meth in ( + A.func, + A().func, + A.cls_func, + A().cls_func, + A.static_func, + A().static_func + ): + with self.subTest(meth=meth): + self.assertEqual(meth.__doc__, 'My function docstring') + self.assertEqual(meth.__annotations__['arg'], int) + + self.assertEqual(A.func.__name__, 'func') + self.assertEqual(A().func.__name__, 'func') + self.assertEqual(A.cls_func.__name__, 'cls_func') + self.assertEqual(A().cls_func.__name__, 'cls_func') + self.assertEqual(A.static_func.__name__, 'static_func') + self.assertEqual(A().static_func.__name__, 'static_func') + + def test_double_wrapped_methods(self): + def classmethod_friendly_decorator(func): + wrapped = func.__func__ + @classmethod + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + return wrapped(*args, **kwargs) + return wrapper + + class WithoutSingleDispatch: + @classmethod + @contextlib.contextmanager + def cls_context_manager(cls, arg: int) -> str: + try: + yield str(arg) + finally: + return 'Done' + + @classmethod_friendly_decorator + @classmethod + def decorated_classmethod(cls, arg: int) -> str: + return str(arg) + + class WithSingleDispatch: + @functools.singledispatchmethod + @classmethod + @contextlib.contextmanager + def cls_context_manager(cls, arg: int) -> str: + """My function docstring""" + try: + yield str(arg) + finally: + return 'Done' + + @functools.singledispatchmethod + @classmethod_friendly_decorator + @classmethod + def decorated_classmethod(cls, arg: int) -> str: + """My function docstring""" + return str(arg) + + # These are sanity checks + # to test the test itself is working as expected + with WithoutSingleDispatch.cls_context_manager(5) as foo: + without_single_dispatch_foo = foo + + with WithSingleDispatch.cls_context_manager(5) as foo: + single_dispatch_foo = foo + + self.assertEqual(without_single_dispatch_foo, single_dispatch_foo) + self.assertEqual(single_dispatch_foo, '5') + + self.assertEqual( + WithoutSingleDispatch.decorated_classmethod(5), + WithSingleDispatch.decorated_classmethod(5) + ) + + self.assertEqual(WithSingleDispatch.decorated_classmethod(5), '5') + + # Behavioural checks now follow + for method_name in ('cls_context_manager', 'decorated_classmethod'): + with self.subTest(method=method_name): + self.assertEqual( + getattr(WithSingleDispatch, method_name).__name__, + getattr(WithoutSingleDispatch, method_name).__name__ + ) + + self.assertEqual( + getattr(WithSingleDispatch(), method_name).__name__, + getattr(WithoutSingleDispatch(), method_name).__name__ + ) + + for meth in ( + WithSingleDispatch.cls_context_manager, + WithSingleDispatch().cls_context_manager, + WithSingleDispatch.decorated_classmethod, + WithSingleDispatch().decorated_classmethod + ): + with self.subTest(meth=meth): + self.assertEqual(meth.__doc__, 'My function docstring') + self.assertEqual(meth.__annotations__['arg'], int) + + self.assertEqual( + WithSingleDispatch.cls_context_manager.__name__, + 'cls_context_manager' + ) + self.assertEqual( + WithSingleDispatch().cls_context_manager.__name__, + 'cls_context_manager' + ) + self.assertEqual( + WithSingleDispatch.decorated_classmethod.__name__, + 'decorated_classmethod' + ) + self.assertEqual( + WithSingleDispatch().decorated_classmethod.__name__, + 'decorated_classmethod' + ) + def test_invalid_registrations(self): msg_prefix = "Invalid first argument to `register()`: " msg_suffix = ( @@ -2435,6 +2820,17 @@ def _(arg: typing.Iterable[str]): 'typing.Iterable[str] is not a class.' )) + with self.assertRaises(TypeError) as exc: + @i.register + def _(arg: typing.Union[int, typing.Iterable[str]]): + return "Invalid Union" + self.assertTrue(str(exc.exception).startswith( + "Invalid annotation for 'arg'." + )) + self.assertTrue(str(exc.exception).endswith( + 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.' + )) + def test_invalid_positional_argument(self): @functools.singledispatch def f(*args): @@ -2443,6 +2839,134 @@ def f(*args): with self.assertRaisesRegex(TypeError, msg): f() + def test_union(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: typing.Union[str, bytes]): + return "typing.Union" + + @f.register + def _(arg: int | float): + return "types.UnionType" + + self.assertEqual(f([]), "default") + self.assertEqual(f(""), "typing.Union") + self.assertEqual(f(b""), "typing.Union") + self.assertEqual(f(1), "types.UnionType") + self.assertEqual(f(1.0), "types.UnionType") + + def test_union_conflict(self): + @functools.singledispatch + def f(arg): + return "default" + + @f.register + def _(arg: typing.Union[str, bytes]): + return "typing.Union" + + @f.register + def _(arg: int | str): + return "types.UnionType" + + self.assertEqual(f([]), "default") + self.assertEqual(f(""), "types.UnionType") # last one wins + self.assertEqual(f(b""), "typing.Union") + self.assertEqual(f(1), "types.UnionType") + + def test_union_None(self): + @functools.singledispatch + def typing_union(arg): + return "default" + + @typing_union.register + def _(arg: typing.Union[str, None]): + return "typing.Union" + + self.assertEqual(typing_union(1), "default") + self.assertEqual(typing_union(""), "typing.Union") + self.assertEqual(typing_union(None), "typing.Union") + + @functools.singledispatch + def types_union(arg): + return "default" + + @types_union.register + def _(arg: int | None): + return "types.UnionType" + + self.assertEqual(types_union(""), "default") + self.assertEqual(types_union(1), "types.UnionType") + self.assertEqual(types_union(None), "types.UnionType") + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_register_genericalias(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int], lambda arg: "types.GenericAlias") + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int], lambda arg: "typing.GenericAlias") + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)") + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]") + + self.assertEqual(f([1]), "default") + self.assertEqual(f([1.0]), "default") + self.assertEqual(f(""), "default") + self.assertEqual(f(b""), "default") + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_register_genericalias_decorator(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int]) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int]) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(list[int] | str) + with self.assertRaisesRegex(TypeError, "Invalid first argument to "): + f.register(typing.List[int] | str) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_register_genericalias_annotation(self): + @functools.singledispatch + def f(arg): + return "default" + + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: list[int]): + return "types.GenericAlias" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: typing.List[float]): + return "typing.GenericAlias" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: list[int] | str): + return "types.UnionType(types.GenericAlias)" + with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"): + @f.register + def _(arg: typing.List[float] | bytes): + return "typing.Union[typing.GenericAlias]" + + self.assertEqual(f([1]), "default") + self.assertEqual(f([1.0]), "default") + self.assertEqual(f(""), "default") + self.assertEqual(f(b""), "default") + class CachedCostItem: _cost = 1 @@ -2469,21 +2993,6 @@ def get_cost(self): cached_cost = py_functools.cached_property(get_cost) -class CachedCostItemWait: - - def __init__(self, event): - self._cost = 1 - self.lock = py_functools.RLock() - self.event = event - - @py_functools.cached_property - def cost(self): - self.event.wait(1) - with self.lock: - self._cost += 1 - return self._cost - - class CachedCostItemWithSlots: __slots__ = ('_cost') @@ -2508,28 +3017,6 @@ def test_cached_attribute_name_differs_from_func_name(self): self.assertEqual(item.get_cost(), 4) self.assertEqual(item.cached_cost, 3) - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_threaded(self): - go = threading.Event() - item = CachedCostItemWait(go) - - num_threads = 3 - - orig_si = sys.getswitchinterval() - sys.setswitchinterval(1e-6) - try: - threads = [ - threading.Thread(target=lambda: item.cost) - for k in range(num_threads) - ] - with threading_helper.start_threads(threads): - go.set() - finally: - sys.setswitchinterval(orig_si) - - self.assertEqual(item.cost, 2) - # TODO: RUSTPYTHON @unittest.expectedFailure def test_object_with_slots(self): @@ -2559,7 +3046,7 @@ class MyClass(metaclass=MyMeta): @unittest.expectedFailure def test_reuse_different_names(self): """Disallow this case because decorated function a would not be cached.""" - with self.assertRaises(RuntimeError) as ctx: + with self.assertRaises(TypeError) as ctx: class ReusedCachedProperty: @py_functools.cached_property def a(self): @@ -2568,7 +3055,7 @@ def a(self): b = a self.assertEqual( - str(ctx.exception.__context__), + str(ctx.exception), str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")) ) @@ -2614,6 +3101,25 @@ def test_access_from_class(self): def test_doc(self): self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.") + def test_subclass_with___set__(self): + """Caching still works for a subclass defining __set__.""" + class readonly_cached_property(py_functools.cached_property): + def __set__(self, obj, value): + raise AttributeError("read only property") + + class Test: + def __init__(self, prop): + self._prop = prop + + @readonly_cached_property + def prop(self): + return self._prop + + t = Test(1) + self.assertEqual(t.prop, 1) + t._prop = 999 + self.assertEqual(t.prop, 1) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/_context.py b/Lib/test/test_importlib/_context.py new file mode 100644 index 0000000000..8a53eb55d1 --- /dev/null +++ b/Lib/test/test_importlib/_context.py @@ -0,0 +1,13 @@ +import contextlib + + +# from jaraco.context 4.3 +class suppress(contextlib.suppress, contextlib.ContextDecorator): + """ + A version of contextlib.suppress with decorator support. + + >>> @suppress(KeyError) + ... def key_error(): + ... {}[''] + >>> key_error() + """ diff --git a/Lib/test/test_importlib/_path.py b/Lib/test/test_importlib/_path.py new file mode 100644 index 0000000000..71a704389b --- /dev/null +++ b/Lib/test/test_importlib/_path.py @@ -0,0 +1,109 @@ +# from jaraco.path 3.5 + +import functools +import pathlib +from typing import Dict, Union + +try: + from typing import Protocol, runtime_checkable +except ImportError: # pragma: no cover + # Python 3.7 + from typing_extensions import Protocol, runtime_checkable # type: ignore + + +FilesSpec = Dict[str, Union[str, bytes, 'FilesSpec']] # type: ignore + + +@runtime_checkable +class TreeMaker(Protocol): + def __truediv__(self, *args, **kwargs): + ... # pragma: no cover + + def mkdir(self, **kwargs): + ... # pragma: no cover + + def write_text(self, content, **kwargs): + ... # pragma: no cover + + def write_bytes(self, content): + ... # pragma: no cover + + +def _ensure_tree_maker(obj: Union[str, TreeMaker]) -> TreeMaker: + return obj if isinstance(obj, TreeMaker) else pathlib.Path(obj) # type: ignore + + +def build( + spec: FilesSpec, + prefix: Union[str, TreeMaker] = pathlib.Path(), # type: ignore +): + """ + Build a set of files/directories, as described by the spec. + + Each key represents a pathname, and the value represents + the content. Content may be a nested directory. + + >>> spec = { + ... 'README.txt': "A README file", + ... "foo": { + ... "__init__.py": "", + ... "bar": { + ... "__init__.py": "", + ... }, + ... "baz.py": "# Some code", + ... } + ... } + >>> target = getfixture('tmp_path') + >>> build(spec, target) + >>> target.joinpath('foo/baz.py').read_text(encoding='utf-8') + '# Some code' + """ + for name, contents in spec.items(): + create(contents, _ensure_tree_maker(prefix) / name) + + +@functools.singledispatch +def create(content: Union[str, bytes, FilesSpec], path): + path.mkdir(exist_ok=True) + build(content, prefix=path) # type: ignore + + +@create.register +def _(content: bytes, path): + path.write_bytes(content) + + +@create.register +def _(content: str, path): + path.write_text(content, encoding='utf-8') + + +@create.register +def _(content: str, path): + path.write_text(content, encoding='utf-8') + + +class Recording: + """ + A TreeMaker object that records everything that would be written. + + >>> r = Recording() + >>> build({'foo': {'foo1.txt': 'yes'}, 'bar.txt': 'abc'}, r) + >>> r.record + ['foo/foo1.txt', 'bar.txt'] + """ + + def __init__(self, loc=pathlib.PurePosixPath(), record=None): + self.loc = loc + self.record = record if record is not None else [] + + def __truediv__(self, other): + return Recording(self.loc / other, self.record) + + def write_text(self, content, **kwargs): + self.record.append(str(self.loc)) + + write_bytes = write_text + + def mkdir(self, **kwargs): + return diff --git a/Lib/test/test_importlib/builtin/test_finder.py b/Lib/test/test_importlib/builtin/test_finder.py index a4869e07b9..111c4af1ea 100644 --- a/Lib/test/test_importlib/builtin/test_finder.py +++ b/Lib/test/test_importlib/builtin/test_finder.py @@ -37,61 +37,11 @@ def test_failure(self): spec = self.machinery.BuiltinImporter.find_spec(name) self.assertIsNone(spec) - def test_ignore_path(self): - # The value for 'path' should always trigger a failed import. - with util.uncache(util.BUILTINS.good_name): - spec = self.machinery.BuiltinImporter.find_spec(util.BUILTINS.good_name, - ['pkg']) - self.assertIsNone(spec) - (Frozen_FindSpecTests, Source_FindSpecTests ) = util.test_both(FindSpecTests, machinery=machinery) -@unittest.skipIf(util.BUILTINS.good_name is None, 'no reasonable builtin module') -class FinderTests(abc.FinderTests): - - """Test find_module() for built-in modules.""" - - def test_module(self): - # Common case. - with util.uncache(util.BUILTINS.good_name): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - found = self.machinery.BuiltinImporter.find_module(util.BUILTINS.good_name) - self.assertTrue(found) - self.assertTrue(hasattr(found, 'load_module')) - - # Built-in modules cannot be a package. - test_package = test_package_in_package = test_package_over_module = None - - # Built-in modules cannot be in a package. - test_module_in_package = None - - def test_failure(self): - assert 'importlib' not in sys.builtin_module_names - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.BuiltinImporter.find_module('importlib') - self.assertIsNone(loader) - - def test_ignore_path(self): - # The value for 'path' should always trigger a failed import. - with util.uncache(util.BUILTINS.good_name): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.BuiltinImporter.find_module( - util.BUILTINS.good_name, - ['pkg']) - self.assertIsNone(loader) - - -(Frozen_FinderTests, - Source_FinderTests - ) = util.test_both(FinderTests, machinery=machinery) - - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/extension/test_case_sensitivity.py b/Lib/test/test_importlib/extension/test_case_sensitivity.py index 366e565cf4..0bb74fff5f 100644 --- a/Lib/test/test_importlib/extension/test_case_sensitivity.py +++ b/Lib/test/test_importlib/extension/test_case_sensitivity.py @@ -8,7 +8,7 @@ machinery = util.import_importlib('importlib.machinery') -@unittest.skipIf(util.EXTENSIONS.filename is None, '_testcapi not available') +@unittest.skipIf(util.EXTENSIONS.filename is None, f'{util.EXTENSIONS.name} not available') @util.case_insensitive_tests class ExtensionModuleCaseSensitivityTest(util.CASEOKTestBase): diff --git a/Lib/test/test_importlib/extension/test_loader.py b/Lib/test/test_importlib/extension/test_loader.py index 6c5cd577c1..d06558f2ad 100644 --- a/Lib/test/test_importlib/extension/test_loader.py +++ b/Lib/test/test_importlib/extension/test_loader.py @@ -13,9 +13,9 @@ from test.support.script_helper import assert_python_failure -class LoaderTests(abc.LoaderTests): +class LoaderTests: - """Test load_module() for extension modules.""" + """Test ExtensionFileLoader.""" def setUp(self): if not self.machinery.EXTENSION_SUFFIXES: @@ -32,17 +32,6 @@ def load_module(self, fullname): warnings.simplefilter("ignore", DeprecationWarning) return self.loader.load_module(fullname) - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_load_module_API(self): - # Test the default argument for load_module(). - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - self.loader.load_module() - self.loader.load_module(None) - with self.assertRaises(ImportError): - self.load_module('XXX') - def test_equality(self): other = self.machinery.ExtensionFileLoader(util.EXTENSIONS.name, util.EXTENSIONS.file_path) @@ -53,6 +42,15 @@ def test_inequality(self): util.EXTENSIONS.file_path) self.assertNotEqual(self.loader, other) + def test_load_module_API(self): + # Test the default argument for load_module(). + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + self.loader.load_module() + self.loader.load_module(None) + with self.assertRaises(ImportError): + self.load_module('XXX') + # TODO: RUSTPYTHON @unittest.expectedFailure def test_module(self): @@ -72,14 +70,6 @@ def test_module(self): # No extension module in a package available for testing. test_lacking_parent = None - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_module_reuse(self): - with util.uncache(util.EXTENSIONS.name): - module1 = self.load_module(util.EXTENSIONS.name) - module2 = self.load_module(util.EXTENSIONS.name) - self.assertIs(module1, module2) - # No easy way to trigger a failure after a successful import. test_state_after_failure = None @@ -89,6 +79,12 @@ def test_unloadable(self): self.load_module(name) self.assertEqual(cm.exception.name, name) + def test_module_reuse(self): + with util.uncache(util.EXTENSIONS.name): + module1 = self.load_module(util.EXTENSIONS.name) + module2 = self.load_module(util.EXTENSIONS.name) + self.assertIs(module1, module2) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_is_package(self): @@ -98,11 +94,94 @@ def test_is_package(self): loader = self.machinery.ExtensionFileLoader('pkg', path) self.assertTrue(loader.is_package('pkg')) + (Frozen_LoaderTests, Source_LoaderTests ) = util.test_both(LoaderTests, machinery=machinery) -@unittest.skip("TODO: RUSTPYTHON, AssertionError") + +class SinglePhaseExtensionModuleTests(abc.LoaderTests): + # Test loading extension modules without multi-phase initialization. + + def setUp(self): + if not self.machinery.EXTENSION_SUFFIXES: + raise unittest.SkipTest("Requires dynamic loading support.") + self.name = '_testsinglephase' + if self.name in sys.builtin_module_names: + raise unittest.SkipTest( + f"{self.name} is a builtin module" + ) + finder = self.machinery.FileFinder(None) + self.spec = importlib.util.find_spec(self.name) + assert self.spec + self.loader = self.machinery.ExtensionFileLoader( + self.name, self.spec.origin) + + def load_module(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return self.loader.load_module(self.name) + + def load_module_by_name(self, fullname): + # Load a module from the test extension by name. + origin = self.spec.origin + loader = self.machinery.ExtensionFileLoader(fullname, origin) + spec = importlib.util.spec_from_loader(fullname, loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) + return module + + def test_module(self): + # Test loading an extension module. + with util.uncache(self.name): + module = self.load_module() + for attr, value in [('__name__', self.name), + ('__file__', self.spec.origin), + ('__package__', '')]: + self.assertEqual(getattr(module, attr), value) + with self.assertRaises(AttributeError): + module.__path__ + self.assertIs(module, sys.modules[self.name]) + self.assertIsInstance(module.__loader__, + self.machinery.ExtensionFileLoader) + + # No extension module as __init__ available for testing. + test_package = None + + # No extension module in a package available for testing. + test_lacking_parent = None + + # No easy way to trigger a failure after a successful import. + test_state_after_failure = None + + def test_unloadable(self): + name = 'asdfjkl;' + with self.assertRaises(ImportError) as cm: + self.load_module_by_name(name) + self.assertEqual(cm.exception.name, name) + + def test_unloadable_nonascii(self): + # Test behavior with nonexistent module with non-ASCII name. + name = 'fo\xf3' + with self.assertRaises(ImportError) as cm: + self.load_module_by_name(name) + self.assertEqual(cm.exception.name, name) + + # It may make sense to add the equivalent to + # the following MultiPhaseExtensionModuleTests tests: + # + # * test_nonmodule + # * test_nonmodule_with_methods + # * test_bad_modules + # * test_nonascii + + +(Frozen_SinglePhaseExtensionModuleTests, + Source_SinglePhaseExtensionModuleTests + ) = util.test_both(SinglePhaseExtensionModuleTests, machinery=machinery) + + +# @unittest.skip("TODO: RUSTPYTHON, AssertionError") class MultiPhaseExtensionModuleTests(abc.LoaderTests): # Test loading extension modules with multi-phase initialization (PEP 489). @@ -188,15 +267,16 @@ def test_reload(self): def test_try_registration(self): # Assert that the PyState_{Find,Add,Remove}Module C API doesn't work. - module = self.load_module() - with self.subTest('PyState_FindModule'): - self.assertEqual(module.call_state_registration_func(0), None) - with self.subTest('PyState_AddModule'): - with self.assertRaises(SystemError): - module.call_state_registration_func(1) - with self.subTest('PyState_RemoveModule'): - with self.assertRaises(SystemError): - module.call_state_registration_func(2) + with util.uncache(self.name): + module = self.load_module() + with self.subTest('PyState_FindModule'): + self.assertEqual(module.call_state_registration_func(0), None) + with self.subTest('PyState_AddModule'): + with self.assertRaises(SystemError): + module.call_state_registration_func(1) + with self.subTest('PyState_RemoveModule'): + with self.assertRaises(SystemError): + module.call_state_registration_func(2) def test_load_submodule(self): # Test loading a simulated submodule. @@ -274,12 +354,19 @@ def test_bad_modules(self): 'exec_err', 'exec_raise', 'exec_unreported_exception', + 'multiple_create_slots', + 'multiple_multiple_interpreters_slots', ]: with self.subTest(name_base): name = self.name + '_' + name_base - with self.assertRaises(SystemError): + with self.assertRaises(SystemError) as cm: self.load_module_by_name(name) + # If there is an unreported exception, it should be chained + # with the `SystemError`. + if "unreported_exception" in name_base: + self.assertIsNotNone(cm.exception.__cause__) + def test_nonascii(self): # Test that modules with non-ASCII names can be loaded. # punycode behaves slightly differently in some-ASCII and no-ASCII diff --git a/Lib/test/test_importlib/extension/test_path_hook.py b/Lib/test/test_importlib/extension/test_path_hook.py index a0adc70ad1..ec9644dc52 100644 --- a/Lib/test/test_importlib/extension/test_path_hook.py +++ b/Lib/test/test_importlib/extension/test_path_hook.py @@ -19,7 +19,7 @@ def hook(self, entry): def test_success(self): # Path hook should handle a directory where a known extension module # exists. - self.assertTrue(hasattr(self.hook(util.EXTENSIONS.path), 'find_module')) + self.assertTrue(hasattr(self.hook(util.EXTENSIONS.path), 'find_spec')) (Frozen_PathHooksTests, diff --git a/Lib/test/test_importlib/fixtures.py b/Lib/test/test_importlib/fixtures.py index e7be77b395..73e5da2ba9 100644 --- a/Lib/test/test_importlib/fixtures.py +++ b/Lib/test/test_importlib/fixtures.py @@ -10,7 +10,10 @@ from test.support.os_helper import FS_NONASCII from test.support import requires_zlib -from typing import Dict, Union + +from . import _path +from ._path import FilesSpec + try: from importlib import resources # type: ignore @@ -83,13 +86,8 @@ def setUp(self): self.fixtures.enter_context(self.add_sys_path(self.site_dir)) -# Except for python/mypy#731, prefer to define -# FilesDef = Dict[str, Union['FilesDef', str]] -FilesDef = Dict[str, Union[Dict[str, Union[Dict[str, str], str]], str]] - - class DistInfoPkg(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "distinfo_pkg-1.0.0.dist-info": { "METADATA": """ Name: distinfo-pkg @@ -131,7 +129,7 @@ def make_uppercase(self): class DistInfoPkgWithDot(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "pkg_dot-1.0.0.dist-info": { "METADATA": """ Name: pkg.dot @@ -146,7 +144,7 @@ def setUp(self): class DistInfoPkgWithDotLegacy(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "pkg.dot-1.0.0.dist-info": { "METADATA": """ Name: pkg.dot @@ -173,7 +171,7 @@ def setUp(self): class EggInfoPkg(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "egginfo_pkg.egg-info": { "PKG-INFO": """ Name: egginfo-pkg @@ -212,8 +210,99 @@ def setUp(self): build_files(EggInfoPkg.files, prefix=self.site_dir) +class EggInfoPkgPipInstalledNoToplevel(OnSysPath, SiteDir): + files: FilesSpec = { + "egg_with_module_pkg.egg-info": { + "PKG-INFO": "Name: egg_with_module-pkg", + # SOURCES.txt is made from the source archive, and contains files + # (setup.py) that are not present after installation. + "SOURCES.txt": """ + egg_with_module.py + setup.py + egg_with_module_pkg.egg-info/PKG-INFO + egg_with_module_pkg.egg-info/SOURCES.txt + egg_with_module_pkg.egg-info/top_level.txt + """, + # installed-files.txt is written by pip, and is a strictly more + # accurate source than SOURCES.txt as to the installed contents of + # the package. + "installed-files.txt": """ + ../egg_with_module.py + PKG-INFO + SOURCES.txt + top_level.txt + """, + # missing top_level.txt (to trigger fallback to installed-files.txt) + }, + "egg_with_module.py": """ + def main(): + print("hello world") + """, + } + + def setUp(self): + super().setUp() + build_files(EggInfoPkgPipInstalledNoToplevel.files, prefix=self.site_dir) + + +class EggInfoPkgPipInstalledNoModules(OnSysPath, SiteDir): + files: FilesSpec = { + "egg_with_no_modules_pkg.egg-info": { + "PKG-INFO": "Name: egg_with_no_modules-pkg", + # SOURCES.txt is made from the source archive, and contains files + # (setup.py) that are not present after installation. + "SOURCES.txt": """ + setup.py + egg_with_no_modules_pkg.egg-info/PKG-INFO + egg_with_no_modules_pkg.egg-info/SOURCES.txt + egg_with_no_modules_pkg.egg-info/top_level.txt + """, + # installed-files.txt is written by pip, and is a strictly more + # accurate source than SOURCES.txt as to the installed contents of + # the package. + "installed-files.txt": """ + PKG-INFO + SOURCES.txt + top_level.txt + """, + # top_level.txt correctly reflects that no modules are installed + "top_level.txt": b"\n", + }, + } + + def setUp(self): + super().setUp() + build_files(EggInfoPkgPipInstalledNoModules.files, prefix=self.site_dir) + + +class EggInfoPkgSourcesFallback(OnSysPath, SiteDir): + files: FilesSpec = { + "sources_fallback_pkg.egg-info": { + "PKG-INFO": "Name: sources_fallback-pkg", + # SOURCES.txt is made from the source archive, and contains files + # (setup.py) that are not present after installation. + "SOURCES.txt": """ + sources_fallback.py + setup.py + sources_fallback_pkg.egg-info/PKG-INFO + sources_fallback_pkg.egg-info/SOURCES.txt + """, + # missing installed-files.txt (i.e. not installed by pip) and + # missing top_level.txt (to trigger fallback to SOURCES.txt) + }, + "sources_fallback.py": """ + def main(): + print("hello world") + """, + } + + def setUp(self): + super().setUp() + build_files(EggInfoPkgSourcesFallback.files, prefix=self.site_dir) + + class EggInfoFile(OnSysPath, SiteDir): - files: FilesDef = { + files: FilesSpec = { "egginfo_file.egg-info": """ Metadata-Version: 1.0 Name: egginfo_file @@ -233,38 +322,22 @@ def setUp(self): build_files(EggInfoFile.files, prefix=self.site_dir) -def build_files(file_defs, prefix=pathlib.Path()): - """Build a set of files/directories, as described by the +# dedent all text strings before writing +orig = _path.create.registry[str] +_path.create.register(str, lambda content, path: orig(DALS(content), path)) - file_defs dictionary. Each key/value pair in the dictionary is - interpreted as a filename/contents pair. If the contents value is a - dictionary, a directory is created, and the dictionary interpreted - as the files within it, recursively. - For example: +build_files = _path.build - {"README.txt": "A README file", - "foo": { - "__init__.py": "", - "bar": { - "__init__.py": "", - }, - "baz.py": "# Some code", - } - } - """ - for name, contents in file_defs.items(): - full_name = prefix / name - if isinstance(contents, dict): - full_name.mkdir() - build_files(contents, prefix=full_name) - else: - if isinstance(contents, bytes): - with full_name.open('wb') as f: - f.write(contents) - else: - with full_name.open('w', encoding='utf-8') as f: - f.write(DALS(contents)) + +def build_record(file_defs): + return ''.join(f'{name},,\n' for name in record_names(file_defs)) + + +def record_names(file_defs): + recording = _path.Recording() + _path.build(file_defs, recording) + return recording.record class FileBuilder: @@ -277,11 +350,6 @@ def DALS(str): return textwrap.dedent(str).lstrip() -class NullFinder: - def find_module(self, name): - pass - - @requires_zlib() class ZipFixtures: root = 'test.test_importlib.data' diff --git a/Lib/test/test_importlib/frozen/test_finder.py b/Lib/test/test_importlib/frozen/test_finder.py index a82148f865..5bb075f377 100644 --- a/Lib/test/test_importlib/frozen/test_finder.py +++ b/Lib/test/test_importlib/frozen/test_finder.py @@ -70,14 +70,6 @@ def check_search_locations(self, spec): expected = [os.path.dirname(filename)] self.assertListEqual(spec.submodule_search_locations, expected) - def test_package(self): - spec = self.find('__phello__') - self.assertIsNotNone(spec) - - def test_module_in_package(self): - spec = self.find('__phello__.spam', ['__phello__']) - self.assertIsNotNone(spec) - # TODO: RUSTPYTHON @unittest.expectedFailure def test_module(self): @@ -196,45 +188,5 @@ def test_not_using_frozen(self): ) = util.test_both(FindSpecTests, machinery=machinery) -class FinderTests(abc.FinderTests): - - """Test finding frozen modules.""" - - def find(self, name, path=None): - finder = self.machinery.FrozenImporter - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - with import_helper.frozen_modules(): - return finder.find_module(name, path) - - def test_module(self): - name = '__hello__' - loader = self.find(name) - self.assertTrue(hasattr(loader, 'load_module')) - - def test_package(self): - loader = self.find('__phello__') - self.assertTrue(hasattr(loader, 'load_module')) - - def test_module_in_package(self): - loader = self.find('__phello__.spam', ['__phello__']) - self.assertTrue(hasattr(loader, 'load_module')) - - # No frozen package within another package to test with. - test_package_in_package = None - - # No easy way to test. - test_package_over_module = None - - def test_failure(self): - loader = self.find('') - self.assertIsNone(loader) - - -(Frozen_FinderTests, - Source_FinderTests - ) = util.test_both(FinderTests, machinery=machinery) - - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/frozen/test_loader.py b/Lib/test/test_importlib/frozen/test_loader.py index db256ff0fb..4f1af454b5 100644 --- a/Lib/test/test_importlib/frozen/test_loader.py +++ b/Lib/test/test_importlib/frozen/test_loader.py @@ -103,15 +103,7 @@ def test_lacking_parent(self): expected=value)) self.assertEqual(output, 'Hello world!\n') - def test_module_repr(self): - name = '__hello__' - module, output = self.exec_module(name) - with deprecated(): - repr_str = self.machinery.FrozenImporter.module_repr(module) - self.assertEqual(repr_str, - "") - - def test_module_repr_indirect(self): + def test_module_repr_indirect_through_spec(self): name = '__hello__' module, output = self.exec_module(name) self.assertEqual(repr(module), @@ -133,101 +125,6 @@ def test_unloadable(self): ) = util.test_both(ExecModuleTests, machinery=machinery) -class LoaderTests(abc.LoaderTests): - - def load_module(self, name): - with fresh(name, oldapi=True): - module = self.machinery.FrozenImporter.load_module(name) - with captured_stdout() as stdout: - module.main() - return module, stdout - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_module(self): - module, stdout = self.load_module('__hello__') - filename = resolve_stdlib_file('__hello__') - check = {'__name__': '__hello__', - '__package__': '', - '__loader__': self.machinery.FrozenImporter, - '__file__': filename, - } - for attr, value in check.items(): - self.assertEqual(getattr(module, attr, None), value) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_package(self): - module, stdout = self.load_module('__phello__') - filename = resolve_stdlib_file('__phello__', ispkg=True) - pkgdir = os.path.dirname(filename) - check = {'__name__': '__phello__', - '__package__': '__phello__', - '__path__': [pkgdir], - '__loader__': self.machinery.FrozenImporter, - '__file__': filename, - } - for attr, value in check.items(): - attr_value = getattr(module, attr, None) - self.assertEqual(attr_value, value, - "for __phello__.%s, %r != %r" % - (attr, attr_value, value)) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_lacking_parent(self): - with util.uncache('__phello__'): - module, stdout = self.load_module('__phello__.spam') - filename = resolve_stdlib_file('__phello__.spam') - check = {'__name__': '__phello__.spam', - '__package__': '__phello__', - '__loader__': self.machinery.FrozenImporter, - '__file__': filename, - } - for attr, value in check.items(): - attr_value = getattr(module, attr) - self.assertEqual(attr_value, value, - "for __phello__.spam.%s, %r != %r" % - (attr, attr_value, value)) - self.assertEqual(stdout.getvalue(), 'Hello world!\n') - - def test_module_reuse(self): - with fresh('__hello__', oldapi=True): - module1 = self.machinery.FrozenImporter.load_module('__hello__') - module2 = self.machinery.FrozenImporter.load_module('__hello__') - with captured_stdout() as stdout: - module1.main() - module2.main() - self.assertIs(module1, module2) - self.assertEqual(stdout.getvalue(), - 'Hello world!\nHello world!\n') - - def test_module_repr(self): - with fresh('__hello__', oldapi=True): - module = self.machinery.FrozenImporter.load_module('__hello__') - repr_str = self.machinery.FrozenImporter.module_repr(module) - self.assertEqual(repr_str, - "") - - # No way to trigger an error in a frozen module. - test_state_after_failure = None - - def test_unloadable(self): - with import_helper.frozen_modules(): - with deprecated(): - assert self.machinery.FrozenImporter.find_module('_not_real') is None - with self.assertRaises(ImportError) as cm: - self.load_module('_not_real') - self.assertEqual(cm.exception.name, '_not_real') - - -(Frozen_LoaderTests, - Source_LoaderTests - ) = util.test_both(LoaderTests, machinery=machinery) - - class InspectLoaderTests: """Tests for the InspectLoader methods for FrozenImporter.""" diff --git a/Lib/test/test_importlib/import_/test___loader__.py b/Lib/test/test_importlib/import_/test___loader__.py index eaf665a6f5..a14163919a 100644 --- a/Lib/test/test_importlib/import_/test___loader__.py +++ b/Lib/test/test_importlib/import_/test___loader__.py @@ -33,48 +33,5 @@ def test___loader__(self): ) = util.test_both(SpecLoaderAttributeTests, __import__=util.__import__) -class LoaderMock: - - def find_module(self, fullname, path=None): - return self - - def load_module(self, fullname): - sys.modules[fullname] = self.module - return self.module - - -class LoaderAttributeTests: - - def test___loader___missing(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - module = types.ModuleType('blah') - try: - del module.__loader__ - except AttributeError: - pass - loader = LoaderMock() - loader.module = module - with util.uncache('blah'), util.import_state(meta_path=[loader]): - module = self.__import__('blah') - self.assertEqual(loader, module.__loader__) - - def test___loader___is_None(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - module = types.ModuleType('blah') - module.__loader__ = None - loader = LoaderMock() - loader.module = module - with util.uncache('blah'), util.import_state(meta_path=[loader]): - returned_module = self.__import__('blah') - self.assertEqual(loader, module.__loader__) - - -(Frozen_Tests, - Source_Tests - ) = util.test_both(LoaderAttributeTests, __import__=util.__import__) - - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/import_/test___package__.py b/Lib/test/test_importlib/import_/test___package__.py index cc2fa0f459..431faea5b4 100644 --- a/Lib/test/test_importlib/import_/test___package__.py +++ b/Lib/test/test_importlib/import_/test___package__.py @@ -78,8 +78,8 @@ def test_spec_fallback(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_warn_when_package_and_spec_disagree(self): - # Raise an ImportWarning if __package__ != __spec__.parent. - with self.assertWarns(ImportWarning): + # Raise a DeprecationWarning if __package__ != __spec__.parent. + with self.assertWarns(DeprecationWarning): self.import_module({'__package__': 'pkg.fake', '__spec__': FakeSpec('pkg.fakefake')}) @@ -99,25 +99,6 @@ def __init__(self, parent): self.parent = parent -class Using__package__PEP302(Using__package__): - mock_modules = util.mock_modules - - def test_using___package__(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_using___package__() - - def test_spec_fallback(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_spec_fallback() - - -(Frozen_UsingPackagePEP302, - Source_UsingPackagePEP302 - ) = util.test_both(Using__package__PEP302, __import__=util.__import__) - - class Using__package__PEP451(Using__package__): mock_modules = util.mock_spec @@ -166,23 +147,6 @@ def test_submodule(self): module = getattr(pkg, 'mod') self.assertEqual(module.__package__, 'pkg') -class Setting__package__PEP302(Setting__package__, unittest.TestCase): - mock_modules = util.mock_modules - - def test_top_level(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_top_level() - - def test_package(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_package() - - def test_submodule(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - super().test_submodule() class Setting__package__PEP451(Setting__package__, unittest.TestCase): mock_modules = util.mock_spec diff --git a/Lib/test/test_importlib/import_/test_api.py b/Lib/test/test_importlib/import_/test_api.py index 0ee032b020..d6ad590b3d 100644 --- a/Lib/test/test_importlib/import_/test_api.py +++ b/Lib/test/test_importlib/import_/test_api.py @@ -28,11 +28,6 @@ def exec_module(module): class BadLoaderFinder: - @classmethod - def find_module(cls, fullname, path): - if fullname == SUBMOD_NAME: - return cls - @classmethod def load_module(cls, fullname): if fullname == SUBMOD_NAME: diff --git a/Lib/test/test_importlib/import_/test_caching.py b/Lib/test/test_importlib/import_/test_caching.py index 3ca765fb4a..aedf0fd4f9 100644 --- a/Lib/test/test_importlib/import_/test_caching.py +++ b/Lib/test/test_importlib/import_/test_caching.py @@ -52,12 +52,11 @@ class ImportlibUseCache(UseCache, unittest.TestCase): __import__ = util.__import__['Source'] def create_mock(self, *names, return_=None): - mock = util.mock_modules(*names) - original_load = mock.load_module - def load_module(self, fullname): - original_load(fullname) - return return_ - mock.load_module = MethodType(load_module, mock) + mock = util.mock_spec(*names) + original_spec = mock.find_spec + def find_spec(self, fullname, path, target=None): + return original_spec(fullname) + mock.find_spec = MethodType(find_spec, mock) return mock # __import__ inconsistent between loaders and built-in import when it comes @@ -86,14 +85,12 @@ def test_using_cache_for_assigning_to_attribute(self): # See test_using_cache_after_loader() for reasoning. def test_using_cache_for_fromlist(self): # [from cache for fromlist] - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - with self.create_mock('pkg.__init__', 'pkg.module') as importer: - with util.import_state(meta_path=[importer]): - module = self.__import__('pkg', fromlist=['module']) - self.assertTrue(hasattr(module, 'module')) - self.assertEqual(id(module.module), - id(sys.modules['pkg.module'])) + with self.create_mock('pkg.__init__', 'pkg.module') as importer: + with util.import_state(meta_path=[importer]): + module = self.__import__('pkg', fromlist=['module']) + self.assertTrue(hasattr(module, 'module')) + self.assertEqual(id(module.module), + id(sys.modules['pkg.module'])) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test_helpers.py b/Lib/test/test_importlib/import_/test_helpers.py new file mode 100644 index 0000000000..28cdc0e526 --- /dev/null +++ b/Lib/test/test_importlib/import_/test_helpers.py @@ -0,0 +1,192 @@ +"""Tests for helper functions used by import.c .""" + +from importlib import _bootstrap_external, machinery +import os.path +from types import ModuleType, SimpleNamespace +import unittest +import warnings + +from .. import util + + +class FixUpModuleTests: + + def test_no_loader_but_spec(self): + loader = object() + name = "hello" + path = "hello.py" + spec = machinery.ModuleSpec(name, loader) + ns = {"__spec__": spec} + _bootstrap_external._fix_up_module(ns, name, path) + + expected = {"__spec__": spec, "__loader__": loader, "__file__": path, + "__cached__": None} + self.assertEqual(ns, expected) + + def test_no_loader_no_spec_but_sourceless(self): + name = "hello" + path = "hello.py" + ns = {} + _bootstrap_external._fix_up_module(ns, name, path, path) + + expected = {"__file__": path, "__cached__": path} + + for key, val in expected.items(): + with self.subTest(f"{key}: {val}"): + self.assertEqual(ns[key], val) + + spec = ns["__spec__"] + self.assertIsInstance(spec, machinery.ModuleSpec) + self.assertEqual(spec.name, name) + self.assertEqual(spec.origin, os.path.abspath(path)) + self.assertEqual(spec.cached, os.path.abspath(path)) + self.assertIsInstance(spec.loader, machinery.SourcelessFileLoader) + self.assertEqual(spec.loader.name, name) + self.assertEqual(spec.loader.path, path) + self.assertEqual(spec.loader, ns["__loader__"]) + + def test_no_loader_no_spec_but_source(self): + name = "hello" + path = "hello.py" + ns = {} + _bootstrap_external._fix_up_module(ns, name, path) + + expected = {"__file__": path, "__cached__": None} + + for key, val in expected.items(): + with self.subTest(f"{key}: {val}"): + self.assertEqual(ns[key], val) + + spec = ns["__spec__"] + self.assertIsInstance(spec, machinery.ModuleSpec) + self.assertEqual(spec.name, name) + self.assertEqual(spec.origin, os.path.abspath(path)) + self.assertIsInstance(spec.loader, machinery.SourceFileLoader) + self.assertEqual(spec.loader.name, name) + self.assertEqual(spec.loader.path, path) + self.assertEqual(spec.loader, ns["__loader__"]) + + +FrozenFixUpModuleTests, SourceFixUpModuleTests = util.test_both(FixUpModuleTests) + + +class TestBlessMyLoader(unittest.TestCase): + # GH#86298 is part of the migration away from module attributes and toward + # __spec__ attributes. There are several cases to test here. This will + # have to change in Python 3.14 when we actually remove/ignore __loader__ + # in favor of requiring __spec__.loader. + + def test_gh86298_no_loader_and_no_spec(self): + bar = ModuleType('bar') + del bar.__loader__ + del bar.__spec__ + # 2022-10-06(warsaw): For backward compatibility with the + # implementation in _warnings.c, this can't raise an + # AttributeError. See _bless_my_loader() in _bootstrap_external.py + # If working with a module: + ## self.assertRaises( + ## AttributeError, _bootstrap_external._bless_my_loader, + ## bar.__dict__) + self.assertIsNone(_bootstrap_external._bless_my_loader(bar.__dict__)) + + def test_gh86298_loader_is_none_and_no_spec(self): + bar = ModuleType('bar') + bar.__loader__ = None + del bar.__spec__ + # 2022-10-06(warsaw): For backward compatibility with the + # implementation in _warnings.c, this can't raise an + # AttributeError. See _bless_my_loader() in _bootstrap_external.py + # If working with a module: + ## self.assertRaises( + ## AttributeError, _bootstrap_external._bless_my_loader, + ## bar.__dict__) + self.assertIsNone(_bootstrap_external._bless_my_loader(bar.__dict__)) + + def test_gh86298_no_loader_and_spec_is_none(self): + bar = ModuleType('bar') + del bar.__loader__ + bar.__spec__ = None + self.assertRaises( + ValueError, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_loader_is_none_and_spec_is_none(self): + bar = ModuleType('bar') + bar.__loader__ = None + bar.__spec__ = None + self.assertRaises( + ValueError, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_loader_is_none_and_spec_loader_is_none(self): + bar = ModuleType('bar') + bar.__loader__ = None + bar.__spec__ = SimpleNamespace(loader=None) + self.assertRaises( + ValueError, + _bootstrap_external._bless_my_loader, bar.__dict__) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_gh86298_no_spec(self): + bar = ModuleType('bar') + bar.__loader__ = object() + del bar.__spec__ + with warnings.catch_warnings(): + self.assertWarns( + DeprecationWarning, + _bootstrap_external._bless_my_loader, bar.__dict__) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_gh86298_spec_is_none(self): + bar = ModuleType('bar') + bar.__loader__ = object() + bar.__spec__ = None + with warnings.catch_warnings(): + self.assertWarns( + DeprecationWarning, + _bootstrap_external._bless_my_loader, bar.__dict__) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_gh86298_no_spec_loader(self): + bar = ModuleType('bar') + bar.__loader__ = object() + bar.__spec__ = SimpleNamespace() + with warnings.catch_warnings(): + self.assertWarns( + DeprecationWarning, + _bootstrap_external._bless_my_loader, bar.__dict__) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_gh86298_loader_and_spec_loader_disagree(self): + bar = ModuleType('bar') + bar.__loader__ = object() + bar.__spec__ = SimpleNamespace(loader=object()) + with warnings.catch_warnings(): + self.assertWarns( + DeprecationWarning, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_no_loader_and_no_spec_loader(self): + bar = ModuleType('bar') + del bar.__loader__ + bar.__spec__ = SimpleNamespace() + self.assertRaises( + AttributeError, + _bootstrap_external._bless_my_loader, bar.__dict__) + + def test_gh86298_no_loader_with_spec_loader_okay(self): + bar = ModuleType('bar') + del bar.__loader__ + loader = object() + bar.__spec__ = SimpleNamespace(loader=loader) + self.assertEqual( + _bootstrap_external._bless_my_loader(bar.__dict__), + loader) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_importlib/import_/test_meta_path.py b/Lib/test/test_importlib/import_/test_meta_path.py index c52fc57065..26e7b070b9 100644 --- a/Lib/test/test_importlib/import_/test_meta_path.py +++ b/Lib/test/test_importlib/import_/test_meta_path.py @@ -115,16 +115,6 @@ def test_with_path(self): super().test_no_path() -class CallSignaturePEP302(CallSignoreSuppressImportWarning): - mock_modules = util.mock_modules - finder_name = 'find_module' - - -(Frozen_CallSignaturePEP302, - Source_CallSignaturePEP302 - ) = util.test_both(CallSignaturePEP302, __import__=util.__import__) - - class CallSignaturePEP451(CallSignature): mock_modules = util.mock_spec finder_name = 'find_spec' diff --git a/Lib/test/test_importlib/import_/test_path.py b/Lib/test/test_importlib/import_/test_path.py index 3873d9f3ed..9cf3a77cb8 100644 --- a/Lib/test/test_importlib/import_/test_path.py +++ b/Lib/test/test_importlib/import_/test_path.py @@ -118,46 +118,6 @@ def test_None_on_sys_path(self): if email is not missing: sys.modules['email'] = email - def test_finder_with_find_module(self): - class TestFinder: - def find_module(self, fullname): - return self.to_return - failing_finder = TestFinder() - failing_finder.to_return = None - path = 'testing path' - with util.import_state(path_importer_cache={path: failing_finder}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - self.assertIsNone( - self.machinery.PathFinder.find_spec('whatever', [path])) - success_finder = TestFinder() - success_finder.to_return = __loader__ - with util.import_state(path_importer_cache={path: success_finder}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - spec = self.machinery.PathFinder.find_spec('whatever', [path]) - self.assertEqual(spec.loader, __loader__) - - def test_finder_with_find_loader(self): - class TestFinder: - loader = None - portions = [] - def find_loader(self, fullname): - return self.loader, self.portions - path = 'testing path' - with util.import_state(path_importer_cache={path: TestFinder()}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - self.assertIsNone( - self.machinery.PathFinder.find_spec('whatever', [path])) - success_finder = TestFinder() - success_finder.loader = __loader__ - with util.import_state(path_importer_cache={path: success_finder}): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - spec = self.machinery.PathFinder.find_spec('whatever', [path]) - self.assertEqual(spec.loader, __loader__) - def test_finder_with_find_spec(self): class TestFinder: spec = None @@ -230,9 +190,9 @@ def invalidate_caches(self): class FindModuleTests(FinderTests): def find(self, *args, **kwargs): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - return self.machinery.PathFinder.find_module(*args, **kwargs) + spec = self.machinery.PathFinder.find_spec(*args, **kwargs) + return None if spec is None else spec.loader + def check_found(self, found, importer): self.assertIs(found, importer) @@ -257,16 +217,14 @@ def check_found(self, found, importer): class PathEntryFinderTests: def test_finder_with_failing_find_spec(self): - # PathEntryFinder with find_module() defined should work. - # Issue #20763. class Finder: - path_location = 'test_finder_with_find_module' + path_location = 'test_finder_with_find_spec' def __init__(self, path): if path != self.path_location: raise ImportError @staticmethod - def find_module(fullname): + def find_spec(fullname, target=None): return None @@ -276,27 +234,6 @@ def find_module(fullname): warnings.simplefilter("ignore", ImportWarning) self.machinery.PathFinder.find_spec('importlib') - def test_finder_with_failing_find_module(self): - # PathEntryFinder with find_module() defined should work. - # Issue #20763. - class Finder: - path_location = 'test_finder_with_find_module' - def __init__(self, path): - if path != self.path_location: - raise ImportError - - @staticmethod - def find_module(fullname): - return None - - - with util.import_state(path=[Finder.path_location]+sys.path[:], - path_hooks=[Finder]): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ImportWarning) - warnings.simplefilter("ignore", DeprecationWarning) - self.machinery.PathFinder.find_module('importlib') - (Frozen_PEFTests, Source_PEFTests diff --git a/Lib/test/test_importlib/resources/_path.py b/Lib/test/test_importlib/resources/_path.py new file mode 100644 index 0000000000..1f97c96146 --- /dev/null +++ b/Lib/test/test_importlib/resources/_path.py @@ -0,0 +1,56 @@ +import pathlib +import functools + +from typing import Dict, Union + + +#### +# from jaraco.path 3.4.1 + +FilesSpec = Dict[str, Union[str, bytes, 'FilesSpec']] # type: ignore + + +def build(spec: FilesSpec, prefix=pathlib.Path()): + """ + Build a set of files/directories, as described by the spec. + + Each key represents a pathname, and the value represents + the content. Content may be a nested directory. + + >>> spec = { + ... 'README.txt': "A README file", + ... "foo": { + ... "__init__.py": "", + ... "bar": { + ... "__init__.py": "", + ... }, + ... "baz.py": "# Some code", + ... } + ... } + >>> target = getfixture('tmp_path') + >>> build(spec, target) + >>> target.joinpath('foo/baz.py').read_text(encoding='utf-8') + '# Some code' + """ + for name, contents in spec.items(): + create(contents, pathlib.Path(prefix) / name) + + +@functools.singledispatch +def create(content: Union[str, bytes, FilesSpec], path): + path.mkdir(exist_ok=True) + build(content, prefix=path) # type: ignore + + +@create.register +def _(content: bytes, path): + path.write_bytes(content) + + +@create.register +def _(content: str, path): + path.write_text(content, encoding='utf-8') + + +# end from jaraco.path +#### diff --git a/Lib/test/test_importlib/data01/__init__.py b/Lib/test/test_importlib/resources/data01/__init__.py similarity index 100% rename from Lib/test/test_importlib/data01/__init__.py rename to Lib/test/test_importlib/resources/data01/__init__.py diff --git a/Lib/test/test_importlib/resources/data01/binary.file b/Lib/test/test_importlib/resources/data01/binary.file new file mode 100644 index 0000000000000000000000000000000000000000..eaf36c1daccfdf325514461cd1a2ffbc139b5464 GIT binary patch literal 4 LcmZQzWMT#Y01f~L literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/data01/subdirectory/__init__.py b/Lib/test/test_importlib/resources/data01/subdirectory/__init__.py similarity index 100% rename from Lib/test/test_importlib/data01/subdirectory/__init__.py rename to Lib/test/test_importlib/resources/data01/subdirectory/__init__.py diff --git a/Lib/test/test_importlib/resources/data01/subdirectory/binary.file b/Lib/test/test_importlib/resources/data01/subdirectory/binary.file new file mode 100644 index 0000000000000000000000000000000000000000..eaf36c1daccfdf325514461cd1a2ffbc139b5464 GIT binary patch literal 4 LcmZQzWMT#Y01f~L literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/resources/data01/utf-16.file b/Lib/test/test_importlib/resources/data01/utf-16.file new file mode 100644 index 0000000000000000000000000000000000000000..2cb772295ef4b480a8d83725bd5006a0236d8f68 GIT binary patch literal 44 ucmezW&x0YAAqNQa8FUyF7(y9B7~B|i84MZBfV^^`Xc15@g+Y;liva-T)Ce>H literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/data01/utf-8.file b/Lib/test/test_importlib/resources/data01/utf-8.file similarity index 100% rename from Lib/test/test_importlib/data01/utf-8.file rename to Lib/test/test_importlib/resources/data01/utf-8.file diff --git a/Lib/test/test_importlib/data02/__init__.py b/Lib/test/test_importlib/resources/data02/__init__.py similarity index 100% rename from Lib/test/test_importlib/data02/__init__.py rename to Lib/test/test_importlib/resources/data02/__init__.py diff --git a/Lib/test/test_importlib/data02/one/__init__.py b/Lib/test/test_importlib/resources/data02/one/__init__.py similarity index 100% rename from Lib/test/test_importlib/data02/one/__init__.py rename to Lib/test/test_importlib/resources/data02/one/__init__.py diff --git a/Lib/test/test_importlib/data02/one/resource1.txt b/Lib/test/test_importlib/resources/data02/one/resource1.txt similarity index 100% rename from Lib/test/test_importlib/data02/one/resource1.txt rename to Lib/test/test_importlib/resources/data02/one/resource1.txt diff --git a/Lib/test/test_importlib/resources/data02/subdirectory/subsubdir/resource.txt b/Lib/test/test_importlib/resources/data02/subdirectory/subsubdir/resource.txt new file mode 100644 index 0000000000..48f587a2d0 --- /dev/null +++ b/Lib/test/test_importlib/resources/data02/subdirectory/subsubdir/resource.txt @@ -0,0 +1 @@ +a resource \ No newline at end of file diff --git a/Lib/test/test_importlib/data02/two/__init__.py b/Lib/test/test_importlib/resources/data02/two/__init__.py similarity index 100% rename from Lib/test/test_importlib/data02/two/__init__.py rename to Lib/test/test_importlib/resources/data02/two/__init__.py diff --git a/Lib/test/test_importlib/data02/two/resource2.txt b/Lib/test/test_importlib/resources/data02/two/resource2.txt similarity index 100% rename from Lib/test/test_importlib/data02/two/resource2.txt rename to Lib/test/test_importlib/resources/data02/two/resource2.txt diff --git a/Lib/test/test_importlib/data03/__init__.py b/Lib/test/test_importlib/resources/data03/__init__.py similarity index 100% rename from Lib/test/test_importlib/data03/__init__.py rename to Lib/test/test_importlib/resources/data03/__init__.py diff --git a/Lib/test/test_importlib/data03/namespace/portion1/__init__.py b/Lib/test/test_importlib/resources/data03/namespace/portion1/__init__.py similarity index 100% rename from Lib/test/test_importlib/data03/namespace/portion1/__init__.py rename to Lib/test/test_importlib/resources/data03/namespace/portion1/__init__.py diff --git a/Lib/test/test_importlib/data03/namespace/portion2/__init__.py b/Lib/test/test_importlib/resources/data03/namespace/portion2/__init__.py similarity index 100% rename from Lib/test/test_importlib/data03/namespace/portion2/__init__.py rename to Lib/test/test_importlib/resources/data03/namespace/portion2/__init__.py diff --git a/Lib/test/test_importlib/data03/namespace/resource1.txt b/Lib/test/test_importlib/resources/data03/namespace/resource1.txt similarity index 100% rename from Lib/test/test_importlib/data03/namespace/resource1.txt rename to Lib/test/test_importlib/resources/data03/namespace/resource1.txt diff --git a/Lib/test/test_importlib/resources/namespacedata01/binary.file b/Lib/test/test_importlib/resources/namespacedata01/binary.file new file mode 100644 index 0000000000000000000000000000000000000000..eaf36c1daccfdf325514461cd1a2ffbc139b5464 GIT binary patch literal 4 LcmZQzWMT#Y01f~L literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/resources/namespacedata01/utf-16.file b/Lib/test/test_importlib/resources/namespacedata01/utf-16.file new file mode 100644 index 0000000000000000000000000000000000000000..2cb772295ef4b480a8d83725bd5006a0236d8f68 GIT binary patch literal 44 ucmezW&x0YAAqNQa8FUyF7(y9B7~B|i84MZBfV^^`Xc15@g+Y;liva-T)Ce>H literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/namespacedata01/utf-8.file b/Lib/test/test_importlib/resources/namespacedata01/utf-8.file similarity index 100% rename from Lib/test/test_importlib/namespacedata01/utf-8.file rename to Lib/test/test_importlib/resources/namespacedata01/utf-8.file diff --git a/Lib/test/test_importlib/test_compatibilty_files.py b/Lib/test/test_importlib/resources/test_compatibilty_files.py similarity index 93% rename from Lib/test/test_importlib/test_compatibilty_files.py rename to Lib/test/test_importlib/resources/test_compatibilty_files.py index 9a823f2d93..bcf608d9e2 100644 --- a/Lib/test/test_importlib/test_compatibilty_files.py +++ b/Lib/test/test_importlib/resources/test_compatibilty_files.py @@ -8,7 +8,7 @@ wrap_spec, ) -from .resources import util +from . import util class CompatibilityFilesTests(unittest.TestCase): @@ -64,11 +64,13 @@ def test_orphan_path_name(self): def test_spec_path_open(self): self.assertEqual(self.files.read_bytes(), b'Hello, world!') - self.assertEqual(self.files.read_text(), 'Hello, world!') + self.assertEqual(self.files.read_text(encoding='utf-8'), 'Hello, world!') def test_child_path_open(self): self.assertEqual((self.files / 'a').read_bytes(), b'Hello, world!') - self.assertEqual((self.files / 'a').read_text(), 'Hello, world!') + self.assertEqual( + (self.files / 'a').read_text(encoding='utf-8'), 'Hello, world!' + ) def test_orphan_path_open(self): with self.assertRaises(FileNotFoundError): diff --git a/Lib/test/test_importlib/test_contents.py b/Lib/test/test_importlib/resources/test_contents.py similarity index 97% rename from Lib/test/test_importlib/test_contents.py rename to Lib/test/test_importlib/resources/test_contents.py index 3323bf5b5c..1a13f043a8 100644 --- a/Lib/test/test_importlib/test_contents.py +++ b/Lib/test/test_importlib/resources/test_contents.py @@ -2,7 +2,7 @@ from importlib import resources from . import data01 -from .resources import util +from . import util class ContentsTests: diff --git a/Lib/test/test_importlib/resources/test_custom.py b/Lib/test/test_importlib/resources/test_custom.py new file mode 100644 index 0000000000..73127209a2 --- /dev/null +++ b/Lib/test/test_importlib/resources/test_custom.py @@ -0,0 +1,46 @@ +import unittest +import contextlib +import pathlib + +from test.support import os_helper + +from importlib import resources +from importlib.resources.abc import TraversableResources, ResourceReader +from . import util + + +class SimpleLoader: + """ + A simple loader that only implements a resource reader. + """ + + def __init__(self, reader: ResourceReader): + self.reader = reader + + def get_resource_reader(self, package): + return self.reader + + +class MagicResources(TraversableResources): + """ + Magically returns the resources at path. + """ + + def __init__(self, path: pathlib.Path): + self.path = path + + def files(self): + return self.path + + +class CustomTraversableResourcesTests(unittest.TestCase): + def setUp(self): + self.fixtures = contextlib.ExitStack() + self.addCleanup(self.fixtures.close) + + def test_custom_loader(self): + temp_dir = self.fixtures.enter_context(os_helper.temp_dir()) + loader = SimpleLoader(MagicResources(temp_dir)) + pkg = util.create_package_from_loader(loader) + files = resources.files(pkg) + assert files is temp_dir diff --git a/Lib/test/test_importlib/resources/test_files.py b/Lib/test/test_importlib/resources/test_files.py new file mode 100644 index 0000000000..1450cfb310 --- /dev/null +++ b/Lib/test/test_importlib/resources/test_files.py @@ -0,0 +1,113 @@ +import typing +import textwrap +import unittest +import warnings +import importlib +import contextlib + +from importlib import resources +from importlib.resources.abc import Traversable +from . import data01 +from . import util +from . import _path +from test.support import os_helper +from test.support import import_helper + + +@contextlib.contextmanager +def suppress_known_deprecation(): + with warnings.catch_warnings(record=True) as ctx: + warnings.simplefilter('default', category=DeprecationWarning) + yield ctx + + +class FilesTests: + def test_read_bytes(self): + files = resources.files(self.data) + actual = files.joinpath('utf-8.file').read_bytes() + assert actual == b'Hello, UTF-8 world!\n' + + def test_read_text(self): + files = resources.files(self.data) + actual = files.joinpath('utf-8.file').read_text(encoding='utf-8') + assert actual == 'Hello, UTF-8 world!\n' + + @unittest.skipUnless( + hasattr(typing, 'runtime_checkable'), + "Only suitable when typing supports runtime_checkable", + ) + def test_traversable(self): + assert isinstance(resources.files(self.data), Traversable) + + def test_old_parameter(self): + """ + Files used to take a 'package' parameter. Make sure anyone + passing by name is still supported. + """ + with suppress_known_deprecation(): + resources.files(package=self.data) + + +class OpenDiskTests(FilesTests, unittest.TestCase): + def setUp(self): + self.data = data01 + + +class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase): + pass + + +class OpenNamespaceTests(FilesTests, unittest.TestCase): + def setUp(self): + from . import namespacedata01 + + self.data = namespacedata01 + + +class SiteDir: + def setUp(self): + self.fixtures = contextlib.ExitStack() + self.addCleanup(self.fixtures.close) + self.site_dir = self.fixtures.enter_context(os_helper.temp_dir()) + self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir)) + self.fixtures.enter_context(import_helper.CleanImport()) + + +class ModulesFilesTests(SiteDir, unittest.TestCase): + def test_module_resources(self): + """ + A module can have resources found adjacent to the module. + """ + spec = { + 'mod.py': '', + 'res.txt': 'resources are the best', + } + _path.build(spec, self.site_dir) + import mod + + actual = resources.files(mod).joinpath('res.txt').read_text(encoding='utf-8') + assert actual == spec['res.txt'] + + +class ImplicitContextFilesTests(SiteDir, unittest.TestCase): + def test_implicit_files(self): + """ + Without any parameter, files() will infer the location as the caller. + """ + spec = { + 'somepkg': { + '__init__.py': textwrap.dedent( + """ + import importlib.resources as res + val = res.files().joinpath('res.txt').read_text(encoding='utf-8') + """ + ), + 'res.txt': 'resources are the best', + }, + } + _path.build(spec, self.site_dir) + assert importlib.import_module('somepkg').val == 'resources are the best' + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_importlib/test_open.py b/Lib/test/test_importlib/resources/test_open.py similarity index 82% rename from Lib/test/test_importlib/test_open.py rename to Lib/test/test_importlib/resources/test_open.py index fc0136e865..86becb4bfa 100644 --- a/Lib/test/test_importlib/test_open.py +++ b/Lib/test/test_importlib/resources/test_open.py @@ -2,7 +2,7 @@ from importlib import resources from . import data01 -from .resources import util +from . import util class CommonBinaryTests(util.CommonTests, unittest.TestCase): @@ -15,7 +15,7 @@ def execute(self, package, path): class CommonTextTests(util.CommonTests, unittest.TestCase): def execute(self, package, path): target = resources.files(package).joinpath(path) - with target.open(): + with target.open(encoding='utf-8'): pass @@ -28,7 +28,7 @@ def test_open_binary(self): def test_open_text_default_encoding(self): target = resources.files(self.data) / 'utf-8.file' - with target.open() as fp: + with target.open(encoding='utf-8') as fp: result = fp.read() self.assertEqual(result, 'Hello, UTF-8 world!\n') @@ -39,7 +39,9 @@ def test_open_text_given_encoding(self): self.assertEqual(result, 'Hello, UTF-16 world!\n') def test_open_text_with_errors(self): - # Raises UnicodeError without the 'errors' argument. + """ + Raises UnicodeError without the 'errors' argument. + """ target = resources.files(self.data) / 'utf-16.file' with target.open(encoding='utf-8', errors='strict') as fp: self.assertRaises(UnicodeError, fp.read) @@ -54,11 +56,13 @@ def test_open_text_with_errors(self): def test_open_binary_FileNotFoundError(self): target = resources.files(self.data) / 'does-not-exist' - self.assertRaises(FileNotFoundError, target.open, 'rb') + with self.assertRaises(FileNotFoundError): + target.open('rb') def test_open_text_FileNotFoundError(self): target = resources.files(self.data) / 'does-not-exist' - self.assertRaises(FileNotFoundError, target.open) + with self.assertRaises(FileNotFoundError): + target.open(encoding='utf-8') class OpenDiskTests(OpenTests, unittest.TestCase): @@ -72,12 +76,6 @@ def setUp(self): self.data = namespacedata01 - # TODO: RUSTPYTHON - import sys - if sys.platform == 'win32': - @unittest.expectedFailure - def test_open_text_default_encoding(self): - super().test_open_text_default_encoding() class OpenZipTests(OpenTests, util.ZipSetup, unittest.TestCase): pass diff --git a/Lib/test/test_importlib/test_path.py b/Lib/test/test_importlib/resources/test_path.py similarity index 84% rename from Lib/test/test_importlib/test_path.py rename to Lib/test/test_importlib/resources/test_path.py index 6fc41f301d..34a6bdd2d5 100644 --- a/Lib/test/test_importlib/test_path.py +++ b/Lib/test/test_importlib/resources/test_path.py @@ -3,7 +3,7 @@ from importlib import resources from . import data01 -from .resources import util +from . import util class CommonTests(util.CommonTests, unittest.TestCase): @@ -14,9 +14,12 @@ def execute(self, package, path): class PathTests: def test_reading(self): - # Path should be readable. - # Test also implicitly verifies the returned object is a pathlib.Path - # instance. + """ + Path should be readable. + + Test also implicitly verifies the returned object is a pathlib.Path + instance. + """ target = resources.files(self.data) / 'utf-8.file' with resources.as_file(target) as path: self.assertTrue(path.name.endswith("utf-8.file"), repr(path)) @@ -51,8 +54,10 @@ def setUp(self): class PathZipTests(PathTests, util.ZipSetup, unittest.TestCase): def test_remove_in_context_manager(self): - # It is not an error if the file that was temporarily stashed on the - # file system is removed inside the `with` stanza. + """ + It is not an error if the file that was temporarily stashed on the + file system is removed inside the `with` stanza. + """ target = resources.files(self.data) / 'utf-8.file' with resources.as_file(target) as path: path.unlink() diff --git a/Lib/test/test_importlib/test_read.py b/Lib/test/test_importlib/resources/test_read.py similarity index 86% rename from Lib/test/test_importlib/test_read.py rename to Lib/test/test_importlib/resources/test_read.py index ebd7226777..088982681e 100644 --- a/Lib/test/test_importlib/test_read.py +++ b/Lib/test/test_importlib/resources/test_read.py @@ -2,7 +2,7 @@ from importlib import import_module, resources from . import data01 -from .resources import util +from . import util class CommonBinaryTests(util.CommonTests, unittest.TestCase): @@ -12,7 +12,7 @@ def execute(self, package, path): class CommonTextTests(util.CommonTests, unittest.TestCase): def execute(self, package, path): - resources.files(package).joinpath(path).read_text() + resources.files(package).joinpath(path).read_text(encoding='utf-8') class ReadTests: @@ -21,7 +21,11 @@ def test_read_bytes(self): self.assertEqual(result, b'\0\1\2\3') def test_read_text_default_encoding(self): - result = resources.files(self.data).joinpath('utf-8.file').read_text() + result = ( + resources.files(self.data) + .joinpath('utf-8.file') + .read_text(encoding='utf-8') + ) self.assertEqual(result, 'Hello, UTF-8 world!\n') def test_read_text_given_encoding(self): @@ -33,7 +37,9 @@ def test_read_text_given_encoding(self): self.assertEqual(result, 'Hello, UTF-16 world!\n') def test_read_text_with_errors(self): - # Raises UnicodeError without the 'errors' argument. + """ + Raises UnicodeError without the 'errors' argument. + """ target = resources.files(self.data) / 'utf-16.file' self.assertRaises(UnicodeError, target.read_text, encoding='utf-8') result = target.read_text(encoding='utf-8', errors='ignore') diff --git a/Lib/test/test_importlib/test_reader.py b/Lib/test/test_importlib/resources/test_reader.py similarity index 85% rename from Lib/test/test_importlib/test_reader.py rename to Lib/test/test_importlib/resources/test_reader.py index 9d20c976b8..8670f72a33 100644 --- a/Lib/test/test_importlib/test_reader.py +++ b/Lib/test/test_importlib/resources/test_reader.py @@ -75,6 +75,22 @@ def test_join_path(self): str(path.joinpath('imaginary'))[len(prefix) + 1 :], os.path.join('namespacedata01', 'imaginary'), ) + self.assertEqual(path.joinpath(), path) + + def test_join_path_compound(self): + path = MultiplexedPath(self.folder) + assert not path.joinpath('imaginary/foo.py').exists() + + def test_join_path_common_subdir(self): + prefix = os.path.abspath(os.path.join(__file__, '..')) + data01 = os.path.join(prefix, 'data01') + data02 = os.path.join(prefix, 'data02') + path = MultiplexedPath(data01, data02) + self.assertIsInstance(path.joinpath('subdirectory'), MultiplexedPath) + self.assertEqual( + str(path.joinpath('subdirectory', 'subsubdir'))[len(prefix) + 1 :], + os.path.join('data02', 'subdirectory', 'subsubdir'), + ) def test_repr(self): self.assertEqual( diff --git a/Lib/test/test_importlib/test_resource.py b/Lib/test/test_importlib/resources/test_resource.py similarity index 74% rename from Lib/test/test_importlib/test_resource.py rename to Lib/test/test_importlib/resources/test_resource.py index 834b8bd8a2..6f75cf57f0 100644 --- a/Lib/test/test_importlib/test_resource.py +++ b/Lib/test/test_importlib/resources/test_resource.py @@ -1,3 +1,4 @@ +import contextlib import sys import unittest import uuid @@ -5,9 +6,9 @@ from . import data01 from . import zipdata01, zipdata02 -from .resources import util +from . import util from importlib import resources, import_module -from test.support import import_helper +from test.support import import_helper, os_helper from test.support.os_helper import unlink @@ -69,10 +70,12 @@ def test_resource_missing(self): class ResourceCornerCaseTests(unittest.TestCase): def test_package_has_no_reader_fallback(self): - # Test odd ball packages which: + """ + Test odd ball packages which: # 1. Do not have a ResourceReader as a loader # 2. Are not on the file system # 3. Are not in a zip file + """ module = util.create_package( file=data01, path=data01.__file__, contents=['A', 'B', 'C'] ) @@ -111,6 +114,14 @@ def test_submodule_contents_by_name(self): {'__init__.py', 'binary.file'}, ) + def test_as_file_directory(self): + with resources.as_file(resources.files('ziptestdata')) as data: + assert data.name == 'ziptestdata' + assert data.is_dir() + assert data.joinpath('subdirectory').is_dir() + assert len(list(data.iterdir())) + assert not data.parent.exists() + class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase): ZIP_MODULE = zipdata02 # type: ignore @@ -130,82 +141,71 @@ def test_unrelated_contents(self): ) +@contextlib.contextmanager +def zip_on_path(dir): + data_path = pathlib.Path(zipdata01.__file__) + source_zip_path = data_path.parent.joinpath('ziptestdata.zip') + zip_path = pathlib.Path(dir) / f'{uuid.uuid4()}.zip' + zip_path.write_bytes(source_zip_path.read_bytes()) + sys.path.append(str(zip_path)) + import_module('ziptestdata') + + try: + yield + finally: + with contextlib.suppress(ValueError): + sys.path.remove(str(zip_path)) + + with contextlib.suppress(KeyError): + del sys.path_importer_cache[str(zip_path)] + del sys.modules['ziptestdata'] + + with contextlib.suppress(OSError): + unlink(zip_path) + + class DeletingZipsTest(unittest.TestCase): """Having accessed resources in a zip file should not keep an open reference to the zip. """ - ZIP_MODULE = zipdata01 - def setUp(self): + self.fixtures = contextlib.ExitStack() + self.addCleanup(self.fixtures.close) + modules = import_helper.modules_setup() self.addCleanup(import_helper.modules_cleanup, *modules) - data_path = pathlib.Path(self.ZIP_MODULE.__file__) - data_dir = data_path.parent - self.source_zip_path = data_dir / 'ziptestdata.zip' - self.zip_path = pathlib.Path(f'{uuid.uuid4()}.zip').absolute() - self.zip_path.write_bytes(self.source_zip_path.read_bytes()) - sys.path.append(str(self.zip_path)) - self.data = import_module('ziptestdata') - - def tearDown(self): - try: - sys.path.remove(str(self.zip_path)) - except ValueError: - pass - - try: - del sys.path_importer_cache[str(self.zip_path)] - del sys.modules[self.data.__name__] - except KeyError: - pass - - try: - unlink(self.zip_path) - except OSError: - # If the test fails, this will probably fail too - pass + temp_dir = self.fixtures.enter_context(os_helper.temp_dir()) + self.fixtures.enter_context(zip_on_path(temp_dir)) def test_iterdir_does_not_keep_open(self): - c = [item.name for item in resources.files('ziptestdata').iterdir()] - self.zip_path.unlink() - del c + [item.name for item in resources.files('ziptestdata').iterdir()] def test_is_file_does_not_keep_open(self): - c = resources.files('ziptestdata').joinpath('binary.file').is_file() - self.zip_path.unlink() - del c + resources.files('ziptestdata').joinpath('binary.file').is_file() def test_is_file_failure_does_not_keep_open(self): - c = resources.files('ziptestdata').joinpath('not-present').is_file() - self.zip_path.unlink() - del c + resources.files('ziptestdata').joinpath('not-present').is_file() @unittest.skip("Desired but not supported.") def test_as_file_does_not_keep_open(self): # pragma: no cover - c = resources.as_file(resources.files('ziptestdata') / 'binary.file') - self.zip_path.unlink() - del c + resources.as_file(resources.files('ziptestdata') / 'binary.file') def test_entered_path_does_not_keep_open(self): - # This is what certifi does on import to make its bundle - # available for the process duration. - c = resources.as_file( - resources.files('ziptestdata') / 'binary.file' - ).__enter__() - self.zip_path.unlink() - del c + """ + Mimic what certifi does on import to make its bundle + available for the process duration. + """ + resources.as_file(resources.files('ziptestdata') / 'binary.file').__enter__() def test_read_binary_does_not_keep_open(self): - c = resources.files('ziptestdata').joinpath('binary.file').read_bytes() - self.zip_path.unlink() - del c + resources.files('ziptestdata').joinpath('binary.file').read_bytes() def test_read_text_does_not_keep_open(self): - c = resources.files('ziptestdata').joinpath('utf-8.file').read_text() - self.zip_path.unlink() - del c + resources.files('ziptestdata').joinpath('utf-8.file').read_text( + encoding='utf-8' + ) class ResourceFromNamespaceTest01(unittest.TestCase): diff --git a/Lib/test/test_importlib/update-zips.py b/Lib/test/test_importlib/resources/update-zips.py similarity index 100% rename from Lib/test/test_importlib/update-zips.py rename to Lib/test/test_importlib/resources/update-zips.py diff --git a/Lib/test/test_importlib/resources/util.py b/Lib/test/test_importlib/resources/util.py index 11c8aa8080..dbe6ee8147 100644 --- a/Lib/test/test_importlib/resources/util.py +++ b/Lib/test/test_importlib/resources/util.py @@ -3,11 +3,11 @@ import io import sys import types -from pathlib import Path, PurePath +import pathlib -from .. import data01 -from .. import zipdata01 -from importlib.abc import ResourceReader +from . import data01 +from . import zipdata01 +from importlib.resources.abc import ResourceReader from test.support import import_helper @@ -80,43 +80,44 @@ def execute(self, package, path): """ def test_package_name(self): - # Passing in the package name should succeed. + """ + Passing in the package name should succeed. + """ self.execute(data01.__name__, 'utf-8.file') def test_package_object(self): - # Passing in the package itself should succeed. + """ + Passing in the package itself should succeed. + """ self.execute(data01, 'utf-8.file') def test_string_path(self): - # Passing in a string for the path should succeed. + """ + Passing in a string for the path should succeed. + """ path = 'utf-8.file' self.execute(data01, path) def test_pathlib_path(self): - # Passing in a pathlib.PurePath object for the path should succeed. - path = PurePath('utf-8.file') + """ + Passing in a pathlib.PurePath object for the path should succeed. + """ + path = pathlib.PurePath('utf-8.file') self.execute(data01, path) def test_importing_module_as_side_effect(self): - # The anchor package can already be imported. + """ + The anchor package can already be imported. + """ del sys.modules[data01.__name__] self.execute(data01.__name__, 'utf-8.file') - def test_non_package_by_name(self): - # The anchor package cannot be a module. - with self.assertRaises(TypeError): - self.execute(__name__, 'utf-8.file') - - def test_non_package_by_package(self): - # The anchor package cannot be a module. - with self.assertRaises(TypeError): - module = sys.modules['test.test_importlib.resources.util'] - self.execute(module, 'utf-8.file') - def test_missing_path(self): - # Attempting to open or read or request the path for a - # non-existent path should succeed if open_resource - # can return a viable data stream. + """ + Attempting to open or read or request the path for a + non-existent path should succeed if open_resource + can return a viable data stream. + """ bytes_data = io.BytesIO(b'Hello, world!') package = create_package(file=bytes_data, path=FileNotFoundError()) self.execute(package, 'utf-8.file') @@ -144,7 +145,7 @@ class ZipSetupBase: @classmethod def setUpClass(cls): - data_path = Path(cls.ZIP_MODULE.__file__) + data_path = pathlib.Path(cls.ZIP_MODULE.__file__) data_dir = data_path.parent cls._zip_path = str(data_dir / 'ziptestdata.zip') sys.path.append(cls._zip_path) diff --git a/Lib/test/test_importlib/zipdata01/__init__.py b/Lib/test/test_importlib/resources/zipdata01/__init__.py similarity index 100% rename from Lib/test/test_importlib/zipdata01/__init__.py rename to Lib/test/test_importlib/resources/zipdata01/__init__.py diff --git a/Lib/test/test_importlib/resources/zipdata01/ziptestdata.zip b/Lib/test/test_importlib/resources/zipdata01/ziptestdata.zip new file mode 100644 index 0000000000000000000000000000000000000000..9a3bb0739f87e97c1084b94d7d153680f6727738 GIT binary patch literal 876 zcmWIWW@Zs#00HOCX@Q%&m27l?Y!DU);;PJolGNgol*E!m{nC;&T|+ayw9K5;|NlG~ zQWMD z9;rDw`8o=rA#S=B3g!7lIVp-}COK17UPc zNtt;*xhM-3R!jMEPhCreO-3*u>5Df}T7+BJ{639e$2uhfsIs`pJ5Qf}C xGXyDE@VNvOv@o!wQJfLgCAgysx3f@9jKpUmiW^zkK<;1z!tFpk^MROw0RS~O%0&PG literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/zipdata02/__init__.py b/Lib/test/test_importlib/resources/zipdata02/__init__.py similarity index 100% rename from Lib/test/test_importlib/zipdata02/__init__.py rename to Lib/test/test_importlib/resources/zipdata02/__init__.py diff --git a/Lib/test/test_importlib/resources/zipdata02/ziptestdata.zip b/Lib/test/test_importlib/resources/zipdata02/ziptestdata.zip new file mode 100644 index 0000000000000000000000000000000000000000..d63ff512d2807ef2fd259455283b81b02e0e45fb GIT binary patch literal 698 zcmWIWW@Zs#00HOCX@Ot{ln@8fRhb1Psl_EJi6x2p@$s2?nI-Y@dIgmMI5kP5Y0A$_ z#jWw|&p#`9ff_(q7K_HB)Z+ZoqU2OVy^@L&ph*fa0WRVlP*R?c+X1opI-R&20MZDv z&j{oIpa8N17@0(vaR(gGH(;=&5k%n(M%;#g0ulz6G@1gL$cA79E2=^00gEsw4~s!C zUxI@ZWaIMqz|BszK;s4KsL2<9jRy!Q2E6`2cTLHjr{wAk1ZCU@!+_ G1_l6Bc%f?m literal 0 HcmV?d00001 diff --git a/Lib/test/test_importlib/source/test_case_sensitivity.py b/Lib/test/test_importlib/source/test_case_sensitivity.py index 9d472707ab..6a06313319 100644 --- a/Lib/test/test_importlib/source/test_case_sensitivity.py +++ b/Lib/test/test_importlib/source/test_case_sensitivity.py @@ -63,19 +63,6 @@ def test_insensitive(self): self.assertIn(self.name, insensitive.get_filename(self.name)) -class CaseSensitivityTestPEP302(CaseSensitivityTest): - def find(self, finder): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - return finder.find_module(self.name) - - -(Frozen_CaseSensitivityTestPEP302, - Source_CaseSensitivityTestPEP302 - ) = util.test_both(CaseSensitivityTestPEP302, importlib=importlib, - machinery=machinery) - - class CaseSensitivityTestPEP451(CaseSensitivityTest): def find(self, finder): found = finder.find_spec(self.name) diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py index ebf6ec68d7..9c85bd234f 100644 --- a/Lib/test/test_importlib/source/test_file_loader.py +++ b/Lib/test/test_importlib/source/test_file_loader.py @@ -51,7 +51,6 @@ class Tester(self.abc.FileLoader): def get_code(self, _): pass def get_source(self, _): pass def is_package(self, _): pass - def module_repr(self, _): pass path = 'some_path' name = 'some_name' diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py index 3c12ab0123..17d09d4cee 100644 --- a/Lib/test/test_importlib/source/test_finder.py +++ b/Lib/test/test_importlib/source/test_finder.py @@ -120,7 +120,7 @@ def test_package_over_module(self): def test_failure(self): with util.create_modules('blah') as mapping: nothing = self.import_(mapping['.root'], 'sdfsadsadf') - self.assertIsNone(nothing) + self.assertEqual(nothing, self.NOT_FOUND) def test_empty_string_for_dir(self): # The empty string from sys.path means to search in the cwd. @@ -150,7 +150,7 @@ def test_dir_removal_handling(self): found = self._find(finder, 'mod', loader_only=True) self.assertIsNotNone(found) found = self._find(finder, 'mod', loader_only=True) - self.assertIsNone(found) + self.assertEqual(found, self.NOT_FOUND) @unittest.skipUnless(sys.platform != 'win32', 'os.chmod() does not support the needed arguments under Windows') @@ -197,10 +197,12 @@ class FinderTestsPEP420(FinderTests): NOT_FOUND = (None, []) def _find(self, finder, name, loader_only=False): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader_portions = finder.find_loader(name) - return loader_portions[0] if loader_only else loader_portions + spec = finder.find_spec(name) + if spec is None: + return self.NOT_FOUND + if loader_only: + return spec.loader + return spec.loader, spec.submodule_search_locations (Frozen_FinderTestsPEP420, @@ -208,20 +210,5 @@ def _find(self, finder, name, loader_only=False): ) = util.test_both(FinderTestsPEP420, machinery=machinery) -class FinderTestsPEP302(FinderTests): - - NOT_FOUND = None - - def _find(self, finder, name, loader_only=False): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - return finder.find_module(name) - - -(Frozen_FinderTestsPEP302, - Source_FinderTestsPEP302 - ) = util.test_both(FinderTestsPEP302, machinery=machinery) - - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/source/test_path_hook.py b/Lib/test/test_importlib/source/test_path_hook.py index ead62f5e94..f274330e0b 100644 --- a/Lib/test/test_importlib/source/test_path_hook.py +++ b/Lib/test/test_importlib/source/test_path_hook.py @@ -18,19 +18,10 @@ def test_success(self): self.assertTrue(hasattr(self.path_hook()(mapping['.root']), 'find_spec')) - def test_success_legacy(self): - with util.create_modules('dummy') as mapping: - self.assertTrue(hasattr(self.path_hook()(mapping['.root']), - 'find_module')) - def test_empty_string(self): # The empty string represents the cwd. self.assertTrue(hasattr(self.path_hook()(''), 'find_spec')) - def test_empty_string_legacy(self): - # The empty string represents the cwd. - self.assertTrue(hasattr(self.path_hook()(''), 'find_module')) - (Frozen_PathHookTest, Source_PathHooktest diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py index d77b8a0a4d..a231ae1d5f 100644 --- a/Lib/test/test_importlib/test_abc.py +++ b/Lib/test/test_importlib/test_abc.py @@ -2,7 +2,6 @@ import marshal import os import sys -from test import support from test.support import import_helper import types import unittest @@ -148,20 +147,13 @@ def ins(self): class MetaPathFinder: - def find_module(self, fullname, path): - return super().find_module(fullname, path) + pass class MetaPathFinderDefaultsTests(ABCTestHarness): SPLIT = make_abc_subclasses(MetaPathFinder) - def test_find_module(self): - # Default should return None. - with self.assertWarns(DeprecationWarning): - found = self.ins.find_module('something', None) - self.assertIsNone(found) - def test_invalidate_caches(self): # Calling the method is a no-op. self.ins.invalidate_caches() @@ -174,22 +166,13 @@ def test_invalidate_caches(self): class PathEntryFinder: - def find_loader(self, fullname): - return super().find_loader(fullname) + pass class PathEntryFinderDefaultsTests(ABCTestHarness): SPLIT = make_abc_subclasses(PathEntryFinder) - def test_find_loader(self): - with self.assertWarns(DeprecationWarning): - found = self.ins.find_loader('something') - self.assertEqual(found, (None, [])) - - def find_module(self): - self.assertEqual(None, self.ins.find_module('something')) - def test_invalidate_caches(self): # Should be a no-op. self.ins.invalidate_caches() @@ -202,8 +185,7 @@ def test_invalidate_caches(self): class Loader: - def load_module(self, fullname): - return super().load_module(fullname) + pass class LoaderDefaultsTests(ABCTestHarness): @@ -222,8 +204,6 @@ def test_module_repr(self): mod = types.ModuleType('blah') with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) - with self.assertRaises(NotImplementedError): - self.ins.module_repr(mod) original_repr = repr(mod) mod.__loader__ = self.ins # Should still return a proper repr. @@ -323,32 +303,6 @@ def contents(self, *args, **kwargs): return super().contents(*args, **kwargs) -class ResourceReaderDefaultsTests(ABCTestHarness): - - SPLIT = make_abc_subclasses(ResourceReader) - - def test_open_resource(self): - with self.assertRaises(FileNotFoundError): - self.ins.open_resource('dummy_file') - - def test_resource_path(self): - with self.assertRaises(FileNotFoundError): - self.ins.resource_path('dummy_file') - - def test_is_resource(self): - with self.assertRaises(FileNotFoundError): - self.ins.is_resource('dummy_file') - - def test_contents(self): - with self.assertRaises(FileNotFoundError): - self.ins.contents() - - -(Frozen_RRDefaultTests, - Source_RRDefaultsTests - ) = test_util.test_both(ResourceReaderDefaultsTests) - - ##### MetaPathFinder concrete methods ########################################## class MetaPathFinderFindModuleTests: @@ -362,14 +316,6 @@ def find_spec(self, fullname, path, target=None): return MetaPathSpecFinder() - def test_find_module(self): - finder = self.finder(None) - path = ['a', 'b', 'c'] - name = 'blah' - with self.assertWarns(DeprecationWarning): - found = finder.find_module(name, path) - self.assertIsNone(found) - def test_find_spec_with_explicit_target(self): loader = object() spec = self.util.spec_from_loader('blah', loader) @@ -399,53 +345,6 @@ def test_spec(self): ) = test_util.test_both(MetaPathFinderFindModuleTests, abc=abc, util=util) -##### PathEntryFinder concrete methods ######################################### -class PathEntryFinderFindLoaderTests: - - @classmethod - def finder(cls, spec): - class PathEntrySpecFinder(cls.abc.PathEntryFinder): - - def find_spec(self, fullname, target=None): - self.called_for = fullname - return spec - - return PathEntrySpecFinder() - - def test_no_spec(self): - finder = self.finder(None) - name = 'blah' - with self.assertWarns(DeprecationWarning): - found = finder.find_loader(name) - self.assertIsNone(found[0]) - self.assertEqual([], found[1]) - self.assertEqual(name, finder.called_for) - - def test_spec_with_loader(self): - loader = object() - spec = self.util.spec_from_loader('blah', loader) - finder = self.finder(spec) - with self.assertWarns(DeprecationWarning): - found = finder.find_loader('blah') - self.assertIs(found[0], spec.loader) - - def test_spec_with_portions(self): - spec = self.machinery.ModuleSpec('blah', None) - paths = ['a', 'b', 'c'] - spec.submodule_search_locations = paths - finder = self.finder(spec) - with self.assertWarns(DeprecationWarning): - found = finder.find_loader('blah') - self.assertIsNone(found[0]) - self.assertEqual(paths, found[1]) - - -(Frozen_PEFFindLoaderTests, - Source_PEFFindLoaderTests - ) = test_util.test_both(PathEntryFinderFindLoaderTests, abc=abc, util=util, - machinery=machinery) - - ##### Loader concrete methods ################################################## class LoaderLoadModuleTests: @@ -716,9 +615,6 @@ def get_data(self, path): def get_filename(self, fullname): return self.path - def module_repr(self, module): - return '' - SPLIT_SOL = make_abc_subclasses(SourceOnlyLoader, 'SourceLoader') @@ -803,13 +699,7 @@ def verify_code(self, code_object): class SourceOnlyLoaderTests(SourceLoaderTestHarness): - - """Test importlib.abc.SourceLoader for source-only loading. - - Reload testing is subsumed by the tests for - importlib.util.module_for_loader. - - """ + """Test importlib.abc.SourceLoader for source-only loading.""" # TODO: RUSTPYTHON @unittest.expectedFailure diff --git a/Lib/test/test_importlib/test_api.py b/Lib/test/test_importlib/test_api.py index 1beb7835d4..ecf2c47c46 100644 --- a/Lib/test/test_importlib/test_api.py +++ b/Lib/test/test_importlib/test_api.py @@ -6,7 +6,6 @@ import os.path import sys -from test import support from test.support import import_helper from test.support import os_helper import types @@ -96,7 +95,8 @@ def load_b(): (Frozen_ImportModuleTests, Source_ImportModuleTests - ) = test_util.test_both(ImportModuleTests, init=init) + ) = test_util.test_both( + ImportModuleTests, init=init, util=util, machinery=machinery) class FindLoaderTests: @@ -104,29 +104,26 @@ class FindLoaderTests: FakeMetaFinder = None def test_sys_modules(self): - # If a module with __loader__ is in sys.modules, then return it. + # If a module with __spec__.loader is in sys.modules, then return it. name = 'some_mod' with test_util.uncache(name): module = types.ModuleType(name) loader = 'a loader!' - module.__loader__ = loader + module.__spec__ = self.machinery.ModuleSpec(name, loader) sys.modules[name] = module - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - found = self.init.find_loader(name) - self.assertEqual(loader, found) + spec = self.util.find_spec(name) + self.assertIsNotNone(spec) + self.assertEqual(spec.loader, loader) def test_sys_modules_loader_is_None(self): - # If sys.modules[name].__loader__ is None, raise ValueError. + # If sys.modules[name].__spec__.loader is None, raise ValueError. name = 'some_mod' with test_util.uncache(name): module = types.ModuleType(name) module.__loader__ = None sys.modules[name] = module with self.assertRaises(ValueError): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.init.find_loader(name) + self.util.find_spec(name) def test_sys_modules_loader_is_not_set(self): # Should raise ValueError @@ -135,24 +132,20 @@ def test_sys_modules_loader_is_not_set(self): with test_util.uncache(name): module = types.ModuleType(name) try: - del module.__loader__ + del module.__spec__.loader except AttributeError: pass sys.modules[name] = module with self.assertRaises(ValueError): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.init.find_loader(name) + self.util.find_spec(name) def test_success(self): # Return the loader found on sys.meta_path. name = 'some_mod' with test_util.uncache(name): with test_util.import_state(meta_path=[self.FakeMetaFinder]): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - warnings.simplefilter('ignore', ImportWarning) - self.assertEqual((name, None), self.init.find_loader(name)) + spec = self.util.find_spec(name) + self.assertEqual((name, (name, None)), (spec.name, spec.loader)) def test_success_path(self): # Searching on a path should work. @@ -160,17 +153,12 @@ def test_success_path(self): path = 'path to some place' with test_util.uncache(name): with test_util.import_state(meta_path=[self.FakeMetaFinder]): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - warnings.simplefilter('ignore', ImportWarning) - self.assertEqual((name, path), - self.init.find_loader(name, path)) + spec = self.util.find_spec(name, path) + self.assertEqual(name, spec.name) def test_nothing(self): # None is returned upon failure to find a loader. - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertIsNone(self.init.find_loader('nevergoingtofindthismodule')) + self.assertIsNone(self.util.find_spec('nevergoingtofindthismodule')) class FindLoaderPEP451Tests(FindLoaderTests): @@ -183,20 +171,8 @@ def find_spec(name, path=None, target=None): (Frozen_FindLoaderPEP451Tests, Source_FindLoaderPEP451Tests - ) = test_util.test_both(FindLoaderPEP451Tests, init=init) - - -class FindLoaderPEP302Tests(FindLoaderTests): - - class FakeMetaFinder: - @staticmethod - def find_module(name, path=None): - return name, path - - -(Frozen_FindLoaderPEP302Tests, - Source_FindLoaderPEP302Tests - ) = test_util.test_both(FindLoaderPEP302Tests, init=init) + ) = test_util.test_both( + FindLoaderPEP451Tests, init=init, util=util, machinery=machinery) class ReloadTests: @@ -301,7 +277,8 @@ def test_reload_namespace_changed(self): name = 'spam' with os_helper.temp_cwd(None) as cwd: with test_util.uncache('spam'): - with import_helper.DirsOnSysPath(cwd): + with test_util.import_state(path=[cwd]): + self.init._bootstrap_external._install(self.init._bootstrap) # Start as a namespace package. self.init.invalidate_caches() bad_path = os.path.join(cwd, name, '__init.py') @@ -380,7 +357,8 @@ def test_module_missing_spec(self): (Frozen_ReloadTests, Source_ReloadTests - ) = test_util.test_both(ReloadTests, init=init, util=util) + ) = test_util.test_both( + ReloadTests, init=init, util=util, machinery=machinery) class InvalidateCacheTests: @@ -390,8 +368,6 @@ def test_method_called(self): class InvalidatingNullFinder: def __init__(self, *ignored): self.called = False - def find_module(self, *args): - return None def invalidate_caches(self): self.called = True @@ -416,7 +392,8 @@ def test_method_lacking(self): (Frozen_InvalidateCacheTests, Source_InvalidateCacheTests - ) = test_util.test_both(InvalidateCacheTests, init=init) + ) = test_util.test_both( + InvalidateCacheTests, init=init, util=util, machinery=machinery) class FrozenImportlibTests(unittest.TestCase): diff --git a/Lib/test/test_importlib/test_files.py b/Lib/test/test_importlib/test_files.py deleted file mode 100644 index b9170d83be..0000000000 --- a/Lib/test/test_importlib/test_files.py +++ /dev/null @@ -1,46 +0,0 @@ -import typing -import unittest - -from importlib import resources -from importlib.abc import Traversable -from . import data01 -from .resources import util - - -class FilesTests: - def test_read_bytes(self): - files = resources.files(self.data) - actual = files.joinpath('utf-8.file').read_bytes() - assert actual == b'Hello, UTF-8 world!\n' - - def test_read_text(self): - files = resources.files(self.data) - actual = files.joinpath('utf-8.file').read_text(encoding='utf-8') - assert actual == 'Hello, UTF-8 world!\n' - - @unittest.skipUnless( - hasattr(typing, 'runtime_checkable'), - "Only suitable when typing supports runtime_checkable", - ) - def test_traversable(self): - assert isinstance(resources.files(self.data), Traversable) - - -class OpenDiskTests(FilesTests, unittest.TestCase): - def setUp(self): - self.data = data01 - - -class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase): - pass - - -class OpenNamespaceTests(FilesTests, unittest.TestCase): - def setUp(self): - from . import namespacedata01 - - self.data = namespacedata01 - - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py index 32ed67c308..17cce741cc 100644 --- a/Lib/test/test_importlib/test_locks.py +++ b/Lib/test/test_importlib/test_locks.py @@ -33,6 +33,11 @@ class ModuleLockAsRLockTests: test_repr = None test_locked_repr = None + def tearDown(self): + for splitinit in init.values(): + splitinit._bootstrap._blocking_on.clear() + + LOCK_TYPES = {kind: splitinit._bootstrap._ModuleLock for kind, splitinit in init.items()} diff --git a/Lib/test/test_importlib/test_main.py b/Lib/test/test_importlib/test_main.py index d9d067c4b2..81f683799c 100644 --- a/Lib/test/test_importlib/test_main.py +++ b/Lib/test/test_importlib/test_main.py @@ -1,9 +1,10 @@ import re -import json import pickle import unittest import warnings import importlib.metadata +import contextlib +import itertools try: import pyfakefs.fake_filesystem_unittest as ffs @@ -11,6 +12,7 @@ from .stubs import fake_filesystem_unittest as ffs from . import fixtures +from ._context import suppress from importlib.metadata import ( Distribution, EntryPoint, @@ -24,6 +26,13 @@ ) +@contextlib.contextmanager +def suppress_known_deprecation(): + with warnings.catch_warnings(record=True) as ctx: + warnings.simplefilter('default', category=DeprecationWarning) + yield ctx + + class BasicTests(fixtures.DistInfoPkg, unittest.TestCase): version_pattern = r'\d+\.\d+(\.\d)?' @@ -39,7 +48,7 @@ def test_for_name_does_not_exist(self): def test_package_not_found_mentions_metadata(self): """ When a package is not found, that could indicate that the - packgae is not installed or that it is installed without + package is not installed or that it is installed without metadata. Ensure the exception mentions metadata to help guide users toward the cause. See #124. """ @@ -48,15 +57,19 @@ def test_package_not_found_mentions_metadata(self): assert "metadata" in str(ctx.exception) - def test_new_style_classes(self): - self.assertIsInstance(Distribution, type) + # expected to fail until ABC is enforced + @suppress(AssertionError) + @suppress_known_deprecation() + def test_abc_enforced(self): + with self.assertRaises(TypeError): + type('DistributionSubclass', (Distribution,), {})() @fixtures.parameterize( dict(name=None), dict(name=''), ) def test_invalid_inputs_to_from_name(self, name): - with self.assertRaises(Exception): + with self.assertRaises(ValueError): Distribution.from_name(name) @@ -174,11 +187,21 @@ def test_metadata_loads_egg_info(self): assert meta['Description'] == 'pôrˈtend' -class DiscoveryTests(fixtures.EggInfoPkg, fixtures.DistInfoPkg, unittest.TestCase): +class DiscoveryTests( + fixtures.EggInfoPkg, + fixtures.EggInfoPkgPipInstalledNoToplevel, + fixtures.EggInfoPkgPipInstalledNoModules, + fixtures.EggInfoPkgSourcesFallback, + fixtures.DistInfoPkg, + unittest.TestCase, +): def test_package_discovery(self): dists = list(distributions()) assert all(isinstance(dist, Distribution) for dist in dists) assert any(dist.metadata['Name'] == 'egginfo-pkg' for dist in dists) + assert any(dist.metadata['Name'] == 'egg_with_module-pkg' for dist in dists) + assert any(dist.metadata['Name'] == 'egg_with_no_modules-pkg' for dist in dists) + assert any(dist.metadata['Name'] == 'sources_fallback-pkg' for dist in dists) assert any(dist.metadata['Name'] == 'distinfo-pkg' for dist in dists) def test_invalid_usage(self): @@ -260,14 +283,6 @@ def test_hashable(self): """EntryPoints should be hashable""" hash(self.ep) - def test_json_dump(self): - """ - json should not expect to be able to dump an EntryPoint - """ - with self.assertRaises(Exception): - with warnings.catch_warnings(record=True): - json.dumps(self.ep) - def test_module(self): assert self.ep.module == 'value' @@ -334,3 +349,79 @@ def test_packages_distributions_neither_toplevel_nor_files(self): prefix=self.site_dir, ) packages_distributions() + + def test_packages_distributions_all_module_types(self): + """ + Test top-level modules detected on a package without 'top-level.txt'. + """ + suffixes = importlib.machinery.all_suffixes() + metadata = dict( + METADATA=""" + Name: all_distributions + Version: 1.0.0 + """, + ) + files = { + 'all_distributions-1.0.0.dist-info': metadata, + } + for i, suffix in enumerate(suffixes): + files.update( + { + f'importable-name {i}{suffix}': '', + f'in_namespace_{i}': { + f'mod{suffix}': '', + }, + f'in_package_{i}': { + '__init__.py': '', + f'mod{suffix}': '', + }, + } + ) + metadata.update(RECORD=fixtures.build_record(files)) + fixtures.build_files(files, prefix=self.site_dir) + + distributions = packages_distributions() + + for i in range(len(suffixes)): + assert distributions[f'importable-name {i}'] == ['all_distributions'] + assert distributions[f'in_namespace_{i}'] == ['all_distributions'] + assert distributions[f'in_package_{i}'] == ['all_distributions'] + + assert not any(name.endswith('.dist-info') for name in distributions) + + +class PackagesDistributionsEggTest( + fixtures.EggInfoPkg, + fixtures.EggInfoPkgPipInstalledNoToplevel, + fixtures.EggInfoPkgPipInstalledNoModules, + fixtures.EggInfoPkgSourcesFallback, + unittest.TestCase, +): + def test_packages_distributions_on_eggs(self): + """ + Test old-style egg packages with a variation of 'top_level.txt', + 'SOURCES.txt', and 'installed-files.txt', available. + """ + distributions = packages_distributions() + + def import_names_from_package(package_name): + return { + import_name + for import_name, package_names in distributions.items() + if package_name in package_names + } + + # egginfo-pkg declares one import ('mod') via top_level.txt + assert import_names_from_package('egginfo-pkg') == {'mod'} + + # egg_with_module-pkg has one import ('egg_with_module') inferred from + # installed-files.txt (top_level.txt is missing) + assert import_names_from_package('egg_with_module-pkg') == {'egg_with_module'} + + # egg_with_no_modules-pkg should not be associated with any import names + # (top_level.txt is empty, and installed-files.txt has no .py files) + assert import_names_from_package('egg_with_no_modules-pkg') == set() + + # sources_fallback-pkg has one import ('sources_fallback') inferred from + # SOURCES.txt (top_level.txt and installed-files.txt is missing) + assert import_names_from_package('sources_fallback-pkg') == {'sources_fallback'} diff --git a/Lib/test/test_importlib/test_metadata_api.py b/Lib/test/test_importlib/test_metadata_api.py index abf568fcca..55c9f8007e 100644 --- a/Lib/test/test_importlib/test_metadata_api.py +++ b/Lib/test/test_importlib/test_metadata_api.py @@ -27,12 +27,14 @@ def suppress_known_deprecation(): class APITests( fixtures.EggInfoPkg, + fixtures.EggInfoPkgPipInstalledNoToplevel, + fixtures.EggInfoPkgPipInstalledNoModules, + fixtures.EggInfoPkgSourcesFallback, fixtures.DistInfoPkg, fixtures.DistInfoPkgWithDot, fixtures.EggInfoFile, unittest.TestCase, ): - version_pattern = r'\d+\.\d+(\.\d)?' def test_retrieves_version_of_self(self): @@ -63,15 +65,28 @@ def test_prefix_not_matched(self): distribution(prefix) def test_for_top_level(self): - self.assertEqual( - distribution('egginfo-pkg').read_text('top_level.txt').strip(), 'mod' - ) + tests = [ + ('egginfo-pkg', 'mod'), + ('egg_with_no_modules-pkg', ''), + ] + for pkg_name, expect_content in tests: + with self.subTest(pkg_name): + self.assertEqual( + distribution(pkg_name).read_text('top_level.txt').strip(), + expect_content, + ) def test_read_text(self): - top_level = [ - path for path in files('egginfo-pkg') if path.name == 'top_level.txt' - ][0] - self.assertEqual(top_level.read_text(), 'mod\n') + tests = [ + ('egginfo-pkg', 'mod\n'), + ('egg_with_no_modules-pkg', '\n'), + ] + for pkg_name, expect_content in tests: + with self.subTest(pkg_name): + top_level = [ + path for path in files(pkg_name) if path.name == 'top_level.txt' + ][0] + self.assertEqual(top_level.read_text(), expect_content) def test_entry_points(self): eps = entry_points() @@ -124,62 +139,6 @@ def test_entry_points_missing_name(self): def test_entry_points_missing_group(self): assert entry_points(group='missing') == () - def test_entry_points_dict_construction(self): - """ - Prior versions of entry_points() returned simple lists and - allowed casting those lists into maps by name using ``dict()``. - Capture this now deprecated use-case. - """ - with suppress_known_deprecation() as caught: - eps = dict(entry_points(group='entries')) - - assert 'main' in eps - assert eps['main'] == entry_points(group='entries')['main'] - - # check warning - expected = next(iter(caught)) - assert expected.category is DeprecationWarning - assert "Construction of dict of EntryPoints is deprecated" in str(expected) - - def test_entry_points_by_index(self): - """ - Prior versions of Distribution.entry_points would return a - tuple that allowed access by index. - Capture this now deprecated use-case - See python/importlib_metadata#300 and bpo-44246. - """ - eps = distribution('distinfo-pkg').entry_points - with suppress_known_deprecation() as caught: - eps[0] - - # check warning - expected = next(iter(caught)) - assert expected.category is DeprecationWarning - assert "Accessing entry points by index is deprecated" in str(expected) - - def test_entry_points_groups_getitem(self): - """ - Prior versions of entry_points() returned a dict. Ensure - that callers using '.__getitem__()' are supported but warned to - migrate. - """ - with suppress_known_deprecation(): - entry_points()['entries'] == entry_points(group='entries') - - with self.assertRaises(KeyError): - entry_points()['missing'] - - def test_entry_points_groups_get(self): - """ - Prior versions of entry_points() returned a dict. Ensure - that callers using '.get()' are supported but warned to - migrate. - """ - with suppress_known_deprecation(): - entry_points().get('missing', 'default') == 'default' - entry_points().get('entries', 'default') == entry_points()['entries'] - entry_points().get('missing', ()) == () - # TODO: RUSTPYTHON @unittest.expectedFailure def test_entry_points_allows_no_attributes(self): @@ -195,6 +154,28 @@ def test_metadata_for_this_package(self): classifiers = md.get_all('Classifier') assert 'Topic :: Software Development :: Libraries' in classifiers + def test_missing_key_legacy(self): + """ + Requesting a missing key will still return None, but warn. + """ + md = metadata('distinfo-pkg') + with suppress_known_deprecation(): + assert md['does-not-exist'] is None + + def test_get_key(self): + """ + Getting a key gets the key. + """ + md = metadata('egginfo-pkg') + assert md.get('Name') == 'egginfo-pkg' + + def test_get_missing_key(self): + """ + Requesting a missing key will return None. + """ + md = metadata('distinfo-pkg') + assert md.get('does-not-exist') is None + @staticmethod def _test_files(files): root = files[0].root @@ -217,6 +198,9 @@ def test_files_dist_info(self): def test_files_egg_info(self): self._test_files(files('egginfo-pkg')) + self._test_files(files('egg_with_module-pkg')) + self._test_files(files('egg_with_no_modules-pkg')) + self._test_files(files('sources_fallback-pkg')) def test_version_egg_info_file(self): self.assertEqual(version('egginfo-file'), '0.1') diff --git a/Lib/test/test_importlib/test_namespace_pkgs.py b/Lib/test/test_importlib/test_namespace_pkgs.py index cd08498545..65428c3d3e 100644 --- a/Lib/test/test_importlib/test_namespace_pkgs.py +++ b/Lib/test/test_importlib/test_namespace_pkgs.py @@ -79,12 +79,9 @@ def test_cant_import_other(self): with self.assertRaises(ImportError): import foo.two - def test_module_repr(self): + def test_simple_repr(self): import foo.one - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - self.assertEqual(foo.__spec__.loader.module_repr(foo), - "") + assert repr(foo).startswith("'.format(module.__name__) - self.module.__loader__ = Loader() - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, '') - - def test_module___loader___module_repr_bad(self): - class Loader(TestLoader): - def module_repr(self, module): - raise Exception - self.module.__loader__ = Loader() - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ')>'.format('spam')) - - def test_module___spec__(self): - origin = 'in a hole, in the ground' - self.spec.origin = origin - self.module.__spec__ = self.spec - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, ''.format('spam', origin)) - - def test_module___spec___location(self): - location = 'in_a_galaxy_far_far_away.py' - self.spec.origin = location - self.spec._set_fileattr = True - self.module.__spec__ = self.spec - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ''.format('spam', location)) - - def test_module___spec___no_origin(self): - self.spec.loader = TestLoader() - self.module.__spec__ = self.spec - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ')>'.format('spam')) - - def test_module___spec___no_origin_no_loader(self): - self.spec.loader = None - self.module.__spec__ = self.spec - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, ''.format('spam')) - - def test_module_no_name(self): - del self.module.__name__ - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, ''.format('?')) - - def test_module_with_file(self): - filename = 'e/i/e/i/o/spam.py' - self.module.__file__ = filename - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ''.format('spam', filename)) - - def test_module_no_file(self): - self.module.__loader__ = TestLoader() - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, - ')>'.format('spam')) - - def test_module_no_file_no_loader(self): - modrepr = self.bootstrap._module_repr(self.module) - - self.assertEqual(modrepr, ''.format('spam')) - - -(Frozen_ModuleReprTests, - Source_ModuleReprTests - ) = test_util.test_both(ModuleReprTests, init=init, util=util, - machinery=machinery) - - class FactoryTests: def setUp(self): diff --git a/Lib/test/test_importlib/test_threaded_import.py b/Lib/test/test_importlib/test_threaded_import.py index 49c02484b7..148b2e4370 100644 --- a/Lib/test/test_importlib/test_threaded_import.py +++ b/Lib/test/test_importlib/test_threaded_import.py @@ -16,7 +16,7 @@ import unittest from unittest import mock from test.support import verbose -from test.support.import_helper import forget +from test.support.import_helper import forget, mock_register_at_fork from test.support.os_helper import (TESTFN, unlink, rmtree) from test.support import script_helper, threading_helper @@ -42,12 +42,6 @@ def task(N, done, done_tasks, errors): if finished: done.set() -def mock_register_at_fork(func): - # bpo-30599: Mock os.register_at_fork() when importing the random module, - # since this function doesn't allow to unregister callbacks and would leak - # memory. - return mock.patch('os.register_at_fork', create=True)(func) - # Create a circular import structure: A -> C -> B -> D -> A # NOTE: `time` is already loaded and therefore doesn't threaten to deadlock. @@ -251,7 +245,8 @@ def target(): self.addCleanup(forget, TESTFN) self.addCleanup(rmtree, '__pycache__') importlib.invalidate_caches() - __import__(TESTFN) + with threading_helper.wait_threads_exit(): + __import__(TESTFN) del sys.modules[TESTFN] @unittest.skip("TODO: RUSTPYTHON; hang") diff --git a/Lib/test/test_importlib/test_util.py b/Lib/test/test_importlib/test_util.py index 6c791fc012..dc27e4aa99 100644 --- a/Lib/test/test_importlib/test_util.py +++ b/Lib/test/test_importlib/test_util.py @@ -8,14 +8,29 @@ import importlib.util import os import pathlib +import re import string import sys from test import support +import textwrap import types import unittest import unittest.mock import warnings +try: + import _testsinglephase +except ImportError: + _testsinglephase = None +try: + import _testmultiphase +except ImportError: + _testmultiphase = None +try: + import _xxsubinterpreters as _interpreters +except ModuleNotFoundError: + _interpreters = None + class DecodeSourceBytesTests: @@ -127,247 +142,6 @@ def test___cached__(self): util=importlib_util) -class ModuleForLoaderTests: - - """Tests for importlib.util.module_for_loader.""" - - @classmethod - def module_for_loader(cls, func): - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - return cls.util.module_for_loader(func) - - def test_warning(self): - # Should raise a PendingDeprecationWarning when used. - with warnings.catch_warnings(): - warnings.simplefilter('error', DeprecationWarning) - with self.assertRaises(DeprecationWarning): - func = self.util.module_for_loader(lambda x: x) - - def return_module(self, name): - fxn = self.module_for_loader(lambda self, module: module) - return fxn(self, name) - - def raise_exception(self, name): - def to_wrap(self, module): - raise ImportError - fxn = self.module_for_loader(to_wrap) - try: - fxn(self, name) - except ImportError: - pass - - def test_new_module(self): - # Test that when no module exists in sys.modules a new module is - # created. - module_name = 'a.b.c' - with util.uncache(module_name): - module = self.return_module(module_name) - self.assertIn(module_name, sys.modules) - self.assertIsInstance(module, types.ModuleType) - self.assertEqual(module.__name__, module_name) - - def test_reload(self): - # Test that a module is reused if already in sys.modules. - class FakeLoader: - def is_package(self, name): - return True - @self.module_for_loader - def load_module(self, module): - return module - name = 'a.b.c' - module = types.ModuleType('a.b.c') - module.__loader__ = 42 - module.__package__ = 42 - with util.uncache(name): - sys.modules[name] = module - loader = FakeLoader() - returned_module = loader.load_module(name) - self.assertIs(returned_module, sys.modules[name]) - self.assertEqual(module.__loader__, loader) - self.assertEqual(module.__package__, name) - - def test_new_module_failure(self): - # Test that a module is removed from sys.modules if added but an - # exception is raised. - name = 'a.b.c' - with util.uncache(name): - self.raise_exception(name) - self.assertNotIn(name, sys.modules) - - def test_reload_failure(self): - # Test that a failure on reload leaves the module in-place. - name = 'a.b.c' - module = types.ModuleType(name) - with util.uncache(name): - sys.modules[name] = module - self.raise_exception(name) - self.assertIs(module, sys.modules[name]) - - def test_decorator_attrs(self): - def fxn(self, module): pass - wrapped = self.module_for_loader(fxn) - self.assertEqual(wrapped.__name__, fxn.__name__) - self.assertEqual(wrapped.__qualname__, fxn.__qualname__) - - def test_false_module(self): - # If for some odd reason a module is considered false, still return it - # from sys.modules. - class FalseModule(types.ModuleType): - def __bool__(self): return False - - name = 'mod' - module = FalseModule(name) - with util.uncache(name): - self.assertFalse(module) - sys.modules[name] = module - given = self.return_module(name) - self.assertIs(given, module) - - def test_attributes_set(self): - # __name__, __loader__, and __package__ should be set (when - # is_package() is defined; undefined implicitly tested elsewhere). - class FakeLoader: - def __init__(self, is_package): - self._pkg = is_package - def is_package(self, name): - return self._pkg - @self.module_for_loader - def load_module(self, module): - return module - - name = 'pkg.mod' - with util.uncache(name): - loader = FakeLoader(False) - module = loader.load_module(name) - self.assertEqual(module.__name__, name) - self.assertIs(module.__loader__, loader) - self.assertEqual(module.__package__, 'pkg') - - name = 'pkg.sub' - with util.uncache(name): - loader = FakeLoader(True) - module = loader.load_module(name) - self.assertEqual(module.__name__, name) - self.assertIs(module.__loader__, loader) - self.assertEqual(module.__package__, name) - - -(Frozen_ModuleForLoaderTests, - Source_ModuleForLoaderTests - ) = util.test_both(ModuleForLoaderTests, util=importlib_util) - - -class SetPackageTests: - - """Tests for importlib.util.set_package.""" - - def verify(self, module, expect): - """Verify the module has the expected value for __package__ after - passing through set_package.""" - fxn = lambda: module - wrapped = self.util.set_package(fxn) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - wrapped() - self.assertTrue(hasattr(module, '__package__')) - self.assertEqual(expect, module.__package__) - - def test_top_level(self): - # __package__ should be set to the empty string if a top-level module. - # Implicitly tests when package is set to None. - module = types.ModuleType('module') - module.__package__ = None - self.verify(module, '') - - def test_package(self): - # Test setting __package__ for a package. - module = types.ModuleType('pkg') - module.__path__ = [''] - module.__package__ = None - self.verify(module, 'pkg') - - def test_submodule(self): - # Test __package__ for a module in a package. - module = types.ModuleType('pkg.mod') - module.__package__ = None - self.verify(module, 'pkg') - - def test_setting_if_missing(self): - # __package__ should be set if it is missing. - module = types.ModuleType('mod') - if hasattr(module, '__package__'): - delattr(module, '__package__') - self.verify(module, '') - - def test_leaving_alone(self): - # If __package__ is set and not None then leave it alone. - for value in (True, False): - module = types.ModuleType('mod') - module.__package__ = value - self.verify(module, value) - - def test_decorator_attrs(self): - def fxn(module): pass - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - wrapped = self.util.set_package(fxn) - self.assertEqual(wrapped.__name__, fxn.__name__) - self.assertEqual(wrapped.__qualname__, fxn.__qualname__) - - -(Frozen_SetPackageTests, - Source_SetPackageTests - ) = util.test_both(SetPackageTests, util=importlib_util) - - -class SetLoaderTests: - - """Tests importlib.util.set_loader().""" - - @property - def DummyLoader(self): - # Set DummyLoader on the class lazily. - class DummyLoader: - @self.util.set_loader - def load_module(self, module): - return self.module - self.__class__.DummyLoader = DummyLoader - return DummyLoader - - def test_no_attribute(self): - loader = self.DummyLoader() - loader.module = types.ModuleType('blah') - try: - del loader.module.__loader__ - except AttributeError: - pass - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(loader, loader.load_module('blah').__loader__) - - def test_attribute_is_None(self): - loader = self.DummyLoader() - loader.module = types.ModuleType('blah') - loader.module.__loader__ = None - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(loader, loader.load_module('blah').__loader__) - - def test_not_reset(self): - loader = self.DummyLoader() - loader.module = types.ModuleType('blah') - loader.module.__loader__ = 42 - with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(42, loader.load_module('blah').__loader__) - - -(Frozen_SetLoaderTests, - Source_SetLoaderTests - ) = util.test_both(SetLoaderTests, util=importlib_util) - - class ResolveNameTests: """Tests importlib.util.resolve_name().""" @@ -877,7 +651,7 @@ def test_magic_number(self): # stakeholders such as OS package maintainers must be notified # in advance. Such exceptional releases will then require an # adjustment to this test case. - EXPECTED_MAGIC_NUMBER = 3495 + EXPECTED_MAGIC_NUMBER = 3531 actual = int.from_bytes(importlib.util.MAGIC_NUMBER[:2], 'little') msg = ( @@ -895,5 +669,111 @@ def test_magic_number(self): self.assertEqual(EXPECTED_MAGIC_NUMBER, actual, msg) +@unittest.skipIf(_interpreters is None, 'subinterpreters required') +class IncompatibleExtensionModuleRestrictionsTests(unittest.TestCase): + + ERROR = re.compile("^: module (.*) does not support loading in subinterpreters") + + def run_with_own_gil(self, script): + interpid = _interpreters.create(isolated=True) + try: + _interpreters.run_string(interpid, script) + except _interpreters.RunFailedError as exc: + if m := self.ERROR.match(str(exc)): + modname, = m.groups() + raise ImportError(modname) + + def run_with_shared_gil(self, script): + interpid = _interpreters.create(isolated=False) + try: + _interpreters.run_string(interpid, script) + except _interpreters.RunFailedError as exc: + if m := self.ERROR.match(str(exc)): + modname, = m.groups() + raise ImportError(modname) + + @unittest.skipIf(_testsinglephase is None, "test requires _testsinglephase module") + def test_single_phase_init_module(self): + script = textwrap.dedent(''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=True): + import _testsinglephase + ''') + with self.subTest('check disabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check disabled, per-interpreter GIL'): + self.run_with_own_gil(script) + + script = textwrap.dedent(f''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=False): + import _testsinglephase + ''') + with self.subTest('check enabled, shared GIL'): + with self.assertRaises(ImportError): + self.run_with_shared_gil(script) + with self.subTest('check enabled, per-interpreter GIL'): + with self.assertRaises(ImportError): + self.run_with_own_gil(script) + + @unittest.skipIf(_testmultiphase is None, "test requires _testmultiphase module") + def test_incomplete_multi_phase_init_module(self): + prescript = textwrap.dedent(f''' + from importlib.util import spec_from_loader, module_from_spec + from importlib.machinery import ExtensionFileLoader + + name = '_test_shared_gil_only' + filename = {_testmultiphase.__file__!r} + loader = ExtensionFileLoader(name, filename) + spec = spec_from_loader(name, loader) + + ''') + + script = prescript + textwrap.dedent(''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=True): + module = module_from_spec(spec) + loader.exec_module(module) + ''') + with self.subTest('check disabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check disabled, per-interpreter GIL'): + self.run_with_own_gil(script) + + script = prescript + textwrap.dedent(''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=False): + module = module_from_spec(spec) + loader.exec_module(module) + ''') + with self.subTest('check enabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check enabled, per-interpreter GIL'): + with self.assertRaises(ImportError): + self.run_with_own_gil(script) + + @unittest.skipIf(_testmultiphase is None, "test requires _testmultiphase module") + def test_complete_multi_phase_init_module(self): + script = textwrap.dedent(''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=True): + import _testmultiphase + ''') + with self.subTest('check disabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check disabled, per-interpreter GIL'): + self.run_with_own_gil(script) + + script = textwrap.dedent(f''' + from importlib.util import _incompatible_extension_module_restrictions + with _incompatible_extension_module_restrictions(disable_check=False): + import _testmultiphase + ''') + with self.subTest('check enabled, shared GIL'): + self.run_with_shared_gil(script) + with self.subTest('check enabled, per-interpreter GIL'): + self.run_with_own_gil(script) + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_importlib/test_windows.py b/Lib/test/test_importlib/test_windows.py index 051193fae0..f8a9ead9ac 100644 --- a/Lib/test/test_importlib/test_windows.py +++ b/Lib/test/test_importlib/test_windows.py @@ -92,30 +92,16 @@ class WindowsRegistryFinderTests: def test_find_spec_missing(self): spec = self.machinery.WindowsRegistryFinder.find_spec('spam') - self.assertIs(spec, None) - - def test_find_module_missing(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.WindowsRegistryFinder.find_module('spam') - self.assertIs(loader, None) + self.assertIsNone(spec) def test_module_found(self): with setup_module(self.machinery, self.test_module): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module) spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module) - self.assertIsNot(loader, None) - self.assertIsNot(spec, None) + self.assertIsNotNone(spec) def test_module_not_found(self): with setup_module(self.machinery, self.test_module, path="."): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module) spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module) - self.assertIsNone(loader) self.assertIsNone(spec) (Frozen_WindowsRegistryFinderTests, diff --git a/Lib/test/test_importlib/util.py b/Lib/test/test_importlib/util.py index 0b6dcc5eaf..c25be096e5 100644 --- a/Lib/test/test_importlib/util.py +++ b/Lib/test/test_importlib/util.py @@ -27,7 +27,7 @@ EXTENSIONS.ext = None EXTENSIONS.filename = None EXTENSIONS.file_path = None -EXTENSIONS.name = '_testcapi' +EXTENSIONS.name = '_testsinglephase' def _extension_details(): global EXTENSIONS @@ -131,9 +131,8 @@ def uncache(*names): """ for name in names: - if name in ('sys', 'marshal', 'imp'): - raise ValueError( - "cannot uncache {0}".format(name)) + if name in ('sys', 'marshal'): + raise ValueError("cannot uncache {}".format(name)) try: del sys.modules[name] except KeyError: @@ -195,8 +194,7 @@ def import_state(**kwargs): new_value = default setattr(sys, attr, new_value) if len(kwargs): - raise ValueError( - 'unrecognized arguments: {0}'.format(kwargs.keys())) + raise ValueError('unrecognized arguments: {}'.format(kwargs)) yield finally: for attr, value in originals.items(): @@ -244,30 +242,6 @@ def __exit__(self, *exc_info): self._uncache.__exit__(None, None, None) -class mock_modules(_ImporterMock): - - """Importer mock using PEP 302 APIs.""" - - def find_module(self, fullname, path=None): - if fullname not in self.modules: - return None - else: - return self - - def load_module(self, fullname): - if fullname not in self.modules: - raise ImportError - else: - sys.modules[fullname] = self.modules[fullname] - if fullname in self.module_code: - try: - self.module_code[fullname]() - except Exception: - del sys.modules[fullname] - raise - return self.modules[fullname] - - class mock_spec(_ImporterMock): """Importer mock using PEP 451 APIs.""" diff --git a/Lib/test/test_module.py b/Lib/test/test_module/__init__.py similarity index 90% rename from Lib/test/test_module.py rename to Lib/test/test_module/__init__.py index b921fc6f4e..d8a0ba0803 100644 --- a/Lib/test/test_module.py +++ b/Lib/test/test_module/__init__.py @@ -8,17 +8,16 @@ import sys ModuleType = type(sys) + class FullLoader: - @classmethod - def module_repr(cls, m): - return "".format(m.__name__) + pass + class BareLoader: pass class ModuleTests(unittest.TestCase): - def test_uninitialized(self): # An uninitialized module has no __dict__ or __name__, # and __doc__ is None @@ -128,11 +127,9 @@ def test_weakref(self): gc_collect() self.assertIs(wr(), None) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_module_getattr(self): - import test.good_getattr as gga - from test.good_getattr import test + import test.test_module.good_getattr as gga + from test.test_module.good_getattr import test self.assertEqual(test, "There is test") self.assertEqual(gga.x, 1) self.assertEqual(gga.y, 2) @@ -140,54 +137,50 @@ def test_module_getattr(self): "Deprecated, use whatever instead"): gga.yolo self.assertEqual(gga.whatever, "There is whatever") - del sys.modules['test.good_getattr'] + del sys.modules['test.test_module.good_getattr'] - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_module_getattr_errors(self): - import test.bad_getattr as bga - from test import bad_getattr2 + import test.test_module.bad_getattr as bga + from test.test_module import bad_getattr2 self.assertEqual(bga.x, 1) self.assertEqual(bad_getattr2.x, 1) with self.assertRaises(TypeError): bga.nope with self.assertRaises(TypeError): bad_getattr2.nope - del sys.modules['test.bad_getattr'] - if 'test.bad_getattr2' in sys.modules: - del sys.modules['test.bad_getattr2'] + del sys.modules['test.test_module.bad_getattr'] + if 'test.test_module.bad_getattr2' in sys.modules: + del sys.modules['test.test_module.bad_getattr2'] # TODO: RUSTPYTHON @unittest.expectedFailure def test_module_dir(self): - import test.good_getattr as gga + import test.test_module.good_getattr as gga self.assertEqual(dir(gga), ['a', 'b', 'c']) - del sys.modules['test.good_getattr'] + del sys.modules['test.test_module.good_getattr'] # TODO: RUSTPYTHON @unittest.expectedFailure def test_module_dir_errors(self): - import test.bad_getattr as bga - from test import bad_getattr2 + import test.test_module.bad_getattr as bga + from test.test_module import bad_getattr2 with self.assertRaises(TypeError): dir(bga) with self.assertRaises(TypeError): dir(bad_getattr2) - del sys.modules['test.bad_getattr'] - if 'test.bad_getattr2' in sys.modules: - del sys.modules['test.bad_getattr2'] + del sys.modules['test.test_module.bad_getattr'] + if 'test.test_module.bad_getattr2' in sys.modules: + del sys.modules['test.test_module.bad_getattr2'] - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_module_getattr_tricky(self): - from test import bad_getattr3 + from test.test_module import bad_getattr3 # these lookups should not crash with self.assertRaises(AttributeError): bad_getattr3.one with self.assertRaises(AttributeError): bad_getattr3.delgetattr - if 'test.bad_getattr3' in sys.modules: - del sys.modules['test.bad_getattr3'] + if 'test.test_module.bad_getattr3' in sys.modules: + del sys.modules['test.test_module.bad_getattr3'] def test_module_repr_minimal(self): # reprs when modules have no __file__, __name__, or __loader__ @@ -249,10 +242,9 @@ def test_module_repr_with_full_loader(self): # Yes, a class not an instance. m.__loader__ = FullLoader self.assertEqual( - repr(m), "") + repr(m), f")>") def test_module_repr_with_bare_loader_and_filename(self): - # Because the loader has no module_repr(), use the file name. m = ModuleType('foo') # Yes, a class not an instance. m.__loader__ = BareLoader @@ -260,12 +252,11 @@ def test_module_repr_with_bare_loader_and_filename(self): self.assertEqual(repr(m), "") def test_module_repr_with_full_loader_and_filename(self): - # Even though the module has an __file__, use __loader__.module_repr() m = ModuleType('foo') # Yes, a class not an instance. m.__loader__ = FullLoader m.__file__ = '/tmp/foo.py' - self.assertEqual(repr(m), "") + self.assertEqual(repr(m), "") def test_module_repr_builtin(self): self.assertEqual(repr(sys), "") diff --git a/Lib/test/test_module/bad_getattr.py b/Lib/test/test_module/bad_getattr.py new file mode 100644 index 0000000000..16f901b13b --- /dev/null +++ b/Lib/test/test_module/bad_getattr.py @@ -0,0 +1,4 @@ +x = 1 + +__getattr__ = "Surprise!" +__dir__ = "Surprise again!" diff --git a/Lib/test/test_module/bad_getattr2.py b/Lib/test/test_module/bad_getattr2.py new file mode 100644 index 0000000000..0a52a53b54 --- /dev/null +++ b/Lib/test/test_module/bad_getattr2.py @@ -0,0 +1,7 @@ +def __getattr__(): + "Bad one" + +x = 1 + +def __dir__(bad_sig): + return [] diff --git a/Lib/test/test_module/bad_getattr3.py b/Lib/test/test_module/bad_getattr3.py new file mode 100644 index 0000000000..0d5f9266c7 --- /dev/null +++ b/Lib/test/test_module/bad_getattr3.py @@ -0,0 +1,5 @@ +def __getattr__(name): + if name != 'delgetattr': + raise AttributeError + del globals()['__getattr__'] + raise AttributeError diff --git a/Lib/test/test_module/good_getattr.py b/Lib/test/test_module/good_getattr.py new file mode 100644 index 0000000000..7d27de6262 --- /dev/null +++ b/Lib/test/test_module/good_getattr.py @@ -0,0 +1,11 @@ +x = 1 + +def __dir__(): + return ['a', 'b', 'c'] + +def __getattr__(name): + if name == "yolo": + raise AttributeError("Deprecated, use whatever instead") + return f"There is {name}" + +y = 2 diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 6ad272697b..673160c20b 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -9,7 +9,6 @@ import sys import tempfile import textwrap -import time import unittest import warnings @@ -31,7 +30,7 @@ def setUpClass(cls): "test.support.warnings_helper", like=".*used in test_support.*" ) cls._test_support_token = support.ignore_deprecations_from( - "test.test_support", like=".*You should NOT be seeing this.*" + __name__, like=".*You should NOT be seeing this.*" ) assert len(warnings.filters) == orig_filter_len + 2 @@ -464,18 +463,12 @@ def test_reap_children(self): # child process: do nothing, just exit os._exit(0) - t0 = time.monotonic() - deadline = time.monotonic() + support.SHORT_TIMEOUT - was_altered = support.environment_altered try: support.environment_altered = False stderr = io.StringIO() - while True: - if time.monotonic() > deadline: - self.fail("timeout") - + for _ in support.sleeping_retry(support.SHORT_TIMEOUT): with support.swap_attr(support.print_warning, 'orig_stderr', stderr): support.reap_children() @@ -484,9 +477,6 @@ def test_reap_children(self): if support.environment_altered: break - # loop until the child process completed - time.sleep(0.100) - msg = "Warning -- reap_children() reaped child process %s" % pid self.assertIn(msg, stderr.getvalue()) self.assertTrue(support.environment_altered) @@ -513,6 +503,7 @@ def check_options(self, args, func, expected=None): self.assertEqual(proc.stdout.rstrip(), repr(expected)) self.assertEqual(proc.returncode, 0) + @support.requires_resource('cpu') def test_args_from_interpreter_flags(self): # Test test.support.args_from_interpreter_flags() for opts in ( @@ -702,6 +693,84 @@ def test_has_strftime_extensions(self): else: self.assertTrue(support.has_strftime_extensions) + @unittest.expectedFailure + def test_get_recursion_depth(self): + # test support.get_recursion_depth() + code = textwrap.dedent(""" + from test import support + import sys + + def check(cond): + if not cond: + raise AssertionError("test failed") + + # depth 1 + check(support.get_recursion_depth() == 1) + + # depth 2 + def test_func(): + check(support.get_recursion_depth() == 2) + test_func() + + def test_recursive(depth, limit): + if depth >= limit: + # cannot call get_recursion_depth() at this depth, + # it can raise RecursionError + return + get_depth = support.get_recursion_depth() + print(f"test_recursive: {depth}/{limit}: " + f"get_recursion_depth() says {get_depth}") + check(get_depth == depth) + test_recursive(depth + 1, limit) + + # depth up to 25 + with support.infinite_recursion(max_depth=25): + limit = sys.getrecursionlimit() + print(f"test with sys.getrecursionlimit()={limit}") + test_recursive(2, limit) + + # depth up to 500 + with support.infinite_recursion(max_depth=500): + limit = sys.getrecursionlimit() + print(f"test with sys.getrecursionlimit()={limit}") + test_recursive(2, limit) + """) + script_helper.assert_python_ok("-c", code) + + def test_recursion(self): + # Test infinite_recursion() and get_recursion_available() functions. + def recursive_function(depth): + if depth: + recursive_function(depth - 1) + + for max_depth in (5, 25, 250): + with support.infinite_recursion(max_depth): + available = support.get_recursion_available() + + # Recursion up to 'available' additional frames should be OK. + recursive_function(available) + + # Recursion up to 'available+1' additional frames must raise + # RecursionError. Avoid self.assertRaises(RecursionError) which + # can consume more than 3 frames and so raises RecursionError. + try: + recursive_function(available + 1) + except RecursionError: + pass + else: + self.fail("RecursionError was not raised") + + # Test the bare minimumum: max_depth=3 + with support.infinite_recursion(3): + try: + recursive_function(3) + except RecursionError: + pass + else: + self.fail("RecursionError was not raised") + + #self.assertEqual(available, 2) + # XXX -follows a list of untested API # make_legacy_pyc # is_resource_enabled diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 071a2a06c1..17c9f01cd8 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -9,17 +9,22 @@ import codecs import itertools import operator +import pickle import struct import sys import textwrap import unicodedata import unittest import warnings -from test.support import import_helper from test.support import warnings_helper from test import support, string_tests from test.support.script_helper import assert_python_failure +try: + import _testcapi +except ImportError: + _testcapi = None + # Error handling (bad decoder return) def search_function(encoding): def decode1(input, errors="strict"): @@ -89,88 +94,85 @@ def test_literals(self): self.assertNotEqual(r"\u0020", " ") def test_ascii(self): - if not sys.platform.startswith('java'): - # Test basic sanity of repr() - self.assertEqual(ascii('abc'), "'abc'") - self.assertEqual(ascii('ab\\c'), "'ab\\\\c'") - self.assertEqual(ascii('ab\\'), "'ab\\\\'") - self.assertEqual(ascii('\\c'), "'\\\\c'") - self.assertEqual(ascii('\\'), "'\\\\'") - self.assertEqual(ascii('\n'), "'\\n'") - self.assertEqual(ascii('\r'), "'\\r'") - self.assertEqual(ascii('\t'), "'\\t'") - self.assertEqual(ascii('\b'), "'\\x08'") - self.assertEqual(ascii("'\""), """'\\'"'""") - self.assertEqual(ascii("'\""), """'\\'"'""") - self.assertEqual(ascii("'"), '''"'"''') - self.assertEqual(ascii('"'), """'"'""") - latin1repr = ( - "'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r" - "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a" - "\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHI" - "JKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f" - "\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d" - "\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b" - "\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9" - "\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7" - "\\xb8\\xb9\\xba\\xbb\\xbc\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5" - "\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3" - "\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1" - "\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef" - "\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd" - "\\xfe\\xff'") - testrepr = ascii(''.join(map(chr, range(256)))) - self.assertEqual(testrepr, latin1repr) - # Test ascii works on wide unicode escapes without overflow. - self.assertEqual(ascii("\U00010000" * 39 + "\uffff" * 4096), - ascii("\U00010000" * 39 + "\uffff" * 4096)) - - class WrongRepr: - def __repr__(self): - return b'byte-repr' - self.assertRaises(TypeError, ascii, WrongRepr()) + self.assertEqual(ascii('abc'), "'abc'") + self.assertEqual(ascii('ab\\c'), "'ab\\\\c'") + self.assertEqual(ascii('ab\\'), "'ab\\\\'") + self.assertEqual(ascii('\\c'), "'\\\\c'") + self.assertEqual(ascii('\\'), "'\\\\'") + self.assertEqual(ascii('\n'), "'\\n'") + self.assertEqual(ascii('\r'), "'\\r'") + self.assertEqual(ascii('\t'), "'\\t'") + self.assertEqual(ascii('\b'), "'\\x08'") + self.assertEqual(ascii("'\""), """'\\'"'""") + self.assertEqual(ascii("'\""), """'\\'"'""") + self.assertEqual(ascii("'"), '''"'"''') + self.assertEqual(ascii('"'), """'"'""") + latin1repr = ( + "'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r" + "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a" + "\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHI" + "JKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f" + "\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d" + "\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b" + "\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9" + "\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7" + "\\xb8\\xb9\\xba\\xbb\\xbc\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5" + "\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3" + "\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1" + "\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef" + "\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb\\xfc\\xfd" + "\\xfe\\xff'") + testrepr = ascii(''.join(map(chr, range(256)))) + self.assertEqual(testrepr, latin1repr) + # Test ascii works on wide unicode escapes without overflow. + self.assertEqual(ascii("\U00010000" * 39 + "\uffff" * 4096), + ascii("\U00010000" * 39 + "\uffff" * 4096)) + + class WrongRepr: + def __repr__(self): + return b'byte-repr' + self.assertRaises(TypeError, ascii, WrongRepr()) def test_repr(self): - if not sys.platform.startswith('java'): - # Test basic sanity of repr() - self.assertEqual(repr('abc'), "'abc'") - self.assertEqual(repr('ab\\c'), "'ab\\\\c'") - self.assertEqual(repr('ab\\'), "'ab\\\\'") - self.assertEqual(repr('\\c'), "'\\\\c'") - self.assertEqual(repr('\\'), "'\\\\'") - self.assertEqual(repr('\n'), "'\\n'") - self.assertEqual(repr('\r'), "'\\r'") - self.assertEqual(repr('\t'), "'\\t'") - self.assertEqual(repr('\b'), "'\\x08'") - self.assertEqual(repr("'\""), """'\\'"'""") - self.assertEqual(repr("'\""), """'\\'"'""") - self.assertEqual(repr("'"), '''"'"''') - self.assertEqual(repr('"'), """'"'""") - latin1repr = ( - "'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r" - "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a" - "\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHI" - "JKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f" - "\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d" - "\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b" - "\\x9c\\x9d\\x9e\\x9f\\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9" - "\xaa\xab\xac\\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7" - "\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5" - "\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3" - "\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1" - "\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef" - "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd" - "\xfe\xff'") - testrepr = repr(''.join(map(chr, range(256)))) - self.assertEqual(testrepr, latin1repr) - # Test repr works on wide unicode escapes without overflow. - self.assertEqual(repr("\U00010000" * 39 + "\uffff" * 4096), - repr("\U00010000" * 39 + "\uffff" * 4096)) - - class WrongRepr: - def __repr__(self): - return b'byte-repr' - self.assertRaises(TypeError, repr, WrongRepr()) + # Test basic sanity of repr() + self.assertEqual(repr('abc'), "'abc'") + self.assertEqual(repr('ab\\c'), "'ab\\\\c'") + self.assertEqual(repr('ab\\'), "'ab\\\\'") + self.assertEqual(repr('\\c'), "'\\\\c'") + self.assertEqual(repr('\\'), "'\\\\'") + self.assertEqual(repr('\n'), "'\\n'") + self.assertEqual(repr('\r'), "'\\r'") + self.assertEqual(repr('\t'), "'\\t'") + self.assertEqual(repr('\b'), "'\\x08'") + self.assertEqual(repr("'\""), """'\\'"'""") + self.assertEqual(repr("'\""), """'\\'"'""") + self.assertEqual(repr("'"), '''"'"''') + self.assertEqual(repr('"'), """'"'""") + latin1repr = ( + "'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b\\x0c\\r" + "\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a" + "\\x1b\\x1c\\x1d\\x1e\\x1f !\"#$%&\\'()*+,-./0123456789:;<=>?@ABCDEFGHI" + "JKLMNOPQRSTUVWXYZ[\\\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\\x7f" + "\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d" + "\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b" + "\\x9c\\x9d\\x9e\\x9f\\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9" + "\xaa\xab\xac\\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7" + "\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5" + "\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3" + "\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1" + "\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef" + "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd" + "\xfe\xff'") + testrepr = repr(''.join(map(chr, range(256)))) + self.assertEqual(testrepr, latin1repr) + # Test repr works on wide unicode escapes without overflow. + self.assertEqual(repr("\U00010000" * 39 + "\uffff" * 4096), + repr("\U00010000" * 39 + "\uffff" * 4096)) + + class WrongRepr: + def __repr__(self): + return b'byte-repr' + self.assertRaises(TypeError, repr, WrongRepr()) def test_iterators(self): # Make sure unicode objects have an __iter__ method @@ -180,6 +182,36 @@ def test_iterators(self): self.assertEqual(next(it), "\u3333") self.assertRaises(StopIteration, next, it) + def test_iterators_invocation(self): + cases = [type(iter('abc')), type(iter('🚀'))] + for cls in cases: + with self.subTest(cls=cls): + self.assertRaises(TypeError, cls) + + def test_iteration(self): + cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"] + for case in cases: + with self.subTest(string=case): + self.assertEqual(case, "".join(iter(case))) + + def test_exhausted_iterator(self): + cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"] + for case in cases: + with self.subTest(case=case): + iterator = iter(case) + tuple(iterator) + self.assertRaises(StopIteration, next, iterator) + + def test_pickle_iterator(self): + cases = ['abc', '🚀🚀🚀', "\u1111\u2222\u3333"] + for case in cases: + with self.subTest(case=case): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + it = iter(case) + with self.subTest(proto=proto): + pickled = "".join(pickle.loads(pickle.dumps(it, proto))) + self.assertEqual(case, pickled) + def test_count(self): string_tests.CommonTest.test_count(self) # check mixed argument types @@ -205,6 +237,10 @@ def test_count(self): self.checkequal(0, 'a' * 10, 'count', 'a\u0102') self.checkequal(0, 'a' * 10, 'count', 'a\U00100304') self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304') + # test subclass + class MyStr(str): + pass + self.checkequal(3, MyStr('aaa'), 'count', 'a') def test_find(self): string_tests.CommonTest.test_find(self) @@ -221,6 +257,20 @@ def test_find(self): self.checkequalnofix(9, 'abcdefghiabc', 'find', 'abc', 1) self.checkequalnofix(-1, 'abcdefghiabc', 'find', 'def', 4) + # test utf-8 non-ascii char + self.checkequal(0, 'тест', 'find', 'т') + self.checkequal(3, 'тест', 'find', 'т', 1) + self.checkequal(-1, 'тест', 'find', 'т', 1, 3) + self.checkequal(-1, 'тест', 'find', 'e') # english `e` + # test utf-8 non-ascii slice + self.checkequal(1, 'тест тест', 'find', 'ес') + self.checkequal(1, 'тест тест', 'find', 'ес', 1) + self.checkequal(1, 'тест тест', 'find', 'ес', 1, 3) + self.checkequal(6, 'тест тест', 'find', 'ес', 2) + self.checkequal(-1, 'тест тест', 'find', 'ес', 6, 7) + self.checkequal(-1, 'тест тест', 'find', 'ес', 7) + self.checkequal(-1, 'тест тест', 'find', 'ec') # english `ec` + self.assertRaises(TypeError, 'hello'.find) self.assertRaises(TypeError, 'hello'.find, 42) # test mixed kinds @@ -251,6 +301,19 @@ def test_rfind(self): self.checkequalnofix(9, 'abcdefghiabc', 'rfind', 'abc') self.checkequalnofix(12, 'abcdefghiabc', 'rfind', '') self.checkequalnofix(12, 'abcdefghiabc', 'rfind', '') + # test utf-8 non-ascii char + self.checkequal(1, 'тест', 'rfind', 'е') + self.checkequal(1, 'тест', 'rfind', 'е', 1) + self.checkequal(-1, 'тест', 'rfind', 'е', 2) + self.checkequal(-1, 'тест', 'rfind', 'e') # english `e` + # test utf-8 non-ascii slice + self.checkequal(6, 'тест тест', 'rfind', 'ес') + self.checkequal(6, 'тест тест', 'rfind', 'ес', 1) + self.checkequal(1, 'тест тест', 'rfind', 'ес', 1, 3) + self.checkequal(6, 'тест тест', 'rfind', 'ес', 2) + self.checkequal(-1, 'тест тест', 'rfind', 'ес', 6, 7) + self.checkequal(-1, 'тест тест', 'rfind', 'ес', 7) + self.checkequal(-1, 'тест тест', 'rfind', 'ec') # english `ec` # test mixed kinds self.checkequal(0, 'a' + '\u0102' * 100, 'rfind', 'a') self.checkequal(0, 'a' + '\U00100304' * 100, 'rfind', 'a') @@ -407,10 +470,10 @@ def test_split(self): def test_rsplit(self): string_tests.CommonTest.test_rsplit(self) # test mixed kinds - for left, right in ('ba', '\u0101\u0100', '\U00010301\U00010300'): + for left, right in ('ba', 'юё', '\u0101\u0100', '\U00010301\U00010300'): left *= 9 right *= 9 - for delim in ('c', '\u0102', '\U00010302'): + for delim in ('c', 'ы', '\u0102', '\U00010302'): self.checkequal([left + right], left + right, 'rsplit', delim) self.checkequal([left, right], @@ -420,6 +483,10 @@ def test_rsplit(self): self.checkequal([left, right], left + delim * 2 + right, 'rsplit', delim *2) + # Check `None` as well: + self.checkequal([left + right], + left + right, 'rsplit', None) + def test_partition(self): string_tests.MixinStrUnicodeUserStringTest.test_partition(self) # test mixed kinds @@ -619,8 +686,7 @@ def test_islower(self): def test_isupper(self): super().test_isupper() - if not sys.platform.startswith('java'): - self.checkequalnofix(False, '\u1FFc', 'isupper') + self.checkequalnofix(False, '\u1FFc', 'isupper') self.assertTrue('\u2167'.isupper()) self.assertFalse('\u2177'.isupper()) # non-BMP, uppercase @@ -757,9 +823,9 @@ def test_isidentifier(self): self.assertFalse("0".isidentifier()) @support.cpython_only - @support.requires_legacy_unicode_capi + @support.requires_legacy_unicode_capi() + @unittest.skipIf(_testcapi is None, 'need _testcapi module') def test_isidentifier_legacy(self): - import _testcapi u = '𝖀𝖓𝖎𝖈𝖔𝖉𝖊' self.assertTrue(u.isidentifier()) with warnings_helper.check_warnings(): @@ -1261,6 +1327,20 @@ def __repr__(self): self.assertRaises(ValueError, ("{" + big + "}").format) self.assertRaises(ValueError, ("{[" + big + "]}").format, [0]) + # test number formatter errors: + self.assertRaises(ValueError, '{0:x}'.format, 1j) + self.assertRaises(ValueError, '{0:x}'.format, 1.0) + self.assertRaises(ValueError, '{0:X}'.format, 1j) + self.assertRaises(ValueError, '{0:X}'.format, 1.0) + self.assertRaises(ValueError, '{0:o}'.format, 1j) + self.assertRaises(ValueError, '{0:o}'.format, 1.0) + self.assertRaises(ValueError, '{0:u}'.format, 1j) + self.assertRaises(ValueError, '{0:u}'.format, 1.0) + self.assertRaises(ValueError, '{0:i}'.format, 1j) + self.assertRaises(ValueError, '{0:i}'.format, 1.0) + self.assertRaises(ValueError, '{0:d}'.format, 1j) + self.assertRaises(ValueError, '{0:d}'.format, 1.0) + # issue 6089 self.assertRaises(ValueError, "{0[0]x}".format, [None]) self.assertRaises(ValueError, "{0[0](10)}".format, [None]) @@ -1431,10 +1511,9 @@ def test_formatting(self): self.assertEqual("%s, %s, %i, %f, %5.2f" % ("abc", "abc", -1, -2, 3.5), 'abc, abc, -1, -2.000000, 3.50') self.assertEqual("%s, %s, %i, %f, %5.2f" % ("abc", "abc", -1, -2, 3.57), 'abc, abc, -1, -2.000000, 3.57') self.assertEqual("%s, %s, %i, %f, %5.2f" % ("abc", "abc", -1, -2, 1003.57), 'abc, abc, -1, -2.000000, 1003.57') - if not sys.platform.startswith('java'): - self.assertEqual("%r, %r" % (b"abc", "abc"), "b'abc', 'abc'") - self.assertEqual("%r" % ("\u1234",), "'\u1234'") - self.assertEqual("%a" % ("\u1234",), "'\\u1234'") + self.assertEqual("%r, %r" % (b"abc", "abc"), "b'abc', 'abc'") + self.assertEqual("%r" % ("\u1234",), "'\u1234'") + self.assertEqual("%a" % ("\u1234",), "'\\u1234'") self.assertEqual("%(x)s, %(y)s" % {'x':"abc", 'y':"def"}, 'abc, def') self.assertEqual("%(x)s, %(\xfc)s" % {'x':"abc", '\xfc':"def"}, 'abc, def') @@ -1503,38 +1582,60 @@ def __int__(self): self.assertEqual('%X' % letter_m, '6D') self.assertEqual('%o' % letter_m, '155') self.assertEqual('%c' % letter_m, 'm') - self.assertRaisesRegex(TypeError, '%x format: an integer is required, not float', operator.mod, '%x', 3.14), - self.assertRaisesRegex(TypeError, '%X format: an integer is required, not float', operator.mod, '%X', 2.11), - self.assertRaisesRegex(TypeError, '%o format: an integer is required, not float', operator.mod, '%o', 1.79), - self.assertRaisesRegex(TypeError, '%x format: an integer is required, not PseudoFloat', operator.mod, '%x', pi), - self.assertRaises(TypeError, operator.mod, '%c', pi), + self.assertRaisesRegex(TypeError, '%x format: an integer is required, not float', operator.mod, '%x', 3.14) + self.assertRaisesRegex(TypeError, '%X format: an integer is required, not float', operator.mod, '%X', 2.11) + self.assertRaisesRegex(TypeError, '%o format: an integer is required, not float', operator.mod, '%o', 1.79) + self.assertRaisesRegex(TypeError, '%x format: an integer is required, not PseudoFloat', operator.mod, '%x', pi) + self.assertRaisesRegex(TypeError, '%x format: an integer is required, not complex', operator.mod, '%x', 3j) + self.assertRaisesRegex(TypeError, '%X format: an integer is required, not complex', operator.mod, '%X', 2j) + self.assertRaisesRegex(TypeError, '%o format: an integer is required, not complex', operator.mod, '%o', 1j) + self.assertRaisesRegex(TypeError, '%u format: a real number is required, not complex', operator.mod, '%u', 3j) + self.assertRaisesRegex(TypeError, '%i format: a real number is required, not complex', operator.mod, '%i', 2j) + self.assertRaisesRegex(TypeError, '%d format: a real number is required, not complex', operator.mod, '%d', 1j) + self.assertRaisesRegex(TypeError, '%c requires int or char', operator.mod, '%c', pi) + + class RaisingNumber: + def __int__(self): + raise RuntimeError('int') # should not be `TypeError` + def __index__(self): + raise RuntimeError('index') # should not be `TypeError` + rn = RaisingNumber() + self.assertRaisesRegex(RuntimeError, 'int', operator.mod, '%d', rn) + self.assertRaisesRegex(RuntimeError, 'int', operator.mod, '%i', rn) + self.assertRaisesRegex(RuntimeError, 'int', operator.mod, '%u', rn) + self.assertRaisesRegex(RuntimeError, 'index', operator.mod, '%x', rn) + self.assertRaisesRegex(RuntimeError, 'index', operator.mod, '%X', rn) + self.assertRaisesRegex(RuntimeError, 'index', operator.mod, '%o', rn) - # TODO: RUSTPYTHON, AssertionError: '...15...' != '...Int.IDES...' - @unittest.expectedFailure def test_formatting_with_enum(self): # issue18780 import enum class Float(float, enum.Enum): + # a mixed-in type will use the name for %s etc. PI = 3.1415926 class Int(enum.IntEnum): + # IntEnum uses the value and not the name for %s etc. IDES = 15 - class Str(str, enum.Enum): + class Str(enum.StrEnum): + # StrEnum uses the value and not the name for %s etc. ABC = 'abc' # Testing Unicode formatting strings... self.assertEqual("%s, %s" % (Str.ABC, Str.ABC), - 'Str.ABC, Str.ABC') + 'abc, abc') self.assertEqual("%s, %s, %d, %i, %u, %f, %5.2f" % (Str.ABC, Str.ABC, Int.IDES, Int.IDES, Int.IDES, Float.PI, Float.PI), - 'Str.ABC, Str.ABC, 15, 15, 15, 3.141593, 3.14') + 'abc, abc, 15, 15, 15, 3.141593, 3.14') # formatting jobs delegated from the string implementation: self.assertEqual('...%(foo)s...' % {'foo':Str.ABC}, - '...Str.ABC...') + '...abc...') + self.assertEqual('...%(foo)r...' % {'foo':Int.IDES}, + '......') self.assertEqual('...%(foo)s...' % {'foo':Int.IDES}, - '...Int.IDES...') + '...15...') self.assertEqual('...%(foo)i...' % {'foo':Int.IDES}, '...15...') self.assertEqual('...%(foo)d...' % {'foo':Int.IDES}, @@ -1559,9 +1660,9 @@ def __rmod__(self, other): "Success, self.__rmod__('lhs %% %r') was called") @support.cpython_only + @unittest.skipIf(_testcapi is None, 'need _testcapi module') def test_formatting_huge_precision_c_limits(self): - from _testcapi import INT_MAX - format_string = "%.{}f".format(INT_MAX + 1) + format_string = "%.{}f".format(_testcapi.INT_MAX + 1) with self.assertRaises(ValueError): result = format_string % 2.34 @@ -1627,29 +1728,27 @@ def __str__(self): # unicode(obj, encoding, error) tests (this maps to # PyUnicode_FromEncodedObject() at C level) - if not sys.platform.startswith('java'): - self.assertRaises( - TypeError, - str, - 'decoding unicode is not supported', - 'utf-8', - 'strict' - ) + self.assertRaises( + TypeError, + str, + 'decoding unicode is not supported', + 'utf-8', + 'strict' + ) self.assertEqual( str(b'strings are decoded to unicode', 'utf-8', 'strict'), 'strings are decoded to unicode' ) - if not sys.platform.startswith('java'): - self.assertEqual( - str( - memoryview(b'character buffers are decoded to unicode'), - 'utf-8', - 'strict' - ), - 'character buffers are decoded to unicode' - ) + self.assertEqual( + str( + memoryview(b'character buffers are decoded to unicode'), + 'utf-8', + 'strict' + ), + 'character buffers are decoded to unicode' + ) self.assertRaises(TypeError, str, 42, 42, 42) @@ -2347,12 +2446,7 @@ class s1: def __repr__(self): return '\\n' - class s2: - def __repr__(self): - return '\\n' - self.assertEqual(repr(s1()), '\\n') - self.assertEqual(repr(s2()), '\\n') def test_printable_repr(self): self.assertEqual(repr('\U00010000'), "'%c'" % (0x10000,)) # printable @@ -2374,20 +2468,19 @@ def test_expandtabs_optimization(self): @unittest.skip("TODO: RUSTPYTHON, aborted: memory allocation of 9223372036854775759 bytes failed") def test_raiseMemError(self): - if struct.calcsize('P') == 8: - # 64 bits pointers - ascii_struct_size = 48 - compact_struct_size = 72 - else: - # 32 bits pointers - ascii_struct_size = 24 - compact_struct_size = 36 + asciifields = "nnb" + compactfields = asciifields + "nP" + ascii_struct_size = support.calcobjsize(asciifields) + compact_struct_size = support.calcobjsize(compactfields) for char in ('a', '\xe9', '\u20ac', '\U0010ffff'): code = ord(char) - if code < 0x100: + if code < 0x80: char_size = 1 # sizeof(Py_UCS1) struct_size = ascii_struct_size + elif code < 0x100: + char_size = 1 # sizeof(Py_UCS1) + struct_size = compact_struct_size elif code < 0x10000: char_size = 2 # sizeof(Py_UCS2) struct_size = compact_struct_size @@ -2399,8 +2492,18 @@ def test_raiseMemError(self): # be allocatable, given enough memory. maxlen = ((sys.maxsize - struct_size) // char_size) alloc = lambda: char * maxlen - self.assertRaises(MemoryError, alloc) - self.assertRaises(MemoryError, alloc) + with self.subTest( + char=char, + struct_size=struct_size, + char_size=char_size + ): + # self-check + self.assertEqual( + sys.getsizeof(char * 42), + struct_size + (char_size * (42 + 1)) + ) + self.assertRaises(MemoryError, alloc) + self.assertRaises(MemoryError, alloc) def test_format_subclass(self): class S(str): @@ -2430,22 +2533,22 @@ def test_getnewargs(self): self.assertEqual(len(args), 1) @support.cpython_only - @support.requires_legacy_unicode_capi + @support.requires_legacy_unicode_capi() + @unittest.skipIf(_testcapi is None, 'need _testcapi module') def test_resize(self): - from _testcapi import getargs_u for length in range(1, 100, 7): # generate a fresh string (refcount=1) text = 'a' * length + 'b' # fill wstr internal field with self.assertWarns(DeprecationWarning): - abc = getargs_u(text) + abc = _testcapi.getargs_u(text) self.assertEqual(abc, text) # resize text: wstr field must be cleared and then recomputed text += 'c' with self.assertWarns(DeprecationWarning): - abcdef = getargs_u(text) + abcdef = _testcapi.getargs_u(text) self.assertNotEqual(abc, abcdef) self.assertEqual(abcdef, text) @@ -2592,473 +2695,6 @@ def test_check_encoding_errors(self): self.assertEqual(proc.rc, 10, proc) -class CAPITest(unittest.TestCase): - - # Test PyUnicode_FromFormat() - def test_from_format(self): - import_helper.import_module('ctypes') - from ctypes import ( - c_char_p, - pythonapi, py_object, sizeof, - c_int, c_long, c_longlong, c_ssize_t, - c_uint, c_ulong, c_ulonglong, c_size_t, c_void_p) - name = "PyUnicode_FromFormat" - _PyUnicode_FromFormat = getattr(pythonapi, name) - _PyUnicode_FromFormat.argtypes = (c_char_p,) - _PyUnicode_FromFormat.restype = py_object - - def PyUnicode_FromFormat(format, *args): - cargs = tuple( - py_object(arg) if isinstance(arg, str) else arg - for arg in args) - return _PyUnicode_FromFormat(format, *cargs) - - def check_format(expected, format, *args): - text = PyUnicode_FromFormat(format, *args) - self.assertEqual(expected, text) - - # ascii format, non-ascii argument - check_format('ascii\x7f=unicode\xe9', - b'ascii\x7f=%U', 'unicode\xe9') - - # non-ascii format, ascii argument: ensure that PyUnicode_FromFormatV() - # raises an error - self.assertRaisesRegex(ValueError, - r'^PyUnicode_FromFormatV\(\) expects an ASCII-encoded format ' - 'string, got a non-ASCII byte: 0xe9$', - PyUnicode_FromFormat, b'unicode\xe9=%s', 'ascii') - - # test "%c" - check_format('\uabcd', - b'%c', c_int(0xabcd)) - check_format('\U0010ffff', - b'%c', c_int(0x10ffff)) - with self.assertRaises(OverflowError): - PyUnicode_FromFormat(b'%c', c_int(0x110000)) - # Issue #18183 - check_format('\U00010000\U00100000', - b'%c%c', c_int(0x10000), c_int(0x100000)) - - # test "%" - check_format('%', - b'%') - check_format('%', - b'%%') - check_format('%s', - b'%%s') - check_format('[%]', - b'[%%]') - check_format('%abc', - b'%%%s', b'abc') - - # truncated string - check_format('abc', - b'%.3s', b'abcdef') - check_format('abc[\ufffd', - b'%.5s', 'abc[\u20ac]'.encode('utf8')) - check_format("'\\u20acABC'", - b'%A', '\u20acABC') - check_format("'\\u20", - b'%.5A', '\u20acABCDEF') - check_format("'\u20acABC'", - b'%R', '\u20acABC') - check_format("'\u20acA", - b'%.3R', '\u20acABCDEF') - check_format('\u20acAB', - b'%.3S', '\u20acABCDEF') - check_format('\u20acAB', - b'%.3U', '\u20acABCDEF') - check_format('\u20acAB', - b'%.3V', '\u20acABCDEF', None) - check_format('abc[\ufffd', - b'%.5V', None, 'abc[\u20ac]'.encode('utf8')) - - # following tests comes from #7330 - # test width modifier and precision modifier with %S - check_format("repr= abc", - b'repr=%5S', 'abc') - check_format("repr=ab", - b'repr=%.2S', 'abc') - check_format("repr= ab", - b'repr=%5.2S', 'abc') - - # test width modifier and precision modifier with %R - check_format("repr= 'abc'", - b'repr=%8R', 'abc') - check_format("repr='ab", - b'repr=%.3R', 'abc') - check_format("repr= 'ab", - b'repr=%5.3R', 'abc') - - # test width modifier and precision modifier with %A - check_format("repr= 'abc'", - b'repr=%8A', 'abc') - check_format("repr='ab", - b'repr=%.3A', 'abc') - check_format("repr= 'ab", - b'repr=%5.3A', 'abc') - - # test width modifier and precision modifier with %s - check_format("repr= abc", - b'repr=%5s', b'abc') - check_format("repr=ab", - b'repr=%.2s', b'abc') - check_format("repr= ab", - b'repr=%5.2s', b'abc') - - # test width modifier and precision modifier with %U - check_format("repr= abc", - b'repr=%5U', 'abc') - check_format("repr=ab", - b'repr=%.2U', 'abc') - check_format("repr= ab", - b'repr=%5.2U', 'abc') - - # test width modifier and precision modifier with %V - check_format("repr= abc", - b'repr=%5V', 'abc', b'123') - check_format("repr=ab", - b'repr=%.2V', 'abc', b'123') - check_format("repr= ab", - b'repr=%5.2V', 'abc', b'123') - check_format("repr= 123", - b'repr=%5V', None, b'123') - check_format("repr=12", - b'repr=%.2V', None, b'123') - check_format("repr= 12", - b'repr=%5.2V', None, b'123') - - # test integer formats (%i, %d, %u) - check_format('010', - b'%03i', c_int(10)) - check_format('0010', - b'%0.4i', c_int(10)) - check_format('-123', - b'%i', c_int(-123)) - check_format('-123', - b'%li', c_long(-123)) - check_format('-123', - b'%lli', c_longlong(-123)) - check_format('-123', - b'%zi', c_ssize_t(-123)) - - check_format('-123', - b'%d', c_int(-123)) - check_format('-123', - b'%ld', c_long(-123)) - check_format('-123', - b'%lld', c_longlong(-123)) - check_format('-123', - b'%zd', c_ssize_t(-123)) - - check_format('123', - b'%u', c_uint(123)) - check_format('123', - b'%lu', c_ulong(123)) - check_format('123', - b'%llu', c_ulonglong(123)) - check_format('123', - b'%zu', c_size_t(123)) - - # test long output - min_longlong = -(2 ** (8 * sizeof(c_longlong) - 1)) - max_longlong = -min_longlong - 1 - check_format(str(min_longlong), - b'%lld', c_longlong(min_longlong)) - check_format(str(max_longlong), - b'%lld', c_longlong(max_longlong)) - max_ulonglong = 2 ** (8 * sizeof(c_ulonglong)) - 1 - check_format(str(max_ulonglong), - b'%llu', c_ulonglong(max_ulonglong)) - PyUnicode_FromFormat(b'%p', c_void_p(-1)) - - # test padding (width and/or precision) - check_format('123'.rjust(10, '0'), - b'%010i', c_int(123)) - check_format('123'.rjust(100), - b'%100i', c_int(123)) - check_format('123'.rjust(100, '0'), - b'%.100i', c_int(123)) - check_format('123'.rjust(80, '0').rjust(100), - b'%100.80i', c_int(123)) - - check_format('123'.rjust(10, '0'), - b'%010u', c_uint(123)) - check_format('123'.rjust(100), - b'%100u', c_uint(123)) - check_format('123'.rjust(100, '0'), - b'%.100u', c_uint(123)) - check_format('123'.rjust(80, '0').rjust(100), - b'%100.80u', c_uint(123)) - - check_format('123'.rjust(10, '0'), - b'%010x', c_int(0x123)) - check_format('123'.rjust(100), - b'%100x', c_int(0x123)) - check_format('123'.rjust(100, '0'), - b'%.100x', c_int(0x123)) - check_format('123'.rjust(80, '0').rjust(100), - b'%100.80x', c_int(0x123)) - - # test %A - check_format(r"%A:'abc\xe9\uabcd\U0010ffff'", - b'%%A:%A', 'abc\xe9\uabcd\U0010ffff') - - # test %V - check_format('repr=abc', - b'repr=%V', 'abc', b'xyz') - - # Test string decode from parameter of %s using utf-8. - # b'\xe4\xba\xba\xe6\xb0\x91' is utf-8 encoded byte sequence of - # '\u4eba\u6c11' - check_format('repr=\u4eba\u6c11', - b'repr=%V', None, b'\xe4\xba\xba\xe6\xb0\x91') - - #Test replace error handler. - check_format('repr=abc\ufffd', - b'repr=%V', None, b'abc\xff') - - # not supported: copy the raw format string. these tests are just here - # to check for crashes and should not be considered as specifications - check_format('%s', - b'%1%s', b'abc') - check_format('%1abc', - b'%1abc') - check_format('%+i', - b'%+i', c_int(10)) - check_format('%.%s', - b'%.%s', b'abc') - - # Issue #33817: empty strings - check_format('', - b'') - check_format('', - b'%s', b'') - - # Test PyUnicode_AsWideChar() - @support.cpython_only - def test_aswidechar(self): - from _testcapi import unicode_aswidechar - import_helper.import_module('ctypes') - from ctypes import c_wchar, sizeof - - wchar, size = unicode_aswidechar('abcdef', 2) - self.assertEqual(size, 2) - self.assertEqual(wchar, 'ab') - - wchar, size = unicode_aswidechar('abc', 3) - self.assertEqual(size, 3) - self.assertEqual(wchar, 'abc') - - wchar, size = unicode_aswidechar('abc', 4) - self.assertEqual(size, 3) - self.assertEqual(wchar, 'abc\0') - - wchar, size = unicode_aswidechar('abc', 10) - self.assertEqual(size, 3) - self.assertEqual(wchar, 'abc\0') - - wchar, size = unicode_aswidechar('abc\0def', 20) - self.assertEqual(size, 7) - self.assertEqual(wchar, 'abc\0def\0') - - nonbmp = chr(0x10ffff) - if sizeof(c_wchar) == 2: - buflen = 3 - nchar = 2 - else: # sizeof(c_wchar) == 4 - buflen = 2 - nchar = 1 - wchar, size = unicode_aswidechar(nonbmp, buflen) - self.assertEqual(size, nchar) - self.assertEqual(wchar, nonbmp + '\0') - - # Test PyUnicode_AsWideCharString() - @support.cpython_only - def test_aswidecharstring(self): - from _testcapi import unicode_aswidecharstring - import_helper.import_module('ctypes') - from ctypes import c_wchar, sizeof - - wchar, size = unicode_aswidecharstring('abc') - self.assertEqual(size, 3) - self.assertEqual(wchar, 'abc\0') - - wchar, size = unicode_aswidecharstring('abc\0def') - self.assertEqual(size, 7) - self.assertEqual(wchar, 'abc\0def\0') - - nonbmp = chr(0x10ffff) - if sizeof(c_wchar) == 2: - nchar = 2 - else: # sizeof(c_wchar) == 4 - nchar = 1 - wchar, size = unicode_aswidecharstring(nonbmp) - self.assertEqual(size, nchar) - self.assertEqual(wchar, nonbmp + '\0') - - # Test PyUnicode_AsUCS4() - @support.cpython_only - def test_asucs4(self): - from _testcapi import unicode_asucs4 - for s in ['abc', '\xa1\xa2', '\u4f60\u597d', 'a\U0001f600', - 'a\ud800b\udfffc', '\ud834\udd1e']: - l = len(s) - self.assertEqual(unicode_asucs4(s, l, True), s+'\0') - self.assertEqual(unicode_asucs4(s, l, False), s+'\uffff') - self.assertEqual(unicode_asucs4(s, l+1, True), s+'\0\uffff') - self.assertEqual(unicode_asucs4(s, l+1, False), s+'\0\uffff') - self.assertRaises(SystemError, unicode_asucs4, s, l-1, True) - self.assertRaises(SystemError, unicode_asucs4, s, l-2, False) - s = '\0'.join([s, s]) - self.assertEqual(unicode_asucs4(s, len(s), True), s+'\0') - self.assertEqual(unicode_asucs4(s, len(s), False), s+'\uffff') - - # Test PyUnicode_AsUTF8() - @support.cpython_only - def test_asutf8(self): - from _testcapi import unicode_asutf8 - - bmp = '\u0100' - bmp2 = '\uffff' - nonbmp = chr(0x10ffff) - - self.assertEqual(unicode_asutf8(bmp), b'\xc4\x80') - self.assertEqual(unicode_asutf8(bmp2), b'\xef\xbf\xbf') - self.assertEqual(unicode_asutf8(nonbmp), b'\xf4\x8f\xbf\xbf') - self.assertRaises(UnicodeEncodeError, unicode_asutf8, 'a\ud800b\udfffc') - - # Test PyUnicode_AsUTF8AndSize() - @support.cpython_only - def test_asutf8andsize(self): - from _testcapi import unicode_asutf8andsize - - bmp = '\u0100' - bmp2 = '\uffff' - nonbmp = chr(0x10ffff) - - self.assertEqual(unicode_asutf8andsize(bmp), (b'\xc4\x80', 2)) - self.assertEqual(unicode_asutf8andsize(bmp2), (b'\xef\xbf\xbf', 3)) - self.assertEqual(unicode_asutf8andsize(nonbmp), (b'\xf4\x8f\xbf\xbf', 4)) - self.assertRaises(UnicodeEncodeError, unicode_asutf8andsize, 'a\ud800b\udfffc') - - # Test PyUnicode_FindChar() - @support.cpython_only - def test_findchar(self): - from _testcapi import unicode_findchar - - for str in "\xa1", "\u8000\u8080", "\ud800\udc02", "\U0001f100\U0001f1f1": - for i, ch in enumerate(str): - self.assertEqual(unicode_findchar(str, ord(ch), 0, len(str), 1), i) - self.assertEqual(unicode_findchar(str, ord(ch), 0, len(str), -1), i) - - str = "!>_= end - self.assertEqual(unicode_findchar(str, ord('!'), 0, 0, 1), -1) - self.assertEqual(unicode_findchar(str, ord('!'), len(str), 0, 1), -1) - # negative - self.assertEqual(unicode_findchar(str, ord('!'), -len(str), -1, 1), 0) - self.assertEqual(unicode_findchar(str, ord('!'), -len(str), -1, -1), 0) - - # Test PyUnicode_CopyCharacters() - @support.cpython_only - def test_copycharacters(self): - from _testcapi import unicode_copycharacters - - strings = [ - 'abcde', '\xa1\xa2\xa3\xa4\xa5', - '\u4f60\u597d\u4e16\u754c\uff01', - '\U0001f600\U0001f601\U0001f602\U0001f603\U0001f604' - ] - - for idx, from_ in enumerate(strings): - # wide -> narrow: exceed maxchar limitation - for to in strings[:idx]: - self.assertRaises( - SystemError, - unicode_copycharacters, to, 0, from_, 0, 5 - ) - # same kind - for from_start in range(5): - self.assertEqual( - unicode_copycharacters(from_, 0, from_, from_start, 5), - (from_[from_start:from_start+5].ljust(5, '\0'), - 5-from_start) - ) - for to_start in range(5): - self.assertEqual( - unicode_copycharacters(from_, to_start, from_, to_start, 5), - (from_[to_start:to_start+5].rjust(5, '\0'), - 5-to_start) - ) - # narrow -> wide - # Tests omitted since this creates invalid strings. - - s = strings[0] - self.assertRaises(IndexError, unicode_copycharacters, s, 6, s, 0, 5) - self.assertRaises(IndexError, unicode_copycharacters, s, -1, s, 0, 5) - self.assertRaises(IndexError, unicode_copycharacters, s, 0, s, 6, 5) - self.assertRaises(IndexError, unicode_copycharacters, s, 0, s, -1, 5) - self.assertRaises(SystemError, unicode_copycharacters, s, 1, s, 0, 5) - self.assertRaises(SystemError, unicode_copycharacters, s, 0, s, 0, -1) - self.assertRaises(SystemError, unicode_copycharacters, s, 0, b'', 0, 0) - - @support.cpython_only - @support.requires_legacy_unicode_capi - def test_encode_decimal(self): - from _testcapi import unicode_encodedecimal - with warnings_helper.check_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(unicode_encodedecimal('123'), - b'123') - self.assertEqual(unicode_encodedecimal('\u0663.\u0661\u0664'), - b'3.14') - self.assertEqual(unicode_encodedecimal( - "\N{EM SPACE}3.14\N{EN SPACE}"), b' 3.14 ') - self.assertRaises(UnicodeEncodeError, - unicode_encodedecimal, "123\u20ac", "strict") - self.assertRaisesRegex( - ValueError, - "^'decimal' codec can't encode character", - unicode_encodedecimal, "123\u20ac", "replace") - - @support.cpython_only - @support.requires_legacy_unicode_capi - def test_transform_decimal(self): - from _testcapi import unicode_transformdecimaltoascii as transform_decimal - with warnings_helper.check_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) - self.assertEqual(transform_decimal('123'), - '123') - self.assertEqual(transform_decimal('\u0663.\u0661\u0664'), - '3.14') - self.assertEqual(transform_decimal("\N{EM SPACE}3.14\N{EN SPACE}"), - "\N{EM SPACE}3.14\N{EN SPACE}") - self.assertEqual(transform_decimal('123\u20ac'), - '123\u20ac') - - @support.cpython_only - def test_pep393_utf8_caching_bug(self): - # Issue #25709: Problem with string concatenation and utf-8 cache - from _testcapi import getargs_s_hash - for k in 0x24, 0xa4, 0x20ac, 0x1f40d: - s = '' - for i in range(5): - # Due to CPython specific optimization the 's' string can be - # resized in-place. - s += chr(k) - # Parsing with the "s#" format code calls indirectly - # PyUnicode_AsUTF8AndSize() which creates the UTF-8 - # encoded string cached in the Unicode object. - self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1)) - # Check that the second call returns the same result - self.assertEqual(getargs_s_hash(s), chr(k).encode() * (i + 1)) - class StringModuleTest(unittest.TestCase): def test_formatter_parser(self): def parse(format): @@ -3109,6 +2745,30 @@ def split(name): ]]) self.assertRaises(TypeError, _string.formatter_field_name_split, 1) + def test_str_subclass_attr(self): + + name = StrSubclass("name") + name2 = StrSubclass("name2") + class Bag: + pass + + o = Bag() + with self.assertRaises(AttributeError): + delattr(o, name) + setattr(o, name, 1) + self.assertEqual(o.name, 1) + o.name = 2 + self.assertEqual(list(o.__dict__), [name]) + + with self.assertRaises(AttributeError): + delattr(o, name2) + with self.assertRaises(AttributeError): + del o.name2 + setattr(o, name2, 3) + self.assertEqual(o.name2, 3) + o.name2 = 4 + self.assertEqual(list(o.__dict__), [name, name2]) + if __name__ == "__main__": unittest.main() diff --git a/README.md b/README.md index 2c8266cb12..65d57a7ee5 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # [RustPython](https://rustpython.github.io/) -A Python-3 (CPython >= 3.11.0) Interpreter written in Rust :snake: :scream: +A Python-3 (CPython >= 3.12.0) Interpreter written in Rust :snake: :scream: :metal:. [![Build Status](https://github.com/RustPython/RustPython/workflows/CI/badge.svg)](https://github.com/RustPython/RustPython/actions?query=workflow%3ACI) diff --git a/extra_tests/snippets/builtin_none.py b/extra_tests/snippets/builtin_none.py index b0080a9d25..c75f04ea73 100644 --- a/extra_tests/snippets/builtin_none.py +++ b/extra_tests/snippets/builtin_none.py @@ -22,5 +22,4 @@ def none2(): assert None.__eq__(3) is NotImplemented assert None.__ne__(3) is NotImplemented assert None.__eq__(None) is True -assert None.__ne__(None) is False - +# assert None.__ne__(None) is False # changed in 3.12 diff --git a/extra_tests/snippets/builtin_slice.py b/extra_tests/snippets/builtin_slice.py index 71ab7cbde5..57fb7e21c2 100644 --- a/extra_tests/snippets/builtin_slice.py +++ b/extra_tests/snippets/builtin_slice.py @@ -82,14 +82,15 @@ assert_raises(TypeError, lambda: slice(0) <= 3) assert_raises(TypeError, lambda: slice(0) >= 3) -assert_raises(TypeError, hash, slice(0)) -assert_raises(TypeError, hash, slice(None)) - -def dict_slice(): - d = {} - d[slice(0)] = 3 - -assert_raises(TypeError, dict_slice) +# TODO: slice is hashable in CPython 3.12 +# assert_raises(TypeError, hash, slice(0)) +# assert_raises(TypeError, hash, slice(None)) +# +# def dict_slice(): +# d = {} +# d[slice(0)] = 3 +# +# assert_raises(TypeError, dict_slice) assert slice(None ).indices(10) == (0, 10, 1) assert slice(None, None, 2).indices(10) == (0, 10, 2) diff --git a/extra_tests/snippets/syntax_async.py b/extra_tests/snippets/syntax_async.py index 011182cce7..953669c2c4 100644 --- a/extra_tests/snippets/syntax_async.py +++ b/extra_tests/snippets/syntax_async.py @@ -128,5 +128,5 @@ async def foo(): foo().send(None) -if __name__ == "__main__": - unittest.main() + if __name__ == "__main__": + unittest.main() diff --git a/vm/src/version.rs b/vm/src/version.rs index ec23e896b0..9a75f71142 100644 --- a/vm/src/version.rs +++ b/vm/src/version.rs @@ -4,9 +4,9 @@ use chrono::{prelude::DateTime, Local}; use std::time::{Duration, UNIX_EPOCH}; -// = 3.11.0alpha +// = 3.12.0alpha pub const MAJOR: usize = 3; -pub const MINOR: usize = 11; +pub const MINOR: usize = 12; pub const MICRO: usize = 0; pub const RELEASELEVEL: &str = "alpha"; pub const RELEASELEVEL_N: usize = 0xA; diff --git a/whats_left.py b/whats_left.py index 7f3ad80c63..4f087f89af 100755 --- a/whats_left.py +++ b/whats_left.py @@ -35,8 +35,8 @@ implementation = platform.python_implementation() if implementation != "CPython": sys.exit(f"whats_left.py must be run under CPython, got {implementation} instead") -if sys.version_info[:2] < (3, 11): - sys.exit(f"whats_left.py must be run under CPython 3.11 or newer, got {implementation} {sys.version} instead") +if sys.version_info[:2] < (3, 12): + sys.exit(f"whats_left.py must be run under CPython 3.12 or newer, got {implementation} {sys.version} instead") def parse_args(): parser = argparse.ArgumentParser(description="Process some integers.") From 4dacbc51e280802c222422927d2c725769494a22 Mon Sep 17 00:00:00 2001 From: Olivier Lemasle Date: Tue, 7 Nov 2023 06:12:21 +0100 Subject: [PATCH 140/893] Remove outdated info about yanked crates (#5111) The `rustpython-*` crates have been published since this message was added. --- README.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/README.md b/README.md index 65d57a7ee5..c4d1fc9a7d 100644 --- a/README.md +++ b/README.md @@ -62,10 +62,6 @@ Welcome to the magnificent Rust Python interpreter >>>>> ``` -(The `rustpython-*` crates are currently yanked from crates.io due to being out -of date and not building on newer rust versions; we hope to release a new -version Soon™) - If you'd like to make https requests, you can enable the `ssl` feature, which also lets you install the `pip` package manager. Note that on Windows, you may need to install OpenSSL, or you can enable the `ssl-vendor` feature instead, From b4b71e5a119778a972db7ae45913c3c45534d2c4 Mon Sep 17 00:00:00 2001 From: Tae-Geun Kim Date: Wed, 8 Nov 2023 15:37:06 +0900 Subject: [PATCH 141/893] Bump puruspe to 0.2 (#5112) --- stdlib/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index cdb3902e52..1cd6f09b5e 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -43,7 +43,7 @@ base64 = "0.13.0" csv-core = "0.1.10" dyn-clone = "1.0.10" libz-sys = { version = "1.1.5", optional = true } -puruspe = "0.1.5" +puruspe = "0.2.0" xml-rs = "0.8.14" # random From b8f22e296776cc77a34ed70b7822b7d2231ab193 Mon Sep 17 00:00:00 2001 From: Steve Shi Date: Wed, 8 Nov 2023 08:43:49 +0200 Subject: [PATCH 142/893] bump malachite to 0.4.4 (#5069) --- Cargo.toml | 16 ++++++++-------- common/src/int.rs | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a3b78fd70a..680f090cc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,11 +29,11 @@ rustpython-pylib = { path = "pylib", version = "0.3.0" } rustpython-stdlib = { path = "stdlib", default-features = false, version = "0.3.0" } rustpython-doc = { git = "https://github.com/RustPython/__doc__", tag = "0.3.0", version = "0.3.0" } -rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } -rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } -rustpython-parser = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } -rustpython-ast = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } -rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "13cae0af64d0a23de95f08c0210e97ad74d155e9" } +rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } +rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } +rustpython-parser = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } +rustpython-ast = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } +rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } # rustpython-literal = { path = "../RustPython-parser/literal" } # rustpython-parser-core = { path = "../RustPython-parser/core" } # rustpython-parser = { path = "../RustPython-parser/parser" } @@ -59,9 +59,9 @@ is-macro = "0.3.0" libc = "0.2.133" log = "0.4.16" nix = "0.26" -malachite-bigint = { version = "0.1.0" } -malachite-q = "0.3.2" -malachite-base = "0.3.2" +malachite-bigint = "0.1.1" +malachite-q = "0.4.4" +malachite-base = "0.4.4" num-complex = "0.4.0" num-integer = "0.1.44" num-traits = "0.2" diff --git a/common/src/int.rs b/common/src/int.rs index c968d747ef..2de993b59b 100644 --- a/common/src/int.rs +++ b/common/src/int.rs @@ -5,14 +5,14 @@ use malachite_q::Rational; use num_traits::{One, ToPrimitive, Zero}; pub fn true_div(numerator: &BigInt, denominator: &BigInt) -> f64 { - let val: f64 = Rational::from_integers_ref(numerator.into(), denominator.into()) - .rounding_into(RoundingMode::Nearest); - - if val == f64::MAX || val == f64::MIN { - // FIXME: not possible for available ratio? - return f64::INFINITY; + let rational = Rational::from_integers_ref(numerator.into(), denominator.into()); + match rational.rounding_into(RoundingMode::Nearest) { + // returned value is $t::MAX but still less than the original + (val, std::cmp::Ordering::Less) if val == f64::MAX => f64::INFINITY, + // returned value is $t::MIN but still greater than the original + (val, std::cmp::Ordering::Greater) if val == f64::MIN => f64::NEG_INFINITY, + (val, _) => val, } - val } pub fn float_to_ratio(value: f64) -> Option<(BigInt, BigInt)> { From c32369bd27a6e3df1e0dfbc510391bd7dc2d18cc Mon Sep 17 00:00:00 2001 From: NakanoMiku <91249276+NakanoMiku39@users.noreply.github.com> Date: Wed, 8 Nov 2023 15:37:26 +0800 Subject: [PATCH 143/893] Add Lib/test/test___all__.py (#5110) * Add Lib/test/test___all__.py cpython version: 3.12 * Edit Lib/test/test___all__.py Add @unittest.expectedFailure for function test_all(self) cpython version: 3.12 --- Lib/test/test___all__.py | 143 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 Lib/test/test___all__.py diff --git a/Lib/test/test___all__.py b/Lib/test/test___all__.py new file mode 100644 index 0000000000..a620dd5b4c --- /dev/null +++ b/Lib/test/test___all__.py @@ -0,0 +1,143 @@ +import unittest +from test import support +from test.support import warnings_helper +import os +import sys +import types + +try: + import _multiprocessing +except ModuleNotFoundError: + _multiprocessing = None + + +if support.check_sanitizer(address=True, memory=True): + # bpo-46633: test___all__ is skipped because importing some modules + # directly can trigger known problems with ASAN (like tk or crypt). + raise unittest.SkipTest("workaround ASAN build issues on loading tests " + "like tk or crypt") + + +class NoAll(RuntimeError): + pass + +class FailedImport(RuntimeError): + pass + + +class AllTest(unittest.TestCase): + + def setUp(self): + # concurrent.futures uses a __getattr__ hook. Its __all__ triggers + # import of a submodule, which fails when _multiprocessing is not + # available. + if _multiprocessing is None: + sys.modules["_multiprocessing"] = types.ModuleType("_multiprocessing") + + def tearDown(self): + if _multiprocessing is None: + sys.modules.pop("_multiprocessing") + + def check_all(self, modname): + names = {} + with warnings_helper.check_warnings( + (f".*{modname}", DeprecationWarning), + (".* (module|package)", DeprecationWarning), + (".* (module|package)", PendingDeprecationWarning), + ("", ResourceWarning), + quiet=True): + try: + exec("import %s" % modname, names) + except: + # Silent fail here seems the best route since some modules + # may not be available or not initialize properly in all + # environments. + raise FailedImport(modname) + if not hasattr(sys.modules[modname], "__all__"): + raise NoAll(modname) + names = {} + with self.subTest(module=modname): + with warnings_helper.check_warnings( + ("", DeprecationWarning), + ("", ResourceWarning), + quiet=True): + try: + exec("from %s import *" % modname, names) + except Exception as e: + # Include the module name in the exception string + self.fail("__all__ failure in {}: {}: {}".format( + modname, e.__class__.__name__, e)) + if "__builtins__" in names: + del names["__builtins__"] + if '__annotations__' in names: + del names['__annotations__'] + if "__warningregistry__" in names: + del names["__warningregistry__"] + keys = set(names) + all_list = sys.modules[modname].__all__ + all_set = set(all_list) + self.assertCountEqual(all_set, all_list, "in module {}".format(modname)) + self.assertEqual(keys, all_set, "in module {}".format(modname)) + + def walk_modules(self, basedir, modpath): + for fn in sorted(os.listdir(basedir)): + path = os.path.join(basedir, fn) + if os.path.isdir(path): + pkg_init = os.path.join(path, '__init__.py') + if os.path.exists(pkg_init): + yield pkg_init, modpath + fn + for p, m in self.walk_modules(path, modpath + fn + "."): + yield p, m + continue + if not fn.endswith('.py') or fn == '__init__.py': + continue + yield path, modpath + fn[:-3] + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_all(self): + # List of denied modules and packages + denylist = set([ + # Will raise a SyntaxError when compiling the exec statement + '__future__', + ]) + + # In case _socket fails to build, make this test fail more gracefully + # than an AttributeError somewhere deep in CGIHTTPServer. + import _socket + + ignored = [] + failed_imports = [] + lib_dir = os.path.dirname(os.path.dirname(__file__)) + for path, modname in self.walk_modules(lib_dir, ""): + m = modname + denied = False + while m: + if m in denylist: + denied = True + break + m = m.rpartition('.')[0] + if denied: + continue + if support.verbose: + print(modname) + try: + # This heuristic speeds up the process by removing, de facto, + # most test modules (and avoiding the auto-executing ones). + with open(path, "rb") as f: + if b"__all__" not in f.read(): + raise NoAll(modname) + self.check_all(modname) + except NoAll: + ignored.append(modname) + except FailedImport: + failed_imports.append(modname) + + if support.verbose: + print('Following modules have no __all__ and have been ignored:', + ignored) + print('Following modules failed to be imported:', failed_imports) + + +if __name__ == "__main__": + unittest.main() From 2bfbfd7a035b3203e39417cf8afafe350b30d71a Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Wed, 8 Nov 2023 19:25:01 +0800 Subject: [PATCH 144/893] Update test__locale.py from CPython v3.12 --- Lib/test/test_locale.py | 118 ++++++++++++++++------------------------ 1 file changed, 47 insertions(+), 71 deletions(-) diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py index bc8a7a35fb..b0d7998559 100644 --- a/Lib/test/test_locale.py +++ b/Lib/test/test_locale.py @@ -141,18 +141,9 @@ class BaseFormattingTest(object): # Utility functions for formatting tests # - def _test_formatfunc(self, format, value, out, func, **format_opts): - self.assertEqual( - func(format, value, **format_opts), out) - - def _test_format(self, format, value, out, **format_opts): - with check_warnings(('', DeprecationWarning)): - self._test_formatfunc(format, value, out, - func=locale.format, **format_opts) - def _test_format_string(self, format, value, out, **format_opts): - self._test_formatfunc(format, value, out, - func=locale.format_string, **format_opts) + self.assertEqual( + locale.format_string(format, value, **format_opts), out) def _test_currency(self, value, out, **format_opts): self.assertEqual(locale.currency(value, **format_opts), out) @@ -166,44 +157,40 @@ def setUp(self): self.sep = locale.localeconv()['thousands_sep'] def test_grouping(self): - self._test_format("%f", 1024, grouping=1, out='1%s024.000000' % self.sep) - self._test_format("%f", 102, grouping=1, out='102.000000') - self._test_format("%f", -42, grouping=1, out='-42.000000') - self._test_format("%+f", -42, grouping=1, out='-42.000000') + self._test_format_string("%f", 1024, grouping=1, out='1%s024.000000' % self.sep) + self._test_format_string("%f", 102, grouping=1, out='102.000000') + self._test_format_string("%f", -42, grouping=1, out='-42.000000') + self._test_format_string("%+f", -42, grouping=1, out='-42.000000') def test_grouping_and_padding(self): - self._test_format("%20.f", -42, grouping=1, out='-42'.rjust(20)) + self._test_format_string("%20.f", -42, grouping=1, out='-42'.rjust(20)) if self.sep: - self._test_format("%+10.f", -4200, grouping=1, + self._test_format_string("%+10.f", -4200, grouping=1, out=('-4%s200' % self.sep).rjust(10)) - self._test_format("%-10.f", -4200, grouping=1, + self._test_format_string("%-10.f", -4200, grouping=1, out=('-4%s200' % self.sep).ljust(10)) def test_integer_grouping(self): - self._test_format("%d", 4200, grouping=True, out='4%s200' % self.sep) - self._test_format("%+d", 4200, grouping=True, out='+4%s200' % self.sep) - self._test_format("%+d", -4200, grouping=True, out='-4%s200' % self.sep) + self._test_format_string("%d", 4200, grouping=True, out='4%s200' % self.sep) + self._test_format_string("%+d", 4200, grouping=True, out='+4%s200' % self.sep) + self._test_format_string("%+d", -4200, grouping=True, out='-4%s200' % self.sep) def test_integer_grouping_and_padding(self): - self._test_format("%10d", 4200, grouping=True, + self._test_format_string("%10d", 4200, grouping=True, out=('4%s200' % self.sep).rjust(10)) - self._test_format("%-10d", -4200, grouping=True, + self._test_format_string("%-10d", -4200, grouping=True, out=('-4%s200' % self.sep).ljust(10)) def test_simple(self): - self._test_format("%f", 1024, grouping=0, out='1024.000000') - self._test_format("%f", 102, grouping=0, out='102.000000') - self._test_format("%f", -42, grouping=0, out='-42.000000') - self._test_format("%+f", -42, grouping=0, out='-42.000000') + self._test_format_string("%f", 1024, grouping=0, out='1024.000000') + self._test_format_string("%f", 102, grouping=0, out='102.000000') + self._test_format_string("%f", -42, grouping=0, out='-42.000000') + self._test_format_string("%+f", -42, grouping=0, out='-42.000000') def test_padding(self): - self._test_format("%20.f", -42, grouping=0, out='-42'.rjust(20)) - self._test_format("%+10.f", -4200, grouping=0, out='-4200'.rjust(10)) - self._test_format("%-10.f", 4200, grouping=0, out='4200'.ljust(10)) - - def test_format_deprecation(self): - with self.assertWarns(DeprecationWarning): - locale.format("%-10.f", 4200, grouping=True) + self._test_format_string("%20.f", -42, grouping=0, out='-42'.rjust(20)) + self._test_format_string("%+10.f", -4200, grouping=0, out='-4200'.rjust(10)) + self._test_format_string("%-10.f", 4200, grouping=0, out='4200'.ljust(10)) def test_complex_formatting(self): # Spaces in formatting string @@ -230,20 +217,9 @@ def test_complex_formatting(self): out='int 1%s000 float 1%s000.00 str str' % (self.sep, self.sep)) - -class TestFormatPatternArg(unittest.TestCase): - # Test handling of pattern argument of format - - def test_onlyOnePattern(self): - with check_warnings(('', DeprecationWarning)): - # Issue 2522: accept exactly one % pattern, and no extra chars. - self.assertRaises(ValueError, locale.format, "%f\n", 'foo') - self.assertRaises(ValueError, locale.format, "%f\r", 'foo') - self.assertRaises(ValueError, locale.format, "%f\r\n", 'foo') - self.assertRaises(ValueError, locale.format, " %f", 'foo') - self.assertRaises(ValueError, locale.format, "%fg", 'foo') - self.assertRaises(ValueError, locale.format, "%^g", 'foo') - self.assertRaises(ValueError, locale.format, "%f%%", 'foo') + self._test_format_string("total=%i%%", 100, out='total=100%') + self._test_format_string("newline: %i\n", 3, out='newline: 3\n') + self._test_format_string("extra: %ii", 3, out='extra: 3i') class TestLocaleFormatString(unittest.TestCase): @@ -292,45 +268,45 @@ class TestCNumberFormatting(CCookedTest, BaseFormattingTest): # Test number formatting with a cooked "C" locale. def test_grouping(self): - self._test_format("%.2f", 12345.67, grouping=True, out='12345.67') + self._test_format_string("%.2f", 12345.67, grouping=True, out='12345.67') def test_grouping_and_padding(self): - self._test_format("%9.2f", 12345.67, grouping=True, out=' 12345.67') + self._test_format_string("%9.2f", 12345.67, grouping=True, out=' 12345.67') class TestFrFRNumberFormatting(FrFRCookedTest, BaseFormattingTest): # Test number formatting with a cooked "fr_FR" locale. def test_decimal_point(self): - self._test_format("%.2f", 12345.67, out='12345,67') + self._test_format_string("%.2f", 12345.67, out='12345,67') def test_grouping(self): - self._test_format("%.2f", 345.67, grouping=True, out='345,67') - self._test_format("%.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%.2f", 345.67, grouping=True, out='345,67') + self._test_format_string("%.2f", 12345.67, grouping=True, out='12 345,67') def test_grouping_and_padding(self): - self._test_format("%6.2f", 345.67, grouping=True, out='345,67') - self._test_format("%7.2f", 345.67, grouping=True, out=' 345,67') - self._test_format("%8.2f", 12345.67, grouping=True, out='12 345,67') - self._test_format("%9.2f", 12345.67, grouping=True, out='12 345,67') - self._test_format("%10.2f", 12345.67, grouping=True, out=' 12 345,67') - self._test_format("%-6.2f", 345.67, grouping=True, out='345,67') - self._test_format("%-7.2f", 345.67, grouping=True, out='345,67 ') - self._test_format("%-8.2f", 12345.67, grouping=True, out='12 345,67') - self._test_format("%-9.2f", 12345.67, grouping=True, out='12 345,67') - self._test_format("%-10.2f", 12345.67, grouping=True, out='12 345,67 ') + self._test_format_string("%6.2f", 345.67, grouping=True, out='345,67') + self._test_format_string("%7.2f", 345.67, grouping=True, out=' 345,67') + self._test_format_string("%8.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%9.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%10.2f", 12345.67, grouping=True, out=' 12 345,67') + self._test_format_string("%-6.2f", 345.67, grouping=True, out='345,67') + self._test_format_string("%-7.2f", 345.67, grouping=True, out='345,67 ') + self._test_format_string("%-8.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%-9.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%-10.2f", 12345.67, grouping=True, out='12 345,67 ') def test_integer_grouping(self): - self._test_format("%d", 200, grouping=True, out='200') - self._test_format("%d", 4200, grouping=True, out='4 200') + self._test_format_string("%d", 200, grouping=True, out='200') + self._test_format_string("%d", 4200, grouping=True, out='4 200') def test_integer_grouping_and_padding(self): - self._test_format("%4d", 4200, grouping=True, out='4 200') - self._test_format("%5d", 4200, grouping=True, out='4 200') - self._test_format("%10d", 4200, grouping=True, out='4 200'.rjust(10)) - self._test_format("%-4d", 4200, grouping=True, out='4 200') - self._test_format("%-5d", 4200, grouping=True, out='4 200') - self._test_format("%-10d", 4200, grouping=True, out='4 200'.ljust(10)) + self._test_format_string("%4d", 4200, grouping=True, out='4 200') + self._test_format_string("%5d", 4200, grouping=True, out='4 200') + self._test_format_string("%10d", 4200, grouping=True, out='4 200'.rjust(10)) + self._test_format_string("%-4d", 4200, grouping=True, out='4 200') + self._test_format_string("%-5d", 4200, grouping=True, out='4 200') + self._test_format_string("%-10d", 4200, grouping=True, out='4 200'.ljust(10)) def test_currency(self): euro = '\u20ac' From f682d184fbe805ab40b45c6b4160e58dc493acf1 Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Wed, 8 Nov 2023 20:18:46 +0800 Subject: [PATCH 145/893] Update Cargo.lock --- Cargo.lock | 209 ++++++++++++----------------------------------------- 1 file changed, 48 insertions(+), 161 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8688266ad1..c0c59f1552 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,7 +27,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", - "getrandom 0.2.8", + "getrandom", "once_cell", "version_check", ] @@ -142,17 +142,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" dependencies = [ - "digest 0.10.6", -] - -[[package]] -name = "block-buffer" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" -dependencies = [ - "block-padding", - "generic-array", + "digest", ] [[package]] @@ -164,12 +154,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "block-padding" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" - [[package]] name = "bstr" version = "0.2.17" @@ -655,22 +639,13 @@ dependencies = [ "syn 1.0.107", ] -[[package]] -name = "digest" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" -dependencies = [ - "generic-array", -] - [[package]] name = "digest" version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ - "block-buffer 0.10.3", + "block-buffer", "crypto-common", "subtle", ] @@ -909,17 +884,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - [[package]] name = "getrandom" version = "0.2.8" @@ -929,7 +893,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "wasm-bindgen", ] @@ -1064,15 +1028,6 @@ dependencies = [ "syn 2.0.32", ] -[[package]] -name = "itertools" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.10.5" @@ -1281,9 +1236,9 @@ dependencies = [ [[package]] name = "malachite" -version = "0.3.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6cf7f4730c30071ba374fac86ad35b1cb7a0716f774737768667ea3fa1828e3" +checksum = "220cb36c52aa6eff45559df497abe0e2a4c1209f92279a746a399f622d7b95c7" dependencies = [ "malachite-base", "malachite-nz", @@ -1292,22 +1247,19 @@ dependencies = [ [[package]] name = "malachite-base" -version = "0.3.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b06bfa98a4b4802af5a4263b4ad4660e28e51e8490f6354eb9336c70767e1c5" +checksum = "6538136c5daf04126d6be4899f7fe4879b7f8de896dd1b4210fe6de5b94f2555" dependencies = [ - "itertools 0.9.0", - "rand 0.7.3", - "rand_chacha 0.2.2", + "itertools 0.11.0", "ryu", - "sha3 0.9.1", ] [[package]] name = "malachite-bigint" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a5110aee54537b0cef214efbebdd7df79b7408db8eef4f6a4b6db9d0d8fc01b" +checksum = "76c3eca3b5df299486144c8423c45c24bdf9e82e2452c8a1eeda547c4d8b5d41" dependencies = [ "derive_more", "malachite", @@ -1318,22 +1270,22 @@ dependencies = [ [[package]] name = "malachite-nz" -version = "0.3.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89e21c64b7af5be3dc8cef16f786243faf59459fe4ba93b44efdeb264e5ade4" +checksum = "5f0b05577b7a3f09433106460b10304f97fc572f0baabf6640e6cb1e23f5fc52" dependencies = [ "embed-doc-image", - "itertools 0.9.0", + "itertools 0.11.0", "malachite-base", ] [[package]] name = "malachite-q" -version = "0.3.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3755e541d5134b5016594c9043094172c4dda9259b3ce824a7b8101941850360" +checksum = "a1cfdb4016292e6acd832eaee261175f3af8bbee62afeefe4420ebce4c440cb5" dependencies = [ - "itertools 0.9.0", + "itertools 0.11.0", "malachite-base", "malachite-nz", ] @@ -1356,7 +1308,7 @@ version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" dependencies = [ - "digest 0.10.6", + "digest", ] [[package]] @@ -1407,7 +1359,7 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12ca7f22ed370d5991a9caec16a83187e865bc8a532f889670337d5a5689e3a1" dependencies = [ - "rand_core 0.6.4", + "rand_core", ] [[package]] @@ -1526,12 +1478,6 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "openssl" version = "0.10.55" @@ -1657,7 +1603,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1181c94580fa345f50f19d738aaa39c0ed30a600d95cb2d3e23f94266f14fbf" dependencies = [ "phf_shared", - "rand 0.8.5", + "rand", ] [[package]] @@ -1758,9 +1704,9 @@ dependencies = [ [[package]] name = "puruspe" -version = "0.1.5" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b7e158a385023d209d6d5f2585c4b468f6dcb3dd5aca9b75c4f1678c05bb375" +checksum = "fe7765e19fb2ba6fd4373b8d90399f5321683ea7c11b598c6bbaa3a72e9c83b8" [[package]] name = "python3-sys" @@ -1797,19 +1743,6 @@ dependencies = [ "nibble_vec", ] -[[package]] -name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc", -] - [[package]] name = "rand" version = "0.8.5" @@ -1817,18 +1750,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", + "rand_chacha", + "rand_core", ] [[package]] @@ -1838,16 +1761,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", + "rand_core", ] [[package]] @@ -1856,16 +1770,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.8", -] - -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", + "getrandom", ] [[package]] @@ -1911,7 +1816,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ - "getrandom 0.2.8", + "getrandom", "redox_syscall 0.2.16", "thiserror", ] @@ -2041,7 +1946,7 @@ dependencies = [ [[package]] name = "rustpython-ast" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "is-macro", "malachite-bigint", @@ -2087,7 +1992,7 @@ dependencies = [ "once_cell", "parking_lot", "radium", - "rand 0.8.5", + "rand", "rustpython-format", "siphasher", "volatile", @@ -2153,7 +2058,7 @@ dependencies = [ [[package]] name = "rustpython-format" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "bitflags 2.4.0", "itertools 0.11.0", @@ -2180,7 +2085,7 @@ dependencies = [ [[package]] name = "rustpython-literal" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "hexf-parse", "is-macro", @@ -2192,7 +2097,7 @@ dependencies = [ [[package]] name = "rustpython-parser" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "anyhow", "is-macro", @@ -2215,7 +2120,7 @@ dependencies = [ [[package]] name = "rustpython-parser-core" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "is-macro", "memchr", @@ -2225,7 +2130,7 @@ dependencies = [ [[package]] name = "rustpython-parser-vendored" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "memchr", "once_cell", @@ -2254,7 +2159,7 @@ dependencies = [ "crc32fast", "crossbeam-utils", "csv-core", - "digest 0.10.6", + "digest", "dns-lookup", "dyn-clone", "flate2", @@ -2284,15 +2189,15 @@ dependencies = [ "parking_lot", "paste", "puruspe", - "rand 0.8.5", - "rand_core 0.6.4", + "rand", + "rand_core", "rustpython-common", "rustpython-derive", "rustpython-vm", "schannel", "sha-1", "sha2", - "sha3 0.10.6", + "sha3", "socket2", "system-configuration", "termios", @@ -2327,7 +2232,7 @@ dependencies = [ "exitcode", "flame", "flamer", - "getrandom 0.2.8", + "getrandom", "glob", "half", "hex", @@ -2349,7 +2254,7 @@ dependencies = [ "optional", "parking_lot", "paste", - "rand 0.8.5", + "rand", "result-like", "rustc_version", "rustpython-ast", @@ -2437,9 +2342,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.12" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" [[package]] name = "same-file" @@ -2538,7 +2443,7 @@ checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" dependencies = [ "cfg-if", "cpufeatures", - "digest 0.10.6", + "digest", ] [[package]] @@ -2549,19 +2454,7 @@ checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" dependencies = [ "cfg-if", "cpufeatures", - "digest 0.10.6", -] - -[[package]] -name = "sha3" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f81199417d4e5de3f04b1e871023acea7389672c4135918f05aa9cbf2f2fa809" -dependencies = [ - "block-buffer 0.9.0", - "digest 0.9.0", - "keccak", - "opaque-debug", + "digest", ] [[package]] @@ -2570,7 +2463,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bdf0c33fae925bdc080598b84bc15c55e7b9a4a43b3c704da051f977469691c9" dependencies = [ - "digest 0.10.6", + "digest", "keccak", ] @@ -3066,7 +2959,7 @@ dependencies = [ "getopts", "log", "phf_codegen", - "rand 0.8.5", + "rand", "time", ] @@ -3083,8 +2976,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1674845326ee10d37ca60470760d4288a6f80f304007d92e5c53bab78c9cfd79" dependencies = [ "atomic", - "getrandom 0.2.8", - "rand 0.8.5", + "getrandom", + "rand", "uuid-macro-internal", ] @@ -3134,12 +3027,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" From 99d992f70808e2c8e686f4a0d92926d4431ab5a4 Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Wed, 8 Nov 2023 20:41:23 +0800 Subject: [PATCH 146/893] Update test_atexit.py from CPython v3.12 --- Lib/test/test_atexit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Lib/test/test_atexit.py b/Lib/test/test_atexit.py index 7ac063cfc7..913b7556be 100644 --- a/Lib/test/test_atexit.py +++ b/Lib/test/test_atexit.py @@ -1,6 +1,5 @@ import atexit import os -import sys import textwrap import unittest from test import support From fe617431f4b444d6b1d13fe76f20dc26cd287054 Mon Sep 17 00:00:00 2001 From: NakanoMiku <91249276+NakanoMiku39@users.noreply.github.com> Date: Thu, 9 Nov 2023 14:39:56 +0800 Subject: [PATCH 147/893] Update test__locale.py and test_atexit.py from CPython v3.12 (#5114) * Update test__locale.py from CPython v3.12 * Update Cargo.lock * Update test_atexit.py from CPython v3.12 --- Cargo.lock | 209 +++++++++------------------------------- Lib/test/test_atexit.py | 1 - Lib/test/test_locale.py | 118 +++++++++-------------- 3 files changed, 95 insertions(+), 233 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8688266ad1..c0c59f1552 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,7 +27,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", - "getrandom 0.2.8", + "getrandom", "once_cell", "version_check", ] @@ -142,17 +142,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" dependencies = [ - "digest 0.10.6", -] - -[[package]] -name = "block-buffer" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" -dependencies = [ - "block-padding", - "generic-array", + "digest", ] [[package]] @@ -164,12 +154,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "block-padding" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d696c370c750c948ada61c69a0ee2cbbb9c50b1019ddb86d9317157a99c2cae" - [[package]] name = "bstr" version = "0.2.17" @@ -655,22 +639,13 @@ dependencies = [ "syn 1.0.107", ] -[[package]] -name = "digest" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" -dependencies = [ - "generic-array", -] - [[package]] name = "digest" version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ - "block-buffer 0.10.3", + "block-buffer", "crypto-common", "subtle", ] @@ -909,17 +884,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "getrandom" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.9.0+wasi-snapshot-preview1", -] - [[package]] name = "getrandom" version = "0.2.8" @@ -929,7 +893,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "wasm-bindgen", ] @@ -1064,15 +1028,6 @@ dependencies = [ "syn 2.0.32", ] -[[package]] -name = "itertools" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.10.5" @@ -1281,9 +1236,9 @@ dependencies = [ [[package]] name = "malachite" -version = "0.3.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6cf7f4730c30071ba374fac86ad35b1cb7a0716f774737768667ea3fa1828e3" +checksum = "220cb36c52aa6eff45559df497abe0e2a4c1209f92279a746a399f622d7b95c7" dependencies = [ "malachite-base", "malachite-nz", @@ -1292,22 +1247,19 @@ dependencies = [ [[package]] name = "malachite-base" -version = "0.3.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b06bfa98a4b4802af5a4263b4ad4660e28e51e8490f6354eb9336c70767e1c5" +checksum = "6538136c5daf04126d6be4899f7fe4879b7f8de896dd1b4210fe6de5b94f2555" dependencies = [ - "itertools 0.9.0", - "rand 0.7.3", - "rand_chacha 0.2.2", + "itertools 0.11.0", "ryu", - "sha3 0.9.1", ] [[package]] name = "malachite-bigint" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a5110aee54537b0cef214efbebdd7df79b7408db8eef4f6a4b6db9d0d8fc01b" +checksum = "76c3eca3b5df299486144c8423c45c24bdf9e82e2452c8a1eeda547c4d8b5d41" dependencies = [ "derive_more", "malachite", @@ -1318,22 +1270,22 @@ dependencies = [ [[package]] name = "malachite-nz" -version = "0.3.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89e21c64b7af5be3dc8cef16f786243faf59459fe4ba93b44efdeb264e5ade4" +checksum = "5f0b05577b7a3f09433106460b10304f97fc572f0baabf6640e6cb1e23f5fc52" dependencies = [ "embed-doc-image", - "itertools 0.9.0", + "itertools 0.11.0", "malachite-base", ] [[package]] name = "malachite-q" -version = "0.3.2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3755e541d5134b5016594c9043094172c4dda9259b3ce824a7b8101941850360" +checksum = "a1cfdb4016292e6acd832eaee261175f3af8bbee62afeefe4420ebce4c440cb5" dependencies = [ - "itertools 0.9.0", + "itertools 0.11.0", "malachite-base", "malachite-nz", ] @@ -1356,7 +1308,7 @@ version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6365506850d44bff6e2fbcb5176cf63650e48bd45ef2fe2665ae1570e0f4b9ca" dependencies = [ - "digest 0.10.6", + "digest", ] [[package]] @@ -1407,7 +1359,7 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "12ca7f22ed370d5991a9caec16a83187e865bc8a532f889670337d5a5689e3a1" dependencies = [ - "rand_core 0.6.4", + "rand_core", ] [[package]] @@ -1526,12 +1478,6 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" -[[package]] -name = "opaque-debug" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" - [[package]] name = "openssl" version = "0.10.55" @@ -1657,7 +1603,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1181c94580fa345f50f19d738aaa39c0ed30a600d95cb2d3e23f94266f14fbf" dependencies = [ "phf_shared", - "rand 0.8.5", + "rand", ] [[package]] @@ -1758,9 +1704,9 @@ dependencies = [ [[package]] name = "puruspe" -version = "0.1.5" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b7e158a385023d209d6d5f2585c4b468f6dcb3dd5aca9b75c4f1678c05bb375" +checksum = "fe7765e19fb2ba6fd4373b8d90399f5321683ea7c11b598c6bbaa3a72e9c83b8" [[package]] name = "python3-sys" @@ -1797,19 +1743,6 @@ dependencies = [ "nibble_vec", ] -[[package]] -name = "rand" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" -dependencies = [ - "getrandom 0.1.16", - "libc", - "rand_chacha 0.2.2", - "rand_core 0.5.1", - "rand_hc", -] - [[package]] name = "rand" version = "0.8.5" @@ -1817,18 +1750,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" -dependencies = [ - "ppv-lite86", - "rand_core 0.5.1", + "rand_chacha", + "rand_core", ] [[package]] @@ -1838,16 +1761,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" -dependencies = [ - "getrandom 0.1.16", + "rand_core", ] [[package]] @@ -1856,16 +1770,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.8", -] - -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core 0.5.1", + "getrandom", ] [[package]] @@ -1911,7 +1816,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ - "getrandom 0.2.8", + "getrandom", "redox_syscall 0.2.16", "thiserror", ] @@ -2041,7 +1946,7 @@ dependencies = [ [[package]] name = "rustpython-ast" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "is-macro", "malachite-bigint", @@ -2087,7 +1992,7 @@ dependencies = [ "once_cell", "parking_lot", "radium", - "rand 0.8.5", + "rand", "rustpython-format", "siphasher", "volatile", @@ -2153,7 +2058,7 @@ dependencies = [ [[package]] name = "rustpython-format" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "bitflags 2.4.0", "itertools 0.11.0", @@ -2180,7 +2085,7 @@ dependencies = [ [[package]] name = "rustpython-literal" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "hexf-parse", "is-macro", @@ -2192,7 +2097,7 @@ dependencies = [ [[package]] name = "rustpython-parser" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "anyhow", "is-macro", @@ -2215,7 +2120,7 @@ dependencies = [ [[package]] name = "rustpython-parser-core" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "is-macro", "memchr", @@ -2225,7 +2130,7 @@ dependencies = [ [[package]] name = "rustpython-parser-vendored" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=13cae0af64d0a23de95f08c0210e97ad74d155e9#13cae0af64d0a23de95f08c0210e97ad74d155e9" +source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" dependencies = [ "memchr", "once_cell", @@ -2254,7 +2159,7 @@ dependencies = [ "crc32fast", "crossbeam-utils", "csv-core", - "digest 0.10.6", + "digest", "dns-lookup", "dyn-clone", "flate2", @@ -2284,15 +2189,15 @@ dependencies = [ "parking_lot", "paste", "puruspe", - "rand 0.8.5", - "rand_core 0.6.4", + "rand", + "rand_core", "rustpython-common", "rustpython-derive", "rustpython-vm", "schannel", "sha-1", "sha2", - "sha3 0.10.6", + "sha3", "socket2", "system-configuration", "termios", @@ -2327,7 +2232,7 @@ dependencies = [ "exitcode", "flame", "flamer", - "getrandom 0.2.8", + "getrandom", "glob", "half", "hex", @@ -2349,7 +2254,7 @@ dependencies = [ "optional", "parking_lot", "paste", - "rand 0.8.5", + "rand", "result-like", "rustc_version", "rustpython-ast", @@ -2437,9 +2342,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.12" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" [[package]] name = "same-file" @@ -2538,7 +2443,7 @@ checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" dependencies = [ "cfg-if", "cpufeatures", - "digest 0.10.6", + "digest", ] [[package]] @@ -2549,19 +2454,7 @@ checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" dependencies = [ "cfg-if", "cpufeatures", - "digest 0.10.6", -] - -[[package]] -name = "sha3" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f81199417d4e5de3f04b1e871023acea7389672c4135918f05aa9cbf2f2fa809" -dependencies = [ - "block-buffer 0.9.0", - "digest 0.9.0", - "keccak", - "opaque-debug", + "digest", ] [[package]] @@ -2570,7 +2463,7 @@ version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bdf0c33fae925bdc080598b84bc15c55e7b9a4a43b3c704da051f977469691c9" dependencies = [ - "digest 0.10.6", + "digest", "keccak", ] @@ -3066,7 +2959,7 @@ dependencies = [ "getopts", "log", "phf_codegen", - "rand 0.8.5", + "rand", "time", ] @@ -3083,8 +2976,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1674845326ee10d37ca60470760d4288a6f80f304007d92e5c53bab78c9cfd79" dependencies = [ "atomic", - "getrandom 0.2.8", - "rand 0.8.5", + "getrandom", + "rand", "uuid-macro-internal", ] @@ -3134,12 +3027,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/Lib/test/test_atexit.py b/Lib/test/test_atexit.py index 7ac063cfc7..913b7556be 100644 --- a/Lib/test/test_atexit.py +++ b/Lib/test/test_atexit.py @@ -1,6 +1,5 @@ import atexit import os -import sys import textwrap import unittest from test import support diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py index bc8a7a35fb..b0d7998559 100644 --- a/Lib/test/test_locale.py +++ b/Lib/test/test_locale.py @@ -141,18 +141,9 @@ class BaseFormattingTest(object): # Utility functions for formatting tests # - def _test_formatfunc(self, format, value, out, func, **format_opts): - self.assertEqual( - func(format, value, **format_opts), out) - - def _test_format(self, format, value, out, **format_opts): - with check_warnings(('', DeprecationWarning)): - self._test_formatfunc(format, value, out, - func=locale.format, **format_opts) - def _test_format_string(self, format, value, out, **format_opts): - self._test_formatfunc(format, value, out, - func=locale.format_string, **format_opts) + self.assertEqual( + locale.format_string(format, value, **format_opts), out) def _test_currency(self, value, out, **format_opts): self.assertEqual(locale.currency(value, **format_opts), out) @@ -166,44 +157,40 @@ def setUp(self): self.sep = locale.localeconv()['thousands_sep'] def test_grouping(self): - self._test_format("%f", 1024, grouping=1, out='1%s024.000000' % self.sep) - self._test_format("%f", 102, grouping=1, out='102.000000') - self._test_format("%f", -42, grouping=1, out='-42.000000') - self._test_format("%+f", -42, grouping=1, out='-42.000000') + self._test_format_string("%f", 1024, grouping=1, out='1%s024.000000' % self.sep) + self._test_format_string("%f", 102, grouping=1, out='102.000000') + self._test_format_string("%f", -42, grouping=1, out='-42.000000') + self._test_format_string("%+f", -42, grouping=1, out='-42.000000') def test_grouping_and_padding(self): - self._test_format("%20.f", -42, grouping=1, out='-42'.rjust(20)) + self._test_format_string("%20.f", -42, grouping=1, out='-42'.rjust(20)) if self.sep: - self._test_format("%+10.f", -4200, grouping=1, + self._test_format_string("%+10.f", -4200, grouping=1, out=('-4%s200' % self.sep).rjust(10)) - self._test_format("%-10.f", -4200, grouping=1, + self._test_format_string("%-10.f", -4200, grouping=1, out=('-4%s200' % self.sep).ljust(10)) def test_integer_grouping(self): - self._test_format("%d", 4200, grouping=True, out='4%s200' % self.sep) - self._test_format("%+d", 4200, grouping=True, out='+4%s200' % self.sep) - self._test_format("%+d", -4200, grouping=True, out='-4%s200' % self.sep) + self._test_format_string("%d", 4200, grouping=True, out='4%s200' % self.sep) + self._test_format_string("%+d", 4200, grouping=True, out='+4%s200' % self.sep) + self._test_format_string("%+d", -4200, grouping=True, out='-4%s200' % self.sep) def test_integer_grouping_and_padding(self): - self._test_format("%10d", 4200, grouping=True, + self._test_format_string("%10d", 4200, grouping=True, out=('4%s200' % self.sep).rjust(10)) - self._test_format("%-10d", -4200, grouping=True, + self._test_format_string("%-10d", -4200, grouping=True, out=('-4%s200' % self.sep).ljust(10)) def test_simple(self): - self._test_format("%f", 1024, grouping=0, out='1024.000000') - self._test_format("%f", 102, grouping=0, out='102.000000') - self._test_format("%f", -42, grouping=0, out='-42.000000') - self._test_format("%+f", -42, grouping=0, out='-42.000000') + self._test_format_string("%f", 1024, grouping=0, out='1024.000000') + self._test_format_string("%f", 102, grouping=0, out='102.000000') + self._test_format_string("%f", -42, grouping=0, out='-42.000000') + self._test_format_string("%+f", -42, grouping=0, out='-42.000000') def test_padding(self): - self._test_format("%20.f", -42, grouping=0, out='-42'.rjust(20)) - self._test_format("%+10.f", -4200, grouping=0, out='-4200'.rjust(10)) - self._test_format("%-10.f", 4200, grouping=0, out='4200'.ljust(10)) - - def test_format_deprecation(self): - with self.assertWarns(DeprecationWarning): - locale.format("%-10.f", 4200, grouping=True) + self._test_format_string("%20.f", -42, grouping=0, out='-42'.rjust(20)) + self._test_format_string("%+10.f", -4200, grouping=0, out='-4200'.rjust(10)) + self._test_format_string("%-10.f", 4200, grouping=0, out='4200'.ljust(10)) def test_complex_formatting(self): # Spaces in formatting string @@ -230,20 +217,9 @@ def test_complex_formatting(self): out='int 1%s000 float 1%s000.00 str str' % (self.sep, self.sep)) - -class TestFormatPatternArg(unittest.TestCase): - # Test handling of pattern argument of format - - def test_onlyOnePattern(self): - with check_warnings(('', DeprecationWarning)): - # Issue 2522: accept exactly one % pattern, and no extra chars. - self.assertRaises(ValueError, locale.format, "%f\n", 'foo') - self.assertRaises(ValueError, locale.format, "%f\r", 'foo') - self.assertRaises(ValueError, locale.format, "%f\r\n", 'foo') - self.assertRaises(ValueError, locale.format, " %f", 'foo') - self.assertRaises(ValueError, locale.format, "%fg", 'foo') - self.assertRaises(ValueError, locale.format, "%^g", 'foo') - self.assertRaises(ValueError, locale.format, "%f%%", 'foo') + self._test_format_string("total=%i%%", 100, out='total=100%') + self._test_format_string("newline: %i\n", 3, out='newline: 3\n') + self._test_format_string("extra: %ii", 3, out='extra: 3i') class TestLocaleFormatString(unittest.TestCase): @@ -292,45 +268,45 @@ class TestCNumberFormatting(CCookedTest, BaseFormattingTest): # Test number formatting with a cooked "C" locale. def test_grouping(self): - self._test_format("%.2f", 12345.67, grouping=True, out='12345.67') + self._test_format_string("%.2f", 12345.67, grouping=True, out='12345.67') def test_grouping_and_padding(self): - self._test_format("%9.2f", 12345.67, grouping=True, out=' 12345.67') + self._test_format_string("%9.2f", 12345.67, grouping=True, out=' 12345.67') class TestFrFRNumberFormatting(FrFRCookedTest, BaseFormattingTest): # Test number formatting with a cooked "fr_FR" locale. def test_decimal_point(self): - self._test_format("%.2f", 12345.67, out='12345,67') + self._test_format_string("%.2f", 12345.67, out='12345,67') def test_grouping(self): - self._test_format("%.2f", 345.67, grouping=True, out='345,67') - self._test_format("%.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%.2f", 345.67, grouping=True, out='345,67') + self._test_format_string("%.2f", 12345.67, grouping=True, out='12 345,67') def test_grouping_and_padding(self): - self._test_format("%6.2f", 345.67, grouping=True, out='345,67') - self._test_format("%7.2f", 345.67, grouping=True, out=' 345,67') - self._test_format("%8.2f", 12345.67, grouping=True, out='12 345,67') - self._test_format("%9.2f", 12345.67, grouping=True, out='12 345,67') - self._test_format("%10.2f", 12345.67, grouping=True, out=' 12 345,67') - self._test_format("%-6.2f", 345.67, grouping=True, out='345,67') - self._test_format("%-7.2f", 345.67, grouping=True, out='345,67 ') - self._test_format("%-8.2f", 12345.67, grouping=True, out='12 345,67') - self._test_format("%-9.2f", 12345.67, grouping=True, out='12 345,67') - self._test_format("%-10.2f", 12345.67, grouping=True, out='12 345,67 ') + self._test_format_string("%6.2f", 345.67, grouping=True, out='345,67') + self._test_format_string("%7.2f", 345.67, grouping=True, out=' 345,67') + self._test_format_string("%8.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%9.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%10.2f", 12345.67, grouping=True, out=' 12 345,67') + self._test_format_string("%-6.2f", 345.67, grouping=True, out='345,67') + self._test_format_string("%-7.2f", 345.67, grouping=True, out='345,67 ') + self._test_format_string("%-8.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%-9.2f", 12345.67, grouping=True, out='12 345,67') + self._test_format_string("%-10.2f", 12345.67, grouping=True, out='12 345,67 ') def test_integer_grouping(self): - self._test_format("%d", 200, grouping=True, out='200') - self._test_format("%d", 4200, grouping=True, out='4 200') + self._test_format_string("%d", 200, grouping=True, out='200') + self._test_format_string("%d", 4200, grouping=True, out='4 200') def test_integer_grouping_and_padding(self): - self._test_format("%4d", 4200, grouping=True, out='4 200') - self._test_format("%5d", 4200, grouping=True, out='4 200') - self._test_format("%10d", 4200, grouping=True, out='4 200'.rjust(10)) - self._test_format("%-4d", 4200, grouping=True, out='4 200') - self._test_format("%-5d", 4200, grouping=True, out='4 200') - self._test_format("%-10d", 4200, grouping=True, out='4 200'.ljust(10)) + self._test_format_string("%4d", 4200, grouping=True, out='4 200') + self._test_format_string("%5d", 4200, grouping=True, out='4 200') + self._test_format_string("%10d", 4200, grouping=True, out='4 200'.rjust(10)) + self._test_format_string("%-4d", 4200, grouping=True, out='4 200') + self._test_format_string("%-5d", 4200, grouping=True, out='4 200') + self._test_format_string("%-10d", 4200, grouping=True, out='4 200'.ljust(10)) def test_currency(self): euro = '\u20ac' From b5e21ab136ddedfb94aad34ab9d043afc4ceef1e Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Thu, 9 Nov 2023 14:59:59 +0800 Subject: [PATCH 148/893] Add Lib/test/test_opcodes.py cpython version: 3.12 --- Lib/test/test_opcodes.py | 138 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 Lib/test/test_opcodes.py diff --git a/Lib/test/test_opcodes.py b/Lib/test/test_opcodes.py new file mode 100644 index 0000000000..72488b2bb6 --- /dev/null +++ b/Lib/test/test_opcodes.py @@ -0,0 +1,138 @@ +# Python test set -- part 2, opcodes + +import unittest +from test import support +from test.typinganndata import ann_module + +class OpcodeTest(unittest.TestCase): + + def test_try_inside_for_loop(self): + n = 0 + for i in range(10): + n = n+i + try: 1/0 + except NameError: pass + except ZeroDivisionError: pass + except TypeError: pass + try: pass + except: pass + try: pass + finally: pass + n = n+i + if n != 90: + self.fail('try inside for') + + def test_setup_annotations_line(self): + # check that SETUP_ANNOTATIONS does not create spurious line numbers + try: + with open(ann_module.__file__, encoding="utf-8") as f: + txt = f.read() + co = compile(txt, ann_module.__file__, 'exec') + self.assertEqual(co.co_firstlineno, 1) + except OSError: + pass + + def test_default_annotations_exist(self): + class C: pass + self.assertEqual(C.__annotations__, {}) + + def test_use_existing_annotations(self): + ns = {'__annotations__': {1: 2}} + exec('x: int', ns) + self.assertEqual(ns['__annotations__'], {'x': int, 1: 2}) + + def test_do_not_recreate_annotations(self): + # Don't rely on the existence of the '__annotations__' global. + with support.swap_item(globals(), '__annotations__', {}): + del globals()['__annotations__'] + class C: + del __annotations__ + with self.assertRaises(NameError): + x: int + + def test_raise_class_exceptions(self): + + class AClass(Exception): pass + class BClass(AClass): pass + class CClass(Exception): pass + class DClass(AClass): + def __init__(self, ignore): + pass + + try: raise AClass() + except: pass + + try: raise AClass() + except AClass: pass + + try: raise BClass() + except AClass: pass + + try: raise BClass() + except CClass: self.fail() + except: pass + + a = AClass() + b = BClass() + + try: + raise b + except AClass as v: + self.assertEqual(v, b) + else: + self.fail("no exception") + + # not enough arguments + ##try: raise BClass, a + ##except TypeError: pass + ##else: self.fail("no exception") + + try: raise DClass(a) + except DClass as v: + self.assertIsInstance(v, DClass) + else: + self.fail("no exception") + + def test_compare_function_objects(self): + + f = eval('lambda: None') + g = eval('lambda: None') + self.assertNotEqual(f, g) + + f = eval('lambda a: a') + g = eval('lambda a: a') + self.assertNotEqual(f, g) + + f = eval('lambda a=1: a') + g = eval('lambda a=1: a') + self.assertNotEqual(f, g) + + f = eval('lambda: 0') + g = eval('lambda: 1') + self.assertNotEqual(f, g) + + f = eval('lambda: None') + g = eval('lambda a: None') + self.assertNotEqual(f, g) + + f = eval('lambda a: None') + g = eval('lambda b: None') + self.assertNotEqual(f, g) + + f = eval('lambda a: None') + g = eval('lambda a=None: None') + self.assertNotEqual(f, g) + + f = eval('lambda a=0: None') + g = eval('lambda a=1: None') + self.assertNotEqual(f, g) + + def test_modulo_of_string_subclasses(self): + class MyString(str): + def __mod__(self, value): + return 42 + self.assertEqual(MyString() % 3, 42) + + +if __name__ == '__main__': + unittest.main() From 06b9b4938d381f67fb36cf5d7c1d9ffbda007100 Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Thu, 9 Nov 2023 15:05:04 +0800 Subject: [PATCH 149/893] Add Lib/test/typinganndata folder cpython version: 3.12 --- Lib/test/typinganndata/__init__.py | 0 Lib/test/typinganndata/ann_module.py | 62 +++++++++++++++++++++++++++ Lib/test/typinganndata/ann_module2.py | 36 ++++++++++++++++ Lib/test/typinganndata/ann_module3.py | 18 ++++++++ Lib/test/typinganndata/ann_module4.py | 5 +++ Lib/test/typinganndata/ann_module5.py | 10 +++++ Lib/test/typinganndata/ann_module6.py | 7 +++ Lib/test/typinganndata/ann_module7.py | 11 +++++ Lib/test/typinganndata/ann_module8.py | 10 +++++ Lib/test/typinganndata/ann_module9.py | 14 ++++++ 10 files changed, 173 insertions(+) create mode 100644 Lib/test/typinganndata/__init__.py create mode 100644 Lib/test/typinganndata/ann_module.py create mode 100644 Lib/test/typinganndata/ann_module2.py create mode 100644 Lib/test/typinganndata/ann_module3.py create mode 100644 Lib/test/typinganndata/ann_module4.py create mode 100644 Lib/test/typinganndata/ann_module5.py create mode 100644 Lib/test/typinganndata/ann_module6.py create mode 100644 Lib/test/typinganndata/ann_module7.py create mode 100644 Lib/test/typinganndata/ann_module8.py create mode 100644 Lib/test/typinganndata/ann_module9.py diff --git a/Lib/test/typinganndata/__init__.py b/Lib/test/typinganndata/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/Lib/test/typinganndata/ann_module.py b/Lib/test/typinganndata/ann_module.py new file mode 100644 index 0000000000..5081e6b583 --- /dev/null +++ b/Lib/test/typinganndata/ann_module.py @@ -0,0 +1,62 @@ + + +""" +The module for testing variable annotations. +Empty lines above are for good reason (testing for correct line numbers) +""" + +from typing import Optional +from functools import wraps + +__annotations__[1] = 2 + +class C: + + x = 5; y: Optional['C'] = None + +from typing import Tuple +x: int = 5; y: str = x; f: Tuple[int, int] + +class M(type): + + __annotations__['123'] = 123 + o: type = object + +(pars): bool = True + +class D(C): + j: str = 'hi'; k: str= 'bye' + +from types import new_class +h_class = new_class('H', (C,)) +j_class = new_class('J') + +class F(): + z: int = 5 + def __init__(self, x): + pass + +class Y(F): + def __init__(self): + super(F, self).__init__(123) + +class Meta(type): + def __new__(meta, name, bases, namespace): + return super().__new__(meta, name, bases, namespace) + +class S(metaclass = Meta): + x: str = 'something' + y: str = 'something else' + +def foo(x: int = 10): + def bar(y: List[str]): + x: str = 'yes' + bar() + +def dec(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + +u: int | float diff --git a/Lib/test/typinganndata/ann_module2.py b/Lib/test/typinganndata/ann_module2.py new file mode 100644 index 0000000000..76cf5b3ad9 --- /dev/null +++ b/Lib/test/typinganndata/ann_module2.py @@ -0,0 +1,36 @@ +""" +Some correct syntax for variable annotation here. +More examples are in test_grammar and test_parser. +""" + +from typing import no_type_check, ClassVar + +i: int = 1 +j: int +x: float = i/10 + +def f(): + class C: ... + return C() + +f().new_attr: object = object() + +class C: + def __init__(self, x: int) -> None: + self.x = x + +c = C(5) +c.new_attr: int = 10 + +__annotations__ = {} + + +@no_type_check +class NTC: + def meth(self, param: complex) -> None: + ... + +class CV: + var: ClassVar['CV'] + +CV.var = CV() diff --git a/Lib/test/typinganndata/ann_module3.py b/Lib/test/typinganndata/ann_module3.py new file mode 100644 index 0000000000..eccd7be22d --- /dev/null +++ b/Lib/test/typinganndata/ann_module3.py @@ -0,0 +1,18 @@ +""" +Correct syntax for variable annotation that should fail at runtime +in a certain manner. More examples are in test_grammar and test_parser. +""" + +def f_bad_ann(): + __annotations__[1] = 2 + +class C_OK: + def __init__(self, x: int) -> None: + self.x: no_such_name = x # This one is OK as proposed by Guido + +class D_bad_ann: + def __init__(self, x: int) -> None: + sfel.y: int = 0 + +def g_bad_ann(): + no_such_name.attr: int = 0 diff --git a/Lib/test/typinganndata/ann_module4.py b/Lib/test/typinganndata/ann_module4.py new file mode 100644 index 0000000000..13e9aee54c --- /dev/null +++ b/Lib/test/typinganndata/ann_module4.py @@ -0,0 +1,5 @@ +# This ann_module isn't for test_typing, +# it's for test_module + +a:int=3 +b:str=4 diff --git a/Lib/test/typinganndata/ann_module5.py b/Lib/test/typinganndata/ann_module5.py new file mode 100644 index 0000000000..837041e121 --- /dev/null +++ b/Lib/test/typinganndata/ann_module5.py @@ -0,0 +1,10 @@ +# Used by test_typing to verify that Final wrapped in ForwardRef works. + +from __future__ import annotations + +from typing import Final + +name: Final[str] = "final" + +class MyClass: + value: Final = 3000 diff --git a/Lib/test/typinganndata/ann_module6.py b/Lib/test/typinganndata/ann_module6.py new file mode 100644 index 0000000000..679175669b --- /dev/null +++ b/Lib/test/typinganndata/ann_module6.py @@ -0,0 +1,7 @@ +# Tests that top-level ClassVar is not allowed + +from __future__ import annotations + +from typing import ClassVar + +wrong: ClassVar[int] = 1 diff --git a/Lib/test/typinganndata/ann_module7.py b/Lib/test/typinganndata/ann_module7.py new file mode 100644 index 0000000000..8f890cd280 --- /dev/null +++ b/Lib/test/typinganndata/ann_module7.py @@ -0,0 +1,11 @@ +# Tests class have ``__text_signature__`` + +from __future__ import annotations + +DEFAULT_BUFFER_SIZE = 8192 + +class BufferedReader(object): + """BufferedReader(raw, buffer_size=DEFAULT_BUFFER_SIZE)\n--\n\n + Create a new buffered reader using the given readable raw IO object. + """ + pass diff --git a/Lib/test/typinganndata/ann_module8.py b/Lib/test/typinganndata/ann_module8.py new file mode 100644 index 0000000000..bd03148137 --- /dev/null +++ b/Lib/test/typinganndata/ann_module8.py @@ -0,0 +1,10 @@ +# Test `@no_type_check`, +# see https://bugs.python.org/issue46571 + +class NoTypeCheck_Outer: + class Inner: + x: int + + +def NoTypeCheck_function(arg: int) -> int: + ... diff --git a/Lib/test/typinganndata/ann_module9.py b/Lib/test/typinganndata/ann_module9.py new file mode 100644 index 0000000000..952217393e --- /dev/null +++ b/Lib/test/typinganndata/ann_module9.py @@ -0,0 +1,14 @@ +# Test ``inspect.formatannotation`` +# https://github.com/python/cpython/issues/96073 + +from typing import Union, List + +ann = Union[List[str], int] + +# mock typing._type_repr behaviour +class A: ... + +A.__module__ = 'testModule.typing' +A.__qualname__ = 'A' + +ann1 = Union[List[A], int] From a96926cd770d2785c472ba11e000be511daffc82 Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Thu, 9 Nov 2023 15:06:08 +0800 Subject: [PATCH 150/893] Remove Lib/test/test_opcode.py This file is a duplicate of test_codeop.py with a few changes --- Lib/test/test_opcode.py | 356 ---------------------------------------- 1 file changed, 356 deletions(-) delete mode 100644 Lib/test/test_opcode.py diff --git a/Lib/test/test_opcode.py b/Lib/test/test_opcode.py deleted file mode 100644 index 170eb1cb1d..0000000000 --- a/Lib/test/test_opcode.py +++ /dev/null @@ -1,356 +0,0 @@ -""" - Test cases for codeop.py - Nick Mathewson -""" -import sys -import unittest -import warnings -from test import support -from test.support import warnings_helper - -from codeop import compile_command, PyCF_DONT_IMPLY_DEDENT -import io - -if support.is_jython: - - def unify_callables(d): - for n,v in d.items(): - if hasattr(v, '__call__'): - d[n] = True - return d - -class CodeopTests(unittest.TestCase): - - def assertValid(self, str, symbol='single'): - '''succeed iff str is a valid piece of code''' - if support.is_jython: - code = compile_command(str, "", symbol) - self.assertTrue(code) - if symbol == "single": - d,r = {},{} - saved_stdout = sys.stdout - sys.stdout = io.StringIO() - try: - exec(code, d) - exec(compile(str,"","single"), r) - finally: - sys.stdout = saved_stdout - elif symbol == 'eval': - ctx = {'a': 2} - d = { 'value': eval(code,ctx) } - r = { 'value': eval(str,ctx) } - self.assertEqual(unify_callables(r),unify_callables(d)) - else: - expected = compile(str, "", symbol, PyCF_DONT_IMPLY_DEDENT) - self.assertEqual(compile_command(str, "", symbol), expected) - - def assertIncomplete(self, str, symbol='single'): - '''succeed iff str is the start of a valid piece of code''' - self.assertEqual(compile_command(str, symbol=symbol), None) - - def assertInvalid(self, str, symbol='single', is_syntax=1): - '''succeed iff str is the start of an invalid piece of code''' - try: - compile_command(str,symbol=symbol) - self.fail("No exception raised for invalid code") - except SyntaxError: - self.assertTrue(is_syntax) - except OverflowError: - self.assertTrue(not is_syntax) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_valid(self): - av = self.assertValid - - # special case - if not support.is_jython: - self.assertEqual(compile_command(""), - compile("pass", "", 'single', - PyCF_DONT_IMPLY_DEDENT)) - self.assertEqual(compile_command("\n"), - compile("pass", "", 'single', - PyCF_DONT_IMPLY_DEDENT)) - else: - av("") - av("\n") - - av("a = 1") - av("\na = 1") - av("a = 1\n") - av("a = 1\n\n") - av("\n\na = 1\n\n") - - av("def x():\n pass\n") - av("if 1:\n pass\n") - - av("\n\nif 1: pass\n") - av("\n\nif 1: pass\n\n") - - av("def x():\n\n pass\n") - av("def x():\n pass\n \n") - av("def x():\n pass\n \n") - - av("pass\n") - av("3**3\n") - - av("if 9==3:\n pass\nelse:\n pass\n") - av("if 1:\n pass\n if 1:\n pass\n else:\n pass\n") - - av("#a\n#b\na = 3\n") - av("#a\n\n \na=3\n") - av("a=3\n\n") - av("a = 9+ \\\n3") - - av("3**3","eval") - av("(lambda z: \n z**3)","eval") - - av("9+ \\\n3","eval") - av("9+ \\\n3\n","eval") - - av("\n\na**3","eval") - av("\n \na**3","eval") - av("#a\n#b\na**3","eval") - - av("\n\na = 1\n\n") - av("\n\nif 1: a=1\n\n") - - av("if 1:\n pass\n if 1:\n pass\n else:\n pass\n") - av("#a\n\n \na=3\n\n") - - av("\n\na**3","eval") - av("\n \na**3","eval") - av("#a\n#b\na**3","eval") - - av("def f():\n try: pass\n finally: [x for x in (1,2)]\n") - av("def f():\n pass\n#foo\n") - av("@a.b.c\ndef f():\n pass\n") - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_incomplete(self): - ai = self.assertIncomplete - - ai("(a **") - ai("(a,b,") - ai("(a,b,(") - ai("(a,b,(") - ai("a = (") - ai("a = {") - ai("b + {") - - ai("print([1,\n2,") - ai("print({1:1,\n2:3,") - ai("print((1,\n2,") - - ai("if 9==3:\n pass\nelse:") - ai("if 9==3:\n pass\nelse:\n") - ai("if 9==3:\n pass\nelse:\n pass") - ai("if 1:") - ai("if 1:\n") - ai("if 1:\n pass\n if 1:\n pass\n else:") - ai("if 1:\n pass\n if 1:\n pass\n else:\n") - ai("if 1:\n pass\n if 1:\n pass\n else:\n pass") - - ai("def x():") - ai("def x():\n") - ai("def x():\n\n") - - ai("def x():\n pass") - ai("def x():\n pass\n ") - ai("def x():\n pass\n ") - ai("\n\ndef x():\n pass") - - ai("a = 9+ \\") - ai("a = 'a\\") - ai("a = '''xy") - - ai("","eval") - ai("\n","eval") - ai("(","eval") - ai("(9+","eval") - ai("9+ \\","eval") - ai("lambda z: \\","eval") - - ai("if True:\n if True:\n if True: \n") - - ai("@a(") - ai("@a(b") - ai("@a(b,") - ai("@a(b,c") - ai("@a(b,c,") - - ai("from a import (") - ai("from a import (b") - ai("from a import (b,") - ai("from a import (b,c") - ai("from a import (b,c,") - - ai("[") - ai("[a") - ai("[a,") - ai("[a,b") - ai("[a,b,") - - ai("{") - ai("{a") - ai("{a:") - ai("{a:b") - ai("{a:b,") - ai("{a:b,c") - ai("{a:b,c:") - ai("{a:b,c:d") - ai("{a:b,c:d,") - - ai("a(") - ai("a(b") - ai("a(b,") - ai("a(b,c") - ai("a(b,c,") - - ai("a[") - ai("a[b") - ai("a[b,") - ai("a[b:") - ai("a[b:c") - ai("a[b:c:") - ai("a[b:c:d") - - ai("def a(") - ai("def a(b") - ai("def a(b,") - ai("def a(b,c") - ai("def a(b,c,") - - ai("(") - ai("(a") - ai("(a,") - ai("(a,b") - ai("(a,b,") - - ai("if a:\n pass\nelif b:") - ai("if a:\n pass\nelif b:\n pass\nelse:") - - ai("while a:") - ai("while a:\n pass\nelse:") - - ai("for a in b:") - ai("for a in b:\n pass\nelse:") - - ai("try:") - ai("try:\n pass\nexcept:") - ai("try:\n pass\nfinally:") - ai("try:\n pass\nexcept:\n pass\nfinally:") - - ai("with a:") - ai("with a as b:") - - ai("class a:") - ai("class a(") - ai("class a(b") - ai("class a(b,") - ai("class a():") - - ai("[x for") - ai("[x for x in") - ai("[x for x in (") - - ai("(x for") - ai("(x for x in") - ai("(x for x in (") - - def test_invalid(self): - ai = self.assertInvalid - ai("a b") - - ai("a @") - ai("a b @") - ai("a ** @") - - ai("a = ") - ai("a = 9 +") - - ai("def x():\n\npass\n") - - ai("\n\n if 1: pass\n\npass") - - ai("a = 9+ \\\n") - ai("a = 'a\\ ") - ai("a = 'a\\\n") - - ai("a = 1","eval") - ai("]","eval") - ai("())","eval") - ai("[}","eval") - ai("9+","eval") - ai("lambda z:","eval") - ai("a b","eval") - - ai("return 2.3") - ai("if (a == 1 and b = 2): pass") - - ai("del 1") - ai("del (1,)") - ai("del [1]") - ai("del '1'") - - ai("[i for i in range(10)] = (1, 2, 3)") - - def test_invalid_exec(self): - ai = self.assertInvalid - ai("raise = 4", symbol="exec") - ai('def a-b', symbol='exec') - ai('await?', symbol='exec') - ai('=!=', symbol='exec') - ai('a await raise b', symbol='exec') - ai('a await raise b?+1', symbol='exec') - - def test_filename(self): - self.assertEqual(compile_command("a = 1\n", "abc").co_filename, - compile("a = 1\n", "abc", 'single').co_filename) - self.assertNotEqual(compile_command("a = 1\n", "abc").co_filename, - compile("a = 1\n", "def", 'single').co_filename) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_warning(self): - # Test that the warning is only returned once. - with warnings_helper.check_warnings( - (".*literal", SyntaxWarning), - (".*invalid", DeprecationWarning), - ) as w: - compile_command(r"'\e' is 0") - self.assertEqual(len(w.warnings), 2) - - # bpo-41520: check SyntaxWarning treated as an SyntaxError - with warnings.catch_warnings(), self.assertRaises(SyntaxError): - warnings.simplefilter('error', SyntaxWarning) - compile_command('1 is 1', symbol='exec') - - # Check DeprecationWarning treated as an SyntaxError - with warnings.catch_warnings(), self.assertRaises(SyntaxError): - warnings.simplefilter('error', DeprecationWarning) - compile_command(r"'\e'", symbol='exec') - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_incomplete_warning(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - self.assertIncomplete("'\\e' + (") - self.assertEqual(w, []) - - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_invalid_warning(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - self.assertInvalid("'\\e' 1") - self.assertEqual(len(w), 1) - self.assertEqual(w[0].category, DeprecationWarning) - self.assertRegex(str(w[0].message), 'invalid escape sequence') - self.assertEqual(w[0].filename, '') - - -if __name__ == "__main__": - unittest.main() From 8ff0872d41c3692435e3bce37db8cf2fd4ffb345 Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Thu, 9 Nov 2023 15:11:35 +0800 Subject: [PATCH 151/893] Update test_codeop.py from CPython v3.12 --- Lib/test/test_codeop.py | 149 ++++++++++++++++++++-------------------- 1 file changed, 74 insertions(+), 75 deletions(-) diff --git a/Lib/test/test_codeop.py b/Lib/test/test_codeop.py index 671148ce2b..f39f51e824 100644 --- a/Lib/test/test_codeop.py +++ b/Lib/test/test_codeop.py @@ -2,48 +2,19 @@ Test cases for codeop.py Nick Mathewson """ -import sys import unittest import warnings -from test import support from test.support import warnings_helper +from textwrap import dedent from codeop import compile_command, PyCF_DONT_IMPLY_DEDENT -import io - -if support.is_jython: - - def unify_callables(d): - for n, v in d.items(): - if hasattr(v, '__call__'): - d[n] = True - return d - class CodeopTests(unittest.TestCase): def assertValid(self, str, symbol='single'): '''succeed iff str is a valid piece of code''' - if support.is_jython: - code = compile_command(str, "", symbol) - self.assertTrue(code) - if symbol == "single": - d, r = {}, {} - saved_stdout = sys.stdout - sys.stdout = io.StringIO() - try: - exec(code, d) - exec(compile(str, "", "single"), r) - finally: - sys.stdout = saved_stdout - elif symbol == 'eval': - ctx = {'a': 2} - d = {'value': eval(code, ctx)} - r = {'value': eval(str, ctx)} - self.assertEqual(unify_callables(r), unify_callables(d)) - else: - expected = compile(str, "", symbol, PyCF_DONT_IMPLY_DEDENT) - self.assertEqual(compile_command(str, "", symbol), expected) + expected = compile(str, "", symbol, PyCF_DONT_IMPLY_DEDENT) + self.assertEqual(compile_command(str, "", symbol), expected) def assertIncomplete(self, str, symbol='single'): '''succeed iff str is the start of a valid piece of code''' @@ -52,30 +23,23 @@ def assertIncomplete(self, str, symbol='single'): def assertInvalid(self, str, symbol='single', is_syntax=1): '''succeed iff str is the start of an invalid piece of code''' try: - compile_command(str, symbol=symbol) + compile_command(str,symbol=symbol) self.fail("No exception raised for invalid code") except SyntaxError: self.assertTrue(is_syntax) except OverflowError: self.assertTrue(not is_syntax) - # TODO: RUSTPYTHON - - @unittest.expectedFailure def test_valid(self): av = self.assertValid # special case - if not support.is_jython: - self.assertEqual(compile_command(""), - compile("pass", "", 'single', - PyCF_DONT_IMPLY_DEDENT)) - self.assertEqual(compile_command("\n"), - compile("pass", "", 'single', - PyCF_DONT_IMPLY_DEDENT)) - else: - av("") - av("\n") + self.assertEqual(compile_command(""), + compile("pass", "", 'single', + PyCF_DONT_IMPLY_DEDENT)) + self.assertEqual(compile_command("\n"), + compile("pass", "", 'single', + PyCF_DONT_IMPLY_DEDENT)) av("a = 1") av("\na = 1") @@ -104,15 +68,15 @@ def test_valid(self): av("a=3\n\n") av("a = 9+ \\\n3") - av("3**3", "eval") - av("(lambda z: \n z**3)", "eval") + av("3**3","eval") + av("(lambda z: \n z**3)","eval") - av("9+ \\\n3", "eval") - av("9+ \\\n3\n", "eval") + av("9+ \\\n3","eval") + av("9+ \\\n3\n","eval") - av("\n\na**3", "eval") - av("\n \na**3", "eval") - av("#a\n#b\na**3", "eval") + av("\n\na**3","eval") + av("\n \na**3","eval") + av("#a\n#b\na**3","eval") av("\n\na = 1\n\n") av("\n\nif 1: a=1\n\n") @@ -120,9 +84,9 @@ def test_valid(self): av("if 1:\n pass\n if 1:\n pass\n else:\n pass\n") av("#a\n\n \na=3\n\n") - av("\n\na**3", "eval") - av("\n \na**3", "eval") - av("#a\n#b\na**3", "eval") + av("\n\na**3","eval") + av("\n \na**3","eval") + av("#a\n#b\na**3","eval") av("def f():\n try: pass\n finally: [x for x in (1,2)]\n") av("def f():\n pass\n#foo\n") @@ -141,6 +105,10 @@ def test_incomplete(self): ai("a = {") ai("b + {") + ai("print([1,\n2,") + ai("print({1:1,\n2:3,") + ai("print((1,\n2,") + ai("if 9==3:\n pass\nelse:") ai("if 9==3:\n pass\nelse:\n") ai("if 9==3:\n pass\nelse:\n pass") @@ -163,13 +131,12 @@ def test_incomplete(self): ai("a = 'a\\") ai("a = '''xy") - ai("", "eval") - ai("\n", "eval") - ai("(", "eval") - ai("(\n\n\n", "eval") - ai("(9+", "eval") - ai("9+ \\", "eval") - ai("lambda z: \\", "eval") + ai("","eval") + ai("\n","eval") + ai("(","eval") + ai("(9+","eval") + ai("9+ \\","eval") + ai("lambda z: \\","eval") ai("if True:\n if True:\n if True: \n") @@ -277,14 +244,13 @@ def test_invalid(self): ai("a = 'a\\ ") ai("a = 'a\\\n") - ai("a = 1", "eval") - ai("a = (", "eval") - ai("]", "eval") - ai("())", "eval") - ai("[}", "eval") - ai("9+", "eval") - ai("lambda z:", "eval") - ai("a b", "eval") + ai("a = 1","eval") + ai("]","eval") + ai("())","eval") + ai("[}","eval") + ai("9+","eval") + ai("lambda z:","eval") + ai("a b","eval") ai("return 2.3") ai("if (a == 1 and b = 2): pass") @@ -314,11 +280,11 @@ def test_filename(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_warning(self): - # Teswarnings_helper.check_warningsonly returned once. + # Test that the warning is only returned once. with warnings_helper.check_warnings( - (".*literal", SyntaxWarning), - (".*invalid", DeprecationWarning), - ) as w: + ('"is" with \'str\' literal', SyntaxWarning), + ("invalid escape sequence", SyntaxWarning), + ) as w: compile_command(r"'\e' is 0") self.assertEqual(len(w.warnings), 2) @@ -327,6 +293,39 @@ def test_warning(self): warnings.simplefilter('error', SyntaxWarning) compile_command('1 is 1', symbol='exec') + # Check SyntaxWarning treated as an SyntaxError + with warnings.catch_warnings(), self.assertRaises(SyntaxError): + warnings.simplefilter('error', SyntaxWarning) + compile_command(r"'\e'", symbol='exec') + + def test_incomplete_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + self.assertIncomplete("'\\e' + (") + self.assertEqual(w, []) + + def test_invalid_warning(self): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + self.assertInvalid("'\\e' 1") + self.assertEqual(len(w), 1) + self.assertEqual(w[0].category, SyntaxWarning) + self.assertRegex(str(w[0].message), 'invalid escape sequence') + self.assertEqual(w[0].filename, '') + + def assertSyntaxErrorMatches(self, code, message): + with self.subTest(code): + with self.assertRaisesRegex(SyntaxError, message): + compile_command(code, symbol='exec') + + def test_syntax_errors(self): + self.assertSyntaxErrorMatches( + dedent("""\ + def foo(x,x): + pass + """), "duplicate argument 'x' in function definition") + + if __name__ == "__main__": unittest.main() From 488bac741373b2a7639c9931d5e81c52e7f5b9e4 Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Thu, 9 Nov 2023 15:20:52 +0800 Subject: [PATCH 152/893] Update codeop.py from CPython v3.12 --- Lib/codeop.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/Lib/codeop.py b/Lib/codeop.py index 2213b69f23..4dd096574b 100644 --- a/Lib/codeop.py +++ b/Lib/codeop.py @@ -70,8 +70,7 @@ def _maybe_compile(compiler, source, filename, symbol): return None # fallthrough - return compiler(source, filename, symbol) - + return compiler(source, filename, symbol, incomplete_input=False) def _is_syntax_error(err1, err2): rep1 = repr(err1) @@ -82,8 +81,13 @@ def _is_syntax_error(err1, err2): return True return False -def _compile(source, filename, symbol): - return compile(source, filename, symbol, PyCF_DONT_IMPLY_DEDENT | PyCF_ALLOW_INCOMPLETE_INPUT) +def _compile(source, filename, symbol, incomplete_input=True): + flags = 0 + if incomplete_input: + flags |= PyCF_ALLOW_INCOMPLETE_INPUT + flags |= PyCF_DONT_IMPLY_DEDENT + return compile(source, filename, symbol, flags) + def compile_command(source, filename="", symbol="single"): r"""Compile a command and determine whether it is incomplete. @@ -114,8 +118,12 @@ class Compile: def __init__(self): self.flags = PyCF_DONT_IMPLY_DEDENT | PyCF_ALLOW_INCOMPLETE_INPUT - def __call__(self, source, filename, symbol): - codeob = compile(source, filename, symbol, self.flags, True) + def __call__(self, source, filename, symbol, **kwargs): + flags = self.flags + if kwargs.get('incomplete_input', True) is False: + flags &= ~PyCF_DONT_IMPLY_DEDENT + flags &= ~PyCF_ALLOW_INCOMPLETE_INPUT + codeob = compile(source, filename, symbol, flags, True) for feature in _features: if codeob.co_flags & feature.compiler_flag: self.flags |= feature.compiler_flag From ee9f6ae079ea13b51679b97eaf773b6c04828785 Mon Sep 17 00:00:00 2001 From: NakanoMiku39 Date: Thu, 9 Nov 2023 15:29:35 +0800 Subject: [PATCH 153/893] Edit Lib/test/test_codeop.py cpython version: 3.12 --- Lib/test/test_codeop.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_codeop.py b/Lib/test/test_codeop.py index f39f51e824..19117fa409 100644 --- a/Lib/test/test_codeop.py +++ b/Lib/test/test_codeop.py @@ -30,6 +30,8 @@ def assertInvalid(self, str, symbol='single', is_syntax=1): except OverflowError: self.assertTrue(not is_syntax) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_valid(self): av = self.assertValid @@ -298,12 +300,15 @@ def test_warning(self): warnings.simplefilter('error', SyntaxWarning) compile_command(r"'\e'", symbol='exec') - def test_incomplete_warning(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - self.assertIncomplete("'\\e' + (") - self.assertEqual(w, []) + # TODO: RUSTPYTHON + #def test_incomplete_warning(self): + # with warnings.catch_warnings(record=True) as w: + # warnings.simplefilter('always') + # self.assertIncomplete("'\\e' + (") + # self.assertEqual(w, []) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_invalid_warning(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') From 87cf891e5062fe518944ceca70e3c0b168cfb16e Mon Sep 17 00:00:00 2001 From: Noa Date: Tue, 14 Nov 2023 21:52:47 -0600 Subject: [PATCH 154/893] Make PyMethodDef construction const (#5117) * Make PyMethodDef construction const * Remove iter_chain![] Obsolete since arrays now impl IntoIterator --- derive-impl/src/pyclass.rs | 72 +++++++++++++++++-------------------- derive-impl/src/pymodule.rs | 63 ++++++++++++++++---------------- derive-impl/src/util.rs | 7 ---- vm/src/class.rs | 7 ++-- vm/src/function/builtin.rs | 32 +++++++++++++++++ vm/src/function/method.rs | 51 ++++++++++++++++++++++++++ vm/src/function/mod.rs | 2 +- 7 files changed, 149 insertions(+), 85 deletions(-) diff --git a/derive-impl/src/pyclass.rs b/derive-impl/src/pyclass.rs index 55705c742d..8f688b366a 100644 --- a/derive-impl/src/pyclass.rs +++ b/derive-impl/src/pyclass.rs @@ -4,7 +4,7 @@ use crate::util::{ ClassItemMeta, ContentItem, ContentItemInner, ErrorVec, ExceptionItemMeta, ItemMeta, ItemMetaInner, ItemNursery, SimpleItemMeta, ALL_ALLOWED_NAMES, }; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree}; use quote::{quote, quote_spanned, ToTokens}; use std::collections::{HashMap, HashSet}; use std::str::FromStr; @@ -172,14 +172,9 @@ pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result, - ) { - #method_def - } + const __OWN_METHOD_DEFS: &'static [::rustpython_vm::function::PyMethodDef] = &#method_def; }, parse_quote! { fn __extend_py_class( @@ -201,6 +196,15 @@ pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result(&[#impl_ty::__OWN_METHOD_DEFS, #(#with_method_defs,)*]) + ) + }; quote! { #imp impl ::rustpython_vm::class::PyClassImpl for #payload_ty { @@ -214,13 +218,7 @@ pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result, - ) { - #impl_ty::__extend_method_def(method_defs); - #with_method_defs - } + const METHOD_DEFS: &'static [::rustpython_vm::function::PyMethodDef] = &#method_defs; fn extend_slots(slots: &mut ::rustpython_vm::types::PyTypeSlots) { #impl_ty::__extend_slots(slots); @@ -268,14 +266,9 @@ pub(crate) fn impl_pyclass_impl(attr: AttributeArgs, item: Item) -> Result, - ) { - #method_def - } + const __OWN_METHOD_DEFS: &'static [::rustpython_vm::function::PyMethodDef] = &#method_def; }, parse_quote! { fn __extend_py_class( @@ -983,6 +976,7 @@ impl MethodNursery { impl ToTokens for MethodNursery { fn to_tokens(&self, tokens: &mut TokenStream) { + let mut inner_tokens = TokenStream::new(); for item in &self.items { let py_name = &item.py_name; let ident = &item.ident; @@ -1011,16 +1005,18 @@ impl ToTokens for MethodNursery { // } else { // quote_spanned! { ident.span() => #py_name } // }; - tokens.extend(quote! { + inner_tokens.extend(quote! [ #(#cfgs)* - method_defs.push(rustpython_vm::function::PyMethodDef { - name: #py_name, - func: rustpython_vm::function::IntoPyNativeFn::into_func(Self::#ident), - flags: #flags, - doc: #doc, - }); - }); + rustpython_vm::function::PyMethodDef::new_const( + #py_name, + Self::#ident, + #flags, + #doc, + ), + ]); } + let array: TokenTree = Group::new(Delimiter::Bracket, inner_tokens).into(); + tokens.extend([array]); } } @@ -1436,7 +1432,7 @@ struct ExtractedImplAttrs { payload: Option, flags: TokenStream, with_impl: TokenStream, - with_method_defs: TokenStream, + with_method_defs: Vec, with_slots: TokenStream, } @@ -1465,18 +1461,18 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result::__extend_py_class), - quote!(#path::::__extend_method_def), + quote!(#path::::__OWN_METHOD_DEFS), quote!(#path::::__extend_slots), ) } else { ( quote!(::__extend_py_class), - quote!(::__extend_method_def), + quote!(::__OWN_METHOD_DEFS), quote!(::__extend_slots), ) }; @@ -1484,9 +1480,7 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result #extend_class(ctx, class); }); - with_method_defs.push(quote_spanned! { path.span() => - #extend_method_def(method_defs); - }); + with_method_defs.push(method_defs); with_slots.push(quote_spanned! { item_span => #extend_slots(slots); }); @@ -1530,9 +1524,7 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result Result Result &'static ::rustpython_vm::builtins::PyModuleDef { DEF.get_or_init(|| { - #[allow(clippy::ptr_arg)] - let method_defs = { - let mut method_defs = Vec::new(); - extend_method_def(ctx, &mut method_defs); - method_defs - }; let mut def = ::rustpython_vm::builtins::PyModuleDef { name: ctx.intern_str(MODULE_NAME), doc: DOC.map(|doc| ctx.intern_str(doc)), - methods: Box::leak(method_defs.into_boxed_slice()), + methods: METHOD_DEFS, slots: Default::default(), }; def.slots.exec = Some(extend_module); @@ -161,23 +155,24 @@ pub fn impl_pymodule(attr: AttributeArgs, module_item: Item) -> Result(&[#(super::#withs::METHOD_DEFS,)* OWN_METHODS]) + }) + }; + items.extend([ parse_quote! { ::rustpython_vm::common::static_cell! { pub(crate) static DEF: ::rustpython_vm::builtins::PyModuleDef; } }, parse_quote! { - #[allow(clippy::ptr_arg)] - pub(crate) fn extend_method_def( - ctx: &::rustpython_vm::Context, - method_defs: &mut Vec<::rustpython_vm::function::PyMethodDef>, - ) { - #( - super::#withs::extend_method_def(ctx, method_defs); - )* - #function_items - } + pub(crate) const METHOD_DEFS: &'static [::rustpython_vm::function::PyMethodDef] = &#method_defs; }, parse_quote! { pub(crate) fn __init_attributes( @@ -361,26 +356,30 @@ struct ValidatedFunctionNursery(FunctionNursery); impl ToTokens for ValidatedFunctionNursery { fn to_tokens(&self, tokens: &mut TokenStream) { + let mut inner_tokens = TokenStream::new(); + let flags = quote! { rustpython_vm::function::PyMethodFlags::empty() }; for item in &self.0.items { let ident = &item.ident; let cfgs = &item.cfgs; + let cfgs = quote!(#(#cfgs)*); let py_names = &item.py_names; let doc = &item.doc; - let flags = quote! { rustpython_vm::function::PyMethodFlags::empty() }; - - tokens.extend(quote! { - #(#cfgs)* - { - let doc = Some(#doc); - #(method_defs.push(rustpython_vm::function::PyMethodDef::new( - (#py_names), + let doc = quote!(Some(#doc)); + + inner_tokens.extend(quote![ + #( + #cfgs + rustpython_vm::function::PyMethodDef::new_const( + #py_names, #ident, #flags, - doc, - ));)* - } - }); + #doc, + ), + )* + ]); } + let array: TokenTree = Group::new(Delimiter::Bracket, inner_tokens).into(); + tokens.extend([array]); } } diff --git a/derive-impl/src/util.rs b/derive-impl/src/util.rs index 9f827f9e9a..916b19db06 100644 --- a/derive-impl/src/util.rs +++ b/derive-impl/src/util.rs @@ -639,13 +639,6 @@ impl ErrorVec for Vec { } } -macro_rules! iter_chain { - ($($it:expr),*$(,)?) => { - ::std::iter::empty() - $(.chain(::std::iter::once($it)))* - }; -} - pub(crate) fn iter_use_idents<'a, F, R: 'a>(item_use: &'a syn::ItemUse, mut f: F) -> Result> where F: FnMut(&'a syn::Ident, bool) -> Result, diff --git a/vm/src/class.rs b/vm/src/class.rs index 8b10c5e626..a2eea21213 100644 --- a/vm/src/class.rs +++ b/vm/src/class.rs @@ -135,19 +135,16 @@ pub trait PyClassImpl: PyClassDef { } fn impl_extend_class(ctx: &Context, class: &'static Py); - fn impl_extend_method_def(method_defs: &mut Vec); + const METHOD_DEFS: &'static [PyMethodDef]; fn extend_slots(slots: &mut PyTypeSlots); fn make_slots() -> PyTypeSlots { - let mut method_defs = Vec::new(); - Self::impl_extend_method_def(&mut method_defs); - let mut slots = PyTypeSlots { flags: Self::TP_FLAGS, name: Self::TP_NAME, basicsize: Self::BASICSIZE, doc: Self::DOC, - methods: Box::leak(method_defs.into_boxed_slice()), + methods: Self::METHOD_DEFS, ..Default::default() }; diff --git a/vm/src/function/builtin.rs b/vm/src/function/builtin.rs index e37270bcc9..3faaf594fe 100644 --- a/vm/src/function/builtin.rs +++ b/vm/src/function/builtin.rs @@ -30,6 +30,7 @@ pub type PyNativeFn = py_dyn_fn!(dyn Fn(&VirtualMachine, FuncArgs) -> PyResult); /// `fn foo(f: F) where F: IntoPyNativeFn` pub trait IntoPyNativeFn: Sized + PyThreadingConstraint + 'static { fn call(&self, vm: &VirtualMachine, args: FuncArgs) -> PyResult; + /// `IntoPyNativeFn::into_func()` generates a PyNativeFn that performs the /// appropriate type and arity checking, any requested conversions, and then if /// successful calls the function with the extracted parameters. @@ -37,6 +38,36 @@ pub trait IntoPyNativeFn: Sized + PyThreadingConstraint + 'static { let boxed = Box::new(move |vm: &VirtualMachine, args| self.call(vm, args)); Box::leak(boxed) } + + /// Equivalent to `into_func()`, but accessible as a constant. This is only + /// valid if this function is zero-sized, i.e. that + /// `std::mem::size_of::() == 0`. If it isn't, use of this constant will + /// raise a compile error. + const STATIC_FUNC: &'static PyNativeFn = { + if std::mem::size_of::() == 0 { + &|vm, args| { + // SAFETY: we just confirmed that Self is zero-sized, so there + // aren't any bytes in it that could be uninit. + #[allow(clippy::uninit_assumed_init)] + let f = unsafe { std::mem::MaybeUninit::::uninit().assume_init() }; + f.call(vm, args) + } + } else { + panic!("function must be zero-sized to access STATIC_FUNC") + } + }; +} + +/// Get the [`STATIC_FUNC`](IntoPyNativeFn::STATIC_FUNC) of the passed function. The same +/// requirements of zero-sizedness apply, see that documentation for details. +#[inline(always)] +pub const fn static_func>(f: F) -> &'static PyNativeFn { + // if f is zero-sized, there's no issue forgetting it - even if a capture of f does have a Drop + // impl, it would never get called anyway. If you passed it to into_func, it would just get + // Box::leak'd, and as a 'static reference it'll never be dropped. and if f isn't zero-sized, + // we'll never reach this point anyway because we'll fail to compile. + std::mem::forget(f); + F::STATIC_FUNC } // TODO: once higher-rank trait bounds are stabilized, remove the `Kind` type @@ -186,5 +217,6 @@ mod tests { check_zst(py_func.into_func()); let empty_closure = || "foo".to_owned(); check_zst(empty_closure.into_func()); + check_zst(static_func(empty_closure)); } } diff --git a/vm/src/function/method.rs b/vm/src/function/method.rs index 5922d1b19a..fa47c16f4a 100644 --- a/vm/src/function/method.rs +++ b/vm/src/function/method.rs @@ -86,6 +86,22 @@ impl PyMethodDef { doc, } } + + #[inline] + pub const fn new_const( + name: &'static str, + func: impl IntoPyNativeFn, + flags: PyMethodFlags, + doc: Option<&'static str>, + ) -> Self { + Self { + name, + func: super::static_func(func), + flags, + doc, + } + } + pub fn to_proper_method( &'static self, class: &'static Py, @@ -192,6 +208,41 @@ impl PyMethodDef { let func = self.to_function(); PyNativeMethod { func, class }.into_ref(ctx) } + + #[doc(hidden)] + pub const fn __const_concat_arrays( + method_groups: &[&[Self]], + ) -> [Self; SUM_LEN] { + const NULL_METHOD: PyMethodDef = PyMethodDef { + name: "", + func: &|_, _| unreachable!(), + flags: PyMethodFlags::empty(), + doc: None, + }; + let mut all_methods = [NULL_METHOD; SUM_LEN]; + let mut all_idx = 0; + let mut group_idx = 0; + while group_idx < method_groups.len() { + let group = method_groups[group_idx]; + let mut method_idx = 0; + while method_idx < group.len() { + all_methods[all_idx] = group[method_idx].const_copy(); + method_idx += 1; + all_idx += 1; + } + group_idx += 1; + } + all_methods + } + + const fn const_copy(&self) -> Self { + Self { + name: self.name, + func: self.func, + flags: self.flags, + doc: self.doc, + } + } } impl std::fmt::Debug for PyMethodDef { diff --git a/vm/src/function/mod.rs b/vm/src/function/mod.rs index 1c3babd808..a0a20c9960 100644 --- a/vm/src/function/mod.rs +++ b/vm/src/function/mod.rs @@ -15,7 +15,7 @@ pub use argument::{ }; pub use arithmetic::{PyArithmeticValue, PyComparisonValue}; pub use buffer::{ArgAsciiBuffer, ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike}; -pub use builtin::{IntoPyNativeFn, PyNativeFn}; +pub use builtin::{static_func, IntoPyNativeFn, PyNativeFn}; pub use either::Either; pub use fspath::FsPath; pub use getset::PySetterValue; From 068249196f6877e2c778bb75056dae8e1471aad1 Mon Sep 17 00:00:00 2001 From: Noa Date: Thu, 16 Nov 2023 22:01:05 -0600 Subject: [PATCH 155/893] Use new 1.74 features (#5118) --- common/src/str.rs | 28 ++++++++++++---------------- vm/src/stdlib/os.rs | 9 +-------- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/common/src/str.rs b/common/src/str.rs index 96c7e675dc..cdee03f14f 100644 --- a/common/src/str.rs +++ b/common/src/str.rs @@ -221,19 +221,6 @@ pub fn to_ascii(value: &str) -> AsciiString { unsafe { AsciiString::from_ascii_unchecked(ascii) } } -#[doc(hidden)] -pub const fn bytes_is_ascii(x: &str) -> bool { - let x = x.as_bytes(); - let mut i = 0; - while i < x.len() { - if !x[i].is_ascii() { - return false; - } - i += 1; - } - true -} - pub mod levenshtein { use std::{cell::RefCell, thread_local}; @@ -335,15 +322,24 @@ pub mod levenshtein { } } +/// Creates an [`AsciiStr`][ascii::AsciiStr] from a string literal, throwing a compile error if the +/// literal isn't actually ascii. +/// +/// ```compile_fail +/// # use rustpython_common::str::ascii; +/// ascii!("I ❤️ Rust & Python"); +/// ``` #[macro_export] macro_rules! ascii { ($x:literal) => {{ - const _: () = { - ["not ascii"][!$crate::str::bytes_is_ascii($x) as usize]; + const STR: &str = $x; + const _: () = if !STR.is_ascii() { + panic!("ascii!() argument is not an ascii string"); }; - unsafe { $crate::vendored::ascii::AsciiStr::from_ascii_unchecked($x.as_bytes()) } + unsafe { $crate::vendored::ascii::AsciiStr::from_ascii_unchecked(STR.as_bytes()) } }}; } +pub use ascii; #[cfg(test)] mod tests { diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index d10acf1b26..8164e5ea48 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -75,15 +75,8 @@ impl OsPath { Path::new(&self.path) } - #[cfg(any(unix, target_os = "wasi"))] pub fn into_bytes(self) -> Vec { - use rustpython_common::os::ffi::OsStrExt; - self.path.as_bytes().to_vec() - } - - #[cfg(windows)] - pub fn into_bytes(self) -> Vec { - self.path.to_string_lossy().to_string().into_bytes() + self.path.into_encoded_bytes() } pub fn into_cstring(self, vm: &VirtualMachine) -> PyResult { From f54b80f88cc049c60734d2010c084640398c0290 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 15:40:57 +0800 Subject: [PATCH 156/893] Add Lib/test/test_bigaddrspace.py from CPython v3.12.0 --- Lib/test/test_bigaddrspace.py | 98 +++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 Lib/test/test_bigaddrspace.py diff --git a/Lib/test/test_bigaddrspace.py b/Lib/test/test_bigaddrspace.py new file mode 100644 index 0000000000..50272e9960 --- /dev/null +++ b/Lib/test/test_bigaddrspace.py @@ -0,0 +1,98 @@ +""" +These tests are meant to exercise that requests to create objects bigger +than what the address space allows are properly met with an OverflowError +(rather than crash weirdly). + +Primarily, this means 32-bit builds with at least 2 GiB of available memory. +You need to pass the -M option to regrtest (e.g. "-M 2.1G") for tests to +be enabled. +""" + +from test import support +from test.support import bigaddrspacetest, MAX_Py_ssize_t + +import unittest +import operator +import sys + + +class BytesTest(unittest.TestCase): + + @bigaddrspacetest + def test_concat(self): + # Allocate a bytestring that's near the maximum size allowed by + # the address space, and then try to build a new, larger one through + # concatenation. + try: + x = b"x" * (MAX_Py_ssize_t - 128) + self.assertRaises(OverflowError, operator.add, x, b"x" * 128) + finally: + x = None + + @bigaddrspacetest + def test_optimized_concat(self): + try: + x = b"x" * (MAX_Py_ssize_t - 128) + + with self.assertRaises(OverflowError) as cm: + # this statement used a fast path in ceval.c + x = x + b"x" * 128 + + with self.assertRaises(OverflowError) as cm: + # this statement used a fast path in ceval.c + x += b"x" * 128 + finally: + x = None + + @bigaddrspacetest + def test_repeat(self): + try: + x = b"x" * (MAX_Py_ssize_t - 128) + self.assertRaises(OverflowError, operator.mul, x, 128) + finally: + x = None + + +class StrTest(unittest.TestCase): + + unicodesize = 4 + + @bigaddrspacetest + def test_concat(self): + try: + # Create a string that would fill almost the address space + x = "x" * int(MAX_Py_ssize_t // (1.1 * self.unicodesize)) + # Unicode objects trigger MemoryError in case an operation that's + # going to cause a size overflow is executed + self.assertRaises(MemoryError, operator.add, x, x) + finally: + x = None + + @bigaddrspacetest + def test_optimized_concat(self): + try: + x = "x" * int(MAX_Py_ssize_t // (1.1 * self.unicodesize)) + + with self.assertRaises(MemoryError) as cm: + # this statement uses a fast path in ceval.c + x = x + x + + with self.assertRaises(MemoryError) as cm: + # this statement uses a fast path in ceval.c + x += x + finally: + x = None + + @bigaddrspacetest + def test_repeat(self): + try: + x = "x" * int(MAX_Py_ssize_t // (1.1 * self.unicodesize)) + self.assertRaises(MemoryError, operator.mul, x, 2) + finally: + x = None + + +if __name__ == '__main__': + if len(sys.argv) > 1: + support.set_memlimit(sys.argv[1]) + unittest.main() From 5303b33c8b7a6566773a68c1304f6d346b03b82b Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 15:42:40 +0800 Subject: [PATCH 157/893] Update Lib/test/test_bigmem.py from CPython v3.12.0 --- Lib/test/test_bigmem.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/Lib/test/test_bigmem.py b/Lib/test/test_bigmem.py index 0a4c141903..e360ec15a8 100644 --- a/Lib/test/test_bigmem.py +++ b/Lib/test/test_bigmem.py @@ -1275,6 +1275,15 @@ def test_sort(self, size): self.assertEqual(l[-10:], [5] * 10) +class DictTest(unittest.TestCase): + + @bigmemtest(size=357913941, memuse=160) + def test_dict(self, size): + # https://github.com/python/cpython/issues/102701 + d = dict.fromkeys(range(size)) + d[size] = 1 + + if __name__ == '__main__': if len(sys.argv) > 1: support.set_memlimit(sys.argv[1]) From 81208637f52cd7686e1dd213a4bb718e2c8eed12 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 15:51:02 +0800 Subject: [PATCH 158/893] Update Lib/test/test_bool.py from CPython v3.12.0 --- Lib/test/test_bool.py | 44 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_bool.py b/Lib/test/test_bool.py index 6da00aacd6..241db7738d 100644 --- a/Lib/test/test_bool.py +++ b/Lib/test/test_bool.py @@ -42,6 +42,12 @@ def test_float(self): self.assertEqual(float(True), 1.0) self.assertIsNot(float(True), True) + def test_complex(self): + self.assertEqual(complex(False), 0j) + self.assertEqual(complex(False), False) + self.assertEqual(complex(True), 1+0j) + self.assertEqual(complex(True), True) + def test_math(self): self.assertEqual(+False, 0) self.assertIsNot(+False, False) @@ -54,8 +60,22 @@ def test_math(self): self.assertEqual(-True, -1) self.assertEqual(abs(True), 1) self.assertIsNot(abs(True), True) - self.assertEqual(~False, -1) - self.assertEqual(~True, -2) + with self.assertWarns(DeprecationWarning): + # We need to put the bool in a variable, because the constant + # ~False is evaluated at compile time due to constant folding; + # consequently the DeprecationWarning would be issued during + # module loading and not during test execution. + false = False + self.assertEqual(~false, -1) + with self.assertWarns(DeprecationWarning): + # also check that the warning is issued in case of constant + # folding at compile time + self.assertEqual(eval("~False"), -1) + with self.assertWarns(DeprecationWarning): + true = True + self.assertEqual(~true, -2) + with self.assertWarns(DeprecationWarning): + self.assertEqual(eval("~True"), -2) self.assertEqual(False+2, 2) self.assertEqual(True+2, 3) @@ -315,6 +335,26 @@ def __len__(self): return -1 self.assertRaises(ValueError, bool, Eggs()) + def test_interpreter_convert_to_bool_raises(self): + class SymbolicBool: + def __bool__(self): + raise TypeError + + class Symbol: + def __gt__(self, other): + return SymbolicBool() + + x = Symbol() + + with self.assertRaises(TypeError): + if x > 0: + msg = "x > 0 was true" + else: + msg = "x > 0 was false" + + # This used to create negative refcounts, see gh-102250 + del x + def test_from_bytes(self): self.assertIs(bool.from_bytes(b'\x00'*8, 'big'), False) self.assertIs(bool.from_bytes(b'abcd', 'little'), True) From 7833fd9b12fa5642f60ad053b8c3aa308febf96e Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 15:53:48 +0800 Subject: [PATCH 159/893] Edit Lib/test/test_bool.py --- Lib/test/test_bool.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_bool.py b/Lib/test/test_bool.py index 241db7738d..3e83e4aceb 100644 --- a/Lib/test/test_bool.py +++ b/Lib/test/test_bool.py @@ -48,6 +48,8 @@ def test_complex(self): self.assertEqual(complex(True), 1+0j) self.assertEqual(complex(True), True) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_math(self): self.assertEqual(+False, 0) self.assertIsNot(+False, False) From b0ebe58636fbe64cc3fdd8f5af02e36bbba5d738 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 15:57:39 +0800 Subject: [PATCH 160/893] Update Lib/test/test_bufio.py from CPython v3.12.0 --- Lib/test/test_bufio.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Lib/test/test_bufio.py b/Lib/test/test_bufio.py index 3471351c45..989d8cd349 100644 --- a/Lib/test/test_bufio.py +++ b/Lib/test/test_bufio.py @@ -1,5 +1,4 @@ import unittest -from test import support from test.support import os_helper import io # C implementation. From ac21576ab98d39ddaa31ddb94694f65ace21b631 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 16:19:42 +0800 Subject: [PATCH 161/893] Add Lib/test/test_c_locale_coercion.py from CPython v3.12.0 --- Lib/test/test_c_locale_coercion.py | 435 +++++++++++++++++++++++++++++ 1 file changed, 435 insertions(+) create mode 100644 Lib/test/test_c_locale_coercion.py diff --git a/Lib/test/test_c_locale_coercion.py b/Lib/test/test_c_locale_coercion.py new file mode 100644 index 0000000000..71f934756e --- /dev/null +++ b/Lib/test/test_c_locale_coercion.py @@ -0,0 +1,435 @@ +# Tests the attempted automatic coercion of the C locale to a UTF-8 locale + +import locale +import os +import subprocess +import sys +import sysconfig +import unittest +from collections import namedtuple + +from test import support +from test.support.script_helper import run_python_until_end + + +# Set the list of ways we expect to be able to ask for the "C" locale +EXPECTED_C_LOCALE_EQUIVALENTS = ["C", "invalid.ascii"] + +# Set our expectation for the default encoding used in the C locale +# for the filesystem encoding and the standard streams +EXPECTED_C_LOCALE_STREAM_ENCODING = "ascii" +EXPECTED_C_LOCALE_FS_ENCODING = "ascii" + +# Set our expectation for the default locale used when none is specified +EXPECT_COERCION_IN_DEFAULT_LOCALE = True + +TARGET_LOCALES = ["C.UTF-8", "C.utf8", "UTF-8"] + +# Apply some platform dependent overrides +if sys.platform.startswith("linux"): + if support.is_android: + # Android defaults to using UTF-8 for all system interfaces + EXPECTED_C_LOCALE_STREAM_ENCODING = "utf-8" + EXPECTED_C_LOCALE_FS_ENCODING = "utf-8" + else: + # Linux distros typically alias the POSIX locale directly to the C + # locale. + # TODO: Once https://bugs.python.org/issue30672 is addressed, we'll be + # able to check this case unconditionally + EXPECTED_C_LOCALE_EQUIVALENTS.append("POSIX") +elif sys.platform.startswith("aix"): + # AIX uses iso8859-1 in the C locale, other *nix platforms use ASCII + EXPECTED_C_LOCALE_STREAM_ENCODING = "iso8859-1" + EXPECTED_C_LOCALE_FS_ENCODING = "iso8859-1" +elif sys.platform == "darwin": + # FS encoding is UTF-8 on macOS + EXPECTED_C_LOCALE_FS_ENCODING = "utf-8" +elif sys.platform == "cygwin": + # Cygwin defaults to using C.UTF-8 + # TODO: Work out a robust dynamic test for this that doesn't rely on + # CPython's own locale handling machinery + EXPECT_COERCION_IN_DEFAULT_LOCALE = False +elif sys.platform == "vxworks": + # VxWorks defaults to using UTF-8 for all system interfaces + EXPECTED_C_LOCALE_STREAM_ENCODING = "utf-8" + EXPECTED_C_LOCALE_FS_ENCODING = "utf-8" + +# Note that the above expectations are still wrong in some cases, such as: +# * Windows when PYTHONLEGACYWINDOWSFSENCODING is set +# * Any platform other than AIX that uses latin-1 in the C locale +# * Any Linux distro where POSIX isn't a simple alias for the C locale +# * Any Linux distro where the default locale is something other than "C" +# +# Options for dealing with this: +# * Don't set the PY_COERCE_C_LOCALE preprocessor definition on +# such platforms (e.g. it isn't set on Windows) +# * Fix the test expectations to match the actual platform behaviour + +# In order to get the warning messages to match up as expected, the candidate +# order here must much the target locale order in Python/pylifecycle.c +_C_UTF8_LOCALES = ("C.UTF-8", "C.utf8", "UTF-8") + +# There's no reliable cross-platform way of checking locale alias +# lists, so the only way of knowing which of these locales will work +# is to try them with locale.setlocale(). We do that in a subprocess +# in setUpModule() below to avoid altering the locale of the test runner. +# +# If the relevant locale module attributes exist, and we're not on a platform +# where we expect it to always succeed, we also check that +# `locale.nl_langinfo(locale.CODESET)` works, as if it fails, the interpreter +# will skip locale coercion for that particular target locale +_check_nl_langinfo_CODESET = bool( + sys.platform not in ("darwin", "linux") and + hasattr(locale, "nl_langinfo") and + hasattr(locale, "CODESET") +) + +def _set_locale_in_subprocess(locale_name): + cmd_fmt = "import locale; print(locale.setlocale(locale.LC_CTYPE, '{}'))" + if _check_nl_langinfo_CODESET: + # If there's no valid CODESET, we expect coercion to be skipped + cmd_fmt += "; import sys; sys.exit(not locale.nl_langinfo(locale.CODESET))" + cmd = cmd_fmt.format(locale_name) + result, py_cmd = run_python_until_end("-c", cmd, PYTHONCOERCECLOCALE='') + return result.rc == 0 + + + +_fields = "fsencoding stdin_info stdout_info stderr_info lang lc_ctype lc_all" +_EncodingDetails = namedtuple("EncodingDetails", _fields) + +class EncodingDetails(_EncodingDetails): + # XXX (ncoghlan): Using JSON for child state reporting may be less fragile + CHILD_PROCESS_SCRIPT = ";".join([ + "import sys, os", + "print(sys.getfilesystemencoding())", + "print(sys.stdin.encoding + ':' + sys.stdin.errors)", + "print(sys.stdout.encoding + ':' + sys.stdout.errors)", + "print(sys.stderr.encoding + ':' + sys.stderr.errors)", + "print(os.environ.get('LANG', 'not set'))", + "print(os.environ.get('LC_CTYPE', 'not set'))", + "print(os.environ.get('LC_ALL', 'not set'))", + ]) + + @classmethod + def get_expected_details(cls, coercion_expected, fs_encoding, stream_encoding, env_vars): + """Returns expected child process details for a given encoding""" + _stream = stream_encoding + ":{}" + # stdin and stdout should use surrogateescape either because the + # coercion triggered, or because the C locale was detected + stream_info = 2*[_stream.format("surrogateescape")] + # stderr should always use backslashreplace + stream_info.append(_stream.format("backslashreplace")) + expected_lang = env_vars.get("LANG", "not set") + if coercion_expected: + expected_lc_ctype = CLI_COERCION_TARGET + else: + expected_lc_ctype = env_vars.get("LC_CTYPE", "not set") + expected_lc_all = env_vars.get("LC_ALL", "not set") + env_info = expected_lang, expected_lc_ctype, expected_lc_all + return dict(cls(fs_encoding, *stream_info, *env_info)._asdict()) + + @classmethod + def get_child_details(cls, env_vars): + """Retrieves fsencoding and standard stream details from a child process + + Returns (encoding_details, stderr_lines): + + - encoding_details: EncodingDetails for eager decoding + - stderr_lines: result of calling splitlines() on the stderr output + + The child is run in isolated mode if the current interpreter supports + that. + """ + result, py_cmd = run_python_until_end( + "-X", "utf8=0", "-c", cls.CHILD_PROCESS_SCRIPT, + **env_vars + ) + if not result.rc == 0: + result.fail(py_cmd) + # All subprocess outputs in this test case should be pure ASCII + stdout_lines = result.out.decode("ascii").splitlines() + child_encoding_details = dict(cls(*stdout_lines)._asdict()) + stderr_lines = result.err.decode("ascii").rstrip().splitlines() + return child_encoding_details, stderr_lines + + +# Details of the shared library warning emitted at runtime +LEGACY_LOCALE_WARNING = ( + "Python runtime initialized with LC_CTYPE=C (a locale with default ASCII " + "encoding), which may cause Unicode compatibility problems. Using C.UTF-8, " + "C.utf8, or UTF-8 (if available) as alternative Unicode-compatible " + "locales is recommended." +) + +# Details of the CLI locale coercion warning emitted at runtime +CLI_COERCION_WARNING_FMT = ( + "Python detected LC_CTYPE=C: LC_CTYPE coerced to {} (set another locale " + "or PYTHONCOERCECLOCALE=0 to disable this locale coercion behavior)." +) + + +AVAILABLE_TARGETS = None +CLI_COERCION_TARGET = None +CLI_COERCION_WARNING = None + +def setUpModule(): + global AVAILABLE_TARGETS + global CLI_COERCION_TARGET + global CLI_COERCION_WARNING + + if AVAILABLE_TARGETS is not None: + # initialization already done + return + AVAILABLE_TARGETS = [] + + # Find the target locales available in the current system + for target_locale in _C_UTF8_LOCALES: + if _set_locale_in_subprocess(target_locale): + AVAILABLE_TARGETS.append(target_locale) + + if AVAILABLE_TARGETS: + # Coercion is expected to use the first available target locale + CLI_COERCION_TARGET = AVAILABLE_TARGETS[0] + CLI_COERCION_WARNING = CLI_COERCION_WARNING_FMT.format(CLI_COERCION_TARGET) + + if support.verbose: + print(f"AVAILABLE_TARGETS = {AVAILABLE_TARGETS!r}") + print(f"EXPECTED_C_LOCALE_EQUIVALENTS = {EXPECTED_C_LOCALE_EQUIVALENTS!r}") + print(f"EXPECTED_C_LOCALE_STREAM_ENCODING = {EXPECTED_C_LOCALE_STREAM_ENCODING!r}") + print(f"EXPECTED_C_LOCALE_FS_ENCODING = {EXPECTED_C_LOCALE_FS_ENCODING!r}") + print(f"EXPECT_COERCION_IN_DEFAULT_LOCALE = {EXPECT_COERCION_IN_DEFAULT_LOCALE!r}") + print(f"_C_UTF8_LOCALES = {_C_UTF8_LOCALES!r}") + print(f"_check_nl_langinfo_CODESET = {_check_nl_langinfo_CODESET!r}") + + +class _LocaleHandlingTestCase(unittest.TestCase): + # Base class to check expected locale handling behaviour + + def _check_child_encoding_details(self, + env_vars, + expected_fs_encoding, + expected_stream_encoding, + expected_warnings, + coercion_expected): + """Check the C locale handling for the given process environment + + Parameters: + expected_fs_encoding: expected sys.getfilesystemencoding() result + expected_stream_encoding: expected encoding for standard streams + expected_warning: stderr output to expect (if any) + """ + result = EncodingDetails.get_child_details(env_vars) + encoding_details, stderr_lines = result + expected_details = EncodingDetails.get_expected_details( + coercion_expected, + expected_fs_encoding, + expected_stream_encoding, + env_vars + ) + self.assertEqual(encoding_details, expected_details) + if expected_warnings is None: + expected_warnings = [] + self.assertEqual(stderr_lines, expected_warnings) + + +class LocaleConfigurationTests(_LocaleHandlingTestCase): + # Test explicit external configuration via the process environment + + @classmethod + def setUpClass(cls): + # This relies on setUpModule() having been run, so it can't be + # handled via the @unittest.skipUnless decorator + if not AVAILABLE_TARGETS: + raise unittest.SkipTest("No C-with-UTF-8 locale available") + + def test_external_target_locale_configuration(self): + + # Explicitly setting a target locale should give the same behaviour as + # is seen when implicitly coercing to that target locale + self.maxDiff = None + + expected_fs_encoding = "utf-8" + expected_stream_encoding = "utf-8" + + base_var_dict = { + "LANG": "", + "LC_CTYPE": "", + "LC_ALL": "", + "PYTHONCOERCECLOCALE": "", + } + for env_var in ("LANG", "LC_CTYPE"): + for locale_to_set in AVAILABLE_TARGETS: + # XXX (ncoghlan): LANG=UTF-8 doesn't appear to work as + # expected, so skip that combination for now + # See https://bugs.python.org/issue30672 for discussion + if env_var == "LANG" and locale_to_set == "UTF-8": + continue + + with self.subTest(env_var=env_var, + configured_locale=locale_to_set): + var_dict = base_var_dict.copy() + var_dict[env_var] = locale_to_set + self._check_child_encoding_details(var_dict, + expected_fs_encoding, + expected_stream_encoding, + expected_warnings=None, + coercion_expected=False) + + + +@support.cpython_only +@unittest.skipUnless(sysconfig.get_config_var("PY_COERCE_C_LOCALE"), + "C locale coercion disabled at build time") +class LocaleCoercionTests(_LocaleHandlingTestCase): + # Test implicit reconfiguration of the environment during CLI startup + + def _check_c_locale_coercion(self, + fs_encoding, stream_encoding, + coerce_c_locale, + expected_warnings=None, + coercion_expected=True, + **extra_vars): + """Check the C locale handling for various configurations + + Parameters: + fs_encoding: expected sys.getfilesystemencoding() result + stream_encoding: expected encoding for standard streams + coerce_c_locale: setting to use for PYTHONCOERCECLOCALE + None: don't set the variable at all + str: the value set in the child's environment + expected_warnings: expected warning lines on stderr + extra_vars: additional environment variables to set in subprocess + """ + self.maxDiff = None + + if not AVAILABLE_TARGETS: + # Locale coercion is disabled when there aren't any target locales + fs_encoding = EXPECTED_C_LOCALE_FS_ENCODING + stream_encoding = EXPECTED_C_LOCALE_STREAM_ENCODING + coercion_expected = False + if expected_warnings: + expected_warnings = [LEGACY_LOCALE_WARNING] + + base_var_dict = { + "LANG": "", + "LC_CTYPE": "", + "LC_ALL": "", + "PYTHONCOERCECLOCALE": "", + } + base_var_dict.update(extra_vars) + if coerce_c_locale is not None: + base_var_dict["PYTHONCOERCECLOCALE"] = coerce_c_locale + + # Check behaviour for the default locale + with self.subTest(default_locale=True, + PYTHONCOERCECLOCALE=coerce_c_locale): + if EXPECT_COERCION_IN_DEFAULT_LOCALE: + _expected_warnings = expected_warnings + _coercion_expected = coercion_expected + else: + _expected_warnings = None + _coercion_expected = False + # On Android CLI_COERCION_WARNING is not printed when all the + # locale environment variables are undefined or empty. When + # this code path is run with environ['LC_ALL'] == 'C', then + # LEGACY_LOCALE_WARNING is printed. + if (support.is_android and + _expected_warnings == [CLI_COERCION_WARNING]): + _expected_warnings = None + self._check_child_encoding_details(base_var_dict, + fs_encoding, + stream_encoding, + _expected_warnings, + _coercion_expected) + + # Check behaviour for explicitly configured locales + for locale_to_set in EXPECTED_C_LOCALE_EQUIVALENTS: + for env_var in ("LANG", "LC_CTYPE"): + with self.subTest(env_var=env_var, + nominal_locale=locale_to_set, + PYTHONCOERCECLOCALE=coerce_c_locale): + var_dict = base_var_dict.copy() + var_dict[env_var] = locale_to_set + # Check behaviour on successful coercion + self._check_child_encoding_details(var_dict, + fs_encoding, + stream_encoding, + expected_warnings, + coercion_expected) + + def test_PYTHONCOERCECLOCALE_not_set(self): + # This should coerce to the first available target locale by default + self._check_c_locale_coercion("utf-8", "utf-8", coerce_c_locale=None) + + def test_PYTHONCOERCECLOCALE_not_zero(self): + # *Any* string other than "0" is considered "set" for our purposes + # and hence should result in the locale coercion being enabled + for setting in ("", "1", "true", "false"): + self._check_c_locale_coercion("utf-8", "utf-8", coerce_c_locale=setting) + + def test_PYTHONCOERCECLOCALE_set_to_warn(self): + # PYTHONCOERCECLOCALE=warn enables runtime warnings for legacy locales + self._check_c_locale_coercion("utf-8", "utf-8", + coerce_c_locale="warn", + expected_warnings=[CLI_COERCION_WARNING]) + + + def test_PYTHONCOERCECLOCALE_set_to_zero(self): + # The setting "0" should result in the locale coercion being disabled + self._check_c_locale_coercion(EXPECTED_C_LOCALE_FS_ENCODING, + EXPECTED_C_LOCALE_STREAM_ENCODING, + coerce_c_locale="0", + coercion_expected=False) + # Setting LC_ALL=C shouldn't make any difference to the behaviour + self._check_c_locale_coercion(EXPECTED_C_LOCALE_FS_ENCODING, + EXPECTED_C_LOCALE_STREAM_ENCODING, + coerce_c_locale="0", + LC_ALL="C", + coercion_expected=False) + + def test_LC_ALL_set_to_C(self): + # Setting LC_ALL should render the locale coercion ineffective + self._check_c_locale_coercion(EXPECTED_C_LOCALE_FS_ENCODING, + EXPECTED_C_LOCALE_STREAM_ENCODING, + coerce_c_locale=None, + LC_ALL="C", + coercion_expected=False) + # And result in a warning about a lack of locale compatibility + self._check_c_locale_coercion(EXPECTED_C_LOCALE_FS_ENCODING, + EXPECTED_C_LOCALE_STREAM_ENCODING, + coerce_c_locale="warn", + LC_ALL="C", + expected_warnings=[LEGACY_LOCALE_WARNING], + coercion_expected=False) + + def test_PYTHONCOERCECLOCALE_set_to_one(self): + # skip the test if the LC_CTYPE locale is C or coerced + old_loc = locale.setlocale(locale.LC_CTYPE, None) + self.addCleanup(locale.setlocale, locale.LC_CTYPE, old_loc) + try: + loc = locale.setlocale(locale.LC_CTYPE, "") + except locale.Error as e: + self.skipTest(str(e)) + if loc == "C": + self.skipTest("test requires LC_CTYPE locale different than C") + if loc in TARGET_LOCALES : + self.skipTest("coerced LC_CTYPE locale: %s" % loc) + + # bpo-35336: PYTHONCOERCECLOCALE=1 must not coerce the LC_CTYPE locale + # if it's not equal to "C" + code = 'import locale; print(locale.setlocale(locale.LC_CTYPE, None))' + env = dict(os.environ, PYTHONCOERCECLOCALE='1') + cmd = subprocess.run([sys.executable, '-c', code], + stdout=subprocess.PIPE, + env=env, + text=True) + self.assertEqual(cmd.stdout.rstrip(), loc) + + +def tearDownModule(): + support.reap_children() + + +if __name__ == "__main__": + unittest.main() From 68dc5ca0520085ee8cd3be5d2f255b705f7f35ed Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 16:21:42 +0800 Subject: [PATCH 162/893] Edit Lib/test/test_c_locale_coercion.py ExpectedFailure added at line 247 --- Lib/test/test_c_locale_coercion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_c_locale_coercion.py b/Lib/test/test_c_locale_coercion.py index 71f934756e..818dc16b83 100644 --- a/Lib/test/test_c_locale_coercion.py +++ b/Lib/test/test_c_locale_coercion.py @@ -243,6 +243,8 @@ def setUpClass(cls): if not AVAILABLE_TARGETS: raise unittest.SkipTest("No C-with-UTF-8 locale available") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_external_target_locale_configuration(self): # Explicitly setting a target locale should give the same behaviour as From e06f5ccfe4177afb8f4de966bfe37b23059ba399 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 16:31:27 +0800 Subject: [PATCH 163/893] Update Lib/test/test_class.py from CPython v3.12.0 --- Lib/test/test_class.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/Lib/test/test_class.py b/Lib/test/test_class.py index 48897d732f..a6c99fbddf 100644 --- a/Lib/test/test_class.py +++ b/Lib/test/test_class.py @@ -445,6 +445,20 @@ def __delattr__(self, *args): del testme.cardinal self.assertCallStack([('__delattr__', (testme, "cardinal"))]) + def testHasAttrString(self): + import sys + from test.support import import_helper + _testcapi = import_helper.import_module('_testcapi') + + class A: + def __init__(self): + self.attr = 1 + + a = A() + self.assertEqual(_testcapi.object_hasattrstring(a, b"attr"), 1) + self.assertEqual(_testcapi.object_hasattrstring(a, b"noattr"), 0) + self.assertIsNone(sys.exception()) + def testDel(self): x = [] From c8256c5450edf30ee4c6be98d403c0c152bcc0c8 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 16:36:59 +0800 Subject: [PATCH 164/893] Update Lib/test/test_cmath.py from CPython v3.12.0 --- Lib/test/test_cmath.py | 51 +++++++++++++++--------------------------- 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/Lib/test/test_cmath.py b/Lib/test/test_cmath.py index 8abbda2d87..1fc7211834 100644 --- a/Lib/test/test_cmath.py +++ b/Lib/test/test_cmath.py @@ -166,6 +166,11 @@ def test_infinity_and_nan_constants(self): self.assertEqual(cmath.nan.imag, 0.0) self.assertEqual(cmath.nanj.real, 0.0) self.assertTrue(math.isnan(cmath.nanj.imag)) + # Also check that the sign of all of these is positive: + self.assertEqual(math.copysign(1., cmath.nan.real), 1.) + self.assertEqual(math.copysign(1., cmath.nan.imag), 1.) + self.assertEqual(math.copysign(1., cmath.nanj.real), 1.) + self.assertEqual(math.copysign(1., cmath.nanj.imag), 1.) # Check consistency with reprs. self.assertEqual(repr(cmath.inf), "inf") @@ -192,14 +197,7 @@ def test_user_object(self): # end up being passed to the cmath functions # usual case: new-style class implementing __complex__ - class MyComplex(object): - def __init__(self, value): - self.value = value - def __complex__(self): - return self.value - - # old-style class implementing __complex__ - class MyComplexOS: + class MyComplex: def __init__(self, value): self.value = value def __complex__(self): @@ -208,18 +206,13 @@ def __complex__(self): # classes for which __complex__ raises an exception class SomeException(Exception): pass - class MyComplexException(object): - def __complex__(self): - raise SomeException - class MyComplexExceptionOS: + class MyComplexException: def __complex__(self): raise SomeException # some classes not providing __float__ or __complex__ class NeitherComplexNorFloat(object): pass - class NeitherComplexNorFloatOS: - pass class Index: def __int__(self): return 2 def __index__(self): return 2 @@ -228,48 +221,32 @@ def __int__(self): return 2 # other possible combinations of __float__ and __complex__ # that should work - class FloatAndComplex(object): + class FloatAndComplex: def __float__(self): return flt_arg def __complex__(self): return cx_arg - class FloatAndComplexOS: - def __float__(self): - return flt_arg - def __complex__(self): - return cx_arg - class JustFloat(object): - def __float__(self): - return flt_arg - class JustFloatOS: + class JustFloat: def __float__(self): return flt_arg for f in self.test_functions: # usual usage self.assertEqual(f(MyComplex(cx_arg)), f(cx_arg)) - self.assertEqual(f(MyComplexOS(cx_arg)), f(cx_arg)) # other combinations of __float__ and __complex__ self.assertEqual(f(FloatAndComplex()), f(cx_arg)) - self.assertEqual(f(FloatAndComplexOS()), f(cx_arg)) self.assertEqual(f(JustFloat()), f(flt_arg)) - self.assertEqual(f(JustFloatOS()), f(flt_arg)) self.assertEqual(f(Index()), f(int(Index()))) # TypeError should be raised for classes not providing # either __complex__ or __float__, even if they provide - # __int__ or __index__. An old-style class - # currently raises AttributeError instead of a TypeError; - # this could be considered a bug. + # __int__ or __index__: self.assertRaises(TypeError, f, NeitherComplexNorFloat()) self.assertRaises(TypeError, f, MyInt()) - self.assertRaises(Exception, f, NeitherComplexNorFloatOS()) # non-complex return value from __complex__ -> TypeError for bad_complex in non_complexes: self.assertRaises(TypeError, f, MyComplex(bad_complex)) - self.assertRaises(TypeError, f, MyComplexOS(bad_complex)) # exceptions in __complex__ should be propagated correctly self.assertRaises(SomeException, f, MyComplexException()) - self.assertRaises(SomeException, f, MyComplexExceptionOS()) def test_input_type(self): # ints should be acceptable inputs to all cmath @@ -647,6 +624,14 @@ def test_complex_near_zero(self): self.assertIsClose(0.001-0.001j, 0.001+0.001j, abs_tol=2e-03) self.assertIsNotClose(0.001-0.001j, 0.001+0.001j, abs_tol=1e-03) + def test_complex_special(self): + self.assertIsNotClose(INF, INF*1j) + self.assertIsNotClose(INF*1j, INF) + self.assertIsNotClose(INF, -INF) + self.assertIsNotClose(-INF, INF) + self.assertIsNotClose(0, INF) + self.assertIsNotClose(0, INF*1j) + if __name__ == "__main__": unittest.main() From 460e9833f0289b1424a4e5f517abbb3c0ce2d2ff Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 23 Nov 2023 16:38:32 +0800 Subject: [PATCH 165/893] Edit Lib/test/test_cmath.py ExpectedFailure added at line 628 --- Lib/test/test_cmath.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_cmath.py b/Lib/test/test_cmath.py index 1fc7211834..51dd2ecf5f 100644 --- a/Lib/test/test_cmath.py +++ b/Lib/test/test_cmath.py @@ -624,6 +624,8 @@ def test_complex_near_zero(self): self.assertIsClose(0.001-0.001j, 0.001+0.001j, abs_tol=2e-03) self.assertIsNotClose(0.001-0.001j, 0.001+0.001j, abs_tol=1e-03) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_complex_special(self): self.assertIsNotClose(INF, INF*1j) self.assertIsNotClose(INF*1j, INF) From 4d6a180638f3ee5731856c894f8e58e69588f7ee Mon Sep 17 00:00:00 2001 From: ChenyG <56422760+cygao90@users.noreply.github.com> Date: Thu, 23 Nov 2023 22:06:12 +0800 Subject: [PATCH 166/893] Update ast, test_ast from CPython 3.12.0 (#5121) --- Lib/ast.py | 348 ++++++---- Lib/test/support/ast_helper.py | 43 ++ Lib/test/test_ast.py | 1100 ++++++++++++++++++++++++++------ 3 files changed, 1189 insertions(+), 302 deletions(-) create mode 100644 Lib/test/support/ast_helper.py diff --git a/Lib/ast.py b/Lib/ast.py index 4f5f982714..07044706dc 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -25,9 +25,10 @@ :license: Python License. """ import sys +import re from _ast import * from contextlib import contextmanager, nullcontext -from enum import IntEnum, auto +from enum import IntEnum, auto, _simple_enum def parse(source, filename='', mode='exec', *, @@ -40,12 +41,13 @@ def parse(source, filename='', mode='exec', *, flags = PyCF_ONLY_AST if type_comments: flags |= PyCF_TYPE_COMMENTS - if isinstance(feature_version, tuple): + if feature_version is None: + feature_version = -1 + elif isinstance(feature_version, tuple): major, minor = feature_version # Should be a 2-tuple. - assert major == 3 + if major != 3: + raise ValueError(f"Unsupported major version: {major}") feature_version = minor - elif feature_version is None: - feature_version = -1 # Else it should be an int giving the minor version for 3.x. return compile(source, filename, mode, flags, _feature_version=feature_version) @@ -292,9 +294,7 @@ def get_docstring(node, clean=True): if not(node.body and isinstance(node.body[0], Expr)): return None node = node.body[0].value - if isinstance(node, Str): - text = node.s - elif isinstance(node, Constant) and isinstance(node.value, str): + if isinstance(node, Constant) and isinstance(node.value, str): text = node.value else: return None @@ -304,28 +304,17 @@ def get_docstring(node, clean=True): return text -def _splitlines_no_ff(source): +_line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))") +def _splitlines_no_ff(source, maxlines=None): """Split a string into lines ignoring form feed and other chars. This mimics how the Python parser splits source code. """ - idx = 0 lines = [] - next_line = '' - while idx < len(source): - c = source[idx] - next_line += c - idx += 1 - # Keep \r\n together - if c == '\r' and idx < len(source) and source[idx] == '\n': - next_line += '\n' - idx += 1 - if c in '\r\n': - lines.append(next_line) - next_line = '' - - if next_line: - lines.append(next_line) + for lineno, match in enumerate(_line_pattern.finditer(source), 1): + if maxlines is not None and lineno > maxlines: + break + lines.append(match[0]) return lines @@ -359,7 +348,7 @@ def get_source_segment(source, node, *, padded=False): except AttributeError: return None - lines = _splitlines_no_ff(source) + lines = _splitlines_no_ff(source, maxlines=end_lineno+1) if end_lineno == lineno: return lines[lineno].encode()[col_offset:end_col_offset].decode() @@ -508,20 +497,52 @@ def generic_visit(self, node): return node +_DEPRECATED_VALUE_ALIAS_MESSAGE = ( + "{name} is deprecated and will be removed in Python {remove}; use value instead" +) +_DEPRECATED_CLASS_MESSAGE = ( + "{name} is deprecated and will be removed in Python {remove}; " + "use ast.Constant instead" +) + + # If the ast module is loaded more than once, only add deprecated methods once if not hasattr(Constant, 'n'): # The following code is for backward compatibility. # It will be removed in future. - def _getter(self): + def _n_getter(self): + """Deprecated. Use value instead.""" + import warnings + warnings._deprecated( + "Attribute n", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14) + ) + return self.value + + def _n_setter(self, value): + import warnings + warnings._deprecated( + "Attribute n", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14) + ) + self.value = value + + def _s_getter(self): """Deprecated. Use value instead.""" + import warnings + warnings._deprecated( + "Attribute s", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14) + ) return self.value - def _setter(self, value): + def _s_setter(self, value): + import warnings + warnings._deprecated( + "Attribute s", message=_DEPRECATED_VALUE_ALIAS_MESSAGE, remove=(3, 14) + ) self.value = value - Constant.n = property(_getter, _setter) - Constant.s = property(_getter, _setter) + Constant.n = property(_n_getter, _n_setter) + Constant.s = property(_s_getter, _s_setter) class _ABC(type): @@ -529,6 +550,13 @@ def __init__(cls, *args): cls.__doc__ = """Deprecated AST node class. Use ast.Constant instead""" def __instancecheck__(cls, inst): + if cls in _const_types: + import warnings + warnings._deprecated( + f"ast.{cls.__qualname__}", + message=_DEPRECATED_CLASS_MESSAGE, + remove=(3, 14) + ) if not isinstance(inst, Constant): return False if cls in _const_types: @@ -552,6 +580,10 @@ def _new(cls, *args, **kwargs): if pos < len(args): raise TypeError(f"{cls.__name__} got multiple values for argument {key!r}") if cls in _const_types: + import warnings + warnings._deprecated( + f"ast.{cls.__qualname__}", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14) + ) return Constant(*args, **kwargs) return Constant.__new__(cls, *args, **kwargs) @@ -574,10 +606,19 @@ class Ellipsis(Constant, metaclass=_ABC): _fields = () def __new__(cls, *args, **kwargs): - if cls is Ellipsis: + if cls is _ast_Ellipsis: + import warnings + warnings._deprecated( + "ast.Ellipsis", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14) + ) return Constant(..., *args, **kwargs) return Constant.__new__(cls, *args, **kwargs) +# Keep another reference to Ellipsis in the global namespace +# so it can be referenced in Ellipsis.__new__ +# (The original "Ellipsis" name is removed from the global namespace later on) +_ast_Ellipsis = Ellipsis + _const_types = { Num: (int, float, complex), Str: (str,), @@ -644,10 +685,12 @@ class Param(expr_context): # We unparse those infinities to INFSTR. _INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) -class _Precedence(IntEnum): +@_simple_enum(IntEnum) +class _Precedence: """Precedence table that originated from python grammar.""" - TUPLE = auto() + NAMED_EXPR = auto() # := + TUPLE = auto() # , YIELD = auto() # 'yield', 'yield from' TEST = auto() # 'if'-'else', 'lambda' OR = auto() # 'or' @@ -685,11 +728,11 @@ class _Unparser(NodeVisitor): def __init__(self, *, _avoid_backslashes=False): self._source = [] - self._buffer = [] self._precedences = {} self._type_ignores = {} self._indent = 0 self._avoid_backslashes = _avoid_backslashes + self._in_try_star = False def interleave(self, inter, f, seq): """Call f on each item in seq, calling inter() in between.""" @@ -724,18 +767,19 @@ def fill(self, text=""): self.maybe_newline() self.write(" " * self._indent + text) - def write(self, text): - """Append a piece of text""" - self._source.append(text) + def write(self, *text): + """Add new source parts""" + self._source.extend(text) - def buffer_writer(self, text): - self._buffer.append(text) + @contextmanager + def buffered(self, buffer = None): + if buffer is None: + buffer = [] - @property - def buffer(self): - value = "".join(self._buffer) - self._buffer.clear() - return value + original_source = self._source + self._source = buffer + yield buffer + self._source = original_source @contextmanager def block(self, *, extra = None): @@ -845,7 +889,7 @@ def visit_Expr(self, node): self.traverse(node.value) def visit_NamedExpr(self, node): - with self.require_parens(_Precedence.TUPLE, node): + with self.require_parens(_Precedence.NAMED_EXPR, node): self.set_precedence(_Precedence.ATOM, node.target, node.value) self.traverse(node.target) self.write(" := ") @@ -866,6 +910,7 @@ def visit_ImportFrom(self, node): def visit_Assign(self, node): self.fill() for target in node.targets: + self.set_precedence(_Precedence.TUPLE, target) self.traverse(target) self.write(" = ") self.traverse(node.value) @@ -958,7 +1003,7 @@ def visit_Raise(self, node): self.write(" from ") self.traverse(node.cause) - def visit_Try(self, node): + def do_visit_try(self, node): self.fill("try") with self.block(): self.traverse(node.body) @@ -973,8 +1018,24 @@ def visit_Try(self, node): with self.block(): self.traverse(node.finalbody) + def visit_Try(self, node): + prev_in_try_star = self._in_try_star + try: + self._in_try_star = False + self.do_visit_try(node) + finally: + self._in_try_star = prev_in_try_star + + def visit_TryStar(self, node): + prev_in_try_star = self._in_try_star + try: + self._in_try_star = True + self.do_visit_try(node) + finally: + self._in_try_star = prev_in_try_star + def visit_ExceptHandler(self, node): - self.fill("except") + self.fill("except*" if self._in_try_star else "except") if node.type: self.write(" ") self.traverse(node.type) @@ -990,6 +1051,8 @@ def visit_ClassDef(self, node): self.fill("@") self.traverse(deco) self.fill("class " + node.name) + if hasattr(node, "type_params"): + self._type_params_helper(node.type_params) with self.delimit_if("(", ")", condition = node.bases or node.keywords): comma = False for e in node.bases: @@ -1021,6 +1084,8 @@ def _function_helper(self, node, fill_suffix): self.traverse(deco) def_str = fill_suffix + " " + node.name self.fill(def_str) + if hasattr(node, "type_params"): + self._type_params_helper(node.type_params) with self.delimit("(", ")"): self.traverse(node.args) if node.returns: @@ -1029,6 +1094,30 @@ def _function_helper(self, node, fill_suffix): with self.block(extra=self.get_type_comment(node)): self._write_docstring_and_traverse_body(node) + def _type_params_helper(self, type_params): + if type_params is not None and len(type_params) > 0: + with self.delimit("[", "]"): + self.interleave(lambda: self.write(", "), self.traverse, type_params) + + def visit_TypeVar(self, node): + self.write(node.name) + if node.bound: + self.write(": ") + self.traverse(node.bound) + + def visit_TypeVarTuple(self, node): + self.write("*" + node.name) + + def visit_ParamSpec(self, node): + self.write("**" + node.name) + + def visit_TypeAlias(self, node): + self.fill("type ") + self.traverse(node.name) + self._type_params_helper(node.type_params) + self.write(" = ") + self.traverse(node.value) + def visit_For(self, node): self._for_helper("for ", node) @@ -1037,6 +1126,7 @@ def visit_AsyncFor(self, node): def _for_helper(self, fill, node): self.fill(fill) + self.set_precedence(_Precedence.TUPLE, node.target) self.traverse(node.target) self.write(" in ") self.traverse(node.iter) @@ -1133,71 +1223,81 @@ def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES): def visit_JoinedStr(self, node): self.write("f") - if self._avoid_backslashes: - self._fstring_JoinedStr(node, self.buffer_writer) - self._write_str_avoiding_backslashes(self.buffer) - return - # If we don't need to avoid backslashes globally (i.e., we only need - # to avoid them inside FormattedValues), it's cosmetically preferred - # to use escaped whitespace. That is, it's preferred to use backslashes - # for cases like: f"{x}\n". To accomplish this, we keep track of what - # in our buffer corresponds to FormattedValues and what corresponds to - # Constant parts of the f-string, and allow escapes accordingly. - buffer = [] + fstring_parts = [] for value in node.values: - meth = getattr(self, "_fstring_" + type(value).__name__) - meth(value, self.buffer_writer) - buffer.append((self.buffer, isinstance(value, Constant))) - new_buffer = [] - quote_types = _ALL_QUOTES - for value, is_constant in buffer: - # Repeatedly narrow down the list of possible quote_types - value, quote_types = self._str_literal_helper( - value, quote_types=quote_types, - escape_special_whitespace=is_constant + with self.buffered() as buffer: + self._write_fstring_inner(value) + fstring_parts.append( + ("".join(buffer), isinstance(value, Constant)) ) - new_buffer.append(value) - value = "".join(new_buffer) + + new_fstring_parts = [] + quote_types = list(_ALL_QUOTES) + fallback_to_repr = False + for value, is_constant in fstring_parts: + if is_constant: + value, new_quote_types = self._str_literal_helper( + value, + quote_types=quote_types, + escape_special_whitespace=True, + ) + if set(new_quote_types).isdisjoint(quote_types): + fallback_to_repr = True + break + quote_types = new_quote_types + elif "\n" in value: + quote_types = [q for q in quote_types if q in _MULTI_QUOTES] + assert quote_types + new_fstring_parts.append(value) + + if fallback_to_repr: + # If we weren't able to find a quote type that works for all parts + # of the JoinedStr, fallback to using repr and triple single quotes. + quote_types = ["'''"] + new_fstring_parts.clear() + for value, is_constant in fstring_parts: + if is_constant: + value = repr('"' + value) # force repr to use single quotes + expected_prefix = "'\"" + assert value.startswith(expected_prefix), repr(value) + value = value[len(expected_prefix):-1] + new_fstring_parts.append(value) + + value = "".join(new_fstring_parts) quote_type = quote_types[0] self.write(f"{quote_type}{value}{quote_type}") + def _write_fstring_inner(self, node): + if isinstance(node, JoinedStr): + # for both the f-string itself, and format_spec + for value in node.values: + self._write_fstring_inner(value) + elif isinstance(node, Constant) and isinstance(node.value, str): + value = node.value.replace("{", "{{").replace("}", "}}") + self.write(value) + elif isinstance(node, FormattedValue): + self.visit_FormattedValue(node) + else: + raise ValueError(f"Unexpected node inside JoinedStr, {node!r}") + def visit_FormattedValue(self, node): - self.write("f") - self._fstring_FormattedValue(node, self.buffer_writer) - self._write_str_avoiding_backslashes(self.buffer) + def unparse_inner(inner): + unparser = type(self)() + unparser.set_precedence(_Precedence.TEST.next(), inner) + return unparser.visit(inner) - def _fstring_JoinedStr(self, node, write): - for value in node.values: - meth = getattr(self, "_fstring_" + type(value).__name__) - meth(value, write) - - def _fstring_Constant(self, node, write): - if not isinstance(node.value, str): - raise ValueError("Constants inside JoinedStr should be a string.") - value = node.value.replace("{", "{{").replace("}", "}}") - write(value) - - def _fstring_FormattedValue(self, node, write): - write("{") - unparser = type(self)(_avoid_backslashes=True) - unparser.set_precedence(_Precedence.TEST.next(), node.value) - expr = unparser.visit(node.value) - if expr.startswith("{"): - write(" ") # Separate pair of opening brackets as "{ {" - if "\\" in expr: - raise ValueError("Unable to avoid backslash in f-string expression part") - write(expr) - if node.conversion != -1: - conversion = chr(node.conversion) - if conversion not in "sra": - raise ValueError("Unknown f-string conversion.") - write(f"!{conversion}") - if node.format_spec: - write(":") - meth = getattr(self, "_fstring_" + type(node.format_spec).__name__) - meth(node.format_spec, write) - write("}") + with self.delimit("{", "}"): + expr = unparse_inner(node.value) + if expr.startswith("{"): + # Separate pair of opening brackets as "{ {" + self.write(" ") + self.write(expr) + if node.conversion != -1: + self.write(f"!{chr(node.conversion)}") + if node.format_spec: + self.write(":") + self._write_fstring_inner(node.format_spec) def visit_Name(self, node): self.write(node.id) @@ -1320,7 +1420,11 @@ def write_item(item): ) def visit_Tuple(self, node): - with self.delimit("(", ")"): + with self.delimit_if( + "(", + ")", + len(node.elts) == 0 or self.get_precedence(node) > _Precedence.TUPLE + ): self.items_view(self.traverse, node.elts) unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} @@ -1336,7 +1440,7 @@ def visit_UnaryOp(self, node): operator_precedence = self.unop_precedence[operator] with self.require_parens(operator_precedence, node): self.write(operator) - # factor prefixes (+, -, ~) shouldn't be seperated + # factor prefixes (+, -, ~) shouldn't be separated # from the value they belong, (e.g: +1 instead of + 1) if operator_precedence is not _Precedence.FACTOR: self.write(" ") @@ -1461,20 +1565,17 @@ def visit_Call(self, node): self.traverse(e) def visit_Subscript(self, node): - def is_simple_tuple(slice_value): - # when unparsing a non-empty tuple, the parentheses can be safely - # omitted if there aren't any elements that explicitly requires - # parentheses (such as starred expressions). + def is_non_empty_tuple(slice_value): return ( isinstance(slice_value, Tuple) and slice_value.elts - and not any(isinstance(elt, Starred) for elt in slice_value.elts) ) self.set_precedence(_Precedence.ATOM, node.value) self.traverse(node.value) with self.delimit("[", "]"): - if is_simple_tuple(node.slice): + if is_non_empty_tuple(node.slice): + # parentheses can be omitted if the tuple isn't empty self.items_view(self.traverse, node.slice.elts) else: self.traverse(node.slice) @@ -1571,8 +1672,11 @@ def visit_keyword(self, node): def visit_Lambda(self, node): with self.require_parens(_Precedence.TEST, node): - self.write("lambda ") - self.traverse(node.args) + self.write("lambda") + with self.buffered() as buffer: + self.traverse(node.args) + if buffer: + self.write(" ", *buffer) self.write(": ") self.set_precedence(_Precedence.TEST, node.body) self.traverse(node.body) @@ -1681,6 +1785,22 @@ def unparse(ast_obj): return unparser.visit(ast_obj) +_deprecated_globals = { + name: globals().pop(name) + for name in ('Num', 'Str', 'Bytes', 'NameConstant', 'Ellipsis') +} + +def __getattr__(name): + if name in _deprecated_globals: + globals()[name] = value = _deprecated_globals[name] + import warnings + warnings._deprecated( + f"ast.{name}", message=_DEPRECATED_CLASS_MESSAGE, remove=(3, 14) + ) + return value + raise AttributeError(f"module 'ast' has no attribute '{name}'") + + def main(): import argparse diff --git a/Lib/test/support/ast_helper.py b/Lib/test/support/ast_helper.py new file mode 100644 index 0000000000..8a0415b6aa --- /dev/null +++ b/Lib/test/support/ast_helper.py @@ -0,0 +1,43 @@ +import ast + +class ASTTestMixin: + """Test mixing to have basic assertions for AST nodes.""" + + def assertASTEqual(self, ast1, ast2): + # Ensure the comparisons start at an AST node + self.assertIsInstance(ast1, ast.AST) + self.assertIsInstance(ast2, ast.AST) + + # An AST comparison routine modeled after ast.dump(), but + # instead of string building, it traverses the two trees + # in lock-step. + def traverse_compare(a, b, missing=object()): + if type(a) is not type(b): + self.fail(f"{type(a)!r} is not {type(b)!r}") + if isinstance(a, ast.AST): + for field in a._fields: + value1 = getattr(a, field, missing) + value2 = getattr(b, field, missing) + # Singletons are equal by definition, so further + # testing can be skipped. + if value1 is not value2: + traverse_compare(value1, value2) + elif isinstance(a, list): + try: + for node1, node2 in zip(a, b, strict=True): + traverse_compare(node1, node2) + except ValueError: + # Attempt a "pretty" error ala assertSequenceEqual() + len1 = len(a) + len2 = len(b) + if len1 > len2: + what = "First" + diff = len1 - len2 + else: + what = "Second" + diff = len2 - len1 + msg = f"{what} list contains {diff} additional elements." + raise self.failureException(msg) from None + elif a != b: + self.fail(f"{a!r} != {b!r}") + traverse_compare(ast1, ast2) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 1062f01c2f..af3e2bb5eb 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -1,18 +1,25 @@ import ast import builtins import dis +import enum import os +import re import sys +import textwrap import types import unittest import warnings import weakref +from functools import partial from textwrap import dedent from test import support +from test.support.import_helper import import_fresh_module +from test.support import os_helper, script_helper +from test.support.ast_helper import ASTTestMixin def to_tuple(t): - if t is None or isinstance(t, (str, int, complex)): + if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis: return t elif isinstance(t, list): return [to_tuple(e) for e in t] @@ -45,10 +52,20 @@ def to_tuple(t): "def f(a=0): pass", # FunctionDef with varargs "def f(*args): pass", + # FunctionDef with varargs as TypeVarTuple + "def f(*args: *Ts): pass", + # FunctionDef with varargs as unpacked Tuple + "def f(*args: *tuple[int, ...]): pass", + # FunctionDef with varargs as unpacked Tuple *and* TypeVarTuple + "def f(*args: *tuple[int, *Ts]): pass", # FunctionDef with kwargs "def f(**kwargs): pass", # FunctionDef with all kind of args and docstring "def f(a, b=1, c=None, d=[], e={}, *args, f=42, **kwargs): 'doc for f()'", + # FunctionDef with type annotation on return involving unpacking + "def f() -> tuple[*Ts]: pass", + "def f() -> tuple[int, *Ts]: pass", + "def f() -> tuple[int, *tuple[int, ...]]: pass", # ClassDef "class C:pass", # ClassDef with docstring @@ -64,6 +81,10 @@ def to_tuple(t): "a,b = c", "(a,b) = c", "[a,b] = c", + # AnnAssign with unpacked types + "x: tuple[*Ts]", + "x: tuple[int, *Ts]", + "x: tuple[int, *tuple[str, ...]]", # AugAssign "v += 1", # For @@ -85,6 +106,8 @@ def to_tuple(t): "try:\n pass\nexcept Exception:\n pass", # TryFinally "try:\n pass\nfinally:\n pass", + # TryStarExcept + "try:\n pass\nexcept* Exception:\n pass", # Assert "assert v", # Import @@ -160,7 +183,22 @@ def to_tuple(t): "def f(a=1, /, b=2, *, c): pass", "def f(a=1, /, b=2, *, c=4, **kwargs): pass", "def f(a=1, /, b=2, *, c, **kwargs): pass", - + # Type aliases + "type X = int", + "type X[T] = int", + "type X[T, *Ts, **P] = (T, Ts, P)", + "type X[T: int, *Ts, **P] = (T, Ts, P)", + "type X[T: (int, str), *Ts, **P] = (T, Ts, P)", + # Generic classes + "class X[T]: pass", + "class X[T, *Ts, **P]: pass", + "class X[T: int, *Ts, **P]: pass", + "class X[T: (int, str), *Ts, **P]: pass", + # Generic functions + "def f[T](): pass", + "def f[T, *Ts, **P](): pass", + "def f[T: int, *Ts, **P](): pass", + "def f[T: (int, str), *Ts, **P](): pass", ] # These are compiled through "single" @@ -241,13 +279,13 @@ def to_tuple(t): "()", # Combination "a.b.c.d(a.b[1:2])", - ] # TODO: expr_context, slice, boolop, operator, unaryop, cmpop, comprehension # excepthandler, arguments, keywords, alias class AST_Tests(unittest.TestCase): + maxDiff = None def _is_ast_node(self, name, node): if not isinstance(node, type): @@ -323,6 +361,42 @@ def test_ast_validation(self): tree = ast.parse(snippet) compile(tree, '', 'exec') + @unittest.skip("TODO: RUSTPYTHON, OverflowError: Python int too large to convert to Rust u32") + def test_invalid_position_information(self): + invalid_linenos = [ + (10, 1), (-10, -11), (10, -11), (-5, -2), (-5, 1) + ] + + for lineno, end_lineno in invalid_linenos: + with self.subTest(f"Check invalid linenos {lineno}:{end_lineno}"): + snippet = "a = 1" + tree = ast.parse(snippet) + tree.body[0].lineno = lineno + tree.body[0].end_lineno = end_lineno + with self.assertRaises(ValueError): + compile(tree, '', 'exec') + + invalid_col_offsets = [ + (10, 1), (-10, -11), (10, -11), (-5, -2), (-5, 1) + ] + for col_offset, end_col_offset in invalid_col_offsets: + with self.subTest(f"Check invalid col_offset {col_offset}:{end_col_offset}"): + snippet = "a = 1" + tree = ast.parse(snippet) + tree.body[0].col_offset = col_offset + tree.body[0].end_col_offset = end_col_offset + with self.assertRaises(ValueError): + compile(tree, '', 'exec') + + def test_compilation_of_ast_nodes_with_default_end_position_values(self): + tree = ast.Module(body=[ + ast.Import(names=[ast.alias(name='builtins', lineno=1, col_offset=0)], lineno=1, col_offset=0), + ast.Import(names=[ast.alias(name='traceback', lineno=0, col_offset=0)], lineno=0, col_offset=1) + ], type_ignores=[]) + + # Check that compilation doesn't crash. Note: this may crash explicitly only on debug mode. + compile(tree, "", "exec") + def test_slice(self): slc = ast.parse("x[::]").body[0].value.slice self.assertIsNone(slc.upper) @@ -359,6 +433,24 @@ def test_alias(self): self.assertEqual(alias.col_offset, 16) self.assertEqual(alias.end_col_offset, 17) + im = ast.parse("from bar import y as z").body[0] + alias = im.names[0] + self.assertEqual(alias.name, "y") + self.assertEqual(alias.asname, "z") + self.assertEqual(alias.lineno, 1) + self.assertEqual(alias.end_lineno, 1) + self.assertEqual(alias.col_offset, 16) + self.assertEqual(alias.end_col_offset, 22) + + im = ast.parse("import bar as foo").body[0] + alias = im.names[0] + self.assertEqual(alias.name, "bar") + self.assertEqual(alias.asname, "foo") + self.assertEqual(alias.lineno, 1) + self.assertEqual(alias.end_lineno, 1) + self.assertEqual(alias.col_offset, 7) + self.assertEqual(alias.end_col_offset, 17) + def test_base_classes(self): self.assertTrue(issubclass(ast.For, ast.stmt)) self.assertTrue(issubclass(ast.Name, ast.expr)) @@ -367,16 +459,42 @@ def test_base_classes(self): self.assertTrue(issubclass(ast.comprehension, ast.AST)) self.assertTrue(issubclass(ast.Gt, ast.AST)) + def test_import_deprecated(self): + ast = import_fresh_module('ast') + depr_regex = ( + r'ast\.{} is deprecated and will be removed in Python 3.14; ' + r'use ast\.Constant instead' + ) + for name in 'Num', 'Str', 'Bytes', 'NameConstant', 'Ellipsis': + with self.assertWarnsRegex(DeprecationWarning, depr_regex.format(name)): + getattr(ast, name) + + def test_field_attr_existence_deprecated(self): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', '', DeprecationWarning) + from ast import Num, Str, Bytes, NameConstant, Ellipsis + + for name in ('Num', 'Str', 'Bytes', 'NameConstant', 'Ellipsis'): + item = getattr(ast, name) + if self._is_ast_node(name, item): + with self.subTest(item): + with self.assertWarns(DeprecationWarning): + x = item() + if isinstance(x, ast.AST): + self.assertIs(type(x._fields), tuple) + def test_field_attr_existence(self): for name, item in ast.__dict__.items(): + # These emit DeprecationWarnings + if name in {'Num', 'Str', 'Bytes', 'NameConstant', 'Ellipsis'}: + continue + # constructor has a different signature + if name == 'Index': + continue if self._is_ast_node(name, item): - if name == 'Index': - # Index(value) just returns value now. - # The argument is required. - continue x = item() if isinstance(x, ast.AST): - self.assertEqual(type(x._fields), tuple) + self.assertIs(type(x._fields), tuple) # TODO: RUSTPYTHON @unittest.expectedFailure @@ -393,25 +511,108 @@ def test_arguments(self): self.assertEqual(x.args, 2) self.assertEqual(x.vararg, 3) + def test_field_attr_writable_deprecated(self): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', '', DeprecationWarning) + x = ast.Num() + # We can assign to _fields + x._fields = 666 + self.assertEqual(x._fields, 666) + def test_field_attr_writable(self): - x = ast.Num() + x = ast.Constant() # We can assign to _fields x._fields = 666 self.assertEqual(x._fields, 666) + def test_classattrs_deprecated(self): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', '', DeprecationWarning) + from ast import Num, Str, Bytes, NameConstant, Ellipsis + + with warnings.catch_warnings(record=True) as wlog: + warnings.filterwarnings('always', '', DeprecationWarning) + x = ast.Num() + self.assertEqual(x._fields, ('value', 'kind')) + + with self.assertRaises(AttributeError): + x.value + + with self.assertRaises(AttributeError): + x.n + + x = ast.Num(42) + self.assertEqual(x.value, 42) + self.assertEqual(x.n, 42) + + with self.assertRaises(AttributeError): + x.lineno + + with self.assertRaises(AttributeError): + x.foobar + + x = ast.Num(lineno=2) + self.assertEqual(x.lineno, 2) + + x = ast.Num(42, lineno=0) + self.assertEqual(x.lineno, 0) + self.assertEqual(x._fields, ('value', 'kind')) + self.assertEqual(x.value, 42) + self.assertEqual(x.n, 42) + + self.assertRaises(TypeError, ast.Num, 1, None, 2) + self.assertRaises(TypeError, ast.Num, 1, None, 2, lineno=0) + + # Arbitrary keyword arguments are supported + self.assertEqual(ast.Num(1, foo='bar').foo, 'bar') + + with self.assertRaisesRegex(TypeError, "Num got multiple values for argument 'n'"): + ast.Num(1, n=2) + + self.assertEqual(ast.Num(42).n, 42) + self.assertEqual(ast.Num(4.25).n, 4.25) + self.assertEqual(ast.Num(4.25j).n, 4.25j) + self.assertEqual(ast.Str('42').s, '42') + self.assertEqual(ast.Bytes(b'42').s, b'42') + self.assertIs(ast.NameConstant(True).value, True) + self.assertIs(ast.NameConstant(False).value, False) + self.assertIs(ast.NameConstant(None).value, None) + + self.assertEqual([str(w.message) for w in wlog], [ + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'ast.Str is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute s is deprecated and will be removed in Python 3.14; use value instead', + 'ast.Bytes is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute s is deprecated and will be removed in Python 3.14; use value instead', + 'ast.NameConstant is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.NameConstant is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.NameConstant is deprecated and will be removed in Python 3.14; use ast.Constant instead', + ]) + def test_classattrs(self): - x = ast.Num() + x = ast.Constant() self.assertEqual(x._fields, ('value', 'kind')) with self.assertRaises(AttributeError): x.value - with self.assertRaises(AttributeError): - x.n - - x = ast.Num(42) + x = ast.Constant(42) self.assertEqual(x.value, 42) - self.assertEqual(x.n, 42) with self.assertRaises(AttributeError): x.lineno @@ -419,36 +620,23 @@ def test_classattrs(self): with self.assertRaises(AttributeError): x.foobar - x = ast.Num(lineno=2) + x = ast.Constant(lineno=2) self.assertEqual(x.lineno, 2) - x = ast.Num(42, lineno=0) + x = ast.Constant(42, lineno=0) self.assertEqual(x.lineno, 0) self.assertEqual(x._fields, ('value', 'kind')) self.assertEqual(x.value, 42) - self.assertEqual(x.n, 42) - self.assertRaises(TypeError, ast.Num, 1, None, 2) - self.assertRaises(TypeError, ast.Num, 1, None, 2, lineno=0) + self.assertRaises(TypeError, ast.Constant, 1, None, 2) + self.assertRaises(TypeError, ast.Constant, 1, None, 2, lineno=0) # Arbitrary keyword arguments are supported self.assertEqual(ast.Constant(1, foo='bar').foo, 'bar') - self.assertEqual(ast.Num(1, foo='bar').foo, 'bar') - with self.assertRaisesRegex(TypeError, "Num got multiple values for argument 'n'"): - ast.Num(1, n=2) with self.assertRaisesRegex(TypeError, "Constant got multiple values for argument 'value'"): ast.Constant(1, value=2) - self.assertEqual(ast.Num(42).n, 42) - self.assertEqual(ast.Num(4.25).n, 4.25) - self.assertEqual(ast.Num(4.25j).n, 4.25j) - self.assertEqual(ast.Str('42').s, '42') - self.assertEqual(ast.Bytes(b'42').s, b'42') - self.assertIs(ast.NameConstant(True).value, True) - self.assertIs(ast.NameConstant(False).value, False) - self.assertIs(ast.NameConstant(None).value, None) - self.assertEqual(ast.Constant(42).value, 42) self.assertEqual(ast.Constant(4.25).value, 4.25) self.assertEqual(ast.Constant(4.25j).value, 4.25j) @@ -460,85 +648,211 @@ def test_classattrs(self): self.assertIs(ast.Constant(...).value, ...) def test_realtype(self): - self.assertEqual(type(ast.Num(42)), ast.Constant) - self.assertEqual(type(ast.Num(4.25)), ast.Constant) - self.assertEqual(type(ast.Num(4.25j)), ast.Constant) - self.assertEqual(type(ast.Str('42')), ast.Constant) - self.assertEqual(type(ast.Bytes(b'42')), ast.Constant) - self.assertEqual(type(ast.NameConstant(True)), ast.Constant) - self.assertEqual(type(ast.NameConstant(False)), ast.Constant) - self.assertEqual(type(ast.NameConstant(None)), ast.Constant) - self.assertEqual(type(ast.Ellipsis()), ast.Constant) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', '', DeprecationWarning) + from ast import Num, Str, Bytes, NameConstant, Ellipsis + + with warnings.catch_warnings(record=True) as wlog: + warnings.filterwarnings('always', '', DeprecationWarning) + self.assertIs(type(ast.Num(42)), ast.Constant) + self.assertIs(type(ast.Num(4.25)), ast.Constant) + self.assertIs(type(ast.Num(4.25j)), ast.Constant) + self.assertIs(type(ast.Str('42')), ast.Constant) + self.assertIs(type(ast.Bytes(b'42')), ast.Constant) + self.assertIs(type(ast.NameConstant(True)), ast.Constant) + self.assertIs(type(ast.NameConstant(False)), ast.Constant) + self.assertIs(type(ast.NameConstant(None)), ast.Constant) + self.assertIs(type(ast.Ellipsis()), ast.Constant) + + self.assertEqual([str(w.message) for w in wlog], [ + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Str is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Bytes is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.NameConstant is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.NameConstant is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.NameConstant is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Ellipsis is deprecated and will be removed in Python 3.14; use ast.Constant instead', + ]) def test_isinstance(self): - self.assertTrue(isinstance(ast.Num(42), ast.Num)) - self.assertTrue(isinstance(ast.Num(4.2), ast.Num)) - self.assertTrue(isinstance(ast.Num(4.2j), ast.Num)) - self.assertTrue(isinstance(ast.Str('42'), ast.Str)) - self.assertTrue(isinstance(ast.Bytes(b'42'), ast.Bytes)) - self.assertTrue(isinstance(ast.NameConstant(True), ast.NameConstant)) - self.assertTrue(isinstance(ast.NameConstant(False), ast.NameConstant)) - self.assertTrue(isinstance(ast.NameConstant(None), ast.NameConstant)) - self.assertTrue(isinstance(ast.Ellipsis(), ast.Ellipsis)) - - self.assertTrue(isinstance(ast.Constant(42), ast.Num)) - self.assertTrue(isinstance(ast.Constant(4.2), ast.Num)) - self.assertTrue(isinstance(ast.Constant(4.2j), ast.Num)) - self.assertTrue(isinstance(ast.Constant('42'), ast.Str)) - self.assertTrue(isinstance(ast.Constant(b'42'), ast.Bytes)) - self.assertTrue(isinstance(ast.Constant(True), ast.NameConstant)) - self.assertTrue(isinstance(ast.Constant(False), ast.NameConstant)) - self.assertTrue(isinstance(ast.Constant(None), ast.NameConstant)) - self.assertTrue(isinstance(ast.Constant(...), ast.Ellipsis)) - - self.assertFalse(isinstance(ast.Str('42'), ast.Num)) - self.assertFalse(isinstance(ast.Num(42), ast.Str)) - self.assertFalse(isinstance(ast.Str('42'), ast.Bytes)) - self.assertFalse(isinstance(ast.Num(42), ast.NameConstant)) - self.assertFalse(isinstance(ast.Num(42), ast.Ellipsis)) - self.assertFalse(isinstance(ast.NameConstant(True), ast.Num)) - self.assertFalse(isinstance(ast.NameConstant(False), ast.Num)) - - self.assertFalse(isinstance(ast.Constant('42'), ast.Num)) - self.assertFalse(isinstance(ast.Constant(42), ast.Str)) - self.assertFalse(isinstance(ast.Constant('42'), ast.Bytes)) - self.assertFalse(isinstance(ast.Constant(42), ast.NameConstant)) - self.assertFalse(isinstance(ast.Constant(42), ast.Ellipsis)) - self.assertFalse(isinstance(ast.Constant(True), ast.Num)) - self.assertFalse(isinstance(ast.Constant(False), ast.Num)) - - self.assertFalse(isinstance(ast.Constant(), ast.Num)) - self.assertFalse(isinstance(ast.Constant(), ast.Str)) - self.assertFalse(isinstance(ast.Constant(), ast.Bytes)) - self.assertFalse(isinstance(ast.Constant(), ast.NameConstant)) - self.assertFalse(isinstance(ast.Constant(), ast.Ellipsis)) + from ast import Constant + + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', '', DeprecationWarning) + from ast import Num, Str, Bytes, NameConstant, Ellipsis + + cls_depr_msg = ( + 'ast.{} is deprecated and will be removed in Python 3.14; ' + 'use ast.Constant instead' + ) + + assertNumDeprecated = partial( + self.assertWarnsRegex, DeprecationWarning, cls_depr_msg.format("Num") + ) + assertStrDeprecated = partial( + self.assertWarnsRegex, DeprecationWarning, cls_depr_msg.format("Str") + ) + assertBytesDeprecated = partial( + self.assertWarnsRegex, DeprecationWarning, cls_depr_msg.format("Bytes") + ) + assertNameConstantDeprecated = partial( + self.assertWarnsRegex, + DeprecationWarning, + cls_depr_msg.format("NameConstant") + ) + assertEllipsisDeprecated = partial( + self.assertWarnsRegex, DeprecationWarning, cls_depr_msg.format("Ellipsis") + ) + + for arg in 42, 4.2, 4.2j: + with self.subTest(arg=arg): + with assertNumDeprecated(): + n = Num(arg) + with assertNumDeprecated(): + self.assertIsInstance(n, Num) + + with assertStrDeprecated(): + s = Str('42') + with assertStrDeprecated(): + self.assertIsInstance(s, Str) + + with assertBytesDeprecated(): + b = Bytes(b'42') + with assertBytesDeprecated(): + self.assertIsInstance(b, Bytes) + + for arg in True, False, None: + with self.subTest(arg=arg): + with assertNameConstantDeprecated(): + n = NameConstant(arg) + with assertNameConstantDeprecated(): + self.assertIsInstance(n, NameConstant) + + with assertEllipsisDeprecated(): + e = Ellipsis() + with assertEllipsisDeprecated(): + self.assertIsInstance(e, Ellipsis) + + for arg in 42, 4.2, 4.2j: + with self.subTest(arg=arg): + with assertNumDeprecated(): + self.assertIsInstance(Constant(arg), Num) + + with assertStrDeprecated(): + self.assertIsInstance(Constant('42'), Str) + + with assertBytesDeprecated(): + self.assertIsInstance(Constant(b'42'), Bytes) + + for arg in True, False, None: + with self.subTest(arg=arg): + with assertNameConstantDeprecated(): + self.assertIsInstance(Constant(arg), NameConstant) + + with assertEllipsisDeprecated(): + self.assertIsInstance(Constant(...), Ellipsis) + + with assertStrDeprecated(): + s = Str('42') + assertNumDeprecated(self.assertNotIsInstance, s, Num) + assertBytesDeprecated(self.assertNotIsInstance, s, Bytes) + + with assertNumDeprecated(): + n = Num(42) + assertStrDeprecated(self.assertNotIsInstance, n, Str) + assertNameConstantDeprecated(self.assertNotIsInstance, n, NameConstant) + assertEllipsisDeprecated(self.assertNotIsInstance, n, Ellipsis) + + with assertNameConstantDeprecated(): + n = NameConstant(True) + with assertNumDeprecated(): + self.assertNotIsInstance(n, Num) + + with assertNameConstantDeprecated(): + n = NameConstant(False) + with assertNumDeprecated(): + self.assertNotIsInstance(n, Num) + + for arg in '42', True, False: + with self.subTest(arg=arg): + with assertNumDeprecated(): + self.assertNotIsInstance(Constant(arg), Num) + + assertStrDeprecated(self.assertNotIsInstance, Constant(42), Str) + assertBytesDeprecated(self.assertNotIsInstance, Constant('42'), Bytes) + assertNameConstantDeprecated(self.assertNotIsInstance, Constant(42), NameConstant) + assertEllipsisDeprecated(self.assertNotIsInstance, Constant(42), Ellipsis) + assertNumDeprecated(self.assertNotIsInstance, Constant(), Num) + assertStrDeprecated(self.assertNotIsInstance, Constant(), Str) + assertBytesDeprecated(self.assertNotIsInstance, Constant(), Bytes) + assertNameConstantDeprecated(self.assertNotIsInstance, Constant(), NameConstant) + assertEllipsisDeprecated(self.assertNotIsInstance, Constant(), Ellipsis) class S(str): pass - self.assertTrue(isinstance(ast.Constant(S('42')), ast.Str)) - self.assertFalse(isinstance(ast.Constant(S('42')), ast.Num)) + with assertStrDeprecated(): + self.assertIsInstance(Constant(S('42')), Str) + with assertNumDeprecated(): + self.assertNotIsInstance(Constant(S('42')), Num) + + def test_constant_subclasses_deprecated(self): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', '', DeprecationWarning) + from ast import Num + + with warnings.catch_warnings(record=True) as wlog: + warnings.filterwarnings('always', '', DeprecationWarning) + class N(ast.Num): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.z = 'spam' + class N2(ast.Num): + pass + + n = N(42) + self.assertEqual(n.n, 42) + self.assertEqual(n.z, 'spam') + self.assertIs(type(n), N) + self.assertIsInstance(n, N) + self.assertIsInstance(n, ast.Num) + self.assertNotIsInstance(n, N2) + self.assertNotIsInstance(ast.Num(42), N) + n = N(n=42) + self.assertEqual(n.n, 42) + self.assertIs(type(n), N) + + self.assertEqual([str(w.message) for w in wlog], [ + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', + ]) - def test_subclasses(self): - class N(ast.Num): + def test_constant_subclasses(self): + class N(ast.Constant): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.z = 'spam' - class N2(ast.Num): + class N2(ast.Constant): pass n = N(42) - self.assertEqual(n.n, 42) + self.assertEqual(n.value, 42) self.assertEqual(n.z, 'spam') self.assertEqual(type(n), N) self.assertTrue(isinstance(n, N)) - self.assertTrue(isinstance(n, ast.Num)) + self.assertTrue(isinstance(n, ast.Constant)) self.assertFalse(isinstance(n, N2)) - self.assertFalse(isinstance(ast.Num(42), N)) - n = N(n=42) - self.assertEqual(n.n, 42) + self.assertFalse(isinstance(ast.Constant(42), N)) + n = N(value=42) + self.assertEqual(n.value, 42) self.assertEqual(type(n), N) def test_module(self): - body = [ast.Num(42)] + body = [ast.Constant(42)] x = ast.Module(body, []) self.assertEqual(x.body, body) @@ -551,8 +865,8 @@ def test_nodeclasses(self): x.foobarbaz = 5 self.assertEqual(x.foobarbaz, 5) - n1 = ast.Num(1) - n3 = ast.Num(3) + n1 = ast.Constant(1) + n3 = ast.Constant(3) addop = ast.Add() x = ast.BinOp(n1, addop, n3) self.assertEqual(x.left, n1) @@ -595,18 +909,11 @@ def test_no_fields(self): @unittest.expectedFailure def test_pickling(self): import pickle - mods = [pickle] - try: - import cPickle - mods.append(cPickle) - except ImportError: - pass - protocols = [0, 1, 2] - for mod in mods: - for protocol in protocols: - for ast in (compile(i, "?", "exec", 0x400) for i in exec_tests): - ast2 = mod.loads(mod.dumps(ast, protocol)) - self.assertEqual(to_tuple(ast2), to_tuple(ast)) + + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + for ast in (compile(i, "?", "exec", 0x400) for i in exec_tests): + ast2 = pickle.loads(pickle.dumps(ast, protocol)) + self.assertEqual(to_tuple(ast2), to_tuple(ast)) # TODO: RUSTPYTHON @unittest.expectedFailure @@ -702,6 +1009,23 @@ def test_ast_asdl_signature(self): expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}" self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_positional_only_feature_version(self): + ast.parse('def foo(x, /): ...', feature_version=(3, 8)) + ast.parse('def bar(x=1, /): ...', feature_version=(3, 8)) + with self.assertRaises(SyntaxError): + ast.parse('def foo(x, /): ...', feature_version=(3, 7)) + with self.assertRaises(SyntaxError): + ast.parse('def bar(x=1, /): ...', feature_version=(3, 7)) + + ast.parse('lambda x, /: ...', feature_version=(3, 8)) + ast.parse('lambda x=1, /: ...', feature_version=(3, 8)) + with self.assertRaises(SyntaxError): + ast.parse('lambda x, /: ...', feature_version=(3, 7)) + with self.assertRaises(SyntaxError): + ast.parse('lambda x=1, /: ...', feature_version=(3, 7)) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_parenthesized_with_feature_version(self): @@ -714,17 +1038,41 @@ def test_parenthesized_with_feature_version(self): # TODO: RUSTPYTHON @unittest.expectedFailure - def test_issue40614_feature_version(self): - ast.parse('f"{x=}"', feature_version=(3, 8)) + def test_assignment_expression_feature_version(self): + ast.parse('(x := 0)', feature_version=(3, 8)) with self.assertRaises(SyntaxError): - ast.parse('f"{x=}"', feature_version=(3, 7)) + ast.parse('(x := 0)', feature_version=(3, 7)) # TODO: RUSTPYTHON @unittest.expectedFailure - def test_assignment_expression_feature_version(self): - ast.parse('(x := 0)', feature_version=(3, 8)) + def test_exception_groups_feature_version(self): + code = dedent(''' + try: ... + except* Exception: ... + ''') + ast.parse(code) with self.assertRaises(SyntaxError): - ast.parse('(x := 0)', feature_version=(3, 7)) + ast.parse(code, feature_version=(3, 10)) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_type_params_feature_version(self): + samples = [ + "type X = int", + "class X[T]: pass", + "def f[T](): pass", + ] + for sample in samples: + with self.subTest(sample): + ast.parse(sample) + with self.assertRaises(SyntaxError): + ast.parse(sample, feature_version=(3, 11)) + + def test_invalid_major_feature_version(self): + with self.assertRaises(ValueError): + ast.parse('pass', feature_version=(2, 7)) + with self.assertRaises(ValueError): + ast.parse('pass', feature_version=(4, 0)) # TODO: RUSTPYTHON @unittest.expectedFailure @@ -735,6 +1083,91 @@ def test_constant_as_name(self): with self.assertRaisesRegex(ValueError, f"identifier field can't represent '{constant}' constant"): compile(expr, "", "eval") + @unittest.skip("TODO: RUSTPYTHON, TypeError: enum mismatch") + def test_precedence_enum(self): + class _Precedence(enum.IntEnum): + """Precedence table that originated from python grammar.""" + NAMED_EXPR = enum.auto() # := + TUPLE = enum.auto() # , + YIELD = enum.auto() # 'yield', 'yield from' + TEST = enum.auto() # 'if'-'else', 'lambda' + OR = enum.auto() # 'or' + AND = enum.auto() # 'and' + NOT = enum.auto() # 'not' + CMP = enum.auto() # '<', '>', '==', '>=', '<=', '!=', + # 'in', 'not in', 'is', 'is not' + EXPR = enum.auto() + BOR = EXPR # '|' + BXOR = enum.auto() # '^' + BAND = enum.auto() # '&' + SHIFT = enum.auto() # '<<', '>>' + ARITH = enum.auto() # '+', '-' + TERM = enum.auto() # '*', '@', '/', '%', '//' + FACTOR = enum.auto() # unary '+', '-', '~' + POWER = enum.auto() # '**' + AWAIT = enum.auto() # 'await' + ATOM = enum.auto() + def next(self): + try: + return self.__class__(self + 1) + except ValueError: + return self + enum._test_simple_enum(_Precedence, ast._Precedence) + + @unittest.skipIf(support.is_wasi, "exhausts limited stack on WASI") + @support.cpython_only + def test_ast_recursion_limit(self): + fail_depth = support.EXCEEDS_RECURSION_LIMIT + crash_depth = 100_000 + success_depth = 1200 + + def check_limit(prefix, repeated): + expect_ok = prefix + repeated * success_depth + ast.parse(expect_ok) + for depth in (fail_depth, crash_depth): + broken = prefix + repeated * depth + details = "Compiling ({!r} + {!r} * {})".format( + prefix, repeated, depth) + with self.assertRaises(RecursionError, msg=details): + with support.infinite_recursion(): + ast.parse(broken) + + check_limit("a", "()") + check_limit("a", ".b") + check_limit("a", "[0]") + check_limit("a", "*a") + + def test_null_bytes(self): + with self.assertRaises(SyntaxError, + msg="source code string cannot contain null bytes"): + ast.parse("a\0b") + + def assert_none_check(self, node: type[ast.AST], attr: str, source: str) -> None: + with self.subTest(f"{node.__name__}.{attr}"): + tree = ast.parse(source) + found = 0 + for child in ast.walk(tree): + if isinstance(child, node): + setattr(child, attr, None) + found += 1 + self.assertEqual(found, 1) + e = re.escape(f"field '{attr}' is required for {node.__name__}") + with self.assertRaisesRegex(ValueError, f"^{e}$"): + compile(tree, "", "exec") + + @unittest.skip("TODO: RUSTPYTHON, TypeError: Expected type 'str' but 'NoneType' found") + def test_none_checks(self) -> None: + tests = [ + (ast.alias, "name", "import spam as SPAM"), + (ast.arg, "arg", "def spam(SPAM): spam"), + (ast.comprehension, "target", "[spam for SPAM in spam]"), + (ast.comprehension, "iter", "[spam for spam in SPAM]"), + (ast.keyword, "value", "spam(**SPAM)"), + (ast.match_case, "pattern", "match spam:\n case SPAM: spam"), + (ast.withitem, "context_expr", "with SPAM: spam"), + ] + for node, attr, source in tests: + self.assert_none_check(node, attr, source) class ASTHelpers_Test(unittest.TestCase): maxDiff = None @@ -871,7 +1304,7 @@ def test_dump_incomplete(self): @unittest.expectedFailure def test_copy_location(self): src = ast.parse('1 + 1', mode='eval') - src.body.right = ast.copy_location(ast.Num(2), src.body.right) + src.body.right = ast.copy_location(ast.Constant(2), src.body.right) self.assertEqual(ast.dump(src, include_attributes=True), 'Expression(body=BinOp(left=Constant(value=1, lineno=1, col_offset=0, ' 'end_lineno=1, end_col_offset=1), op=Add(), right=Constant(value=2, ' @@ -890,7 +1323,7 @@ def test_copy_location(self): def test_fix_missing_locations(self): src = ast.parse('write("spam")') src.body.append(ast.Expr(ast.Call(ast.Name('spam', ast.Load()), - [ast.Str('eggs')], []))) + [ast.Constant('eggs')], []))) self.assertEqual(src, ast.fix_missing_locations(src)) self.maxDiff = None self.assertEqual(ast.dump(src, include_attributes=True), @@ -933,6 +1366,19 @@ def test_increment_lineno(self): self.assertEqual(ast.increment_lineno(src).lineno, 2) self.assertIsNone(ast.increment_lineno(src).end_lineno) + @unittest.skip("TODO: RUSTPYTHON, NameError: name 'PyCF_TYPE_COMMENTS' is not defined") + def test_increment_lineno_on_module(self): + src = ast.parse(dedent("""\ + a = 1 + b = 2 # type: ignore + c = 3 + d = 4 # type: ignore@tag + """), type_comments=True) + ast.increment_lineno(src, n=5) + self.assertEqual(src.type_ignores[0].lineno, 7) + self.assertEqual(src.type_ignores[1].lineno, 9) + self.assertEqual(src.type_ignores[1].tag, '@tag') + def test_iter_fields(self): node = ast.parse('foo()', mode='eval') d = dict(ast.iter_fields(node.body)) @@ -1048,6 +1494,16 @@ def test_literal_eval(self): self.assertRaises(ValueError, ast.literal_eval, '+True') self.assertRaises(ValueError, ast.literal_eval, '2+3') + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_literal_eval_str_int_limit(self): + with support.adjust_int_max_str_digits(4000): + ast.literal_eval('3'*4000) # no error + with self.assertRaises(SyntaxError) as err_ctx: + ast.literal_eval('3'*4001) + self.assertIn('Exceeds the limit ', str(err_ctx.exception)) + self.assertIn(' Consider hexadecimal ', str(err_ctx.exception)) + def test_literal_eval_complex(self): # Issue #4907 self.assertEqual(ast.literal_eval('6j'), 6j) @@ -1193,9 +1649,9 @@ def arguments(args=None, posonlyargs=None, vararg=None, check(arguments(args=args), "must have Load context") check(arguments(posonlyargs=args), "must have Load context") check(arguments(kwonlyargs=args), "must have Load context") - check(arguments(defaults=[ast.Num(3)]), + check(arguments(defaults=[ast.Constant(3)]), "more positional defaults than args") - check(arguments(kw_defaults=[ast.Num(4)]), + check(arguments(kw_defaults=[ast.Constant(4)]), "length of kwonlyargs is not the same as kw_defaults") args = [ast.arg("x", ast.Name("x", ast.Load()))] check(arguments(args=args, defaults=[ast.Name("x", ast.Store())]), @@ -1210,22 +1666,46 @@ def arguments(args=None, posonlyargs=None, vararg=None, @unittest.expectedFailure def test_funcdef(self): a = ast.arguments([], [], None, [], [], None, []) - f = ast.FunctionDef("x", a, [], [], None) + f = ast.FunctionDef("x", a, [], [], None, None, []) self.stmt(f, "empty body on FunctionDef") - f = ast.FunctionDef("x", a, [ast.Pass()], [ast.Name("x", ast.Store())], - None) + f = ast.FunctionDef("x", a, [ast.Pass()], [ast.Name("x", ast.Store())], None, None, []) self.stmt(f, "must have Load context") f = ast.FunctionDef("x", a, [ast.Pass()], [], - ast.Name("x", ast.Store())) + ast.Name("x", ast.Store()), None, []) self.stmt(f, "must have Load context") + f = ast.FunctionDef("x", ast.arguments(), [ast.Pass()]) + self.stmt(f) def fac(args): - return ast.FunctionDef("x", args, [ast.Pass()], [], None) + return ast.FunctionDef("x", args, [ast.Pass()], [], None, None, []) self._check_arguments(fac, self.stmt) + # TODO: RUSTPYTHON, match expression is not implemented yet + # def test_funcdef_pattern_matching(self): + # # gh-104799: New fields on FunctionDef should be added at the end + # def matcher(node): + # match node: + # case ast.FunctionDef("foo", ast.arguments(args=[ast.arg("bar")]), + # [ast.Pass()], + # [ast.Name("capybara", ast.Load())], + # ast.Name("pacarana", ast.Load())): + # return True + # case _: + # return False + + # code = """ + # @capybara + # def foo(bar) -> pacarana: + # pass + # """ + # source = ast.parse(textwrap.dedent(code)) + # funcdef = source.body[0] + # self.assertIsInstance(funcdef, ast.FunctionDef) + # self.assertTrue(matcher(funcdef)) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_classdef(self): - def cls(bases=None, keywords=None, body=None, decorator_list=None): + def cls(bases=None, keywords=None, body=None, decorator_list=None, type_params=None): if bases is None: bases = [] if keywords is None: @@ -1234,8 +1714,10 @@ def cls(bases=None, keywords=None, body=None, decorator_list=None): body = [ast.Pass()] if decorator_list is None: decorator_list = [] + if type_params is None: + type_params = [] return ast.ClassDef("myclass", bases, keywords, - body, decorator_list) + body, decorator_list, type_params) self.stmt(cls(bases=[ast.Name("x", ast.Store())]), "must have Load context") self.stmt(cls(keywords=[ast.keyword("x", ast.Name("x", ast.Store()))]), @@ -1256,9 +1738,9 @@ def test_delete(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_assign(self): - self.stmt(ast.Assign([], ast.Num(3)), "empty targets on Assign") - self.stmt(ast.Assign([None], ast.Num(3)), "None disallowed") - self.stmt(ast.Assign([ast.Name("x", ast.Load())], ast.Num(3)), + self.stmt(ast.Assign([], ast.Constant(3)), "empty targets on Assign") + self.stmt(ast.Assign([None], ast.Constant(3)), "None disallowed") + self.stmt(ast.Assign([ast.Name("x", ast.Load())], ast.Constant(3)), "must have Store context") self.stmt(ast.Assign([ast.Name("x", ast.Store())], ast.Name("y", ast.Store())), @@ -1292,22 +1774,22 @@ def test_for(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_while(self): - self.stmt(ast.While(ast.Num(3), [], []), "empty body on While") + self.stmt(ast.While(ast.Constant(3), [], []), "empty body on While") self.stmt(ast.While(ast.Name("x", ast.Store()), [ast.Pass()], []), "must have Load context") - self.stmt(ast.While(ast.Num(3), [ast.Pass()], + self.stmt(ast.While(ast.Constant(3), [ast.Pass()], [ast.Expr(ast.Name("x", ast.Store()))]), "must have Load context") # TODO: RUSTPYTHON @unittest.expectedFailure def test_if(self): - self.stmt(ast.If(ast.Num(3), [], []), "empty body on If") + self.stmt(ast.If(ast.Constant(3), [], []), "empty body on If") i = ast.If(ast.Name("x", ast.Store()), [ast.Pass()], []) self.stmt(i, "must have Load context") - i = ast.If(ast.Num(3), [ast.Expr(ast.Name("x", ast.Store()))], []) + i = ast.If(ast.Constant(3), [ast.Expr(ast.Name("x", ast.Store()))], []) self.stmt(i, "must have Load context") - i = ast.If(ast.Num(3), [ast.Pass()], + i = ast.If(ast.Constant(3), [ast.Pass()], [ast.Expr(ast.Name("x", ast.Store()))]) self.stmt(i, "must have Load context") @@ -1316,21 +1798,21 @@ def test_if(self): def test_with(self): p = ast.Pass() self.stmt(ast.With([], [p]), "empty items on With") - i = ast.withitem(ast.Num(3), None) + i = ast.withitem(ast.Constant(3), None) self.stmt(ast.With([i], []), "empty body on With") i = ast.withitem(ast.Name("x", ast.Store()), None) self.stmt(ast.With([i], [p]), "must have Load context") - i = ast.withitem(ast.Num(3), ast.Name("x", ast.Load())) + i = ast.withitem(ast.Constant(3), ast.Name("x", ast.Load())) self.stmt(ast.With([i], [p]), "must have Store context") # TODO: RUSTPYTHON @unittest.expectedFailure def test_raise(self): - r = ast.Raise(None, ast.Num(3)) + r = ast.Raise(None, ast.Constant(3)) self.stmt(r, "Raise with cause but no exception") r = ast.Raise(ast.Name("x", ast.Store()), None) self.stmt(r, "must have Load context") - r = ast.Raise(ast.Num(4), ast.Name("x", ast.Store())) + r = ast.Raise(ast.Constant(4), ast.Name("x", ast.Store())) self.stmt(r, "must have Load context") # TODO: RUSTPYTHON @@ -1355,6 +1837,28 @@ def test_try(self): t = ast.Try([p], e, [p], [ast.Expr(ast.Name("x", ast.Store()))]) self.stmt(t, "must have Load context") + # TODO: RUSTPYTHON + @unittest.skip("TODO: RUSTPYTHON, SyntaxError: RustPython does not implement this feature yet") + def test_try_star(self): + p = ast.Pass() + t = ast.TryStar([], [], [], [p]) + self.stmt(t, "empty body on TryStar") + t = ast.TryStar([ast.Expr(ast.Name("x", ast.Store()))], [], [], [p]) + self.stmt(t, "must have Load context") + t = ast.TryStar([p], [], [], []) + self.stmt(t, "TryStar has neither except handlers nor finalbody") + t = ast.TryStar([p], [], [p], [p]) + self.stmt(t, "TryStar has orelse but no except handlers") + t = ast.TryStar([p], [ast.ExceptHandler(None, "x", [])], [], []) + self.stmt(t, "empty body on ExceptHandler") + e = [ast.ExceptHandler(ast.Name("x", ast.Store()), "y", [p])] + self.stmt(ast.TryStar([p], e, [], []), "must have Load context") + e = [ast.ExceptHandler(None, "x", [p])] + t = ast.TryStar([p], e, [ast.Expr(ast.Name("x", ast.Store()))], [p]) + self.stmt(t, "must have Load context") + t = ast.TryStar([p], e, [p], [ast.Expr(ast.Name("x", ast.Store()))]) + self.stmt(t, "must have Load context") + # TODO: RUSTPYTHON @unittest.expectedFailure def test_assert(self): @@ -1396,11 +1900,11 @@ def test_expr(self): def test_boolop(self): b = ast.BoolOp(ast.And(), []) self.expr(b, "less than 2 values") - b = ast.BoolOp(ast.And(), [ast.Num(3)]) + b = ast.BoolOp(ast.And(), [ast.Constant(3)]) self.expr(b, "less than 2 values") - b = ast.BoolOp(ast.And(), [ast.Num(4), None]) + b = ast.BoolOp(ast.And(), [ast.Constant(4), None]) self.expr(b, "None disallowed") - b = ast.BoolOp(ast.And(), [ast.Num(4), ast.Name("x", ast.Store())]) + b = ast.BoolOp(ast.And(), [ast.Constant(4), ast.Name("x", ast.Store())]) self.expr(b, "must have Load context") # TODO: RUSTPYTHON @@ -1509,11 +2013,11 @@ def test_compare(self): left = ast.Name("x", ast.Load()) comp = ast.Compare(left, [ast.In()], []) self.expr(comp, "no comparators") - comp = ast.Compare(left, [ast.In()], [ast.Num(4), ast.Num(5)]) + comp = ast.Compare(left, [ast.In()], [ast.Constant(4), ast.Constant(5)]) self.expr(comp, "different number of comparators and operands") - comp = ast.Compare(ast.Num("blah"), [ast.In()], [left]) + comp = ast.Compare(ast.Constant("blah"), [ast.In()], [left]) self.expr(comp) - comp = ast.Compare(left, [ast.In()], [ast.Num("blah")]) + comp = ast.Compare(left, [ast.In()], [ast.Constant("blah")]) self.expr(comp) # TODO: RUSTPYTHON @@ -1533,16 +2037,30 @@ def test_call(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_num(self): - class subint(int): - pass - class subfloat(float): - pass - class subcomplex(complex): - pass - for obj in "0", "hello": - self.expr(ast.Num(obj)) - for obj in subint(), subfloat(), subcomplex(): - self.expr(ast.Num(obj), "invalid type", exc=TypeError) + with warnings.catch_warnings(record=True) as wlog: + warnings.filterwarnings('ignore', '', DeprecationWarning) + from ast import Num + + with warnings.catch_warnings(record=True) as wlog: + warnings.filterwarnings('always', '', DeprecationWarning) + class subint(int): + pass + class subfloat(float): + pass + class subcomplex(complex): + pass + for obj in "0", "hello": + self.expr(ast.Num(obj)) + for obj in subint(), subfloat(), subcomplex(): + self.expr(ast.Num(obj), "invalid type", exc=TypeError) + + self.assertEqual([str(w.message) for w in wlog], [ + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + 'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead', + ]) # TODO: RUSTPYTHON @unittest.expectedFailure @@ -1553,7 +2071,7 @@ def test_attribute(self): # TODO: RUSTPYTHON @unittest.expectedFailure def test_subscript(self): - sub = ast.Subscript(ast.Name("x", ast.Store()), ast.Num(3), + sub = ast.Subscript(ast.Name("x", ast.Store()), ast.Constant(3), ast.Load()) self.expr(sub, "must have Load context") x = ast.Name("x", ast.Load()) @@ -1575,7 +2093,7 @@ def test_subscript(self): def test_starred(self): left = ast.List([ast.Starred(ast.Name("x", ast.Load()), ast.Store())], ast.Store()) - assign = ast.Assign([left], ast.Num(4)) + assign = ast.Assign([left], ast.Constant(4)) self.stmt(assign, "must have Store context") def _sequence(self, fac): @@ -1594,10 +2112,21 @@ def test_tuple(self): self._sequence(ast.Tuple) def test_nameconstant(self): - self.expr(ast.NameConstant(4)) + with warnings.catch_warnings(record=True) as wlog: + warnings.filterwarnings('ignore', '', DeprecationWarning) + from ast import NameConstant + + with warnings.catch_warnings(record=True) as wlog: + warnings.filterwarnings('always', '', DeprecationWarning) + self.expr(ast.NameConstant(4)) + + self.assertEqual([str(w.message) for w in wlog], [ + 'ast.NameConstant is deprecated and will be removed in Python 3.14; use ast.Constant instead', + ]) # TODO: RUSTPYTHON @unittest.expectedFailure + @support.requires_resource('cpu') def test_stdlib_validates(self): stdlib = os.path.dirname(ast.__file__) tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")] @@ -1714,6 +2243,12 @@ def test_stdlib_validates(self): kwd_attrs=[], kwd_patterns=[ast.MatchStar()] ), + ast.MatchClass( + constant_true, # invalid name + patterns=[], + kwd_attrs=['True'], + kwd_patterns=[pattern_1] + ), ast.MatchSequence( [ ast.MatchStar("True") @@ -1834,7 +2369,7 @@ def get_load_const(self, tree): co = compile(tree, '', 'exec') consts = [] for instr in dis.get_instructions(co): - if instr.opname == 'LOAD_CONST': + if instr.opname == 'LOAD_CONST' or instr.opname == 'RETURN_CONST': consts.append(instr.argval) return consts @@ -2235,6 +2770,17 @@ class C: cdef = ast.parse(s).body[0] self.assertEqual(ast.get_source_segment(s, cdef.body[0], padded=True), s_method) + def test_source_segment_newlines(self): + s = 'def f():\n pass\ndef g():\r pass\r\ndef h():\r\n pass\r\n' + f, g, h = ast.parse(s).body + self._check_content(s, f, 'def f():\n pass') + self._check_content(s, g, 'def g():\r pass') + self._check_content(s, h, 'def h():\r\n pass') + + s = 'def f():\n a = 1\r b = 2\r\n c = 3\n' + f = ast.parse(s).body[0] + self._check_content(s, f, s.rstrip()) + def test_source_segment_missing_info(self): s = 'v = 1\r\nw = 1\nx = 1\n\ry = 1\r\n' v, w, x, y = ast.parse(s).body @@ -2247,9 +2793,10 @@ def test_source_segment_missing_info(self): self.assertIsNone(ast.get_source_segment(s, x)) self.assertIsNone(ast.get_source_segment(s, y)) -class NodeVisitorTests(unittest.TestCase): +class BaseNodeVisitorCases: + # Both `NodeVisitor` and `NodeTranformer` must raise these warnings: def test_old_constant_nodes(self): - class Visitor(ast.NodeVisitor): + class Visitor(self.visitor_class): def visit_Num(self, node): log.append((node.lineno, 'Num', node.n)) def visit_Str(self, node): @@ -2287,16 +2834,149 @@ def visit_Ellipsis(self, node): ]) self.assertEqual([str(w.message) for w in wlog], [ 'visit_Num is deprecated; add visit_Constant', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', 'visit_Num is deprecated; add visit_Constant', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', 'visit_Num is deprecated; add visit_Constant', + 'Attribute n is deprecated and will be removed in Python 3.14; use value instead', 'visit_Str is deprecated; add visit_Constant', + 'Attribute s is deprecated and will be removed in Python 3.14; use value instead', 'visit_Bytes is deprecated; add visit_Constant', + 'Attribute s is deprecated and will be removed in Python 3.14; use value instead', 'visit_NameConstant is deprecated; add visit_Constant', 'visit_NameConstant is deprecated; add visit_Constant', 'visit_Ellipsis is deprecated; add visit_Constant', ]) +class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase): + visitor_class = ast.NodeVisitor + + +class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase): + visitor_class = ast.NodeTransformer + + def assertASTTransformation(self, tranformer_class, + initial_code, expected_code): + initial_ast = ast.parse(dedent(initial_code)) + expected_ast = ast.parse(dedent(expected_code)) + + tranformer = tranformer_class() + result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast)) + + self.assertASTEqual(result_ast, expected_ast) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_node_remove_single(self): + code = 'def func(arg) -> SomeType: ...' + expected = 'def func(arg): ...' + + # Since `FunctionDef.returns` is defined as a single value, we test + # the `if isinstance(old_value, AST):` branch here. + class SomeTypeRemover(ast.NodeTransformer): + def visit_Name(self, node: ast.Name): + self.generic_visit(node) + if node.id == 'SomeType': + return None + return node + + self.assertASTTransformation(SomeTypeRemover, code, expected) + + def test_node_remove_from_list(self): + code = """ + def func(arg): + print(arg) + yield arg + """ + expected = """ + def func(arg): + print(arg) + """ + + # Since `FunctionDef.body` is defined as a list, we test + # the `if isinstance(old_value, list):` branch here. + class YieldRemover(ast.NodeTransformer): + def visit_Expr(self, node: ast.Expr): + self.generic_visit(node) + if isinstance(node.value, ast.Yield): + return None # Remove `yield` from a function + return node + + self.assertASTTransformation(YieldRemover, code, expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_node_return_list(self): + code = """ + class DSL(Base, kw1=True): ... + """ + expected = """ + class DSL(Base, kw1=True, kw2=True, kw3=False): ... + """ + + class ExtendKeywords(ast.NodeTransformer): + def visit_keyword(self, node: ast.keyword): + self.generic_visit(node) + if node.arg == 'kw1': + return [ + node, + ast.keyword('kw2', ast.Constant(True)), + ast.keyword('kw3', ast.Constant(False)), + ] + return node + + self.assertASTTransformation(ExtendKeywords, code, expected) + + def test_node_mutate(self): + code = """ + def func(arg): + print(arg) + """ + expected = """ + def func(arg): + log(arg) + """ + + class PrintToLog(ast.NodeTransformer): + def visit_Call(self, node: ast.Call): + self.generic_visit(node) + if isinstance(node.func, ast.Name) and node.func.id == 'print': + node.func.id = 'log' + return node + + self.assertASTTransformation(PrintToLog, code, expected) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_node_replace(self): + code = """ + def func(arg): + print(arg) + """ + expected = """ + def func(arg): + logger.log(arg, debug=True) + """ + + class PrintToLog(ast.NodeTransformer): + def visit_Call(self, node: ast.Call): + self.generic_visit(node) + if isinstance(node.func, ast.Name) and node.func.id == 'print': + return ast.Call( + func=ast.Attribute( + ast.Name('logger', ctx=ast.Load()), + attr='log', + ctx=ast.Load(), + ), + args=node.args, + keywords=[ast.keyword('debug', ast.Constant(True))], + ) + return node + + self.assertASTTransformation(PrintToLog, code, expected) + + @support.cpython_only class ModuleStateTests(unittest.TestCase): # bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state. @@ -2379,6 +3059,27 @@ def test_subinterpreter(self): self.assertEqual(res, 0) +class ASTMainTests(unittest.TestCase): + # Tests `ast.main()` function. + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_cli_file_input(self): + code = "print(1, 2, 3)" + expected = ast.dump(ast.parse(code), indent=3) + + with os_helper.temp_dir() as tmp_dir: + filename = os.path.join(tmp_dir, "test_module.py") + with open(filename, 'w', encoding='utf-8') as f: + f.write(code) + res, _ = script_helper.run_python_until_end("-m", "ast", filename) + + self.assertEqual(res.err, b"") + self.assertEqual(expected.splitlines(), + res.out.decode("utf8").splitlines()) + self.assertEqual(res.rc, 0) + + def main(): if __name__ != '__main__': return @@ -2398,22 +3099,31 @@ def main(): exec_results = [ ('Module', [('Expr', (1, 0, 1, 4), ('Constant', (1, 0, 1, 4), None, None))], []), ('Module', [('Expr', (1, 0, 1, 18), ('Constant', (1, 0, 1, 18), 'module docstring', None))], []), -('Module', [('FunctionDef', (1, 0, 1, 13), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 9, 1, 13))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 29), 'f', ('arguments', [], [], None, [], [], None, []), [('Expr', (1, 9, 1, 29), ('Constant', (1, 9, 1, 29), 'function docstring', None))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 14), 'f', ('arguments', [], [('arg', (1, 6, 1, 7), 'a', None, None)], None, [], [], None, []), [('Pass', (1, 10, 1, 14))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 16), 'f', ('arguments', [], [('arg', (1, 6, 1, 7), 'a', None, None)], None, [], [], None, [('Constant', (1, 8, 1, 9), 0, None)]), [('Pass', (1, 12, 1, 16))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 18), 'f', ('arguments', [], [], ('arg', (1, 7, 1, 11), 'args', None, None), [], [], None, []), [('Pass', (1, 14, 1, 18))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 21), 'f', ('arguments', [], [], None, [], [], ('arg', (1, 8, 1, 14), 'kwargs', None, None), []), [('Pass', (1, 17, 1, 21))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 71), 'f', ('arguments', [], [('arg', (1, 6, 1, 7), 'a', None, None), ('arg', (1, 9, 1, 10), 'b', None, None), ('arg', (1, 14, 1, 15), 'c', None, None), ('arg', (1, 22, 1, 23), 'd', None, None), ('arg', (1, 28, 1, 29), 'e', None, None)], ('arg', (1, 35, 1, 39), 'args', None, None), [('arg', (1, 41, 1, 42), 'f', None, None)], [('Constant', (1, 43, 1, 45), 42, None)], ('arg', (1, 49, 1, 55), 'kwargs', None, None), [('Constant', (1, 11, 1, 12), 1, None), ('Constant', (1, 16, 1, 20), None, None), ('List', (1, 24, 1, 26), [], ('Load',)), ('Dict', (1, 30, 1, 32), [], [])]), [('Expr', (1, 58, 1, 71), ('Constant', (1, 58, 1, 71), 'doc for f()', None))], [], None, None)], []), -('Module', [('ClassDef', (1, 0, 1, 12), 'C', [], [], [('Pass', (1, 8, 1, 12))], [])], []), -('Module', [('ClassDef', (1, 0, 1, 32), 'C', [], [], [('Expr', (1, 9, 1, 32), ('Constant', (1, 9, 1, 32), 'docstring for class C', None))], [])], []), -('Module', [('ClassDef', (1, 0, 1, 21), 'C', [('Name', (1, 8, 1, 14), 'object', ('Load',))], [], [('Pass', (1, 17, 1, 21))], [])], []), -('Module', [('FunctionDef', (1, 0, 1, 16), 'f', ('arguments', [], [], None, [], [], None, []), [('Return', (1, 8, 1, 16), ('Constant', (1, 15, 1, 16), 1, None))], [], None, None)], []), +('Module', [('FunctionDef', (1, 0, 1, 13), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 9, 1, 13))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 29), 'f', ('arguments', [], [], None, [], [], None, []), [('Expr', (1, 9, 1, 29), ('Constant', (1, 9, 1, 29), 'function docstring', None))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 14), 'f', ('arguments', [], [('arg', (1, 6, 1, 7), 'a', None, None)], None, [], [], None, []), [('Pass', (1, 10, 1, 14))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 16), 'f', ('arguments', [], [('arg', (1, 6, 1, 7), 'a', None, None)], None, [], [], None, [('Constant', (1, 8, 1, 9), 0, None)]), [('Pass', (1, 12, 1, 16))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 18), 'f', ('arguments', [], [], ('arg', (1, 7, 1, 11), 'args', None, None), [], [], None, []), [('Pass', (1, 14, 1, 18))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 23), 'f', ('arguments', [], [], ('arg', (1, 7, 1, 16), 'args', ('Starred', (1, 13, 1, 16), ('Name', (1, 14, 1, 16), 'Ts', ('Load',)), ('Load',)), None), [], [], None, []), [('Pass', (1, 19, 1, 23))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 36), 'f', ('arguments', [], [], ('arg', (1, 7, 1, 29), 'args', ('Starred', (1, 13, 1, 29), ('Subscript', (1, 14, 1, 29), ('Name', (1, 14, 1, 19), 'tuple', ('Load',)), ('Tuple', (1, 20, 1, 28), [('Name', (1, 20, 1, 23), 'int', ('Load',)), ('Constant', (1, 25, 1, 28), Ellipsis, None)], ('Load',)), ('Load',)), ('Load',)), None), [], [], None, []), [('Pass', (1, 32, 1, 36))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 36), 'f', ('arguments', [], [], ('arg', (1, 7, 1, 29), 'args', ('Starred', (1, 13, 1, 29), ('Subscript', (1, 14, 1, 29), ('Name', (1, 14, 1, 19), 'tuple', ('Load',)), ('Tuple', (1, 20, 1, 28), [('Name', (1, 20, 1, 23), 'int', ('Load',)), ('Starred', (1, 25, 1, 28), ('Name', (1, 26, 1, 28), 'Ts', ('Load',)), ('Load',))], ('Load',)), ('Load',)), ('Load',)), None), [], [], None, []), [('Pass', (1, 32, 1, 36))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 21), 'f', ('arguments', [], [], None, [], [], ('arg', (1, 8, 1, 14), 'kwargs', None, None), []), [('Pass', (1, 17, 1, 21))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 71), 'f', ('arguments', [], [('arg', (1, 6, 1, 7), 'a', None, None), ('arg', (1, 9, 1, 10), 'b', None, None), ('arg', (1, 14, 1, 15), 'c', None, None), ('arg', (1, 22, 1, 23), 'd', None, None), ('arg', (1, 28, 1, 29), 'e', None, None)], ('arg', (1, 35, 1, 39), 'args', None, None), [('arg', (1, 41, 1, 42), 'f', None, None)], [('Constant', (1, 43, 1, 45), 42, None)], ('arg', (1, 49, 1, 55), 'kwargs', None, None), [('Constant', (1, 11, 1, 12), 1, None), ('Constant', (1, 16, 1, 20), None, None), ('List', (1, 24, 1, 26), [], ('Load',)), ('Dict', (1, 30, 1, 32), [], [])]), [('Expr', (1, 58, 1, 71), ('Constant', (1, 58, 1, 71), 'doc for f()', None))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 27), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 23, 1, 27))], [], ('Subscript', (1, 11, 1, 21), ('Name', (1, 11, 1, 16), 'tuple', ('Load',)), ('Tuple', (1, 17, 1, 20), [('Starred', (1, 17, 1, 20), ('Name', (1, 18, 1, 20), 'Ts', ('Load',)), ('Load',))], ('Load',)), ('Load',)), None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 32), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 28, 1, 32))], [], ('Subscript', (1, 11, 1, 26), ('Name', (1, 11, 1, 16), 'tuple', ('Load',)), ('Tuple', (1, 17, 1, 25), [('Name', (1, 17, 1, 20), 'int', ('Load',)), ('Starred', (1, 22, 1, 25), ('Name', (1, 23, 1, 25), 'Ts', ('Load',)), ('Load',))], ('Load',)), ('Load',)), None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 45), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 41, 1, 45))], [], ('Subscript', (1, 11, 1, 39), ('Name', (1, 11, 1, 16), 'tuple', ('Load',)), ('Tuple', (1, 17, 1, 38), [('Name', (1, 17, 1, 20), 'int', ('Load',)), ('Starred', (1, 22, 1, 38), ('Subscript', (1, 23, 1, 38), ('Name', (1, 23, 1, 28), 'tuple', ('Load',)), ('Tuple', (1, 29, 1, 37), [('Name', (1, 29, 1, 32), 'int', ('Load',)), ('Constant', (1, 34, 1, 37), Ellipsis, None)], ('Load',)), ('Load',)), ('Load',))], ('Load',)), ('Load',)), None, [])], []), +('Module', [('ClassDef', (1, 0, 1, 12), 'C', [], [], [('Pass', (1, 8, 1, 12))], [], [])], []), +('Module', [('ClassDef', (1, 0, 1, 32), 'C', [], [], [('Expr', (1, 9, 1, 32), ('Constant', (1, 9, 1, 32), 'docstring for class C', None))], [], [])], []), +('Module', [('ClassDef', (1, 0, 1, 21), 'C', [('Name', (1, 8, 1, 14), 'object', ('Load',))], [], [('Pass', (1, 17, 1, 21))], [], [])], []), +('Module', [('FunctionDef', (1, 0, 1, 16), 'f', ('arguments', [], [], None, [], [], None, []), [('Return', (1, 8, 1, 16), ('Constant', (1, 15, 1, 16), 1, None))], [], None, None, [])], []), ('Module', [('Delete', (1, 0, 1, 5), [('Name', (1, 4, 1, 5), 'v', ('Del',))])], []), ('Module', [('Assign', (1, 0, 1, 5), [('Name', (1, 0, 1, 1), 'v', ('Store',))], ('Constant', (1, 4, 1, 5), 1, None), None)], []), ('Module', [('Assign', (1, 0, 1, 7), [('Tuple', (1, 0, 1, 3), [('Name', (1, 0, 1, 1), 'a', ('Store',)), ('Name', (1, 2, 1, 3), 'b', ('Store',))], ('Store',))], ('Name', (1, 6, 1, 7), 'c', ('Load',)), None)], []), ('Module', [('Assign', (1, 0, 1, 9), [('Tuple', (1, 0, 1, 5), [('Name', (1, 1, 1, 2), 'a', ('Store',)), ('Name', (1, 3, 1, 4), 'b', ('Store',))], ('Store',))], ('Name', (1, 8, 1, 9), 'c', ('Load',)), None)], []), ('Module', [('Assign', (1, 0, 1, 9), [('List', (1, 0, 1, 5), [('Name', (1, 1, 1, 2), 'a', ('Store',)), ('Name', (1, 3, 1, 4), 'b', ('Store',))], ('Store',))], ('Name', (1, 8, 1, 9), 'c', ('Load',)), None)], []), +('Module', [('AnnAssign', (1, 0, 1, 13), ('Name', (1, 0, 1, 1), 'x', ('Store',)), ('Subscript', (1, 3, 1, 13), ('Name', (1, 3, 1, 8), 'tuple', ('Load',)), ('Tuple', (1, 9, 1, 12), [('Starred', (1, 9, 1, 12), ('Name', (1, 10, 1, 12), 'Ts', ('Load',)), ('Load',))], ('Load',)), ('Load',)), None, 1)], []), +('Module', [('AnnAssign', (1, 0, 1, 18), ('Name', (1, 0, 1, 1), 'x', ('Store',)), ('Subscript', (1, 3, 1, 18), ('Name', (1, 3, 1, 8), 'tuple', ('Load',)), ('Tuple', (1, 9, 1, 17), [('Name', (1, 9, 1, 12), 'int', ('Load',)), ('Starred', (1, 14, 1, 17), ('Name', (1, 15, 1, 17), 'Ts', ('Load',)), ('Load',))], ('Load',)), ('Load',)), None, 1)], []), +('Module', [('AnnAssign', (1, 0, 1, 31), ('Name', (1, 0, 1, 1), 'x', ('Store',)), ('Subscript', (1, 3, 1, 31), ('Name', (1, 3, 1, 8), 'tuple', ('Load',)), ('Tuple', (1, 9, 1, 30), [('Name', (1, 9, 1, 12), 'int', ('Load',)), ('Starred', (1, 14, 1, 30), ('Subscript', (1, 15, 1, 30), ('Name', (1, 15, 1, 20), 'tuple', ('Load',)), ('Tuple', (1, 21, 1, 29), [('Name', (1, 21, 1, 24), 'str', ('Load',)), ('Constant', (1, 26, 1, 29), Ellipsis, None)], ('Load',)), ('Load',)), ('Load',))], ('Load',)), ('Load',)), None, 1)], []), ('Module', [('AugAssign', (1, 0, 1, 6), ('Name', (1, 0, 1, 1), 'v', ('Store',)), ('Add',), ('Constant', (1, 5, 1, 6), 1, None))], []), ('Module', [('For', (1, 0, 1, 15), ('Name', (1, 4, 1, 5), 'v', ('Store',)), ('Name', (1, 9, 1, 10), 'v', ('Load',)), [('Pass', (1, 11, 1, 15))], [], None)], []), ('Module', [('While', (1, 0, 1, 12), ('Name', (1, 6, 1, 7), 'v', ('Load',)), [('Pass', (1, 8, 1, 12))], [])], []), @@ -2425,6 +3135,7 @@ def main(): ('Module', [('Raise', (1, 0, 1, 25), ('Call', (1, 6, 1, 25), ('Name', (1, 6, 1, 15), 'Exception', ('Load',)), [('Constant', (1, 16, 1, 24), 'string', None)], []), None)], []), ('Module', [('Try', (1, 0, 4, 6), [('Pass', (2, 2, 2, 6))], [('ExceptHandler', (3, 0, 4, 6), ('Name', (3, 7, 3, 16), 'Exception', ('Load',)), None, [('Pass', (4, 2, 4, 6))])], [], [])], []), ('Module', [('Try', (1, 0, 4, 6), [('Pass', (2, 2, 2, 6))], [], [], [('Pass', (4, 2, 4, 6))])], []), +('Module', [('TryStar', (1, 0, 4, 6), [('Pass', (2, 2, 2, 6))], [('ExceptHandler', (3, 0, 4, 6), ('Name', (3, 8, 3, 17), 'Exception', ('Load',)), None, [('Pass', (4, 2, 4, 6))])], [], [])], []), ('Module', [('Assert', (1, 0, 1, 8), ('Name', (1, 7, 1, 8), 'v', ('Load',)), None)], []), ('Module', [('Import', (1, 0, 1, 10), [('alias', (1, 7, 1, 10), 'sys', None)])], []), ('Module', [('ImportFrom', (1, 0, 1, 17), 'sys', [('alias', (1, 16, 1, 17), 'v', None)], 0)], []), @@ -2441,28 +3152,41 @@ def main(): ('Module', [('Expr', (1, 0, 1, 20), ('DictComp', (1, 0, 1, 20), ('Name', (1, 1, 1, 2), 'a', ('Load',)), ('Name', (1, 5, 1, 6), 'b', ('Load',)), [('comprehension', ('Tuple', (1, 11, 1, 14), [('Name', (1, 11, 1, 12), 'v', ('Store',)), ('Name', (1, 13, 1, 14), 'w', ('Store',))], ('Store',)), ('Name', (1, 18, 1, 19), 'x', ('Load',)), [], 0)]))], []), ('Module', [('Expr', (1, 0, 1, 19), ('SetComp', (1, 0, 1, 19), ('Name', (1, 1, 1, 2), 'r', ('Load',)), [('comprehension', ('Name', (1, 7, 1, 8), 'l', ('Store',)), ('Name', (1, 12, 1, 13), 'x', ('Load',)), [('Name', (1, 17, 1, 18), 'g', ('Load',))], 0)]))], []), ('Module', [('Expr', (1, 0, 1, 16), ('SetComp', (1, 0, 1, 16), ('Name', (1, 1, 1, 2), 'r', ('Load',)), [('comprehension', ('Tuple', (1, 7, 1, 10), [('Name', (1, 7, 1, 8), 'l', ('Store',)), ('Name', (1, 9, 1, 10), 'm', ('Store',))], ('Store',)), ('Name', (1, 14, 1, 15), 'x', ('Load',)), [], 0)]))], []), -('Module', [('AsyncFunctionDef', (1, 0, 3, 18), 'f', ('arguments', [], [], None, [], [], None, []), [('Expr', (2, 1, 2, 17), ('Constant', (2, 1, 2, 17), 'async function', None)), ('Expr', (3, 1, 3, 18), ('Await', (3, 1, 3, 18), ('Call', (3, 7, 3, 18), ('Name', (3, 7, 3, 16), 'something', ('Load',)), [], [])))], [], None, None)], []), -('Module', [('AsyncFunctionDef', (1, 0, 3, 8), 'f', ('arguments', [], [], None, [], [], None, []), [('AsyncFor', (2, 1, 3, 8), ('Name', (2, 11, 2, 12), 'e', ('Store',)), ('Name', (2, 16, 2, 17), 'i', ('Load',)), [('Expr', (2, 19, 2, 20), ('Constant', (2, 19, 2, 20), 1, None))], [('Expr', (3, 7, 3, 8), ('Constant', (3, 7, 3, 8), 2, None))], None)], [], None, None)], []), -('Module', [('AsyncFunctionDef', (1, 0, 2, 21), 'f', ('arguments', [], [], None, [], [], None, []), [('AsyncWith', (2, 1, 2, 21), [('withitem', ('Name', (2, 12, 2, 13), 'a', ('Load',)), ('Name', (2, 17, 2, 18), 'b', ('Store',)))], [('Expr', (2, 20, 2, 21), ('Constant', (2, 20, 2, 21), 1, None))], None)], [], None, None)], []), +('Module', [('AsyncFunctionDef', (1, 0, 3, 18), 'f', ('arguments', [], [], None, [], [], None, []), [('Expr', (2, 1, 2, 17), ('Constant', (2, 1, 2, 17), 'async function', None)), ('Expr', (3, 1, 3, 18), ('Await', (3, 1, 3, 18), ('Call', (3, 7, 3, 18), ('Name', (3, 7, 3, 16), 'something', ('Load',)), [], [])))], [], None, None, [])], []), +('Module', [('AsyncFunctionDef', (1, 0, 3, 8), 'f', ('arguments', [], [], None, [], [], None, []), [('AsyncFor', (2, 1, 3, 8), ('Name', (2, 11, 2, 12), 'e', ('Store',)), ('Name', (2, 16, 2, 17), 'i', ('Load',)), [('Expr', (2, 19, 2, 20), ('Constant', (2, 19, 2, 20), 1, None))], [('Expr', (3, 7, 3, 8), ('Constant', (3, 7, 3, 8), 2, None))], None)], [], None, None, [])], []), +('Module', [('AsyncFunctionDef', (1, 0, 2, 21), 'f', ('arguments', [], [], None, [], [], None, []), [('AsyncWith', (2, 1, 2, 21), [('withitem', ('Name', (2, 12, 2, 13), 'a', ('Load',)), ('Name', (2, 17, 2, 18), 'b', ('Store',)))], [('Expr', (2, 20, 2, 21), ('Constant', (2, 20, 2, 21), 1, None))], None)], [], None, None, [])], []), ('Module', [('Expr', (1, 0, 1, 14), ('Dict', (1, 0, 1, 14), [None, ('Constant', (1, 10, 1, 11), 2, None)], [('Dict', (1, 3, 1, 8), [('Constant', (1, 4, 1, 5), 1, None)], [('Constant', (1, 6, 1, 7), 2, None)]), ('Constant', (1, 12, 1, 13), 3, None)]))], []), ('Module', [('Expr', (1, 0, 1, 12), ('Set', (1, 0, 1, 12), [('Starred', (1, 1, 1, 8), ('Set', (1, 2, 1, 8), [('Constant', (1, 3, 1, 4), 1, None), ('Constant', (1, 6, 1, 7), 2, None)]), ('Load',)), ('Constant', (1, 10, 1, 11), 3, None)]))], []), -('Module', [('AsyncFunctionDef', (1, 0, 2, 21), 'f', ('arguments', [], [], None, [], [], None, []), [('Expr', (2, 1, 2, 21), ('ListComp', (2, 1, 2, 21), ('Name', (2, 2, 2, 3), 'i', ('Load',)), [('comprehension', ('Name', (2, 14, 2, 15), 'b', ('Store',)), ('Name', (2, 19, 2, 20), 'c', ('Load',)), [], 1)]))], [], None, None)], []), -('Module', [('FunctionDef', (4, 0, 4, 13), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (4, 9, 4, 13))], [('Name', (1, 1, 1, 6), 'deco1', ('Load',)), ('Call', (2, 1, 2, 8), ('Name', (2, 1, 2, 6), 'deco2', ('Load',)), [], []), ('Call', (3, 1, 3, 9), ('Name', (3, 1, 3, 6), 'deco3', ('Load',)), [('Constant', (3, 7, 3, 8), 1, None)], [])], None, None)], []), -('Module', [('AsyncFunctionDef', (4, 0, 4, 19), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (4, 15, 4, 19))], [('Name', (1, 1, 1, 6), 'deco1', ('Load',)), ('Call', (2, 1, 2, 8), ('Name', (2, 1, 2, 6), 'deco2', ('Load',)), [], []), ('Call', (3, 1, 3, 9), ('Name', (3, 1, 3, 6), 'deco3', ('Load',)), [('Constant', (3, 7, 3, 8), 1, None)], [])], None, None)], []), -('Module', [('ClassDef', (4, 0, 4, 13), 'C', [], [], [('Pass', (4, 9, 4, 13))], [('Name', (1, 1, 1, 6), 'deco1', ('Load',)), ('Call', (2, 1, 2, 8), ('Name', (2, 1, 2, 6), 'deco2', ('Load',)), [], []), ('Call', (3, 1, 3, 9), ('Name', (3, 1, 3, 6), 'deco3', ('Load',)), [('Constant', (3, 7, 3, 8), 1, None)], [])])], []), -('Module', [('FunctionDef', (2, 0, 2, 13), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (2, 9, 2, 13))], [('Call', (1, 1, 1, 19), ('Name', (1, 1, 1, 5), 'deco', ('Load',)), [('GeneratorExp', (1, 5, 1, 19), ('Name', (1, 6, 1, 7), 'a', ('Load',)), [('comprehension', ('Name', (1, 12, 1, 13), 'a', ('Store',)), ('Name', (1, 17, 1, 18), 'b', ('Load',)), [], 0)])], [])], None, None)], []), -('Module', [('FunctionDef', (2, 0, 2, 13), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (2, 9, 2, 13))], [('Attribute', (1, 1, 1, 6), ('Attribute', (1, 1, 1, 4), ('Name', (1, 1, 1, 2), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',))], None, None)], []), +('Module', [('AsyncFunctionDef', (1, 0, 2, 21), 'f', ('arguments', [], [], None, [], [], None, []), [('Expr', (2, 1, 2, 21), ('ListComp', (2, 1, 2, 21), ('Name', (2, 2, 2, 3), 'i', ('Load',)), [('comprehension', ('Name', (2, 14, 2, 15), 'b', ('Store',)), ('Name', (2, 19, 2, 20), 'c', ('Load',)), [], 1)]))], [], None, None, [])], []), +('Module', [('FunctionDef', (4, 0, 4, 13), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (4, 9, 4, 13))], [('Name', (1, 1, 1, 6), 'deco1', ('Load',)), ('Call', (2, 1, 2, 8), ('Name', (2, 1, 2, 6), 'deco2', ('Load',)), [], []), ('Call', (3, 1, 3, 9), ('Name', (3, 1, 3, 6), 'deco3', ('Load',)), [('Constant', (3, 7, 3, 8), 1, None)], [])], None, None, [])], []), +('Module', [('AsyncFunctionDef', (4, 0, 4, 19), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (4, 15, 4, 19))], [('Name', (1, 1, 1, 6), 'deco1', ('Load',)), ('Call', (2, 1, 2, 8), ('Name', (2, 1, 2, 6), 'deco2', ('Load',)), [], []), ('Call', (3, 1, 3, 9), ('Name', (3, 1, 3, 6), 'deco3', ('Load',)), [('Constant', (3, 7, 3, 8), 1, None)], [])], None, None, [])], []), +('Module', [('ClassDef', (4, 0, 4, 13), 'C', [], [], [('Pass', (4, 9, 4, 13))], [('Name', (1, 1, 1, 6), 'deco1', ('Load',)), ('Call', (2, 1, 2, 8), ('Name', (2, 1, 2, 6), 'deco2', ('Load',)), [], []), ('Call', (3, 1, 3, 9), ('Name', (3, 1, 3, 6), 'deco3', ('Load',)), [('Constant', (3, 7, 3, 8), 1, None)], [])], [])], []), +('Module', [('FunctionDef', (2, 0, 2, 13), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (2, 9, 2, 13))], [('Call', (1, 1, 1, 19), ('Name', (1, 1, 1, 5), 'deco', ('Load',)), [('GeneratorExp', (1, 5, 1, 19), ('Name', (1, 6, 1, 7), 'a', ('Load',)), [('comprehension', ('Name', (1, 12, 1, 13), 'a', ('Store',)), ('Name', (1, 17, 1, 18), 'b', ('Load',)), [], 0)])], [])], None, None, [])], []), +('Module', [('FunctionDef', (2, 0, 2, 13), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (2, 9, 2, 13))], [('Attribute', (1, 1, 1, 6), ('Attribute', (1, 1, 1, 4), ('Name', (1, 1, 1, 2), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',))], None, None, [])], []), ('Module', [('Expr', (1, 0, 1, 8), ('NamedExpr', (1, 1, 1, 7), ('Name', (1, 1, 1, 2), 'a', ('Store',)), ('Constant', (1, 6, 1, 7), 1, None)))], []), -('Module', [('FunctionDef', (1, 0, 1, 18), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [], None, [], [], None, []), [('Pass', (1, 14, 1, 18))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 26), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 12, 1, 13), 'c', None, None), ('arg', (1, 15, 1, 16), 'd', None, None), ('arg', (1, 18, 1, 19), 'e', None, None)], None, [], [], None, []), [('Pass', (1, 22, 1, 26))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 29), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 12, 1, 13), 'c', None, None)], None, [('arg', (1, 18, 1, 19), 'd', None, None), ('arg', (1, 21, 1, 22), 'e', None, None)], [None, None], None, []), [('Pass', (1, 25, 1, 29))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 39), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 12, 1, 13), 'c', None, None)], None, [('arg', (1, 18, 1, 19), 'd', None, None), ('arg', (1, 21, 1, 22), 'e', None, None)], [None, None], ('arg', (1, 26, 1, 32), 'kwargs', None, None), []), [('Pass', (1, 35, 1, 39))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 20), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [], None, [], [], None, [('Constant', (1, 8, 1, 9), 1, None)]), [('Pass', (1, 16, 1, 20))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 29), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None), ('arg', (1, 19, 1, 20), 'c', None, None)], None, [], [], None, [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None), ('Constant', (1, 21, 1, 22), 4, None)]), [('Pass', (1, 25, 1, 29))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 32), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None)], None, [('arg', (1, 22, 1, 23), 'c', None, None)], [('Constant', (1, 24, 1, 25), 4, None)], None, [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None)]), [('Pass', (1, 28, 1, 32))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 30), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None)], None, [('arg', (1, 22, 1, 23), 'c', None, None)], [None], None, [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None)]), [('Pass', (1, 26, 1, 30))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 42), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None)], None, [('arg', (1, 22, 1, 23), 'c', None, None)], [('Constant', (1, 24, 1, 25), 4, None)], ('arg', (1, 29, 1, 35), 'kwargs', None, None), [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None)]), [('Pass', (1, 38, 1, 42))], [], None, None)], []), -('Module', [('FunctionDef', (1, 0, 1, 40), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None)], None, [('arg', (1, 22, 1, 23), 'c', None, None)], [None], ('arg', (1, 27, 1, 33), 'kwargs', None, None), [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None)]), [('Pass', (1, 36, 1, 40))], [], None, None)], []), +('Module', [('FunctionDef', (1, 0, 1, 18), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [], None, [], [], None, []), [('Pass', (1, 14, 1, 18))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 26), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 12, 1, 13), 'c', None, None), ('arg', (1, 15, 1, 16), 'd', None, None), ('arg', (1, 18, 1, 19), 'e', None, None)], None, [], [], None, []), [('Pass', (1, 22, 1, 26))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 29), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 12, 1, 13), 'c', None, None)], None, [('arg', (1, 18, 1, 19), 'd', None, None), ('arg', (1, 21, 1, 22), 'e', None, None)], [None, None], None, []), [('Pass', (1, 25, 1, 29))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 39), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 12, 1, 13), 'c', None, None)], None, [('arg', (1, 18, 1, 19), 'd', None, None), ('arg', (1, 21, 1, 22), 'e', None, None)], [None, None], ('arg', (1, 26, 1, 32), 'kwargs', None, None), []), [('Pass', (1, 35, 1, 39))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 20), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [], None, [], [], None, [('Constant', (1, 8, 1, 9), 1, None)]), [('Pass', (1, 16, 1, 20))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 29), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None), ('arg', (1, 19, 1, 20), 'c', None, None)], None, [], [], None, [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None), ('Constant', (1, 21, 1, 22), 4, None)]), [('Pass', (1, 25, 1, 29))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 32), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None)], None, [('arg', (1, 22, 1, 23), 'c', None, None)], [('Constant', (1, 24, 1, 25), 4, None)], None, [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None)]), [('Pass', (1, 28, 1, 32))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 30), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None)], None, [('arg', (1, 22, 1, 23), 'c', None, None)], [None], None, [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None)]), [('Pass', (1, 26, 1, 30))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 42), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None)], None, [('arg', (1, 22, 1, 23), 'c', None, None)], [('Constant', (1, 24, 1, 25), 4, None)], ('arg', (1, 29, 1, 35), 'kwargs', None, None), [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None)]), [('Pass', (1, 38, 1, 42))], [], None, None, [])], []), +('Module', [('FunctionDef', (1, 0, 1, 40), 'f', ('arguments', [('arg', (1, 6, 1, 7), 'a', None, None)], [('arg', (1, 14, 1, 15), 'b', None, None)], None, [('arg', (1, 22, 1, 23), 'c', None, None)], [None], ('arg', (1, 27, 1, 33), 'kwargs', None, None), [('Constant', (1, 8, 1, 9), 1, None), ('Constant', (1, 16, 1, 17), 2, None)]), [('Pass', (1, 36, 1, 40))], [], None, None, [])], []), +('Module', [('TypeAlias', (1, 0, 1, 12), ('Name', (1, 5, 1, 6), 'X', ('Store',)), [], ('Name', (1, 9, 1, 12), 'int', ('Load',)))], []), +('Module', [('TypeAlias', (1, 0, 1, 15), ('Name', (1, 5, 1, 6), 'X', ('Store',)), [('TypeVar', (1, 7, 1, 8), 'T', None)], ('Name', (1, 12, 1, 15), 'int', ('Load',)))], []), +('Module', [('TypeAlias', (1, 0, 1, 32), ('Name', (1, 5, 1, 6), 'X', ('Store',)), [('TypeVar', (1, 7, 1, 8), 'T', None), ('TypeVarTuple', (1, 10, 1, 13), 'Ts'), ('ParamSpec', (1, 15, 1, 18), 'P')], ('Tuple', (1, 22, 1, 32), [('Name', (1, 23, 1, 24), 'T', ('Load',)), ('Name', (1, 26, 1, 28), 'Ts', ('Load',)), ('Name', (1, 30, 1, 31), 'P', ('Load',))], ('Load',)))], []), +('Module', [('TypeAlias', (1, 0, 1, 37), ('Name', (1, 5, 1, 6), 'X', ('Store',)), [('TypeVar', (1, 7, 1, 13), 'T', ('Name', (1, 10, 1, 13), 'int', ('Load',))), ('TypeVarTuple', (1, 15, 1, 18), 'Ts'), ('ParamSpec', (1, 20, 1, 23), 'P')], ('Tuple', (1, 27, 1, 37), [('Name', (1, 28, 1, 29), 'T', ('Load',)), ('Name', (1, 31, 1, 33), 'Ts', ('Load',)), ('Name', (1, 35, 1, 36), 'P', ('Load',))], ('Load',)))], []), +('Module', [('TypeAlias', (1, 0, 1, 44), ('Name', (1, 5, 1, 6), 'X', ('Store',)), [('TypeVar', (1, 7, 1, 20), 'T', ('Tuple', (1, 10, 1, 20), [('Name', (1, 11, 1, 14), 'int', ('Load',)), ('Name', (1, 16, 1, 19), 'str', ('Load',))], ('Load',))), ('TypeVarTuple', (1, 22, 1, 25), 'Ts'), ('ParamSpec', (1, 27, 1, 30), 'P')], ('Tuple', (1, 34, 1, 44), [('Name', (1, 35, 1, 36), 'T', ('Load',)), ('Name', (1, 38, 1, 40), 'Ts', ('Load',)), ('Name', (1, 42, 1, 43), 'P', ('Load',))], ('Load',)))], []), +('Module', [('ClassDef', (1, 0, 1, 16), 'X', [], [], [('Pass', (1, 12, 1, 16))], [], [('TypeVar', (1, 8, 1, 9), 'T', None)])], []), +('Module', [('ClassDef', (1, 0, 1, 26), 'X', [], [], [('Pass', (1, 22, 1, 26))], [], [('TypeVar', (1, 8, 1, 9), 'T', None), ('TypeVarTuple', (1, 11, 1, 14), 'Ts'), ('ParamSpec', (1, 16, 1, 19), 'P')])], []), +('Module', [('ClassDef', (1, 0, 1, 31), 'X', [], [], [('Pass', (1, 27, 1, 31))], [], [('TypeVar', (1, 8, 1, 14), 'T', ('Name', (1, 11, 1, 14), 'int', ('Load',))), ('TypeVarTuple', (1, 16, 1, 19), 'Ts'), ('ParamSpec', (1, 21, 1, 24), 'P')])], []), +('Module', [('ClassDef', (1, 0, 1, 38), 'X', [], [], [('Pass', (1, 34, 1, 38))], [], [('TypeVar', (1, 8, 1, 21), 'T', ('Tuple', (1, 11, 1, 21), [('Name', (1, 12, 1, 15), 'int', ('Load',)), ('Name', (1, 17, 1, 20), 'str', ('Load',))], ('Load',))), ('TypeVarTuple', (1, 23, 1, 26), 'Ts'), ('ParamSpec', (1, 28, 1, 31), 'P')])], []), +('Module', [('FunctionDef', (1, 0, 1, 16), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 12, 1, 16))], [], None, None, [('TypeVar', (1, 6, 1, 7), 'T', None)])], []), +('Module', [('FunctionDef', (1, 0, 1, 26), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 22, 1, 26))], [], None, None, [('TypeVar', (1, 6, 1, 7), 'T', None), ('TypeVarTuple', (1, 9, 1, 12), 'Ts'), ('ParamSpec', (1, 14, 1, 17), 'P')])], []), +('Module', [('FunctionDef', (1, 0, 1, 31), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 27, 1, 31))], [], None, None, [('TypeVar', (1, 6, 1, 12), 'T', ('Name', (1, 9, 1, 12), 'int', ('Load',))), ('TypeVarTuple', (1, 14, 1, 17), 'Ts'), ('ParamSpec', (1, 19, 1, 22), 'P')])], []), +('Module', [('FunctionDef', (1, 0, 1, 38), 'f', ('arguments', [], [], None, [], [], None, []), [('Pass', (1, 34, 1, 38))], [], None, None, [('TypeVar', (1, 6, 1, 19), 'T', ('Tuple', (1, 9, 1, 19), [('Name', (1, 10, 1, 13), 'int', ('Load',)), ('Name', (1, 15, 1, 18), 'str', ('Load',))], ('Load',))), ('TypeVarTuple', (1, 21, 1, 24), 'Ts'), ('ParamSpec', (1, 26, 1, 29), 'P')])], []), ] single_results = [ ('Interactive', [('Expr', (1, 0, 1, 3), ('BinOp', (1, 0, 1, 3), ('Constant', (1, 0, 1, 1), 1, None), ('Add',), ('Constant', (1, 2, 1, 3), 2, None)))]), From dc4f6994fb00ed02090bee888ba640434e44deec Mon Sep 17 00:00:00 2001 From: ChenyG Date: Sat, 25 Nov 2023 12:11:17 +0800 Subject: [PATCH 167/893] Support slice hash (#5123) * make slice object hashable * Update test_slice.py from CPython v3.12 * remove TODO * remove outdated tests --- Lib/test/test_slice.py | 52 ++++++++++++++++++++++++--- extra_tests/snippets/builtin_slice.py | 10 ------ vm/src/builtins/slice.rs | 46 ++++++++++++++++++++++-- 3 files changed, 92 insertions(+), 16 deletions(-) diff --git a/Lib/test/test_slice.py b/Lib/test/test_slice.py index 9e79775aca..53d4c77616 100644 --- a/Lib/test/test_slice.py +++ b/Lib/test/test_slice.py @@ -5,6 +5,7 @@ import sys import unittest import weakref +import copy from pickle import loads, dumps from test import support @@ -79,10 +80,16 @@ def test_repr(self): self.assertEqual(repr(slice(1, 2, 3)), "slice(1, 2, 3)") def test_hash(self): - # Verify clearing of SF bug #800796 - self.assertRaises(TypeError, hash, slice(5)) + self.assertEqual(hash(slice(5)), slice(5).__hash__()) + self.assertEqual(hash(slice(1, 2)), slice(1, 2).__hash__()) + self.assertEqual(hash(slice(1, 2, 3)), slice(1, 2, 3).__hash__()) + self.assertNotEqual(slice(5), slice(6)) + + with self.assertRaises(TypeError): + hash(slice(1, 2, [])) + with self.assertRaises(TypeError): - slice(5).__hash__() + hash(slice(4, {})) def test_cmp(self): s1 = slice(1, 2, 3) @@ -235,13 +242,50 @@ def __setitem__(self, i, k): self.assertEqual(tmp, [(slice(1, 2), 42)]) def test_pickle(self): + import pickle + s = slice(10, 20, 3) - for protocol in (0,1,2): + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): t = loads(dumps(s, protocol)) self.assertEqual(s, t) self.assertEqual(s.indices(15), t.indices(15)) self.assertNotEqual(id(s), id(t)) + def test_copy(self): + s = slice(1, 10) + c = copy.copy(s) + self.assertIs(s, c) + + s = slice(1, 10, 2) + c = copy.copy(s) + self.assertIs(s, c) + + # Corner case for mutable indices: + s = slice([1, 2], [3, 4], [5, 6]) + c = copy.copy(s) + self.assertIs(s, c) + self.assertIs(s.start, c.start) + self.assertIs(s.stop, c.stop) + self.assertIs(s.step, c.step) + + def test_deepcopy(self): + s = slice(1, 10) + c = copy.deepcopy(s) + self.assertEqual(s, c) + + s = slice(1, 10, 2) + c = copy.deepcopy(s) + self.assertEqual(s, c) + + # Corner case for mutable indices: + s = slice([1, 2], [3, 4], [5, 6]) + c = copy.deepcopy(s) + self.assertIsNot(s, c) + self.assertEqual(s, c) + self.assertIsNot(s.start, c.start) + self.assertIsNot(s.stop, c.stop) + self.assertIsNot(s.step, c.step) + # TODO: RUSTPYTHON @unittest.expectedFailure def test_cycle(self): diff --git a/extra_tests/snippets/builtin_slice.py b/extra_tests/snippets/builtin_slice.py index 57fb7e21c2..b5c3a8ceb4 100644 --- a/extra_tests/snippets/builtin_slice.py +++ b/extra_tests/snippets/builtin_slice.py @@ -82,16 +82,6 @@ assert_raises(TypeError, lambda: slice(0) <= 3) assert_raises(TypeError, lambda: slice(0) >= 3) -# TODO: slice is hashable in CPython 3.12 -# assert_raises(TypeError, hash, slice(0)) -# assert_raises(TypeError, hash, slice(None)) -# -# def dict_slice(): -# d = {} -# d[slice(0)] = 3 -# -# assert_raises(TypeError, dict_slice) - assert slice(None ).indices(10) == (0, 10, 1) assert slice(None, None, 2).indices(10) == (0, 10, 2) assert slice(1, None, 2).indices(10) == (1, 10, 2) diff --git a/vm/src/builtins/slice.rs b/vm/src/builtins/slice.rs index e06aee682e..5da3649115 100644 --- a/vm/src/builtins/slice.rs +++ b/vm/src/builtins/slice.rs @@ -3,10 +3,11 @@ use super::{PyStrRef, PyTupleRef, PyType, PyTypeRef}; use crate::{ class::PyClassImpl, + common::hash::{PyHash, PyUHash}, convert::ToPyObject, function::{ArgIndex, FuncArgs, OptionalArg, PyComparisonValue}, sliceable::SaturatedSlice, - types::{Comparable, Constructor, PyComparisonOp, Representable}, + types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable}, AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use malachite_bigint::{BigInt, ToBigInt}; @@ -26,7 +27,7 @@ impl PyPayload for PySlice { } } -#[pyclass(with(Comparable, Representable))] +#[pyclass(with(Comparable, Representable, Hashable))] impl PySlice { #[pygetset] fn start(&self, vm: &VirtualMachine) -> PyObjectRef { @@ -197,6 +198,47 @@ impl PySlice { } } +impl Hashable for PySlice { + #[inline] + fn hash(zelf: &Py, vm: &VirtualMachine) -> PyResult { + const XXPRIME_1: PyUHash = if cfg!(target_pointer_width = "64") { + 11400714785074694791 + } else { + 2654435761 + }; + const XXPRIME_2: PyUHash = if cfg!(target_pointer_width = "64") { + 14029467366897019727 + } else { + 2246822519 + }; + const XXPRIME_5: PyUHash = if cfg!(target_pointer_width = "64") { + 2870177450012600261 + } else { + 374761393 + }; + const ROTATE: u32 = if cfg!(target_pointer_width = "64") { + 31 + } else { + 13 + }; + + let mut acc = XXPRIME_5; + for part in [zelf.start_ref(vm), &zelf.stop, zelf.step_ref(vm)].iter() { + let lane = part.hash(vm)? as PyUHash; + if lane == u64::MAX as PyUHash { + return Ok(-1 as PyHash); + } + acc = acc.wrapping_add(lane.wrapping_mul(XXPRIME_2)); + acc = acc.rotate_left(ROTATE); + acc = acc.wrapping_mul(XXPRIME_1); + } + if acc == u64::MAX as PyUHash { + return Ok(1546275796 as PyHash); + } + Ok(acc as PyHash) + } +} + impl Comparable for PySlice { fn cmp( zelf: &Py, From 06bb68a6c6f3c06bca7889f4a1ba771fce9011c4 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 19:45:47 +0800 Subject: [PATCH 168/893] Update test_cmd_line.py from CPython v3.12.0 --- Lib/test/test_cmd_line.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index 7e15bc8f88..bd1b6620c0 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -18,9 +18,6 @@ if not support.has_subprocess_support: raise unittest.SkipTest("test module requires subprocess") -# Debug build? -Py_DEBUG = hasattr(sys, "gettotalrefcount") - # XXX (ncoghlan): Move to script_helper and make consistent with run_python def _kill_python_and_exit_code(p): @@ -144,7 +141,7 @@ def run_python(*args): # "-X showrefcount" shows the refcount, but only in debug builds rc, out, err = run_python('-I', '-X', 'showrefcount', '-c', code) self.assertEqual(out.rstrip(), b"{'showrefcount': True}") - if Py_DEBUG: + if support.Py_DEBUG: # bpo-46417: Tolerate negative reference count which can occur # because of bugs in C extensions. This test is only about checking # the showrefcount feature. @@ -753,7 +750,7 @@ def test_xdev(self): code = ("import warnings; " "print(' '.join('%s::%s' % (f[0], f[2].__name__) " "for f in warnings.filters))") - if Py_DEBUG: + if support.Py_DEBUG: expected_filters = "default::Warning" else: expected_filters = ("default::Warning " @@ -827,7 +824,7 @@ def test_warnings_filter_precedence(self): expected_filters = ("error::BytesWarning " "once::UserWarning " "always::UserWarning") - if not Py_DEBUG: + if not support.Py_DEBUG: expected_filters += (" " "default::DeprecationWarning " "ignore::DeprecationWarning " @@ -867,10 +864,10 @@ def test_pythonmalloc(self): # Test the PYTHONMALLOC environment variable pymalloc = support.with_pymalloc() if pymalloc: - default_name = 'pymalloc_debug' if Py_DEBUG else 'pymalloc' + default_name = 'pymalloc_debug' if support.Py_DEBUG else 'pymalloc' default_name_debug = 'pymalloc_debug' else: - default_name = 'malloc_debug' if Py_DEBUG else 'malloc' + default_name = 'malloc_debug' if support.Py_DEBUG else 'malloc' default_name_debug = 'malloc_debug' tests = [ @@ -950,7 +947,8 @@ def res2int(res): return tuple(int(i) for i in out.split()) res = assert_python_ok('-c', code) - self.assertEqual(res2int(res), (-1, sys.get_int_max_str_digits())) + current_max = sys.get_int_max_str_digits() + self.assertEqual(res2int(res), (current_max, current_max)) res = assert_python_ok('-X', 'int_max_str_digits=0', '-c', code) self.assertEqual(res2int(res), (0, 0)) res = assert_python_ok('-X', 'int_max_str_digits=4000', '-c', code) From 999dcbdd1bd1dec0b16b04309bf24742d889436b Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 19:48:22 +0800 Subject: [PATCH 169/893] Edit test_cmd_line.py --- Lib/test/test_cmd_line.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index bd1b6620c0..02f060ba2c 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -930,6 +930,8 @@ def test_parsing_error(self): self.assertTrue(proc.stderr.startswith(err_msg), proc.stderr) self.assertNotEqual(proc.returncode, 0) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_int_max_str_digits(self): code = "import sys; print(sys.flags.int_max_str_digits, sys.get_int_max_str_digits())" From 700f2b9c12862a11d5c6b0b53c095565437b4ccc Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 19:49:27 +0800 Subject: [PATCH 170/893] Update test_cmd_line_script.py from CPython v3.12.0 --- Lib/test/test_cmd_line_script.py | 35 ++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py index b258edca19..9f368d8069 100644 --- a/Lib/test/test_cmd_line_script.py +++ b/Lib/test/test_cmd_line_script.py @@ -662,9 +662,9 @@ def test_syntaxerror_multi_line_fstring(self): self.assertEqual( stderr.splitlines()[-3:], [ - b' foo"""', - b' ^', - b'SyntaxError: f-string: empty expression not allowed', + b' foo = f"""{}', + b' ^', + b'SyntaxError: f-string: valid expression required before \'}\'', ], ) @@ -685,8 +685,31 @@ def test_syntaxerror_invalid_escape_sequence_multi_line(self): ], ) - # TODO: RUSTPYTHON - @unittest.expectedFailure + def test_syntaxerror_null_bytes(self): + script = "x = '\0' nothing to see here\n';import os;os.system('echo pwnd')\n" + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'script', script) + exitcode, stdout, stderr = assert_python_failure(script_name) + self.assertEqual( + stderr.splitlines()[-2:], + [ b" x = '", + b'SyntaxError: source code cannot contain null bytes' + ], + ) + + def test_syntaxerror_null_bytes_in_multiline_string(self): + scripts = ["\n'''\nmultilinestring\0\n'''", "\nf'''\nmultilinestring\0\n'''"] # Both normal and f-strings + with os_helper.temp_dir() as script_dir: + for script in scripts: + script_name = _make_test_script(script_dir, 'script', script) + _, _, stderr = assert_python_failure(script_name) + self.assertEqual( + stderr.splitlines()[-2:], + [ b" multilinestring", + b'SyntaxError: source code cannot contain null bytes' + ] + ) + def test_consistent_sys_path_for_direct_execution(self): # This test case ensures that the following all give the same # sys.path configuration: @@ -785,7 +808,7 @@ def test_script_as_dev_fd(self): with os_helper.temp_dir() as work_dir: script_name = _make_test_script(work_dir, 'script.py', script) with open(script_name, "r") as fp: - p = spawn_python(f"/dev/fd/{fp.fileno()}", close_fds=False, pass_fds=(0,1,2,fp.fileno())) + p = spawn_python(f"/dev/fd/{fp.fileno()}", close_fds=True, pass_fds=(0,1,2,fp.fileno())) out, err = p.communicate() self.assertEqual(out, b"12345678912345678912345\n") From 9a61716f74ea990a60c3ebad6addeb421795d163 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 19:52:25 +0800 Subject: [PATCH 171/893] Edit test_cmd_line_script.py --- Lib/test/test_cmd_line_script.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py index 9f368d8069..e40069d780 100644 --- a/Lib/test/test_cmd_line_script.py +++ b/Lib/test/test_cmd_line_script.py @@ -685,6 +685,8 @@ def test_syntaxerror_invalid_escape_sequence_multi_line(self): ], ) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_syntaxerror_null_bytes(self): script = "x = '\0' nothing to see here\n';import os;os.system('echo pwnd')\n" with os_helper.temp_dir() as script_dir: @@ -697,6 +699,8 @@ def test_syntaxerror_null_bytes(self): ], ) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_syntaxerror_null_bytes_in_multiline_string(self): scripts = ["\n'''\nmultilinestring\0\n'''", "\nf'''\nmultilinestring\0\n'''"] # Both normal and f-strings with os_helper.temp_dir() as script_dir: @@ -710,6 +714,8 @@ def test_syntaxerror_null_bytes_in_multiline_string(self): ] ) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_consistent_sys_path_for_direct_execution(self): # This test case ensures that the following all give the same # sys.path configuration: From c28cca97d1483f0100db8349ab2d17faacffe73f Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 19:53:37 +0800 Subject: [PATCH 172/893] Update test_code_module.py from CPython v3.12.0 --- Lib/test/test_code_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/test/test_code_module.py b/Lib/test/test_code_module.py index 2b379fc378..5ac17ef16e 100644 --- a/Lib/test/test_code_module.py +++ b/Lib/test/test_code_module.py @@ -6,6 +6,7 @@ from unittest import mock from test.support import import_helper + code = import_helper.import_module('code') From c5364ca1573b5a02d1041ace5bb0695bd79bb137 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:01:48 +0800 Subject: [PATCH 173/893] Update test_compare.py from CPython v3.12.0 --- Lib/test/test_compare.py | 426 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 409 insertions(+), 17 deletions(-) diff --git a/Lib/test/test_compare.py b/Lib/test/test_compare.py index 2b3faed796..8166b0eea3 100644 --- a/Lib/test/test_compare.py +++ b/Lib/test/test_compare.py @@ -1,21 +1,27 @@ +"""Test equality and order comparisons.""" import unittest from test.support import ALWAYS_EQ +from fractions import Fraction +from decimal import Decimal -class Empty: - def __repr__(self): - return '' -class Cmp: - def __init__(self,arg): - self.arg = arg +class ComparisonSimpleTest(unittest.TestCase): + """Test equality and order comparisons for some simple cases.""" - def __repr__(self): - return '' % self.arg + class Empty: + def __repr__(self): + return '' - def __eq__(self, other): - return self.arg == other + class Cmp: + def __init__(self, arg): + self.arg = arg + + def __repr__(self): + return '' % self.arg + + def __eq__(self, other): + return self.arg == other -class ComparisonTest(unittest.TestCase): set1 = [2, 2.0, 2, 2+0j, Cmp(2.0)] set2 = [[1], (3,), None, Empty()] candidates = set1 + set2 @@ -32,16 +38,15 @@ def test_id_comparisons(self): # Ensure default comparison compares id() of args L = [] for i in range(10): - L.insert(len(L)//2, Empty()) + L.insert(len(L)//2, self.Empty()) for a in L: for b in L: - self.assertEqual(a == b, id(a) == id(b), - 'a=%r, b=%r' % (a, b)) + self.assertEqual(a == b, a is b, 'a=%r, b=%r' % (a, b)) def test_ne_defaults_to_not_eq(self): - a = Cmp(1) - b = Cmp(1) - c = Cmp(2) + a = self.Cmp(1) + b = self.Cmp(1) + c = self.Cmp(2) self.assertIs(a == b, True) self.assertIs(a != b, False) self.assertIs(a != c, True) @@ -114,5 +119,392 @@ def test_issue_1393(self): self.assertEqual(ALWAYS_EQ, y) +class ComparisonFullTest(unittest.TestCase): + """Test equality and ordering comparisons for built-in types and + user-defined classes that implement relevant combinations of rich + comparison methods. + """ + + class CompBase: + """Base class for classes with rich comparison methods. + + The "x" attribute should be set to an underlying value to compare. + + Derived classes have a "meth" tuple attribute listing names of + comparison methods implemented. See assert_total_order(). + """ + + # Class without any rich comparison methods. + class CompNone(CompBase): + meth = () + + # Classes with all combinations of value-based equality comparison methods. + class CompEq(CompBase): + meth = ("eq",) + def __eq__(self, other): + return self.x == other.x + + class CompNe(CompBase): + meth = ("ne",) + def __ne__(self, other): + return self.x != other.x + + class CompEqNe(CompBase): + meth = ("eq", "ne") + def __eq__(self, other): + return self.x == other.x + def __ne__(self, other): + return self.x != other.x + + # Classes with all combinations of value-based less/greater-than order + # comparison methods. + class CompLt(CompBase): + meth = ("lt",) + def __lt__(self, other): + return self.x < other.x + + class CompGt(CompBase): + meth = ("gt",) + def __gt__(self, other): + return self.x > other.x + + class CompLtGt(CompBase): + meth = ("lt", "gt") + def __lt__(self, other): + return self.x < other.x + def __gt__(self, other): + return self.x > other.x + + # Classes with all combinations of value-based less/greater-or-equal-than + # order comparison methods + class CompLe(CompBase): + meth = ("le",) + def __le__(self, other): + return self.x <= other.x + + class CompGe(CompBase): + meth = ("ge",) + def __ge__(self, other): + return self.x >= other.x + + class CompLeGe(CompBase): + meth = ("le", "ge") + def __le__(self, other): + return self.x <= other.x + def __ge__(self, other): + return self.x >= other.x + + # It should be sufficient to combine the comparison methods only within + # each group. + all_comp_classes = ( + CompNone, + CompEq, CompNe, CompEqNe, # equal group + CompLt, CompGt, CompLtGt, # less/greater-than group + CompLe, CompGe, CompLeGe) # less/greater-or-equal group + + def create_sorted_instances(self, class_, values): + """Create objects of type `class_` and return them in a list. + + `values` is a list of values that determines the value of data + attribute `x` of each object. + + Objects in the returned list are sorted by their identity. They + assigned values in `values` list order. By assign decreasing + values to objects with increasing identities, testcases can assert + that order comparison is performed by value and not by identity. + """ + + instances = [class_() for __ in range(len(values))] + instances.sort(key=id) + # Assign the provided values to the instances. + for inst, value in zip(instances, values): + inst.x = value + return instances + + def assert_equality_only(self, a, b, equal): + """Assert equality result and that ordering is not implemented. + + a, b: Instances to be tested (of same or different type). + equal: Boolean indicating the expected equality comparison results. + """ + self.assertEqual(a == b, equal) + self.assertEqual(b == a, equal) + self.assertEqual(a != b, not equal) + self.assertEqual(b != a, not equal) + with self.assertRaisesRegex(TypeError, "not supported"): + a < b + with self.assertRaisesRegex(TypeError, "not supported"): + a <= b + with self.assertRaisesRegex(TypeError, "not supported"): + a > b + with self.assertRaisesRegex(TypeError, "not supported"): + a >= b + with self.assertRaisesRegex(TypeError, "not supported"): + b < a + with self.assertRaisesRegex(TypeError, "not supported"): + b <= a + with self.assertRaisesRegex(TypeError, "not supported"): + b > a + with self.assertRaisesRegex(TypeError, "not supported"): + b >= a + + def assert_total_order(self, a, b, comp, a_meth=None, b_meth=None): + """Test total ordering comparison of two instances. + + a, b: Instances to be tested (of same or different type). + + comp: -1, 0, or 1 indicates that the expected order comparison + result for operations that are supported by the classes is + a <, ==, or > b. + + a_meth, b_meth: Either None, indicating that all rich comparison + methods are available, aa for builtins, or the tuple (subset) + of "eq", "ne", "lt", "le", "gt", and "ge" that are available + for the corresponding instance (of a user-defined class). + """ + self.assert_eq_subtest(a, b, comp, a_meth, b_meth) + self.assert_ne_subtest(a, b, comp, a_meth, b_meth) + self.assert_lt_subtest(a, b, comp, a_meth, b_meth) + self.assert_le_subtest(a, b, comp, a_meth, b_meth) + self.assert_gt_subtest(a, b, comp, a_meth, b_meth) + self.assert_ge_subtest(a, b, comp, a_meth, b_meth) + + # The body of each subtest has form: + # + # if value-based comparison methods: + # expect what the testcase defined for a op b and b rop a; + # else: no value-based comparison + # expect default behavior of object for a op b and b rop a. + + def assert_eq_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "eq" in a_meth or "eq" in b_meth: + self.assertEqual(a == b, comp == 0) + self.assertEqual(b == a, comp == 0) + else: + self.assertEqual(a == b, a is b) + self.assertEqual(b == a, a is b) + + def assert_ne_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or not {"ne", "eq"}.isdisjoint(a_meth + b_meth): + self.assertEqual(a != b, comp != 0) + self.assertEqual(b != a, comp != 0) + else: + self.assertEqual(a != b, a is not b) + self.assertEqual(b != a, a is not b) + + def assert_lt_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "lt" in a_meth or "gt" in b_meth: + self.assertEqual(a < b, comp < 0) + self.assertEqual(b > a, comp < 0) + else: + with self.assertRaisesRegex(TypeError, "not supported"): + a < b + with self.assertRaisesRegex(TypeError, "not supported"): + b > a + + def assert_le_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "le" in a_meth or "ge" in b_meth: + self.assertEqual(a <= b, comp <= 0) + self.assertEqual(b >= a, comp <= 0) + else: + with self.assertRaisesRegex(TypeError, "not supported"): + a <= b + with self.assertRaisesRegex(TypeError, "not supported"): + b >= a + + def assert_gt_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "gt" in a_meth or "lt" in b_meth: + self.assertEqual(a > b, comp > 0) + self.assertEqual(b < a, comp > 0) + else: + with self.assertRaisesRegex(TypeError, "not supported"): + a > b + with self.assertRaisesRegex(TypeError, "not supported"): + b < a + + def assert_ge_subtest(self, a, b, comp, a_meth, b_meth): + if a_meth is None or "ge" in a_meth or "le" in b_meth: + self.assertEqual(a >= b, comp >= 0) + self.assertEqual(b <= a, comp >= 0) + else: + with self.assertRaisesRegex(TypeError, "not supported"): + a >= b + with self.assertRaisesRegex(TypeError, "not supported"): + b <= a + + def test_objects(self): + """Compare instances of type 'object'.""" + a = object() + b = object() + self.assert_equality_only(a, a, True) + self.assert_equality_only(a, b, False) + + def test_comp_classes_same(self): + """Compare same-class instances with comparison methods.""" + + for cls in self.all_comp_classes: + with self.subTest(cls): + instances = self.create_sorted_instances(cls, (1, 2, 1)) + + # Same object. + self.assert_total_order(instances[0], instances[0], 0, + cls.meth, cls.meth) + + # Different objects, same value. + self.assert_total_order(instances[0], instances[2], 0, + cls.meth, cls.meth) + + # Different objects, value ascending for ascending identities. + self.assert_total_order(instances[0], instances[1], -1, + cls.meth, cls.meth) + + # different objects, value descending for ascending identities. + # This is the interesting case to assert that order comparison + # is performed based on the value and not based on the identity. + self.assert_total_order(instances[1], instances[2], +1, + cls.meth, cls.meth) + + def test_comp_classes_different(self): + """Compare different-class instances with comparison methods.""" + + for cls_a in self.all_comp_classes: + for cls_b in self.all_comp_classes: + with self.subTest(a=cls_a, b=cls_b): + a1 = cls_a() + a1.x = 1 + b1 = cls_b() + b1.x = 1 + b2 = cls_b() + b2.x = 2 + + self.assert_total_order( + a1, b1, 0, cls_a.meth, cls_b.meth) + self.assert_total_order( + a1, b2, -1, cls_a.meth, cls_b.meth) + + def test_str_subclass(self): + """Compare instances of str and a subclass.""" + class StrSubclass(str): + pass + + s1 = str("a") + s2 = str("b") + c1 = StrSubclass("a") + c2 = StrSubclass("b") + c3 = StrSubclass("b") + + self.assert_total_order(s1, s1, 0) + self.assert_total_order(s1, s2, -1) + self.assert_total_order(c1, c1, 0) + self.assert_total_order(c1, c2, -1) + self.assert_total_order(c2, c3, 0) + + self.assert_total_order(s1, c2, -1) + self.assert_total_order(s2, c3, 0) + self.assert_total_order(c1, s2, -1) + self.assert_total_order(c2, s2, 0) + + def test_numbers(self): + """Compare number types.""" + + # Same types. + i1 = 1001 + i2 = 1002 + self.assert_total_order(i1, i1, 0) + self.assert_total_order(i1, i2, -1) + + f1 = 1001.0 + f2 = 1001.1 + self.assert_total_order(f1, f1, 0) + self.assert_total_order(f1, f2, -1) + + q1 = Fraction(2002, 2) + q2 = Fraction(2003, 2) + self.assert_total_order(q1, q1, 0) + self.assert_total_order(q1, q2, -1) + + d1 = Decimal('1001.0') + d2 = Decimal('1001.1') + self.assert_total_order(d1, d1, 0) + self.assert_total_order(d1, d2, -1) + + c1 = 1001+0j + c2 = 1001+1j + self.assert_equality_only(c1, c1, True) + self.assert_equality_only(c1, c2, False) + + + # Mixing types. + for n1, n2 in ((i1,f1), (i1,q1), (i1,d1), (f1,q1), (f1,d1), (q1,d1)): + self.assert_total_order(n1, n2, 0) + for n1 in (i1, f1, q1, d1): + self.assert_equality_only(n1, c1, True) + + def test_sequences(self): + """Compare list, tuple, and range.""" + l1 = [1, 2] + l2 = [2, 3] + self.assert_total_order(l1, l1, 0) + self.assert_total_order(l1, l2, -1) + + t1 = (1, 2) + t2 = (2, 3) + self.assert_total_order(t1, t1, 0) + self.assert_total_order(t1, t2, -1) + + r1 = range(1, 2) + r2 = range(2, 2) + self.assert_equality_only(r1, r1, True) + self.assert_equality_only(r1, r2, False) + + self.assert_equality_only(t1, l1, False) + self.assert_equality_only(l1, r1, False) + self.assert_equality_only(r1, t1, False) + + def test_bytes(self): + """Compare bytes and bytearray.""" + bs1 = b'a1' + bs2 = b'b2' + self.assert_total_order(bs1, bs1, 0) + self.assert_total_order(bs1, bs2, -1) + + ba1 = bytearray(b'a1') + ba2 = bytearray(b'b2') + self.assert_total_order(ba1, ba1, 0) + self.assert_total_order(ba1, ba2, -1) + + self.assert_total_order(bs1, ba1, 0) + self.assert_total_order(bs1, ba2, -1) + self.assert_total_order(ba1, bs1, 0) + self.assert_total_order(ba1, bs2, -1) + + def test_sets(self): + """Compare set and frozenset.""" + s1 = {1, 2} + s2 = {1, 2, 3} + self.assert_total_order(s1, s1, 0) + self.assert_total_order(s1, s2, -1) + + f1 = frozenset(s1) + f2 = frozenset(s2) + self.assert_total_order(f1, f1, 0) + self.assert_total_order(f1, f2, -1) + + self.assert_total_order(s1, f1, 0) + self.assert_total_order(s1, f2, -1) + self.assert_total_order(f1, s1, 0) + self.assert_total_order(f1, s2, -1) + + def test_mappings(self): + """ Compare dict. + """ + d1 = {1: "a", 2: "b"} + d2 = {2: "b", 3: "c"} + d3 = {3: "c", 2: "b"} + self.assert_equality_only(d1, d1, True) + self.assert_equality_only(d1, d2, False) + self.assert_equality_only(d2, d3, True) + + if __name__ == '__main__': unittest.main() From fa3eae677d84518139bfe2fb2f517f347c9640e3 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:06:24 +0800 Subject: [PATCH 174/893] Update test_complex.py from CPython v3.12.0 --- Lib/test/test_complex.py | 56 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py index fd6e6de5fc..23d0e3bc81 100644 --- a/Lib/test/test_complex.py +++ b/Lib/test/test_complex.py @@ -109,6 +109,8 @@ def test_truediv(self): complex(random(), random())) self.assertAlmostEqual(complex.__truediv__(2+0j, 1+1j), 1-1j) + self.assertRaises(TypeError, operator.truediv, 1j, None) + self.assertRaises(TypeError, operator.truediv, None, 1j) for denom_real, denom_imag in [(0, NAN), (NAN, 0), (NAN, NAN)]: z = complex(0, 0) / complex(denom_real, denom_imag) @@ -140,6 +142,7 @@ def test_floordiv_zero_division(self): def test_richcompare(self): self.assertIs(complex.__eq__(1+1j, 1<<10000), False) self.assertIs(complex.__lt__(1+1j, None), NotImplemented) + self.assertIs(complex.__eq__(1+1j, None), NotImplemented) self.assertIs(complex.__eq__(1+1j, 1+1j), True) self.assertIs(complex.__eq__(1+1j, 2+2j), False) self.assertIs(complex.__ne__(1+1j, 1+1j), False) @@ -162,6 +165,7 @@ def test_richcompare(self): self.assertIs(operator.eq(1+1j, 2+2j), False) self.assertIs(operator.ne(1+1j, 1+1j), False) self.assertIs(operator.ne(1+1j, 2+2j), True) + self.assertIs(operator.eq(1+1j, 2.0), False) # TODO: RUSTPYTHON @unittest.expectedFailure @@ -182,6 +186,27 @@ def check(n, deltas, is_equal, imag = 0.0): check(2 ** pow, range(1, 101), lambda delta: False, float(i)) check(2 ** 53, range(-100, 0), lambda delta: True) + def test_add(self): + self.assertEqual(1j + int(+1), complex(+1, 1)) + self.assertEqual(1j + int(-1), complex(-1, 1)) + self.assertRaises(OverflowError, operator.add, 1j, 10**1000) + self.assertRaises(TypeError, operator.add, 1j, None) + self.assertRaises(TypeError, operator.add, None, 1j) + + def test_sub(self): + self.assertEqual(1j - int(+1), complex(-1, 1)) + self.assertEqual(1j - int(-1), complex(1, 1)) + self.assertRaises(OverflowError, operator.sub, 1j, 10**1000) + self.assertRaises(TypeError, operator.sub, 1j, None) + self.assertRaises(TypeError, operator.sub, None, 1j) + + def test_mul(self): + self.assertEqual(1j * int(20), complex(0, 20)) + self.assertEqual(1j * int(-1), complex(0, -1)) + self.assertRaises(OverflowError, operator.mul, 1j, 10**1000) + self.assertRaises(TypeError, operator.mul, 1j, None) + self.assertRaises(TypeError, operator.mul, None, 1j) + def test_mod(self): # % is no longer supported on complex numbers with self.assertRaises(TypeError): @@ -214,11 +239,18 @@ def test_divmod_zero_division(self): def test_pow(self): self.assertAlmostEqual(pow(1+1j, 0+0j), 1.0) self.assertAlmostEqual(pow(0+0j, 2+0j), 0.0) + self.assertEqual(pow(0+0j, 2000+0j), 0.0) + self.assertEqual(pow(0, 0+0j), 1.0) + self.assertEqual(pow(-1, 0+0j), 1.0) self.assertRaises(ZeroDivisionError, pow, 0+0j, 1j) + self.assertRaises(ZeroDivisionError, pow, 0+0j, -1000) self.assertAlmostEqual(pow(1j, -1), 1/1j) self.assertAlmostEqual(pow(1j, 200), 1) self.assertRaises(ValueError, pow, 1+1j, 1+1j, 1+1j) self.assertRaises(OverflowError, pow, 1e200+1j, 1e200+1j) + self.assertRaises(TypeError, pow, 1j, None) + self.assertRaises(TypeError, pow, None, 1j) + self.assertAlmostEqual(pow(1j, 0.5), 0.7071067811865476+0.7071067811865475j) a = 3.33+4.43j self.assertEqual(a ** 0j, 1) @@ -303,6 +335,7 @@ def test_boolcontext(self): for i in range(100): self.assertTrue(complex(random() + 1e-6, random() + 1e-6)) self.assertTrue(not complex(0.0, 0.0)) + self.assertTrue(1j) def test_conjugate(self): self.assertClose(complex(5.3, 9.8).conjugate(), 5.3-9.8j) @@ -318,6 +351,8 @@ def __complex__(self): return self.value self.assertRaises(TypeError, complex, {}) self.assertRaises(TypeError, complex, NS(1.5)) self.assertRaises(TypeError, complex, NS(1)) + self.assertRaises(TypeError, complex, object()) + self.assertRaises(TypeError, complex, NS(4.25+0.5j), object()) self.assertAlmostEqual(complex("1+10j"), 1+10j) self.assertAlmostEqual(complex(10), 10+0j) @@ -363,6 +398,8 @@ def __complex__(self): return self.value self.assertAlmostEqual(complex('1e-500'), 0.0 + 0.0j) self.assertAlmostEqual(complex('-1e-500j'), 0.0 - 0.0j) self.assertAlmostEqual(complex('-1e-500+1e-500j'), -0.0 + 0.0j) + self.assertEqual(complex('1-1j'), 1.0 - 1j) + self.assertEqual(complex('1J'), 1j) class complex2(complex): pass self.assertAlmostEqual(complex(complex2(1+1j)), 1+1j) @@ -533,8 +570,12 @@ class complex2(complex): self.assertFloatsAreIdentical(z.real, x) self.assertFloatsAreIdentical(z.imag, y) - # TODO: RUSTPYTHON - @unittest.expectedFailure + def test_constructor_negative_nans_from_string(self): + self.assertEqual(copysign(1., complex("-nan").real), -1.) + self.assertEqual(copysign(1., complex("-nanj").imag), -1.) + self.assertEqual(copysign(1., complex("-nan-nanj").real), -1.) + self.assertEqual(copysign(1., complex("-nan-nanj").imag), -1.) + def test_underscores(self): # check underscores for lit in VALID_UNDERSCORE_LITERALS: @@ -553,6 +594,8 @@ def test_hash(self): x /= 3.0 # now check against floating point self.assertEqual(hash(x), hash(complex(x, 0.))) + self.assertNotEqual(hash(2000005 - 1j), -1) + def test_abs(self): nums = [complex(x/3., y/7.) for x in range(-9,9) for y in range(-9,9)] for num in nums: @@ -575,6 +618,7 @@ def test(v, expected, test_fn=self.assertEqual): test(complex(NAN, 1), "(nan+1j)") test(complex(1, NAN), "(1+nanj)") test(complex(NAN, NAN), "(nan+nanj)") + test(complex(-NAN, -NAN), "(nan+nanj)") test(complex(0, INF), "infj") test(complex(0, -INF), "-infj") @@ -601,6 +645,14 @@ def test(v, expected, test_fn=self.assertEqual): test(complex(-0., 0.), "(-0+0j)") test(complex(-0., -0.), "(-0-0j)") + def test_pos(self): + class ComplexSubclass(complex): + pass + + self.assertEqual(+(1+6j), 1+6j) + self.assertEqual(+ComplexSubclass(1, 6), 1+6j) + self.assertIs(type(+ComplexSubclass(1, 6)), complex) + def test_neg(self): self.assertEqual(-(1+6j), -1-6j) From 07f013dae2b1d0cc7d45dc5c1d49f9d6f594a13d Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:07:44 +0800 Subject: [PATCH 175/893] Edit test_complex.py ExpectedFailures added --- Lib/test/test_complex.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py index 23d0e3bc81..72999c2f7d 100644 --- a/Lib/test/test_complex.py +++ b/Lib/test/test_complex.py @@ -236,6 +236,8 @@ def test_divmod_zero_division(self): for a, b in ZERO_DIVISION: self.assertRaises(TypeError, divmod, a, b) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_pow(self): self.assertAlmostEqual(pow(1+1j, 0+0j), 1.0) self.assertAlmostEqual(pow(0+0j, 2+0j), 0.0) @@ -570,12 +572,16 @@ class complex2(complex): self.assertFloatsAreIdentical(z.real, x) self.assertFloatsAreIdentical(z.imag, y) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_constructor_negative_nans_from_string(self): self.assertEqual(copysign(1., complex("-nan").real), -1.) self.assertEqual(copysign(1., complex("-nanj").imag), -1.) self.assertEqual(copysign(1., complex("-nan-nanj").real), -1.) self.assertEqual(copysign(1., complex("-nan-nanj").imag), -1.) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_underscores(self): # check underscores for lit in VALID_UNDERSCORE_LITERALS: From b93199f007f7a2dbaac2fa58b53da0b62bcd6420 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:11:24 +0800 Subject: [PATCH 176/893] Update test_context.py from CPython v3.12.0 --- Lib/test/test_context.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_context.py b/Lib/test/test_context.py index 5d62fa0405..01b3399327 100644 --- a/Lib/test/test_context.py +++ b/Lib/test/test_context.py @@ -6,6 +6,7 @@ import time import unittest import weakref +from test import support from test.support import threading_helper try: @@ -608,6 +609,7 @@ def test_hamt_collision_3(self): self.assertEqual({k.name for k in h.keys()}, {'C', 'D', 'E'}) + @support.requires_resource('cpu') def test_hamt_stress(self): COLLECTION_SIZE = 7000 TEST_ITERS_EVERY = 647 From 5c6b4cbd3c13530f17dd42a7791dec45aab4f8a2 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:27:09 +0800 Subject: [PATCH 177/893] Update test_decorators.py from CPython v3.12.0 --- Lib/test/test_decorators.py | 1 - 1 file changed, 1 deletion(-) diff --git a/Lib/test/test_decorators.py b/Lib/test/test_decorators.py index 39b8dc0fac..739c9b3909 100644 --- a/Lib/test/test_decorators.py +++ b/Lib/test/test_decorators.py @@ -1,4 +1,3 @@ -from test import support import unittest from types import MethodType From 926c4cef27a51f2becbb31bb5dba7ad18ab35c38 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:27:46 +0800 Subject: [PATCH 178/893] Update test_defaultdict.py from CPython v3.12.0 --- Lib/test/test_defaultdict.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py index 68fc449780..bdbe9b81e8 100644 --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -1,9 +1,7 @@ """Unit tests for collections.defaultdict.""" -import os import copy import pickle -import tempfile import unittest from collections import defaultdict From b05d5920ab7b9f847272de5a76fca626efa1c418 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:40:19 +0800 Subject: [PATCH 179/893] Update test_descr.py from CPython v3.12.0 --- Lib/test/test_descr.py | 58 +++++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index 9e74541aac..bff14e8601 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -21,6 +21,11 @@ except ImportError: _testcapi = None +try: + import xxsubtype +except ImportError: + xxsubtype = None + class OperatorsTest(unittest.TestCase): @@ -299,6 +304,7 @@ def test_explicit_reverse_methods(self): self.assertEqual(float.__rsub__(3.0, 1), -2.0) @support.impl_detail("the module 'xxsubtype' is internal") + @unittest.skipIf(xxsubtype is None, "requires xxsubtype module") def test_spam_lists(self): # Testing spamlist operations... import copy, xxsubtype as spam @@ -343,6 +349,7 @@ def foo(self): return 1 self.assertEqual(a.getstate(), 42) @support.impl_detail("the module 'xxsubtype' is internal") + @unittest.skipIf(xxsubtype is None, "requires xxsubtype module") def test_spam_dicts(self): # Testing spamdict operations... import copy, xxsubtype as spam @@ -426,7 +433,7 @@ def __init__(self_local, *a, **kw): def __getitem__(self, key): return self.get(key, 0) def __setitem__(self_local, key, value): - self.assertIsInstance(key, type(0)) + self.assertIsInstance(key, int) dict.__setitem__(self_local, key, value) def setstate(self, state): self.state = state @@ -842,7 +849,7 @@ def __delattr__(self, name): ("getattr", "foo"), ("delattr", "foo")]) - # http://python.org/sf/1174712 + # https://bugs.python.org/issue1174712 try: class Module(types.ModuleType, str): pass @@ -875,7 +882,7 @@ def setstate(self, state): self.assertEqual(a.getstate(), 10) class D(dict, C): def __init__(self): - type({}).__init__(self) + dict.__init__(self) C.__init__(self) d = D() self.assertEqual(list(d.keys()), []) @@ -1627,6 +1634,7 @@ def test_refleaks_in_classmethod___init__(self): self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10) @support.impl_detail("the module 'xxsubtype' is internal") + @unittest.skipIf(xxsubtype is None, "requires xxsubtype module") def test_classmethods_in_c(self): # Testing C-based class methods... import xxsubtype as spam @@ -1712,6 +1720,7 @@ def test_refleaks_in_staticmethod___init__(self): self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10) @support.impl_detail("the module 'xxsubtype' is internal") + @unittest.skipIf(xxsubtype is None, "requires xxsubtype module") def test_staticmethods_in_c(self): # Testing C-based static methods... import xxsubtype as spam @@ -1848,9 +1857,7 @@ def __init__(self, foo): object.__init__(A(3)) self.assertRaises(TypeError, object.__init__, A(3), 5) - # TODO: RUSTPYTHON, CPython 3.5 and above expect this test case to fail, but in RustPython this currently passes. - # See https://github.com/python/cpython/issues/49572 for more details. - # @unittest.expectedFailure + @unittest.expectedFailure def test_restored_object_new(self): class A(object): def __new__(cls, *args, **kwargs): @@ -2006,7 +2013,7 @@ def __getattr__(self, attr): ns = {} exec(code, ns) number_attrs = ns["number_attrs"] - # Warm up the the function for quickening (PEP 659) + # Warm up the function for quickening (PEP 659) for _ in range(30): self.assertEqual(number_attrs(Numbers()), list(range(280))) @@ -3298,12 +3305,8 @@ def __get__(self, object, otype): if otype: otype = otype.__name__ return 'object=%s; type=%s' % (object, otype) - class OldClass: - __doc__ = DocDescr() - class NewClass(object): + class NewClass: __doc__ = DocDescr() - self.assertEqual(OldClass.__doc__, 'object=None; type=OldClass') - self.assertEqual(OldClass().__doc__, 'object=OldClass instance; type=OldClass') self.assertEqual(NewClass.__doc__, 'object=None; type=NewClass') self.assertEqual(NewClass().__doc__, 'object=NewClass instance; type=NewClass') @@ -3345,7 +3348,7 @@ class Int(int): __slots__ = [] cant(True, int) cant(2, bool) o = object() - cant(o, type(1)) + cant(o, int) cant(o, type(None)) del o class G(object): @@ -3622,7 +3625,6 @@ class MyInt(int): def test_str_of_str_subclass(self): # Testing __str__ defined in subclass of str ... import binascii - import io class octetstring(str): def __str__(self): @@ -4528,7 +4530,7 @@ class Oops(object): o.whatever = Provoker(o) del o - @unittest.skip("TODO: RUSTPYTHON, rustpython segmentation fault") + @support.requires_resource('cpu') def test_wrapper_segfault(self): # SF 927248: deeply nested wrappers could cause stack overflow f = lambda:None @@ -5095,6 +5097,32 @@ class Child(Parent): gc.collect() self.assertEqual(Parent.__subclasses__(), []) + def test_attr_raise_through_property(self): + # test case for gh-103272 + class A: + def __getattr__(self, name): + raise ValueError("FOO") + + @property + def foo(self): + return self.__getattr__("asdf") + + with self.assertRaisesRegex(ValueError, "FOO"): + A().foo + + # test case for gh-103551 + class B: + @property + def __getattr__(self, name): + raise ValueError("FOO") + + @property + def foo(self): + raise NotImplementedError("BAR") + + with self.assertRaisesRegex(NotImplementedError, "BAR"): + B().foo + class DictProxyTests(unittest.TestCase): def setUp(self): From b11d554c11370ad6e836a916e08c9a1942ca59b4 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:45:05 +0800 Subject: [PATCH 180/893] Edit test_descr.py Skip test at line 1861 ExpectedFailure at line 5104 --- Lib/test/test_descr.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index bff14e8601..e634458e00 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -1858,6 +1858,7 @@ def __init__(self, foo): self.assertRaises(TypeError, object.__init__, A(3), 5) @unittest.expectedFailure + @unittest.skip("TODO: RUSTPYTHON") def test_restored_object_new(self): class A(object): def __new__(cls, *args, **kwargs): @@ -4529,7 +4530,8 @@ class Oops(object): o = Oops() o.whatever = Provoker(o) del o - + + @unittest.skip("TODO: RUSTPYTHON, rustpython segmentation fault") @support.requires_resource('cpu') def test_wrapper_segfault(self): # SF 927248: deeply nested wrappers could cause stack overflow @@ -5097,6 +5099,8 @@ class Child(Parent): gc.collect() self.assertEqual(Parent.__subclasses__(), []) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_attr_raise_through_property(self): # test case for gh-103272 class A: From c4e2d6e3790f7cbd866e331582d56b6ad40157ec Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:46:29 +0800 Subject: [PATCH 181/893] Add test_descrtut.py from CPython v3.12.0 --- Lib/test/test_descrtut.py | 482 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 482 insertions(+) create mode 100644 Lib/test/test_descrtut.py diff --git a/Lib/test/test_descrtut.py b/Lib/test/test_descrtut.py new file mode 100644 index 0000000000..7796031ed0 --- /dev/null +++ b/Lib/test/test_descrtut.py @@ -0,0 +1,482 @@ +# This contains most of the executable examples from Guido's descr +# tutorial, once at +# +# https://www.python.org/download/releases/2.2.3/descrintro/ +# +# A few examples left implicit in the writeup were fleshed out, a few were +# skipped due to lack of interest (e.g., faking super() by hand isn't +# of much interest anymore), and a few were fiddled to make the output +# deterministic. + +from test.support import sortdict +import doctest +import unittest + + +class defaultdict(dict): + def __init__(self, default=None): + dict.__init__(self) + self.default = default + + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + return self.default + + def get(self, key, *args): + if not args: + args = (self.default,) + return dict.get(self, key, *args) + + def merge(self, other): + for key in other: + if key not in self: + self[key] = other[key] + +test_1 = """ + +Here's the new type at work: + + >>> print(defaultdict) # show our type + + >>> print(type(defaultdict)) # its metatype + + >>> a = defaultdict(default=0.0) # create an instance + >>> print(a) # show the instance + {} + >>> print(type(a)) # show its type + + >>> print(a.__class__) # show its class + + >>> print(type(a) is a.__class__) # its type is its class + True + >>> a[1] = 3.25 # modify the instance + >>> print(a) # show the new value + {1: 3.25} + >>> print(a[1]) # show the new item + 3.25 + >>> print(a[0]) # a non-existent item + 0.0 + >>> a.merge({1:100, 2:200}) # use a dict method + >>> print(sortdict(a)) # show the result + {1: 3.25, 2: 200} + >>> + +We can also use the new type in contexts where classic only allows "real" +dictionaries, such as the locals/globals dictionaries for the exec +statement or the built-in function eval(): + + >>> print(sorted(a.keys())) + [1, 2] + >>> a['print'] = print # need the print function here + >>> exec("x = 3; print(x)", a) + 3 + >>> print(sorted(a.keys(), key=lambda x: (str(type(x)), x))) + [1, 2, '__builtins__', 'print', 'x'] + >>> print(a['x']) + 3 + >>> + +Now I'll show that defaultdict instances have dynamic instance variables, +just like classic classes: + + >>> a.default = -1 + >>> print(a["noway"]) + -1 + >>> a.default = -1000 + >>> print(a["noway"]) + -1000 + >>> 'default' in dir(a) + True + >>> a.x1 = 100 + >>> a.x2 = 200 + >>> print(a.x1) + 100 + >>> d = dir(a) + >>> 'default' in d and 'x1' in d and 'x2' in d + True + >>> print(sortdict(a.__dict__)) + {'default': -1000, 'x1': 100, 'x2': 200} + >>> +""" + +class defaultdict2(dict): + __slots__ = ['default'] + + def __init__(self, default=None): + dict.__init__(self) + self.default = default + + def __getitem__(self, key): + try: + return dict.__getitem__(self, key) + except KeyError: + return self.default + + def get(self, key, *args): + if not args: + args = (self.default,) + return dict.get(self, key, *args) + + def merge(self, other): + for key in other: + if key not in self: + self[key] = other[key] + +test_2 = """ + +The __slots__ declaration takes a list of instance variables, and reserves +space for exactly these in the instance. When __slots__ is used, other +instance variables cannot be assigned to: + + >>> a = defaultdict2(default=0.0) + >>> a[1] + 0.0 + >>> a.default = -1 + >>> a[1] + -1 + >>> a.x1 = 1 + Traceback (most recent call last): + File "", line 1, in ? + AttributeError: 'defaultdict2' object has no attribute 'x1' + >>> + +""" + +test_3 = """ + +Introspecting instances of built-in types + +For instance of built-in types, x.__class__ is now the same as type(x): + + >>> type([]) + + >>> [].__class__ + + >>> list + + >>> isinstance([], list) + True + >>> isinstance([], dict) + False + >>> isinstance([], object) + True + >>> + +You can get the information from the list type: + + >>> import pprint + >>> pprint.pprint(dir(list)) # like list.__dict__.keys(), but sorted + ['__add__', + '__class__', + '__class_getitem__', + '__contains__', + '__delattr__', + '__delitem__', + '__dir__', + '__doc__', + '__eq__', + '__format__', + '__ge__', + '__getattribute__', + '__getitem__', + '__getstate__', + '__gt__', + '__hash__', + '__iadd__', + '__imul__', + '__init__', + '__init_subclass__', + '__iter__', + '__le__', + '__len__', + '__lt__', + '__mul__', + '__ne__', + '__new__', + '__reduce__', + '__reduce_ex__', + '__repr__', + '__reversed__', + '__rmul__', + '__setattr__', + '__setitem__', + '__sizeof__', + '__str__', + '__subclasshook__', + 'append', + 'clear', + 'copy', + 'count', + 'extend', + 'index', + 'insert', + 'pop', + 'remove', + 'reverse', + 'sort'] + +The new introspection API gives more information than the old one: in +addition to the regular methods, it also shows the methods that are +normally invoked through special notations, e.g. __iadd__ (+=), __len__ +(len), __ne__ (!=). You can invoke any method from this list directly: + + >>> a = ['tic', 'tac'] + >>> list.__len__(a) # same as len(a) + 2 + >>> a.__len__() # ditto + 2 + >>> list.append(a, 'toe') # same as a.append('toe') + >>> a + ['tic', 'tac', 'toe'] + >>> + +This is just like it is for user-defined classes. +""" + +test_4 = """ + +Static methods and class methods + +The new introspection API makes it possible to add static methods and class +methods. Static methods are easy to describe: they behave pretty much like +static methods in C++ or Java. Here's an example: + + >>> class C: + ... + ... @staticmethod + ... def foo(x, y): + ... print("staticmethod", x, y) + + >>> C.foo(1, 2) + staticmethod 1 2 + >>> c = C() + >>> c.foo(1, 2) + staticmethod 1 2 + +Class methods use a similar pattern to declare methods that receive an +implicit first argument that is the *class* for which they are invoked. + + >>> class C: + ... @classmethod + ... def foo(cls, y): + ... print("classmethod", cls, y) + + >>> C.foo(1) + classmethod 1 + >>> c = C() + >>> c.foo(1) + classmethod 1 + + >>> class D(C): + ... pass + + >>> D.foo(1) + classmethod 1 + >>> d = D() + >>> d.foo(1) + classmethod 1 + +This prints "classmethod __main__.D 1" both times; in other words, the +class passed as the first argument of foo() is the class involved in the +call, not the class involved in the definition of foo(). + +But notice this: + + >>> class E(C): + ... @classmethod + ... def foo(cls, y): # override C.foo + ... print("E.foo() called") + ... C.foo(y) + + >>> E.foo(1) + E.foo() called + classmethod 1 + >>> e = E() + >>> e.foo(1) + E.foo() called + classmethod 1 + +In this example, the call to C.foo() from E.foo() will see class C as its +first argument, not class E. This is to be expected, since the call +specifies the class C. But it stresses the difference between these class +methods and methods defined in metaclasses (where an upcall to a metamethod +would pass the target class as an explicit first argument). +""" + +test_5 = """ + +Attributes defined by get/set methods + + + >>> class property(object): + ... + ... def __init__(self, get, set=None): + ... self.__get = get + ... self.__set = set + ... + ... def __get__(self, inst, type=None): + ... return self.__get(inst) + ... + ... def __set__(self, inst, value): + ... if self.__set is None: + ... raise AttributeError("this attribute is read-only") + ... return self.__set(inst, value) + +Now let's define a class with an attribute x defined by a pair of methods, +getx() and setx(): + + >>> class C(object): + ... + ... def __init__(self): + ... self.__x = 0 + ... + ... def getx(self): + ... return self.__x + ... + ... def setx(self, x): + ... if x < 0: x = 0 + ... self.__x = x + ... + ... x = property(getx, setx) + +Here's a small demonstration: + + >>> a = C() + >>> a.x = 10 + >>> print(a.x) + 10 + >>> a.x = -10 + >>> print(a.x) + 0 + >>> + +Hmm -- property is builtin now, so let's try it that way too. + + >>> del property # unmask the builtin + >>> property + + + >>> class C(object): + ... def __init__(self): + ... self.__x = 0 + ... def getx(self): + ... return self.__x + ... def setx(self, x): + ... if x < 0: x = 0 + ... self.__x = x + ... x = property(getx, setx) + + + >>> a = C() + >>> a.x = 10 + >>> print(a.x) + 10 + >>> a.x = -10 + >>> print(a.x) + 0 + >>> +""" + +test_6 = """ + +Method resolution order + +This example is implicit in the writeup. + +>>> class A: # implicit new-style class +... def save(self): +... print("called A.save()") +>>> class B(A): +... pass +>>> class C(A): +... def save(self): +... print("called C.save()") +>>> class D(B, C): +... pass + +>>> D().save() +called C.save() + +>>> class A(object): # explicit new-style class +... def save(self): +... print("called A.save()") +>>> class B(A): +... pass +>>> class C(A): +... def save(self): +... print("called C.save()") +>>> class D(B, C): +... pass + +>>> D().save() +called C.save() +""" + +class A(object): + def m(self): + return "A" + +class B(A): + def m(self): + return "B" + super(B, self).m() + +class C(A): + def m(self): + return "C" + super(C, self).m() + +class D(C, B): + def m(self): + return "D" + super(D, self).m() + + +test_7 = """ + +Cooperative methods and "super" + +>>> print(D().m()) # "DCBA" +DCBA +""" + +test_8 = """ + +Backwards incompatibilities + +>>> class A: +... def foo(self): +... print("called A.foo()") + +>>> class B(A): +... pass + +>>> class C(A): +... def foo(self): +... B.foo(self) + +>>> C().foo() +called A.foo() + +>>> class C(A): +... def foo(self): +... A.foo(self) +>>> C().foo() +called A.foo() +""" + +__test__ = {"tut1": test_1, + "tut2": test_2, + "tut3": test_3, + "tut4": test_4, + "tut5": test_5, + "tut6": test_6, + "tut7": test_7, + "tut8": test_8} + +def load_tests(loader, tests, pattern): + tests.addTest(doctest.DocTestSuite()) + return tests + + +if __name__ == "__main__": + unittest.main() From f1991c25bc5e802525492575042a8a1fcb1ec8da Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:48:58 +0800 Subject: [PATCH 182/893] Edit test_descrtut.py ExpectedFailures at line 469-472 I'm not able to put a @unittest.expectedFailure for each tests so I commented them out --- Lib/test/test_descrtut.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_descrtut.py b/Lib/test/test_descrtut.py index 7796031ed0..4c128f770e 100644 --- a/Lib/test/test_descrtut.py +++ b/Lib/test/test_descrtut.py @@ -34,6 +34,7 @@ def merge(self, other): if key not in self: self[key] = other[key] + test_1 = """ Here's the new type at work: @@ -464,10 +465,11 @@ def m(self): called A.foo() """ -__test__ = {"tut1": test_1, - "tut2": test_2, - "tut3": test_3, - "tut4": test_4, +# TODO: RUSTPYTHON +__test__ = {# "tut1": test_1, + # "tut2": test_2, + # "tut3": test_3, + # "tut4": test_4, "tut5": test_5, "tut6": test_6, "tut7": test_7, From d5e4af7a4bf02accbb2d03c17dc03c47ff6a0265 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:51:40 +0800 Subject: [PATCH 183/893] Update test_dict.py from CPython v3.12.0 --- Lib/test/test_dict.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index 3ec0ac541b..4aa6f1089a 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -8,7 +8,7 @@ import unittest import weakref from test import support -from test.support import import_helper +from test.support import import_helper, C_RECURSION_LIMIT class DictTest(unittest.TestCase): @@ -599,7 +599,7 @@ def __repr__(self): @unittest.skipIf(sys.platform == 'win32', 'TODO: RUSTPYTHON Windows') def test_repr_deep(self): d = {} - for i in range(sys.getrecursionlimit() + 100): + for i in range(C_RECURSION_LIMIT + 1): d = {1: d} self.assertRaises(RecursionError, repr, d) @@ -1099,6 +1099,21 @@ def __init__(self, order): d.update(o.__dict__) self.assertEqual(list(d), ["c", "b", "a"]) + @support.cpython_only + def test_splittable_to_generic_combinedtable(self): + """split table must be correctly resized and converted to generic combined table""" + class C: + pass + + a = C() + a.x = 1 + d = a.__dict__ + before_resize = sys.getsizeof(d) + d[2] = 2 # split table is resized to a generic combined table + + self.assertGreater(sys.getsizeof(d), before_resize) + self.assertEqual(list(d), ['x', 2]) + def test_iterator_pickling(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): data = {1:"a", 2:"b", 3:"c"} From 43ede611e350875cc08ad90c33b846287475f711 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:53:15 +0800 Subject: [PATCH 184/893] Update test_dictviews.py from CPython v3.12.0 --- Lib/test/test_dictviews.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_dictviews.py b/Lib/test/test_dictviews.py index e8db2a61a0..b16353f2b6 100644 --- a/Lib/test/test_dictviews.py +++ b/Lib/test/test_dictviews.py @@ -3,6 +3,7 @@ import pickle import sys import unittest +from test.support import C_RECURSION_LIMIT class DictSetTest(unittest.TestCase): @@ -170,6 +171,10 @@ def test_items_set_operations(self): {('a', 1), ('b', 2)}) self.assertEqual(d1.items() & set(d2.items()), {('b', 2)}) self.assertEqual(d1.items() & set(d3.items()), set()) + self.assertEqual(d1.items() & (("a", 1), ("b", 2)), + {('a', 1), ('b', 2)}) + self.assertEqual(d1.items() & (("a", 2), ("b", 2)), {('b', 2)}) + self.assertEqual(d1.items() & (("d", 4), ("e", 5)), set()) self.assertEqual(d1.items() | d1.items(), {('a', 1), ('b', 2)}) @@ -183,12 +188,23 @@ def test_items_set_operations(self): {('a', 1), ('a', 2), ('b', 2)}) self.assertEqual(d1.items() | set(d3.items()), {('a', 1), ('b', 2), ('d', 4), ('e', 5)}) + self.assertEqual(d1.items() | (('a', 1), ('b', 2)), + {('a', 1), ('b', 2)}) + self.assertEqual(d1.items() | (('a', 2), ('b', 2)), + {('a', 1), ('a', 2), ('b', 2)}) + self.assertEqual(d1.items() | (('d', 4), ('e', 5)), + {('a', 1), ('b', 2), ('d', 4), ('e', 5)}) self.assertEqual(d1.items() ^ d1.items(), set()) self.assertEqual(d1.items() ^ d2.items(), {('a', 1), ('a', 2)}) self.assertEqual(d1.items() ^ d3.items(), {('a', 1), ('b', 2), ('d', 4), ('e', 5)}) + self.assertEqual(d1.items() ^ (('a', 1), ('b', 2)), set()) + self.assertEqual(d1.items() ^ (("a", 2), ("b", 2)), + {('a', 1), ('a', 2)}) + self.assertEqual(d1.items() ^ (("d", 4), ("e", 5)), + {('a', 1), ('b', 2), ('d', 4), ('e', 5)}) self.assertEqual(d1.items() - d1.items(), set()) self.assertEqual(d1.items() - d2.items(), {('a', 1)}) @@ -196,6 +212,9 @@ def test_items_set_operations(self): self.assertEqual(d1.items() - set(d1.items()), set()) self.assertEqual(d1.items() - set(d2.items()), {('a', 1)}) self.assertEqual(d1.items() - set(d3.items()), {('a', 1), ('b', 2)}) + self.assertEqual(d1.items() - (('a', 1), ('b', 2)), set()) + self.assertEqual(d1.items() - (("a", 2), ("b", 2)), {('a', 1)}) + self.assertEqual(d1.items() - (("d", 4), ("e", 5)), {('a', 1), ('b', 2)}) self.assertFalse(d1.items().isdisjoint(d1.items())) self.assertFalse(d1.items().isdisjoint(d2.items())) @@ -261,7 +280,7 @@ def test_recursive_repr(self): def test_deeply_nested_repr(self): d = {} - for i in range(sys.getrecursionlimit() + 100): + for i in range(C_RECURSION_LIMIT//2 + 100): d = {42: d.values()} self.assertRaises(RecursionError, repr, d) From 186eac5095ea8f110cb3ea401cbc344d6e3cfa8a Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:55:02 +0800 Subject: [PATCH 185/893] Update test_dtrace.py from CPython v3.12.0 --- Lib/test/test_dtrace.py | 82 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/Lib/test/test_dtrace.py b/Lib/test/test_dtrace.py index 4b971deacc..e1adf8e974 100644 --- a/Lib/test/test_dtrace.py +++ b/Lib/test/test_dtrace.py @@ -3,6 +3,7 @@ import re import subprocess import sys +import sysconfig import types import unittest @@ -173,6 +174,87 @@ class SystemTapOptimizedTests(TraceTests, unittest.TestCase): backend = SystemTapBackend() optimize_python = 2 +class CheckDtraceProbes(unittest.TestCase): + @classmethod + def setUpClass(cls): + if sysconfig.get_config_var('WITH_DTRACE'): + readelf_major_version, readelf_minor_version = cls.get_readelf_version() + if support.verbose: + print(f"readelf version: {readelf_major_version}.{readelf_minor_version}") + else: + raise unittest.SkipTest("CPython must be configured with the --with-dtrace option.") + + + @staticmethod + def get_readelf_version(): + try: + cmd = ["readelf", "--version"] + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + ) + with proc: + version, stderr = proc.communicate() + + if proc.returncode: + raise Exception( + f"Command {' '.join(cmd)!r} failed " + f"with exit code {proc.returncode}: " + f"stdout={version!r} stderr={stderr!r}" + ) + except OSError: + raise unittest.SkipTest("Couldn't find readelf on the path") + + # Regex to parse: + # 'GNU readelf (GNU Binutils) 2.40.0\n' -> 2.40 + match = re.search(r"^(?:GNU) readelf.*?\b(\d+)\.(\d+)", version) + if match is None: + raise unittest.SkipTest(f"Unable to parse readelf version: {version}") + + return int(match.group(1)), int(match.group(2)) + + def get_readelf_output(self): + command = ["readelf", "-n", sys.executable] + stdout, _ = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + ).communicate() + return stdout + + def test_check_probes(self): + readelf_output = self.get_readelf_output() + + available_probe_names = [ + "Name: import__find__load__done", + "Name: import__find__load__start", + "Name: audit", + "Name: gc__start", + "Name: gc__done", + ] + + for probe_name in available_probe_names: + with self.subTest(probe_name=probe_name): + self.assertIn(probe_name, readelf_output) + + @unittest.expectedFailure + def test_missing_probes(self): + readelf_output = self.get_readelf_output() + + # Missing probes will be added in the future. + missing_probe_names = [ + "Name: function__entry", + "Name: function__return", + "Name: line", + ] + + for probe_name in missing_probe_names: + with self.subTest(probe_name=probe_name): + self.assertIn(probe_name, readelf_output) + if __name__ == '__main__': unittest.main() From f7c7398ba8077605ef0b3e68d6a572bcd7e046cd Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 20:55:54 +0800 Subject: [PATCH 186/893] Update test_dynamic.py from CPython v3.12.0 --- Lib/test/test_dynamic.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_dynamic.py b/Lib/test/test_dynamic.py index 2155b40289..4217e8e01c 100644 --- a/Lib/test/test_dynamic.py +++ b/Lib/test/test_dynamic.py @@ -142,12 +142,14 @@ class MyGlobals(dict): def __missing__(self, key): return int(key.removeprefix("_number_")) - code = "lambda: " + "+".join(f"_number_{i}" for i in range(1000)) - sum_1000 = eval(code, MyGlobals()) - expected = sum(range(1000)) - # Warm up the the function for quickening (PEP 659) + # Need more than 256 variables to use EXTENDED_ARGS + variables = 400 + code = "lambda: " + "+".join(f"_number_{i}" for i in range(variables)) + sum_func = eval(code, MyGlobals()) + expected = sum(range(variables)) + # Warm up the function for quickening (PEP 659) for _ in range(30): - self.assertEqual(sum_1000(), expected) + self.assertEqual(sum_func(), expected) class TestTracing(unittest.TestCase): From 1d76b762c7f82cffcfee78fa1dfeb1dddc19466b Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 21:01:17 +0800 Subject: [PATCH 187/893] Update test_eintr.py from CPython v3.12.0 --- Lib/test/test_eintr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/test/test_eintr.py b/Lib/test/test_eintr.py index 528147802b..49b15f1a2d 100644 --- a/Lib/test/test_eintr.py +++ b/Lib/test/test_eintr.py @@ -9,6 +9,7 @@ class EINTRTests(unittest.TestCase): @unittest.skipUnless(hasattr(signal, "setitimer"), "requires setitimer()") + @support.requires_resource('walltime') def test_all(self): # Run the tester in a sub-process, to make sure there is only one # thread (for reliable signal delivery). From ad3f0aecafeae690e6de2bad727db786cd82d3af Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 21:05:18 +0800 Subject: [PATCH 188/893] Update test_eof.py from CPython v3.12.0 --- Lib/test/test_eof.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_eof.py b/Lib/test/test_eof.py index 21f83db75f..cf12dd3d5d 100644 --- a/Lib/test/test_eof.py +++ b/Lib/test/test_eof.py @@ -4,6 +4,7 @@ from test import support from test.support import os_helper from test.support import script_helper +from test.support import warnings_helper import unittest # TODO: RUSTPYTHON @@ -38,10 +39,11 @@ def test_EOFS_with_file(self): rc, out, err = script_helper.assert_python_failure(file_name) self.assertIn(b'unterminated triple-quoted string literal (detected at line 3)', err) + @warnings_helper.ignore_warnings(category=SyntaxWarning) def test_eof_with_line_continuation(self): expect = "unexpected EOF while parsing (, line 1)" try: - compile('"\\xhh" \\', '', 'exec', dont_inherit=True) + compile('"\\Xhh" \\', '', 'exec') except SyntaxError as msg: self.assertEqual(str(msg), expect) else: From 3945e9a4ed2ce2149b9a73ccd069377f18eb17db Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 21:06:09 +0800 Subject: [PATCH 189/893] Update test_epoll.py from CPython v3.12.0 --- Lib/test/test_epoll.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py index b623852f9e..c94946a6ae 100644 --- a/Lib/test/test_epoll.py +++ b/Lib/test/test_epoll.py @@ -27,6 +27,7 @@ import socket import time import unittest +from test import support if not hasattr(select, "epoll"): raise unittest.SkipTest("test works only on Linux 2.6") @@ -186,10 +187,16 @@ def test_control_and_wait(self): client.sendall(b"Hello!") server.sendall(b"world!!!") - now = time.monotonic() - events = ep.poll(1.0, 4) - then = time.monotonic() - self.assertFalse(then - now > 0.01) + # we might receive events one at a time, necessitating multiple calls to + # poll + events = [] + for _ in support.busy_retry(support.SHORT_TIMEOUT): + now = time.monotonic() + events += ep.poll(1.0, 4) + then = time.monotonic() + self.assertFalse(then - now > 0.01) + if len(events) >= 2: + break expected = [(client.fileno(), select.EPOLLIN | select.EPOLLOUT), (server.fileno(), select.EPOLLIN | select.EPOLLOUT)] From 784b5f1b1c6412844aac9ed0675bbcee6de19cb4 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Thu, 30 Nov 2023 21:58:26 +0800 Subject: [PATCH 190/893] Edit test_dictviews ExpectedFailure added at line 282 --- Lib/test/test_dictviews.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_dictviews.py b/Lib/test/test_dictviews.py index b16353f2b6..172b98aa68 100644 --- a/Lib/test/test_dictviews.py +++ b/Lib/test/test_dictviews.py @@ -278,6 +278,8 @@ def test_recursive_repr(self): # Again. self.assertIsInstance(r, str) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_deeply_nested_repr(self): d = {} for i in range(C_RECURSION_LIMIT//2 + 100): From 99531514c8ae33a2f664b161e366977db6d2886c Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Fri, 22 Dec 2023 03:15:47 +0800 Subject: [PATCH 191/893] Update abc.py from CPython v3.12.0 --- Lib/abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/abc.py b/Lib/abc.py index bfccab2dfc..1ecff5e214 100644 --- a/Lib/abc.py +++ b/Lib/abc.py @@ -18,7 +18,7 @@ class that has a metaclass derived from ABCMeta cannot be class C(metaclass=ABCMeta): @abstractmethod - def my_abstract_method(self, ...): + def my_abstract_method(self, arg1, arg2, argN): ... """ funcobj.__isabstractmethod__ = True From c0f6a2f8c320d77adf67ff8572655ba578dbd576 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Fri, 22 Dec 2023 03:18:10 +0800 Subject: [PATCH 192/893] Update test_abc.py from CPython v3.12.0 --- Lib/test/test_abc.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/Lib/test/test_abc.py b/Lib/test/test_abc.py index f65b1f6293..5ce57cc209 100644 --- a/Lib/test/test_abc.py +++ b/Lib/test/test_abc.py @@ -154,7 +154,7 @@ class C(metaclass=abc_ABCMeta): @abc.abstractmethod def method_one(self): pass - msg = r"class C with abstract method method_one" + msg = r"class C without an implementation for abstract method 'method_one'" self.assertRaisesRegex(TypeError, msg, C) def test_object_new_with_many_abstractmethods(self): @@ -165,7 +165,7 @@ def method_one(self): @abc.abstractmethod def method_two(self): pass - msg = r"class C with abstract methods method_one, method_two" + msg = r"class C without an implementation for abstract methods 'method_one', 'method_two'" self.assertRaisesRegex(TypeError, msg, C) def test_abstractmethod_integration(self): @@ -448,15 +448,16 @@ class S(metaclass=abc_ABCMeta): # Also check that issubclass() propagates exceptions raised by # __subclasses__. + class CustomError(Exception): ... exc_msg = "exception from __subclasses__" def raise_exc(): - raise Exception(exc_msg) + raise CustomError(exc_msg) class S(metaclass=abc_ABCMeta): __subclasses__ = raise_exc - with self.assertRaisesRegex(Exception, exc_msg): + with self.assertRaisesRegex(CustomError, exc_msg): issubclass(int, S) def test_subclasshook(self): @@ -521,6 +522,7 @@ def foo(self): self.assertEqual(A.__abstractmethods__, set()) A() + def test_update_new_abstractmethods(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -534,7 +536,7 @@ def updated_foo(self): A.foo = updated_foo abc.update_abstractmethods(A) self.assertEqual(A.__abstractmethods__, {'foo', 'bar'}) - msg = "class A with abstract methods bar, foo" + msg = "class A without an implementation for abstract methods 'bar', 'foo'" self.assertRaisesRegex(TypeError, msg, A) def test_update_implementation(self): @@ -546,7 +548,7 @@ def foo(self): class B(A): pass - msg = "class B with abstract method foo" + msg = "class B without an implementation for abstract method 'foo'" self.assertRaisesRegex(TypeError, msg, B) self.assertEqual(B.__abstractmethods__, {'foo'}) @@ -604,7 +606,7 @@ def foo(self): abc.update_abstractmethods(B) - msg = "class B with abstract method foo" + msg = "class B without an implementation for abstract method 'foo'" self.assertRaisesRegex(TypeError, msg, B) def test_update_layered_implementation(self): @@ -626,7 +628,7 @@ def foo(self): abc.update_abstractmethods(C) - msg = "class C with abstract method foo" + msg = "class C without an implementation for abstract method 'foo'" self.assertRaisesRegex(TypeError, msg, C) def test_update_multi_inheritance(self): From 79231ee39990f5aca30492c60131706349f6360b Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Fri, 22 Dec 2023 03:22:28 +0800 Subject: [PATCH 193/893] Edit test_abc.py --- Lib/test/test_abc.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_abc.py b/Lib/test/test_abc.py index 5ce57cc209..d912954a41 100644 --- a/Lib/test/test_abc.py +++ b/Lib/test/test_abc.py @@ -149,6 +149,8 @@ def foo(): return 4 self.assertEqual(D.foo(), 4) self.assertEqual(D().foo(), 4) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_object_new_with_one_abstractmethod(self): class C(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -157,6 +159,8 @@ def method_one(self): msg = r"class C without an implementation for abstract method 'method_one'" self.assertRaisesRegex(TypeError, msg, C) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_object_new_with_many_abstractmethods(self): class C(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -522,7 +526,8 @@ def foo(self): self.assertEqual(A.__abstractmethods__, set()) A() - + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_update_new_abstractmethods(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -539,6 +544,8 @@ def updated_foo(self): msg = "class A without an implementation for abstract methods 'bar', 'foo'" self.assertRaisesRegex(TypeError, msg, A) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_update_implementation(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -590,6 +597,8 @@ def updated_foo(self): A() self.assertFalse(hasattr(A, '__abstractmethods__')) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_update_del_implementation(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -609,6 +618,8 @@ def foo(self): msg = "class B without an implementation for abstract method 'foo'" self.assertRaisesRegex(TypeError, msg, B) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_update_layered_implementation(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod From b5176fdbc0537801507ee4c033ae3a44fbd57db0 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 03:54:18 +0800 Subject: [PATCH 194/893] Update bisect.py from CPython v3.12.0 --- Lib/bisect.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Lib/bisect.py b/Lib/bisect.py index d37da74f7b..ca6ca72408 100644 --- a/Lib/bisect.py +++ b/Lib/bisect.py @@ -8,6 +8,8 @@ def insort_right(a, x, lo=0, hi=None, *, key=None): Optional args lo (default 0) and hi (default len(a)) bound the slice of a to be searched. + + A custom key function can be supplied to customize the sort order. """ if key is None: lo = bisect_right(a, x, lo, hi) @@ -25,6 +27,8 @@ def bisect_right(a, x, lo=0, hi=None, *, key=None): Optional args lo (default 0) and hi (default len(a)) bound the slice of a to be searched. + + A custom key function can be supplied to customize the sort order. """ if lo < 0: @@ -57,6 +61,8 @@ def insort_left(a, x, lo=0, hi=None, *, key=None): Optional args lo (default 0) and hi (default len(a)) bound the slice of a to be searched. + + A custom key function can be supplied to customize the sort order. """ if key is None: @@ -74,6 +80,8 @@ def bisect_left(a, x, lo=0, hi=None, *, key=None): Optional args lo (default 0) and hi (default len(a)) bound the slice of a to be searched. + + A custom key function can be supplied to customize the sort order. """ if lo < 0: From b6c2179893265ec3a4c31c96c3274ea98b505372 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 03:56:28 +0800 Subject: [PATCH 195/893] Update test_bisect.py from CPython v3.12.0 --- Lib/test/test_bisect.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/Lib/test/test_bisect.py b/Lib/test/test_bisect.py index ba108221eb..97204d4cad 100644 --- a/Lib/test/test_bisect.py +++ b/Lib/test/test_bisect.py @@ -263,6 +263,34 @@ def test_insort_keynotNone(self): for f in (self.module.insort_left, self.module.insort_right): self.assertRaises(TypeError, f, x, y, key = "b") + def test_lt_returns_non_bool(self): + class A: + def __init__(self, val): + self.val = val + def __lt__(self, other): + return "nonempty" if self.val < other.val else "" + + data = [A(i) for i in range(100)] + i1 = self.module.bisect_left(data, A(33)) + i2 = self.module.bisect_right(data, A(33)) + self.assertEqual(i1, 33) + self.assertEqual(i2, 34) + + def test_lt_returns_notimplemented(self): + class A: + def __init__(self, val): + self.val = val + def __lt__(self, other): + return NotImplemented + def __gt__(self, other): + return self.val > other.val + + data = [A(i) for i in range(100)] + i1 = self.module.bisect_left(data, A(40)) + i2 = self.module.bisect_right(data, A(40)) + self.assertEqual(i1, 40) + self.assertEqual(i2, 41) + class TestBisectPython(TestBisect, unittest.TestCase): module = py_bisect From df98264dd00b19313a1f1e0a28375503720b8e19 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 03:58:10 +0800 Subject: [PATCH 196/893] Update base64.py from CPython v3.12.0 --- Lib/base64.py | 33 ++++++--------------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/Lib/base64.py b/Lib/base64.py index 7e9c2a2ca4..e233647ee7 100755 --- a/Lib/base64.py +++ b/Lib/base64.py @@ -508,14 +508,8 @@ def b85decode(b): def encode(input, output): """Encode a file; input and output are binary files.""" - while True: - s = input.read(MAXBINSIZE) - if not s: - break - while len(s) < MAXBINSIZE: - ns = input.read(MAXBINSIZE-len(s)) - if not ns: - break + while s := input.read(MAXBINSIZE): + while len(s) < MAXBINSIZE and (ns := input.read(MAXBINSIZE-len(s))): s += ns line = binascii.b2a_base64(s) output.write(line) @@ -523,10 +517,7 @@ def encode(input, output): def decode(input, output): """Decode a file; input and output are binary files.""" - while True: - line = input.readline() - if not line: - break + while line := input.readline(): s = binascii.a2b_base64(line) output.write(s) @@ -567,13 +558,12 @@ def decodebytes(s): def main(): """Small main program""" import sys, getopt - usage = """usage: %s [-h|-d|-e|-u|-t] [file|-] + usage = f"""usage: {sys.argv[0]} [-h|-d|-e|-u] [file|-] -h: print this help message and exit -d, -u: decode - -e: encode (default) - -t: encode and decode string 'Aladdin:open sesame'"""%sys.argv[0] + -e: encode (default)""" try: - opts, args = getopt.getopt(sys.argv[1:], 'hdeut') + opts, args = getopt.getopt(sys.argv[1:], 'hdeu') except getopt.error as msg: sys.stdout = sys.stderr print(msg) @@ -584,7 +574,6 @@ def main(): if o == '-e': func = encode if o == '-d': func = decode if o == '-u': func = decode - if o == '-t': test(); return if o == '-h': print(usage); return if args and args[0] != '-': with open(args[0], 'rb') as f: @@ -593,15 +582,5 @@ def main(): func(sys.stdin.buffer, sys.stdout.buffer) -def test(): - s0 = b"Aladdin:open sesame" - print(repr(s0)) - s1 = encodebytes(s0) - print(repr(s1)) - s2 = decodebytes(s1) - print(repr(s2)) - assert s0 == s2 - - if __name__ == '__main__': main() From b4ee044fa6035c418b4ccd4fa39d67b4b57dd665 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 03:58:36 +0800 Subject: [PATCH 197/893] Update test_base64.py from CPython v3.12.0 --- Lib/test/test_base64.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py index 217f294546..fa03fa1d61 100644 --- a/Lib/test/test_base64.py +++ b/Lib/test/test_base64.py @@ -31,6 +31,8 @@ def test_encodebytes(self): b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE" b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT" b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==\n") + eq(base64.encodebytes(b"Aladdin:open sesame"), + b"QWxhZGRpbjpvcGVuIHNlc2FtZQ==\n") # Non-bytes eq(base64.encodebytes(bytearray(b'abc')), b'YWJj\n') eq(base64.encodebytes(memoryview(b'abc')), b'YWJj\n') @@ -50,6 +52,8 @@ def test_decodebytes(self): b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"0123456789!@#0^&*();:<>,. []{}") eq(base64.decodebytes(b''), b'') + eq(base64.decodebytes(b"QWxhZGRpbjpvcGVuIHNlc2FtZQ==\n"), + b"Aladdin:open sesame") # Non-bytes eq(base64.decodebytes(bytearray(b'YWJj\n')), b'abc') eq(base64.decodebytes(memoryview(b'YWJj\n')), b'abc') @@ -762,14 +766,6 @@ def tearDown(self): def get_output(self, *args): return script_helper.assert_python_ok('-m', 'base64', *args).out - def test_encode_decode(self): - output = self.get_output('-t') - self.assertSequenceEqual(output.splitlines(), ( - b"b'Aladdin:open sesame'", - br"b'QWxhZGRpbjpvcGVuIHNlc2FtZQ==\n'", - b"b'Aladdin:open sesame'", - )) - def test_encode_file(self): with open(os_helper.TESTFN, 'wb') as fp: fp.write(b'a\xffb\n') From a00a387735cadd0c2b769312d0628ed80b4e4c16 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 03:59:38 +0800 Subject: [PATCH 198/893] Update bdb.py from CPython v3.12.0 --- Lib/bdb.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/Lib/bdb.py b/Lib/bdb.py index 75d6113576..0f3eec653b 100644 --- a/Lib/bdb.py +++ b/Lib/bdb.py @@ -570,9 +570,12 @@ def format_stack_entry(self, frame_lineno, lprefix=': '): rv = frame.f_locals['__return__'] s += '->' s += reprlib.repr(rv) - line = linecache.getline(filename, lineno, frame.f_globals) - if line: - s += lprefix + line.strip() + if lineno is not None: + line = linecache.getline(filename, lineno, frame.f_globals) + if line: + s += lprefix + line.strip() + else: + s += f'{lprefix}Warning: lineno is None' return s # The following methods can be called by clients to use @@ -805,15 +808,18 @@ def checkfuncname(b, frame): return True -# Determines if there is an effective (active) breakpoint at this -# line of code. Returns breakpoint number or 0 if none def effective(file, line, frame): - """Determine which breakpoint for this file:line is to be acted upon. + """Return (active breakpoint, delete temporary flag) or (None, None) as + breakpoint to act upon. + + The "active breakpoint" is the first entry in bplist[line, file] (which + must exist) that is enabled, for which checkfuncname is True, and that + has neither a False condition nor a positive ignore count. The flag, + meaning that a temporary breakpoint should be deleted, is False only + when the condiion cannot be evaluated (in which case, ignore count is + ignored). - Called only if we know there is a breakpoint at this location. Return - the breakpoint that was triggered and a boolean that indicates if it is - ok to delete a temporary breakpoint. Return (None, None) if there is no - matching breakpoint. + If no such entry exists, then (None, None) is returned. """ possibles = Breakpoint.bplist[file, line] for b in possibles: From dd0c393b4599a6becdb0c1084002c050434984e6 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 04:03:20 +0800 Subject: [PATCH 199/893] Update test_bdb.py from CPython v3.12.0 --- Lib/test/test_bdb.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_bdb.py b/Lib/test/test_bdb.py index 70cb096e92..a3abbbb8db 100644 --- a/Lib/test/test_bdb.py +++ b/Lib/test/test_bdb.py @@ -59,6 +59,7 @@ from itertools import islice, repeat from test.support import import_helper from test.support import os_helper +from test.support import patch_list class BdbException(Exception): pass @@ -432,8 +433,9 @@ def __exit__(self, type_=None, value=None, traceback=None): not_empty = '' if self.tracer.set_list: not_empty += 'All paired tuples have not been processed, ' - not_empty += ('the last one was number %d' % + not_empty += ('the last one was number %d\n' % self.tracer.expect_set_no) + not_empty += repr(self.tracer.set_list) # Make a BdbNotExpectedError a unittest failure. if type_ is not None and issubclass(BdbNotExpectedError, type_): @@ -728,6 +730,14 @@ def test_until_in_caller_frame(self): def test_skip(self): # Check that tracing is skipped over the import statement in # 'tfunc_import()'. + + # Remove all but the standard importers. + sys.meta_path[:] = ( + item + for item in sys.meta_path + if item.__module__.startswith('_frozen_importlib') + ) + code = """ def main(): lno = 3 @@ -1224,5 +1234,12 @@ def main(): tracer.runcall(tfunc_import) +class TestRegressions(unittest.TestCase): + def test_format_stack_entry_no_lineno(self): + # See gh-101517 + self.assertIn('Warning: lineno is None', + Bdb().format_stack_entry((sys._getframe(), None))) + + if __name__ == "__main__": unittest.main() From 44438bffbd28a643bc97b700aa8c60934026fe67 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 04:04:25 +0800 Subject: [PATCH 200/893] Update argparse.py from CPython v3.12.0 --- Lib/argparse.py | 64 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/Lib/argparse.py b/Lib/argparse.py index 7761908861..543d9944f9 100644 --- a/Lib/argparse.py +++ b/Lib/argparse.py @@ -345,21 +345,22 @@ def _format_usage(self, usage, actions, groups, prefix): def get_lines(parts, indent, prefix=None): lines = [] line = [] + indent_length = len(indent) if prefix is not None: line_len = len(prefix) - 1 else: - line_len = len(indent) - 1 + line_len = indent_length - 1 for part in parts: if line_len + 1 + len(part) > text_width and line: lines.append(indent + ' '.join(line)) line = [] - line_len = len(indent) - 1 + line_len = indent_length - 1 line.append(part) line_len += len(part) + 1 if line: lines.append(indent + ' '.join(line)) if prefix is not None: - lines[0] = lines[0][len(indent):] + lines[0] = lines[0][indent_length:] return lines # if prog is short, follow it with optionals or positionals @@ -403,10 +404,18 @@ def _format_actions_usage(self, actions, groups): except ValueError: continue else: - end = start + len(group._group_actions) + group_action_count = len(group._group_actions) + end = start + group_action_count if actions[start:end] == group._group_actions: + + suppressed_actions_count = 0 for action in group._group_actions: group_actions.add(action) + if action.help is SUPPRESS: + suppressed_actions_count += 1 + + exposed_actions_count = group_action_count - suppressed_actions_count + if not group.required: if start in inserts: inserts[start] += ' [' @@ -416,7 +425,7 @@ def _format_actions_usage(self, actions, groups): inserts[end] += ']' else: inserts[end] = ']' - else: + elif exposed_actions_count > 1: if start in inserts: inserts[start] += ' (' else: @@ -490,7 +499,6 @@ def _format_actions_usage(self, actions, groups): text = _re.sub(r'(%s) ' % open, r'\1', text) text = _re.sub(r' (%s)' % close, r'\1', text) text = _re.sub(r'%s *%s' % (open, close), r'', text) - text = _re.sub(r'\(([^|]*)\)', r'\1', text) text = text.strip() # return the text @@ -875,16 +883,19 @@ def __call__(self, parser, namespace, values, option_string=None): raise NotImplementedError(_('.__call__() not defined')) +# FIXME: remove together with `BooleanOptionalAction` deprecated arguments. +_deprecated_default = object() + class BooleanOptionalAction(Action): def __init__(self, option_strings, dest, default=None, - type=None, - choices=None, + type=_deprecated_default, + choices=_deprecated_default, required=False, help=None, - metavar=None): + metavar=_deprecated_default): _option_strings = [] for option_string in option_strings: @@ -894,6 +905,24 @@ def __init__(self, option_string = '--no-' + option_string[2:] _option_strings.append(option_string) + # We need `_deprecated` special value to ban explicit arguments that + # match default value. Like: + # parser.add_argument('-f', action=BooleanOptionalAction, type=int) + for field_name in ('type', 'choices', 'metavar'): + if locals()[field_name] is not _deprecated_default: + warnings._deprecated( + field_name, + "{name!r} is deprecated as of Python 3.12 and will be " + "removed in Python {remove}.", + remove=(3, 14)) + + if type is _deprecated_default: + type = None + if choices is _deprecated_default: + choices = None + if metavar is _deprecated_default: + metavar = None + super().__init__( option_strings=_option_strings, dest=dest, @@ -2165,7 +2194,9 @@ def _read_args_from_files(self, arg_strings): # replace arguments referencing files with the file content else: try: - with open(arg_string[1:]) as args_file: + with open(arg_string[1:], + encoding=_sys.getfilesystemencoding(), + errors=_sys.getfilesystemencodeerrors()) as args_file: arg_strings = [] for arg_line in args_file.read().splitlines(): for arg in self.convert_arg_line_to_args(arg_line): @@ -2479,9 +2510,11 @@ def _get_values(self, action, arg_strings): not action.option_strings): if action.default is not None: value = action.default + self._check_value(action, value) else: + # since arg_strings is always [] at this point + # there is no need to use self._check_value(action, value) value = arg_strings - self._check_value(action, value) # single argument or optional argument produces a single value elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]: @@ -2523,7 +2556,6 @@ def _get_value(self, action, arg_string): # ArgumentTypeErrors indicate errors except ArgumentTypeError as err: - name = getattr(action.type, '__name__', repr(action.type)) msg = str(err) raise ArgumentError(action, msg) @@ -2595,9 +2627,11 @@ def print_help(self, file=None): def _print_message(self, message, file=None): if message: - if file is None: - file = _sys.stderr - file.write(message) + file = file or _sys.stderr + try: + file.write(message) + except (AttributeError, OSError): + pass # =============== # Exiting methods From 5700fa39537e21d89b531a495c81c5374caf3822 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 04:05:39 +0800 Subject: [PATCH 201/893] Update argparse.py from CPython v3.12.0 --- Lib/test/test_argparse.py | 138 ++++++++++++++++++++++++++++++++++---- 1 file changed, 126 insertions(+), 12 deletions(-) diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py index 1acecbb8ab..3a62a16cee 100644 --- a/Lib/test/test_argparse.py +++ b/Lib/test/test_argparse.py @@ -1,5 +1,7 @@ # Author: Steven J. Bethard . +import contextlib +import functools import inspect import io import operator @@ -35,6 +37,35 @@ def getvalue(self): return self.buffer.raw.getvalue().decode('utf-8') +class StdStreamTest(unittest.TestCase): + + def test_skip_invalid_stderr(self): + parser = argparse.ArgumentParser() + with ( + contextlib.redirect_stderr(None), + mock.patch('argparse._sys.exit') + ): + parser.exit(status=0, message='foo') + + def test_skip_invalid_stdout(self): + parser = argparse.ArgumentParser() + for func in ( + parser.print_usage, + parser.print_help, + functools.partial(parser.parse_args, ['-h']) + ): + with ( + self.subTest(func=func), + contextlib.redirect_stdout(None), + # argparse uses stderr as a fallback + StdIOBuffer() as mocked_stderr, + contextlib.redirect_stderr(mocked_stderr), + mock.patch('argparse._sys.exit'), + ): + func() + self.assertRegex(mocked_stderr.getvalue(), r'usage:') + + class TestCase(unittest.TestCase): def setUp(self): @@ -734,6 +765,49 @@ def test_const(self): self.assertIn("got an unexpected keyword argument 'const'", str(cm.exception)) + def test_deprecated_init_kw(self): + # See gh-92248 + parser = argparse.ArgumentParser() + + with self.assertWarns(DeprecationWarning): + parser.add_argument( + '-a', + action=argparse.BooleanOptionalAction, + type=None, + ) + with self.assertWarns(DeprecationWarning): + parser.add_argument( + '-b', + action=argparse.BooleanOptionalAction, + type=bool, + ) + + with self.assertWarns(DeprecationWarning): + parser.add_argument( + '-c', + action=argparse.BooleanOptionalAction, + metavar=None, + ) + with self.assertWarns(DeprecationWarning): + parser.add_argument( + '-d', + action=argparse.BooleanOptionalAction, + metavar='d', + ) + + with self.assertWarns(DeprecationWarning): + parser.add_argument( + '-e', + action=argparse.BooleanOptionalAction, + choices=None, + ) + with self.assertWarns(DeprecationWarning): + parser.add_argument( + '-f', + action=argparse.BooleanOptionalAction, + choices=(), + ) + class TestBooleanOptionalActionRequired(ParserTestCase): """Tests BooleanOptionalAction required""" @@ -1505,14 +1579,15 @@ class TestArgumentsFromFile(TempDirMixin, ParserTestCase): def setUp(self): super(TestArgumentsFromFile, self).setUp() file_texts = [ - ('hello', 'hello world!\n'), - ('recursive', '-a\n' - 'A\n' - '@hello'), - ('invalid', '@no-such-path\n'), + ('hello', os.fsencode(self.hello) + b'\n'), + ('recursive', b'-a\n' + b'A\n' + b'@hello'), + ('invalid', b'@no-such-path\n'), + ('undecodable', self.undecodable + b'\n'), ] for path, text in file_texts: - with open(path, 'w', encoding="utf-8") as file: + with open(path, 'wb') as file: file.write(text) parser_signature = Sig(fromfile_prefix_chars='@') @@ -1522,15 +1597,25 @@ def setUp(self): Sig('y', nargs='+'), ] failures = ['', '-b', 'X', '@invalid', '@missing'] + hello = 'hello world!' + os_helper.FS_NONASCII successes = [ ('X Y', NS(a=None, x='X', y=['Y'])), ('X -a A Y Z', NS(a='A', x='X', y=['Y', 'Z'])), - ('@hello X', NS(a=None, x='hello world!', y=['X'])), - ('X @hello', NS(a=None, x='X', y=['hello world!'])), - ('-a B @recursive Y Z', NS(a='A', x='hello world!', y=['Y', 'Z'])), - ('X @recursive Z -a B', NS(a='B', x='X', y=['hello world!', 'Z'])), + ('@hello X', NS(a=None, x=hello, y=['X'])), + ('X @hello', NS(a=None, x='X', y=[hello])), + ('-a B @recursive Y Z', NS(a='A', x=hello, y=['Y', 'Z'])), + ('X @recursive Z -a B', NS(a='B', x='X', y=[hello, 'Z'])), (["-a", "", "X", "Y"], NS(a='', x='X', y=['Y'])), ] + if os_helper.TESTFN_UNDECODABLE: + undecodable = os_helper.TESTFN_UNDECODABLE.lstrip(b'@') + decoded_undecodable = os.fsdecode(undecodable) + successes += [ + ('@undecodable X', NS(a=None, x=decoded_undecodable, y=['X'])), + ('X @undecodable', NS(a=None, x='X', y=[decoded_undecodable])), + ] + else: + undecodable = b'' class TestArgumentsFromFileConverter(TempDirMixin, ParserTestCase): @@ -1539,10 +1624,10 @@ class TestArgumentsFromFileConverter(TempDirMixin, ParserTestCase): def setUp(self): super(TestArgumentsFromFileConverter, self).setUp() file_texts = [ - ('hello', 'hello world!\n'), + ('hello', b'hello world!\n'), ] for path, text in file_texts: - with open(path, 'w', encoding="utf-8") as file: + with open(path, 'wb') as file: file.write(text) class FromFileConverterArgumentParser(ErrorRaisingArgumentParser): @@ -3753,6 +3838,28 @@ class TestHelpUsage(HelpTestCase): version = '' +class TestHelpUsageWithParentheses(HelpTestCase): + parser_signature = Sig(prog='PROG') + argument_signatures = [ + Sig('positional', metavar='(example) positional'), + Sig('-p', '--optional', metavar='{1 (option A), 2 (option B)}'), + ] + + usage = '''\ + usage: PROG [-h] [-p {1 (option A), 2 (option B)}] (example) positional + ''' + help = usage + '''\ + + positional arguments: + (example) positional + + options: + -h, --help show this help message and exit + -p {1 (option A), 2 (option B)}, --optional {1 (option A), 2 (option B)} + ''' + version = '' + + class TestHelpOnlyUserGroups(HelpTestCase): """Test basic usage messages""" @@ -5219,6 +5326,13 @@ def test_mixed(self): self.assertEqual(NS(v=3, spam=True, badger="B"), args) self.assertEqual(["C", "--foo", "4"], extras) + def test_zero_or_more_optional(self): + parser = argparse.ArgumentParser() + parser.add_argument('x', nargs='*', choices=('x', 'y')) + args = parser.parse_args([]) + self.assertEqual(NS(x=[]), args) + + # =========================== # parse_intermixed_args tests # =========================== From 57f9478d160da8d267ac32800047f172dcfb9e51 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 04:07:12 +0800 Subject: [PATCH 202/893] Add test_bz2.py from CPython v3.12.0 --- Lib/test/test_bz2.py | 1018 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1018 insertions(+) create mode 100644 Lib/test/test_bz2.py diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py new file mode 100644 index 0000000000..1f0b9adc36 --- /dev/null +++ b/Lib/test/test_bz2.py @@ -0,0 +1,1018 @@ +from test import support +from test.support import bigmemtest, _4G + +import array +import unittest +from io import BytesIO, DEFAULT_BUFFER_SIZE +import os +import pickle +import glob +import tempfile +import pathlib +import random +import shutil +import subprocess +import threading +from test.support import import_helper +from test.support import threading_helper +from test.support.os_helper import unlink +import _compression +import sys + + +# Skip tests if the bz2 module doesn't exist. +bz2 = import_helper.import_module('bz2') +from bz2 import BZ2File, BZ2Compressor, BZ2Decompressor + +has_cmdline_bunzip2 = None + +def ext_decompress(data): + global has_cmdline_bunzip2 + if has_cmdline_bunzip2 is None: + has_cmdline_bunzip2 = bool(shutil.which('bunzip2')) + if has_cmdline_bunzip2: + return subprocess.check_output(['bunzip2'], input=data) + else: + return bz2.decompress(data) + +class BaseTest(unittest.TestCase): + "Base for other testcases." + + TEXT_LINES = [ + b'root:x:0:0:root:/root:/bin/bash\n', + b'bin:x:1:1:bin:/bin:\n', + b'daemon:x:2:2:daemon:/sbin:\n', + b'adm:x:3:4:adm:/var/adm:\n', + b'lp:x:4:7:lp:/var/spool/lpd:\n', + b'sync:x:5:0:sync:/sbin:/bin/sync\n', + b'shutdown:x:6:0:shutdown:/sbin:/sbin/shutdown\n', + b'halt:x:7:0:halt:/sbin:/sbin/halt\n', + b'mail:x:8:12:mail:/var/spool/mail:\n', + b'news:x:9:13:news:/var/spool/news:\n', + b'uucp:x:10:14:uucp:/var/spool/uucp:\n', + b'operator:x:11:0:operator:/root:\n', + b'games:x:12:100:games:/usr/games:\n', + b'gopher:x:13:30:gopher:/usr/lib/gopher-data:\n', + b'ftp:x:14:50:FTP User:/var/ftp:/bin/bash\n', + b'nobody:x:65534:65534:Nobody:/home:\n', + b'postfix:x:100:101:postfix:/var/spool/postfix:\n', + b'niemeyer:x:500:500::/home/niemeyer:/bin/bash\n', + b'postgres:x:101:102:PostgreSQL Server:/var/lib/pgsql:/bin/bash\n', + b'mysql:x:102:103:MySQL server:/var/lib/mysql:/bin/bash\n', + b'www:x:103:104::/var/www:/bin/false\n', + ] + TEXT = b''.join(TEXT_LINES) + DATA = b'BZh91AY&SY.\xc8N\x18\x00\x01>_\x80\x00\x10@\x02\xff\xf0\x01\x07n\x00?\xe7\xff\xe00\x01\x99\xaa\x00\xc0\x03F\x86\x8c#&\x83F\x9a\x03\x06\xa6\xd0\xa6\x93M\x0fQ\xa7\xa8\x06\x804hh\x12$\x11\xa4i4\xf14S\xd2\x88\xe5\xcd9gd6\x0b\n\xe9\x9b\xd5\x8a\x99\xf7\x08.K\x8ev\xfb\xf7xw\xbb\xdf\xa1\x92\xf1\xdd|/";\xa2\xba\x9f\xd5\xb1#A\xb6\xf6\xb3o\xc9\xc5y\\\xebO\xe7\x85\x9a\xbc\xb6f8\x952\xd5\xd7"%\x89>V,\xf7\xa6z\xe2\x9f\xa3\xdf\x11\x11"\xd6E)I\xa9\x13^\xca\xf3r\xd0\x03U\x922\xf26\xec\xb6\xed\x8b\xc3U\x13\x9d\xc5\x170\xa4\xfa^\x92\xacDF\x8a\x97\xd6\x19\xfe\xdd\xb8\xbd\x1a\x9a\x19\xa3\x80ankR\x8b\xe5\xd83]\xa9\xc6\x08\x82f\xf6\xb9"6l$\xb8j@\xc0\x8a\xb0l1..\xbak\x83ls\x15\xbc\xf4\xc1\x13\xbe\xf8E\xb8\x9d\r\xa8\x9dk\x84\xd3n\xfa\xacQ\x07\xb1%y\xaav\xb4\x08\xe0z\x1b\x16\xf5\x04\xe9\xcc\xb9\x08z\x1en7.G\xfc]\xc9\x14\xe1B@\xbb!8`' + EMPTY_DATA = b'BZh9\x17rE8P\x90\x00\x00\x00\x00' + BAD_DATA = b'this is not a valid bzip2 file' + + # Some tests need more than one block of uncompressed data. Since one block + # is at least 100,000 bytes, we gather some data dynamically and compress it. + # Note that this assumes that compression works correctly, so we cannot + # simply use the bigger test data for all tests. + test_size = 0 + BIG_TEXT = bytearray(128*1024) + for fname in glob.glob(os.path.join(glob.escape(os.path.dirname(__file__)), '*.py')): + with open(fname, 'rb') as fh: + test_size += fh.readinto(memoryview(BIG_TEXT)[test_size:]) + if test_size > 128*1024: + break + BIG_DATA = bz2.compress(BIG_TEXT, compresslevel=1) + + def setUp(self): + fd, self.filename = tempfile.mkstemp() + os.close(fd) + + def tearDown(self): + unlink(self.filename) + + +class BZ2FileTest(BaseTest): + "Test the BZ2File class." + + def createTempFile(self, streams=1, suffix=b""): + with open(self.filename, "wb") as f: + f.write(self.DATA * streams) + f.write(suffix) + + def testBadArgs(self): + self.assertRaises(TypeError, BZ2File, 123.456) + self.assertRaises(ValueError, BZ2File, os.devnull, "z") + self.assertRaises(ValueError, BZ2File, os.devnull, "rx") + self.assertRaises(ValueError, BZ2File, os.devnull, "rbt") + self.assertRaises(ValueError, BZ2File, os.devnull, compresslevel=0) + self.assertRaises(ValueError, BZ2File, os.devnull, compresslevel=10) + + # compresslevel is keyword-only + self.assertRaises(TypeError, BZ2File, os.devnull, "r", 3) + + def testRead(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.read, float()) + self.assertEqual(bz2f.read(), self.TEXT) + + def testReadBadFile(self): + self.createTempFile(streams=0, suffix=self.BAD_DATA) + with BZ2File(self.filename) as bz2f: + self.assertRaises(OSError, bz2f.read) + + def testReadMultiStream(self): + self.createTempFile(streams=5) + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.read, float()) + self.assertEqual(bz2f.read(), self.TEXT * 5) + + def testReadMonkeyMultiStream(self): + # Test BZ2File.read() on a multi-stream archive where a stream + # boundary coincides with the end of the raw read buffer. + buffer_size = _compression.BUFFER_SIZE + _compression.BUFFER_SIZE = len(self.DATA) + try: + self.createTempFile(streams=5) + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.read, float()) + self.assertEqual(bz2f.read(), self.TEXT * 5) + finally: + _compression.BUFFER_SIZE = buffer_size + + def testReadTrailingJunk(self): + self.createTempFile(suffix=self.BAD_DATA) + with BZ2File(self.filename) as bz2f: + self.assertEqual(bz2f.read(), self.TEXT) + + def testReadMultiStreamTrailingJunk(self): + self.createTempFile(streams=5, suffix=self.BAD_DATA) + with BZ2File(self.filename) as bz2f: + self.assertEqual(bz2f.read(), self.TEXT * 5) + + def testRead0(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.read, float()) + self.assertEqual(bz2f.read(0), b"") + + def testReadChunk10(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + text = b'' + while True: + str = bz2f.read(10) + if not str: + break + text += str + self.assertEqual(text, self.TEXT) + + def testReadChunk10MultiStream(self): + self.createTempFile(streams=5) + with BZ2File(self.filename) as bz2f: + text = b'' + while True: + str = bz2f.read(10) + if not str: + break + text += str + self.assertEqual(text, self.TEXT * 5) + + def testRead100(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + self.assertEqual(bz2f.read(100), self.TEXT[:100]) + + def testPeek(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + pdata = bz2f.peek() + self.assertNotEqual(len(pdata), 0) + self.assertTrue(self.TEXT.startswith(pdata)) + self.assertEqual(bz2f.read(), self.TEXT) + + def testReadInto(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + n = 128 + b = bytearray(n) + self.assertEqual(bz2f.readinto(b), n) + self.assertEqual(b, self.TEXT[:n]) + n = len(self.TEXT) - n + b = bytearray(len(self.TEXT)) + self.assertEqual(bz2f.readinto(b), n) + self.assertEqual(b[:n], self.TEXT[-n:]) + + def testReadLine(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.readline, None) + for line in self.TEXT_LINES: + self.assertEqual(bz2f.readline(), line) + + def testReadLineMultiStream(self): + self.createTempFile(streams=5) + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.readline, None) + for line in self.TEXT_LINES * 5: + self.assertEqual(bz2f.readline(), line) + + def testReadLines(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.readlines, None) + self.assertEqual(bz2f.readlines(), self.TEXT_LINES) + + def testReadLinesMultiStream(self): + self.createTempFile(streams=5) + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.readlines, None) + self.assertEqual(bz2f.readlines(), self.TEXT_LINES * 5) + + def testIterator(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + self.assertEqual(list(iter(bz2f)), self.TEXT_LINES) + + def testIteratorMultiStream(self): + self.createTempFile(streams=5) + with BZ2File(self.filename) as bz2f: + self.assertEqual(list(iter(bz2f)), self.TEXT_LINES * 5) + + def testClosedIteratorDeadlock(self): + # Issue #3309: Iteration on a closed BZ2File should release the lock. + self.createTempFile() + bz2f = BZ2File(self.filename) + bz2f.close() + self.assertRaises(ValueError, next, bz2f) + # This call will deadlock if the above call failed to release the lock. + self.assertRaises(ValueError, bz2f.readlines) + + def testWrite(self): + with BZ2File(self.filename, "w") as bz2f: + self.assertRaises(TypeError, bz2f.write) + bz2f.write(self.TEXT) + with open(self.filename, 'rb') as f: + self.assertEqual(ext_decompress(f.read()), self.TEXT) + + def testWriteChunks10(self): + with BZ2File(self.filename, "w") as bz2f: + n = 0 + while True: + str = self.TEXT[n*10:(n+1)*10] + if not str: + break + bz2f.write(str) + n += 1 + with open(self.filename, 'rb') as f: + self.assertEqual(ext_decompress(f.read()), self.TEXT) + + def testWriteNonDefaultCompressLevel(self): + expected = bz2.compress(self.TEXT, compresslevel=5) + with BZ2File(self.filename, "w", compresslevel=5) as bz2f: + bz2f.write(self.TEXT) + with open(self.filename, "rb") as f: + self.assertEqual(f.read(), expected) + + def testWriteLines(self): + with BZ2File(self.filename, "w") as bz2f: + self.assertRaises(TypeError, bz2f.writelines) + bz2f.writelines(self.TEXT_LINES) + # Issue #1535500: Calling writelines() on a closed BZ2File + # should raise an exception. + self.assertRaises(ValueError, bz2f.writelines, ["a"]) + with open(self.filename, 'rb') as f: + self.assertEqual(ext_decompress(f.read()), self.TEXT) + + def testWriteMethodsOnReadOnlyFile(self): + with BZ2File(self.filename, "w") as bz2f: + bz2f.write(b"abc") + + with BZ2File(self.filename, "r") as bz2f: + self.assertRaises(OSError, bz2f.write, b"a") + self.assertRaises(OSError, bz2f.writelines, [b"a"]) + + def testAppend(self): + with BZ2File(self.filename, "w") as bz2f: + self.assertRaises(TypeError, bz2f.write) + bz2f.write(self.TEXT) + with BZ2File(self.filename, "a") as bz2f: + self.assertRaises(TypeError, bz2f.write) + bz2f.write(self.TEXT) + with open(self.filename, 'rb') as f: + self.assertEqual(ext_decompress(f.read()), self.TEXT * 2) + + def testSeekForward(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.seek) + bz2f.seek(150) + self.assertEqual(bz2f.read(), self.TEXT[150:]) + + def testSeekForwardAcrossStreams(self): + self.createTempFile(streams=2) + with BZ2File(self.filename) as bz2f: + self.assertRaises(TypeError, bz2f.seek) + bz2f.seek(len(self.TEXT) + 150) + self.assertEqual(bz2f.read(), self.TEXT[150:]) + + def testSeekBackwards(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + bz2f.read(500) + bz2f.seek(-150, 1) + self.assertEqual(bz2f.read(), self.TEXT[500-150:]) + + def testSeekBackwardsAcrossStreams(self): + self.createTempFile(streams=2) + with BZ2File(self.filename) as bz2f: + readto = len(self.TEXT) + 100 + while readto > 0: + readto -= len(bz2f.read(readto)) + bz2f.seek(-150, 1) + self.assertEqual(bz2f.read(), self.TEXT[100-150:] + self.TEXT) + + def testSeekBackwardsFromEnd(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + bz2f.seek(-150, 2) + self.assertEqual(bz2f.read(), self.TEXT[len(self.TEXT)-150:]) + + def testSeekBackwardsFromEndAcrossStreams(self): + self.createTempFile(streams=2) + with BZ2File(self.filename) as bz2f: + bz2f.seek(-1000, 2) + self.assertEqual(bz2f.read(), (self.TEXT * 2)[-1000:]) + + def testSeekPostEnd(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + bz2f.seek(150000) + self.assertEqual(bz2f.tell(), len(self.TEXT)) + self.assertEqual(bz2f.read(), b"") + + def testSeekPostEndMultiStream(self): + self.createTempFile(streams=5) + with BZ2File(self.filename) as bz2f: + bz2f.seek(150000) + self.assertEqual(bz2f.tell(), len(self.TEXT) * 5) + self.assertEqual(bz2f.read(), b"") + + def testSeekPostEndTwice(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + bz2f.seek(150000) + bz2f.seek(150000) + self.assertEqual(bz2f.tell(), len(self.TEXT)) + self.assertEqual(bz2f.read(), b"") + + def testSeekPostEndTwiceMultiStream(self): + self.createTempFile(streams=5) + with BZ2File(self.filename) as bz2f: + bz2f.seek(150000) + bz2f.seek(150000) + self.assertEqual(bz2f.tell(), len(self.TEXT) * 5) + self.assertEqual(bz2f.read(), b"") + + def testSeekPreStart(self): + self.createTempFile() + with BZ2File(self.filename) as bz2f: + bz2f.seek(-150) + self.assertEqual(bz2f.tell(), 0) + self.assertEqual(bz2f.read(), self.TEXT) + + def testSeekPreStartMultiStream(self): + self.createTempFile(streams=2) + with BZ2File(self.filename) as bz2f: + bz2f.seek(-150) + self.assertEqual(bz2f.tell(), 0) + self.assertEqual(bz2f.read(), self.TEXT * 2) + + def testFileno(self): + self.createTempFile() + with open(self.filename, 'rb') as rawf: + bz2f = BZ2File(rawf) + try: + self.assertEqual(bz2f.fileno(), rawf.fileno()) + finally: + bz2f.close() + self.assertRaises(ValueError, bz2f.fileno) + + def testSeekable(self): + bz2f = BZ2File(BytesIO(self.DATA)) + try: + self.assertTrue(bz2f.seekable()) + bz2f.read() + self.assertTrue(bz2f.seekable()) + finally: + bz2f.close() + self.assertRaises(ValueError, bz2f.seekable) + + bz2f = BZ2File(BytesIO(), "w") + try: + self.assertFalse(bz2f.seekable()) + finally: + bz2f.close() + self.assertRaises(ValueError, bz2f.seekable) + + src = BytesIO(self.DATA) + src.seekable = lambda: False + bz2f = BZ2File(src) + try: + self.assertFalse(bz2f.seekable()) + finally: + bz2f.close() + self.assertRaises(ValueError, bz2f.seekable) + + def testReadable(self): + bz2f = BZ2File(BytesIO(self.DATA)) + try: + self.assertTrue(bz2f.readable()) + bz2f.read() + self.assertTrue(bz2f.readable()) + finally: + bz2f.close() + self.assertRaises(ValueError, bz2f.readable) + + bz2f = BZ2File(BytesIO(), "w") + try: + self.assertFalse(bz2f.readable()) + finally: + bz2f.close() + self.assertRaises(ValueError, bz2f.readable) + + def testWritable(self): + bz2f = BZ2File(BytesIO(self.DATA)) + try: + self.assertFalse(bz2f.writable()) + bz2f.read() + self.assertFalse(bz2f.writable()) + finally: + bz2f.close() + self.assertRaises(ValueError, bz2f.writable) + + bz2f = BZ2File(BytesIO(), "w") + try: + self.assertTrue(bz2f.writable()) + finally: + bz2f.close() + self.assertRaises(ValueError, bz2f.writable) + + def testOpenDel(self): + self.createTempFile() + for i in range(10000): + o = BZ2File(self.filename) + del o + + def testOpenNonexistent(self): + self.assertRaises(OSError, BZ2File, "/non/existent") + + def testReadlinesNoNewline(self): + # Issue #1191043: readlines() fails on a file containing no newline. + data = b'BZh91AY&SY\xd9b\x89]\x00\x00\x00\x03\x80\x04\x00\x02\x00\x0c\x00 \x00!\x9ah3M\x13<]\xc9\x14\xe1BCe\x8a%t' + with open(self.filename, "wb") as f: + f.write(data) + with BZ2File(self.filename) as bz2f: + lines = bz2f.readlines() + self.assertEqual(lines, [b'Test']) + with BZ2File(self.filename) as bz2f: + xlines = list(bz2f.readlines()) + self.assertEqual(xlines, [b'Test']) + + def testContextProtocol(self): + f = None + with BZ2File(self.filename, "wb") as f: + f.write(b"xxx") + f = BZ2File(self.filename, "rb") + f.close() + try: + with f: + pass + except ValueError: + pass + else: + self.fail("__enter__ on a closed file didn't raise an exception") + try: + with BZ2File(self.filename, "wb") as f: + 1/0 + except ZeroDivisionError: + pass + else: + self.fail("1/0 didn't raise an exception") + + @threading_helper.requires_working_threading() + def testThreading(self): + # Issue #7205: Using a BZ2File from several threads shouldn't deadlock. + data = b"1" * 2**20 + nthreads = 10 + with BZ2File(self.filename, 'wb') as f: + def comp(): + for i in range(5): + f.write(data) + threads = [threading.Thread(target=comp) for i in range(nthreads)] + with threading_helper.start_threads(threads): + pass + + def testMixedIterationAndReads(self): + self.createTempFile() + linelen = len(self.TEXT_LINES[0]) + halflen = linelen // 2 + with BZ2File(self.filename) as bz2f: + bz2f.read(halflen) + self.assertEqual(next(bz2f), self.TEXT_LINES[0][halflen:]) + self.assertEqual(bz2f.read(), self.TEXT[linelen:]) + with BZ2File(self.filename) as bz2f: + bz2f.readline() + self.assertEqual(next(bz2f), self.TEXT_LINES[1]) + self.assertEqual(bz2f.readline(), self.TEXT_LINES[2]) + with BZ2File(self.filename) as bz2f: + bz2f.readlines() + self.assertRaises(StopIteration, next, bz2f) + self.assertEqual(bz2f.readlines(), []) + + def testMultiStreamOrdering(self): + # Test the ordering of streams when reading a multi-stream archive. + data1 = b"foo" * 1000 + data2 = b"bar" * 1000 + with BZ2File(self.filename, "w") as bz2f: + bz2f.write(data1) + with BZ2File(self.filename, "a") as bz2f: + bz2f.write(data2) + with BZ2File(self.filename) as bz2f: + self.assertEqual(bz2f.read(), data1 + data2) + + def testOpenBytesFilename(self): + str_filename = self.filename + try: + bytes_filename = str_filename.encode("ascii") + except UnicodeEncodeError: + self.skipTest("Temporary file name needs to be ASCII") + with BZ2File(bytes_filename, "wb") as f: + f.write(self.DATA) + with BZ2File(bytes_filename, "rb") as f: + self.assertEqual(f.read(), self.DATA) + # Sanity check that we are actually operating on the right file. + with BZ2File(str_filename, "rb") as f: + self.assertEqual(f.read(), self.DATA) + + def testOpenPathLikeFilename(self): + filename = pathlib.Path(self.filename) + with BZ2File(filename, "wb") as f: + f.write(self.DATA) + with BZ2File(filename, "rb") as f: + self.assertEqual(f.read(), self.DATA) + + def testDecompressLimited(self): + """Decompressed data buffering should be limited""" + bomb = bz2.compress(b'\0' * int(2e6), compresslevel=9) + self.assertLess(len(bomb), _compression.BUFFER_SIZE) + + decomp = BZ2File(BytesIO(bomb)) + self.assertEqual(decomp.read(1), b'\0') + max_decomp = 1 + DEFAULT_BUFFER_SIZE + self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, + "Excessive amount of data was decompressed") + + + # Tests for a BZ2File wrapping another file object: + + def testReadBytesIO(self): + with BytesIO(self.DATA) as bio: + with BZ2File(bio) as bz2f: + self.assertRaises(TypeError, bz2f.read, float()) + self.assertEqual(bz2f.read(), self.TEXT) + self.assertFalse(bio.closed) + + def testPeekBytesIO(self): + with BytesIO(self.DATA) as bio: + with BZ2File(bio) as bz2f: + pdata = bz2f.peek() + self.assertNotEqual(len(pdata), 0) + self.assertTrue(self.TEXT.startswith(pdata)) + self.assertEqual(bz2f.read(), self.TEXT) + + def testWriteBytesIO(self): + with BytesIO() as bio: + with BZ2File(bio, "w") as bz2f: + self.assertRaises(TypeError, bz2f.write) + bz2f.write(self.TEXT) + self.assertEqual(ext_decompress(bio.getvalue()), self.TEXT) + self.assertFalse(bio.closed) + + def testSeekForwardBytesIO(self): + with BytesIO(self.DATA) as bio: + with BZ2File(bio) as bz2f: + self.assertRaises(TypeError, bz2f.seek) + bz2f.seek(150) + self.assertEqual(bz2f.read(), self.TEXT[150:]) + + def testSeekBackwardsBytesIO(self): + with BytesIO(self.DATA) as bio: + with BZ2File(bio) as bz2f: + bz2f.read(500) + bz2f.seek(-150, 1) + self.assertEqual(bz2f.read(), self.TEXT[500-150:]) + + def test_read_truncated(self): + # Drop the eos_magic field (6 bytes) and CRC (4 bytes). + truncated = self.DATA[:-10] + with BZ2File(BytesIO(truncated)) as f: + self.assertRaises(EOFError, f.read) + with BZ2File(BytesIO(truncated)) as f: + self.assertEqual(f.read(len(self.TEXT)), self.TEXT) + self.assertRaises(EOFError, f.read, 1) + # Incomplete 4-byte file header, and block header of at least 146 bits. + for i in range(22): + with BZ2File(BytesIO(truncated[:i])) as f: + self.assertRaises(EOFError, f.read, 1) + + def test_issue44439(self): + q = array.array('Q', [1, 2, 3, 4, 5]) + LENGTH = len(q) * q.itemsize + + with BZ2File(BytesIO(), 'w') as f: + self.assertEqual(f.write(q), LENGTH) + self.assertEqual(f.tell(), LENGTH) + + +class BZ2CompressorTest(BaseTest): + def testCompress(self): + bz2c = BZ2Compressor() + self.assertRaises(TypeError, bz2c.compress) + data = bz2c.compress(self.TEXT) + data += bz2c.flush() + self.assertEqual(ext_decompress(data), self.TEXT) + + def testCompressEmptyString(self): + bz2c = BZ2Compressor() + data = bz2c.compress(b'') + data += bz2c.flush() + self.assertEqual(data, self.EMPTY_DATA) + + def testCompressChunks10(self): + bz2c = BZ2Compressor() + n = 0 + data = b'' + while True: + str = self.TEXT[n*10:(n+1)*10] + if not str: + break + data += bz2c.compress(str) + n += 1 + data += bz2c.flush() + self.assertEqual(ext_decompress(data), self.TEXT) + + @support.skip_if_pgo_task + @bigmemtest(size=_4G + 100, memuse=2) + def testCompress4G(self, size): + # "Test BZ2Compressor.compress()/flush() with >4GiB input" + bz2c = BZ2Compressor() + data = b"x" * size + try: + compressed = bz2c.compress(data) + compressed += bz2c.flush() + finally: + data = None # Release memory + data = bz2.decompress(compressed) + try: + self.assertEqual(len(data), size) + self.assertEqual(len(data.strip(b"x")), 0) + finally: + data = None + + def testPickle(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises(TypeError): + pickle.dumps(BZ2Compressor(), proto) + + +class BZ2DecompressorTest(BaseTest): + def test_Constructor(self): + self.assertRaises(TypeError, BZ2Decompressor, 42) + + def testDecompress(self): + bz2d = BZ2Decompressor() + self.assertRaises(TypeError, bz2d.decompress) + text = bz2d.decompress(self.DATA) + self.assertEqual(text, self.TEXT) + + def testDecompressChunks10(self): + bz2d = BZ2Decompressor() + text = b'' + n = 0 + while True: + str = self.DATA[n*10:(n+1)*10] + if not str: + break + text += bz2d.decompress(str) + n += 1 + self.assertEqual(text, self.TEXT) + + def testDecompressUnusedData(self): + bz2d = BZ2Decompressor() + unused_data = b"this is unused data" + text = bz2d.decompress(self.DATA+unused_data) + self.assertEqual(text, self.TEXT) + self.assertEqual(bz2d.unused_data, unused_data) + + def testEOFError(self): + bz2d = BZ2Decompressor() + text = bz2d.decompress(self.DATA) + self.assertRaises(EOFError, bz2d.decompress, b"anything") + self.assertRaises(EOFError, bz2d.decompress, b"") + + @support.skip_if_pgo_task + @bigmemtest(size=_4G + 100, memuse=3.3) + def testDecompress4G(self, size): + # "Test BZ2Decompressor.decompress() with >4GiB input" + blocksize = min(10 * 1024 * 1024, size) + block = random.randbytes(blocksize) + try: + data = block * ((size-1) // blocksize + 1) + compressed = bz2.compress(data) + bz2d = BZ2Decompressor() + decompressed = bz2d.decompress(compressed) + self.assertTrue(decompressed == data) + finally: + data = None + compressed = None + decompressed = None + + def testPickle(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises(TypeError): + pickle.dumps(BZ2Decompressor(), proto) + + def testDecompressorChunksMaxsize(self): + bzd = BZ2Decompressor() + max_length = 100 + out = [] + + # Feed some input + len_ = len(self.BIG_DATA) - 64 + out.append(bzd.decompress(self.BIG_DATA[:len_], + max_length=max_length)) + self.assertFalse(bzd.needs_input) + self.assertEqual(len(out[-1]), max_length) + + # Retrieve more data without providing more input + out.append(bzd.decompress(b'', max_length=max_length)) + self.assertFalse(bzd.needs_input) + self.assertEqual(len(out[-1]), max_length) + + # Retrieve more data while providing more input + out.append(bzd.decompress(self.BIG_DATA[len_:], + max_length=max_length)) + self.assertLessEqual(len(out[-1]), max_length) + + # Retrieve remaining uncompressed data + while not bzd.eof: + out.append(bzd.decompress(b'', max_length=max_length)) + self.assertLessEqual(len(out[-1]), max_length) + + out = b"".join(out) + self.assertEqual(out, self.BIG_TEXT) + self.assertEqual(bzd.unused_data, b"") + + def test_decompressor_inputbuf_1(self): + # Test reusing input buffer after moving existing + # contents to beginning + bzd = BZ2Decompressor() + out = [] + + # Create input buffer and fill it + self.assertEqual(bzd.decompress(self.DATA[:100], + max_length=0), b'') + + # Retrieve some results, freeing capacity at beginning + # of input buffer + out.append(bzd.decompress(b'', 2)) + + # Add more data that fits into input buffer after + # moving existing data to beginning + out.append(bzd.decompress(self.DATA[100:105], 15)) + + # Decompress rest of data + out.append(bzd.decompress(self.DATA[105:])) + self.assertEqual(b''.join(out), self.TEXT) + + def test_decompressor_inputbuf_2(self): + # Test reusing input buffer by appending data at the + # end right away + bzd = BZ2Decompressor() + out = [] + + # Create input buffer and empty it + self.assertEqual(bzd.decompress(self.DATA[:200], + max_length=0), b'') + out.append(bzd.decompress(b'')) + + # Fill buffer with new data + out.append(bzd.decompress(self.DATA[200:280], 2)) + + # Append some more data, not enough to require resize + out.append(bzd.decompress(self.DATA[280:300], 2)) + + # Decompress rest of data + out.append(bzd.decompress(self.DATA[300:])) + self.assertEqual(b''.join(out), self.TEXT) + + def test_decompressor_inputbuf_3(self): + # Test reusing input buffer after extending it + + bzd = BZ2Decompressor() + out = [] + + # Create almost full input buffer + out.append(bzd.decompress(self.DATA[:200], 5)) + + # Add even more data to it, requiring resize + out.append(bzd.decompress(self.DATA[200:300], 5)) + + # Decompress rest of data + out.append(bzd.decompress(self.DATA[300:])) + self.assertEqual(b''.join(out), self.TEXT) + + def test_failure(self): + bzd = BZ2Decompressor() + self.assertRaises(Exception, bzd.decompress, self.BAD_DATA * 30) + # Previously, a second call could crash due to internal inconsistency + self.assertRaises(Exception, bzd.decompress, self.BAD_DATA * 30) + + @support.refcount_test + def test_refleaks_in___init__(self): + gettotalrefcount = support.get_attribute(sys, 'gettotalrefcount') + bzd = BZ2Decompressor() + refs_before = gettotalrefcount() + for i in range(100): + bzd.__init__() + self.assertAlmostEqual(gettotalrefcount() - refs_before, 0, delta=10) + + def test_uninitialized_BZ2Decompressor_crash(self): + self.assertEqual(BZ2Decompressor.__new__(BZ2Decompressor). + decompress(bytes()), b'') + + +class CompressDecompressTest(BaseTest): + def testCompress(self): + data = bz2.compress(self.TEXT) + self.assertEqual(ext_decompress(data), self.TEXT) + + def testCompressEmptyString(self): + text = bz2.compress(b'') + self.assertEqual(text, self.EMPTY_DATA) + + def testDecompress(self): + text = bz2.decompress(self.DATA) + self.assertEqual(text, self.TEXT) + + def testDecompressEmpty(self): + text = bz2.decompress(b"") + self.assertEqual(text, b"") + + def testDecompressToEmptyString(self): + text = bz2.decompress(self.EMPTY_DATA) + self.assertEqual(text, b'') + + def testDecompressIncomplete(self): + self.assertRaises(ValueError, bz2.decompress, self.DATA[:-10]) + + def testDecompressBadData(self): + self.assertRaises(OSError, bz2.decompress, self.BAD_DATA) + + def testDecompressMultiStream(self): + text = bz2.decompress(self.DATA * 5) + self.assertEqual(text, self.TEXT * 5) + + def testDecompressTrailingJunk(self): + text = bz2.decompress(self.DATA + self.BAD_DATA) + self.assertEqual(text, self.TEXT) + + def testDecompressMultiStreamTrailingJunk(self): + text = bz2.decompress(self.DATA * 5 + self.BAD_DATA) + self.assertEqual(text, self.TEXT * 5) + + +class OpenTest(BaseTest): + "Test the open function." + + def open(self, *args, **kwargs): + return bz2.open(*args, **kwargs) + + def test_binary_modes(self): + for mode in ("wb", "xb"): + if mode == "xb": + unlink(self.filename) + with self.open(self.filename, mode) as f: + f.write(self.TEXT) + with open(self.filename, "rb") as f: + file_data = ext_decompress(f.read()) + self.assertEqual(file_data, self.TEXT) + with self.open(self.filename, "rb") as f: + self.assertEqual(f.read(), self.TEXT) + with self.open(self.filename, "ab") as f: + f.write(self.TEXT) + with open(self.filename, "rb") as f: + file_data = ext_decompress(f.read()) + self.assertEqual(file_data, self.TEXT * 2) + + def test_implicit_binary_modes(self): + # Test implicit binary modes (no "b" or "t" in mode string). + for mode in ("w", "x"): + if mode == "x": + unlink(self.filename) + with self.open(self.filename, mode) as f: + f.write(self.TEXT) + with open(self.filename, "rb") as f: + file_data = ext_decompress(f.read()) + self.assertEqual(file_data, self.TEXT) + with self.open(self.filename, "r") as f: + self.assertEqual(f.read(), self.TEXT) + with self.open(self.filename, "a") as f: + f.write(self.TEXT) + with open(self.filename, "rb") as f: + file_data = ext_decompress(f.read()) + self.assertEqual(file_data, self.TEXT * 2) + + def test_text_modes(self): + text = self.TEXT.decode("ascii") + text_native_eol = text.replace("\n", os.linesep) + for mode in ("wt", "xt"): + if mode == "xt": + unlink(self.filename) + with self.open(self.filename, mode, encoding="ascii") as f: + f.write(text) + with open(self.filename, "rb") as f: + file_data = ext_decompress(f.read()).decode("ascii") + self.assertEqual(file_data, text_native_eol) + with self.open(self.filename, "rt", encoding="ascii") as f: + self.assertEqual(f.read(), text) + with self.open(self.filename, "at", encoding="ascii") as f: + f.write(text) + with open(self.filename, "rb") as f: + file_data = ext_decompress(f.read()).decode("ascii") + self.assertEqual(file_data, text_native_eol * 2) + + def test_x_mode(self): + for mode in ("x", "xb", "xt"): + unlink(self.filename) + encoding = "utf-8" if "t" in mode else None + with self.open(self.filename, mode, encoding=encoding) as f: + pass + with self.assertRaises(FileExistsError): + with self.open(self.filename, mode) as f: + pass + + def test_fileobj(self): + with self.open(BytesIO(self.DATA), "r") as f: + self.assertEqual(f.read(), self.TEXT) + with self.open(BytesIO(self.DATA), "rb") as f: + self.assertEqual(f.read(), self.TEXT) + text = self.TEXT.decode("ascii") + with self.open(BytesIO(self.DATA), "rt", encoding="utf-8") as f: + self.assertEqual(f.read(), text) + + def test_bad_params(self): + # Test invalid parameter combinations. + self.assertRaises(ValueError, + self.open, self.filename, "wbt") + self.assertRaises(ValueError, + self.open, self.filename, "xbt") + self.assertRaises(ValueError, + self.open, self.filename, "rb", encoding="utf-8") + self.assertRaises(ValueError, + self.open, self.filename, "rb", errors="ignore") + self.assertRaises(ValueError, + self.open, self.filename, "rb", newline="\n") + + def test_encoding(self): + # Test non-default encoding. + text = self.TEXT.decode("ascii") + text_native_eol = text.replace("\n", os.linesep) + with self.open(self.filename, "wt", encoding="utf-16-le") as f: + f.write(text) + with open(self.filename, "rb") as f: + file_data = ext_decompress(f.read()).decode("utf-16-le") + self.assertEqual(file_data, text_native_eol) + with self.open(self.filename, "rt", encoding="utf-16-le") as f: + self.assertEqual(f.read(), text) + + def test_encoding_error_handler(self): + # Test with non-default encoding error handler. + with self.open(self.filename, "wb") as f: + f.write(b"foo\xffbar") + with self.open(self.filename, "rt", encoding="ascii", errors="ignore") \ + as f: + self.assertEqual(f.read(), "foobar") + + def test_newline(self): + # Test with explicit newline (universal newline mode disabled). + text = self.TEXT.decode("ascii") + with self.open(self.filename, "wt", encoding="utf-8", newline="\n") as f: + f.write(text) + with self.open(self.filename, "rt", encoding="utf-8", newline="\r") as f: + self.assertEqual(f.readlines(), [text]) + + +def tearDownModule(): + support.reap_children() + + +if __name__ == '__main__': + unittest.main() From f519ffdb184b744d06078adf7d4bd0c460ea144e Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Sun, 24 Dec 2023 04:11:04 +0800 Subject: [PATCH 203/893] Add asynchat.py and asyncore.py from CPython v3.12.0 --- Lib/test/support/asynchat.py | 314 +++++++++++++++++ Lib/test/support/asyncore.py | 649 +++++++++++++++++++++++++++++++++++ 2 files changed, 963 insertions(+) create mode 100644 Lib/test/support/asynchat.py create mode 100644 Lib/test/support/asyncore.py diff --git a/Lib/test/support/asynchat.py b/Lib/test/support/asynchat.py new file mode 100644 index 0000000000..38c47a1fda --- /dev/null +++ b/Lib/test/support/asynchat.py @@ -0,0 +1,314 @@ +# TODO: This module was deprecated and removed from CPython 3.12 +# Now it is a test-only helper. Any attempts to rewrite exising tests that +# are using this module and remove it completely are appreciated! +# See: https://github.com/python/cpython/issues/72719 + +# -*- Mode: Python; tab-width: 4 -*- +# Id: asynchat.py,v 2.26 2000/09/07 22:29:26 rushing Exp +# Author: Sam Rushing + +# ====================================================================== +# Copyright 1996 by Sam Rushing +# +# All Rights Reserved +# +# Permission to use, copy, modify, and distribute this software and +# its documentation for any purpose and without fee is hereby +# granted, provided that the above copyright notice appear in all +# copies and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of Sam +# Rushing not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, +# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN +# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# ====================================================================== + +r"""A class supporting chat-style (command/response) protocols. + +This class adds support for 'chat' style protocols - where one side +sends a 'command', and the other sends a response (examples would be +the common internet protocols - smtp, nntp, ftp, etc..). + +The handle_read() method looks at the input stream for the current +'terminator' (usually '\r\n' for single-line responses, '\r\n.\r\n' +for multi-line output), calling self.found_terminator() on its +receipt. + +for example: +Say you build an async nntp client using this class. At the start +of the connection, you'll have self.terminator set to '\r\n', in +order to process the single-line greeting. Just before issuing a +'LIST' command you'll set it to '\r\n.\r\n'. The output of the LIST +command will be accumulated (using your own 'collect_incoming_data' +method) up to the terminator, and then control will be returned to +you - by calling your self.found_terminator() method. +""" + +from collections import deque + +from test.support import asyncore + + +class async_chat(asyncore.dispatcher): + """This is an abstract class. You must derive from this class, and add + the two methods collect_incoming_data() and found_terminator()""" + + # these are overridable defaults + + ac_in_buffer_size = 65536 + ac_out_buffer_size = 65536 + + # we don't want to enable the use of encoding by default, because that is a + # sign of an application bug that we don't want to pass silently + + use_encoding = 0 + encoding = 'latin-1' + + def __init__(self, sock=None, map=None): + # for string terminator matching + self.ac_in_buffer = b'' + + # we use a list here rather than io.BytesIO for a few reasons... + # del lst[:] is faster than bio.truncate(0) + # lst = [] is faster than bio.truncate(0) + self.incoming = [] + + # we toss the use of the "simple producer" and replace it with + # a pure deque, which the original fifo was a wrapping of + self.producer_fifo = deque() + asyncore.dispatcher.__init__(self, sock, map) + + def collect_incoming_data(self, data): + raise NotImplementedError("must be implemented in subclass") + + def _collect_incoming_data(self, data): + self.incoming.append(data) + + def _get_data(self): + d = b''.join(self.incoming) + del self.incoming[:] + return d + + def found_terminator(self): + raise NotImplementedError("must be implemented in subclass") + + def set_terminator(self, term): + """Set the input delimiter. + + Can be a fixed string of any length, an integer, or None. + """ + if isinstance(term, str) and self.use_encoding: + term = bytes(term, self.encoding) + elif isinstance(term, int) and term < 0: + raise ValueError('the number of received bytes must be positive') + self.terminator = term + + def get_terminator(self): + return self.terminator + + # grab some more data from the socket, + # throw it to the collector method, + # check for the terminator, + # if found, transition to the next state. + + def handle_read(self): + + try: + data = self.recv(self.ac_in_buffer_size) + except BlockingIOError: + return + except OSError: + self.handle_error() + return + + if isinstance(data, str) and self.use_encoding: + data = bytes(str, self.encoding) + self.ac_in_buffer = self.ac_in_buffer + data + + # Continue to search for self.terminator in self.ac_in_buffer, + # while calling self.collect_incoming_data. The while loop + # is necessary because we might read several data+terminator + # combos with a single recv(4096). + + while self.ac_in_buffer: + lb = len(self.ac_in_buffer) + terminator = self.get_terminator() + if not terminator: + # no terminator, collect it all + self.collect_incoming_data(self.ac_in_buffer) + self.ac_in_buffer = b'' + elif isinstance(terminator, int): + # numeric terminator + n = terminator + if lb < n: + self.collect_incoming_data(self.ac_in_buffer) + self.ac_in_buffer = b'' + self.terminator = self.terminator - lb + else: + self.collect_incoming_data(self.ac_in_buffer[:n]) + self.ac_in_buffer = self.ac_in_buffer[n:] + self.terminator = 0 + self.found_terminator() + else: + # 3 cases: + # 1) end of buffer matches terminator exactly: + # collect data, transition + # 2) end of buffer matches some prefix: + # collect data to the prefix + # 3) end of buffer does not match any prefix: + # collect data + terminator_len = len(terminator) + index = self.ac_in_buffer.find(terminator) + if index != -1: + # we found the terminator + if index > 0: + # don't bother reporting the empty string + # (source of subtle bugs) + self.collect_incoming_data(self.ac_in_buffer[:index]) + self.ac_in_buffer = self.ac_in_buffer[index+terminator_len:] + # This does the Right Thing if the terminator + # is changed here. + self.found_terminator() + else: + # check for a prefix of the terminator + index = find_prefix_at_end(self.ac_in_buffer, terminator) + if index: + if index != lb: + # we found a prefix, collect up to the prefix + self.collect_incoming_data(self.ac_in_buffer[:-index]) + self.ac_in_buffer = self.ac_in_buffer[-index:] + break + else: + # no prefix, collect it all + self.collect_incoming_data(self.ac_in_buffer) + self.ac_in_buffer = b'' + + def handle_write(self): + self.initiate_send() + + def handle_close(self): + self.close() + + def push(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + sabs = self.ac_out_buffer_size + if len(data) > sabs: + for i in range(0, len(data), sabs): + self.producer_fifo.append(data[i:i+sabs]) + else: + self.producer_fifo.append(data) + self.initiate_send() + + def push_with_producer(self, producer): + self.producer_fifo.append(producer) + self.initiate_send() + + def readable(self): + "predicate for inclusion in the readable for select()" + # cannot use the old predicate, it violates the claim of the + # set_terminator method. + + # return (len(self.ac_in_buffer) <= self.ac_in_buffer_size) + return 1 + + def writable(self): + "predicate for inclusion in the writable for select()" + return self.producer_fifo or (not self.connected) + + def close_when_done(self): + "automatically close this channel once the outgoing queue is empty" + self.producer_fifo.append(None) + + def initiate_send(self): + while self.producer_fifo and self.connected: + first = self.producer_fifo[0] + # handle empty string/buffer or None entry + if not first: + del self.producer_fifo[0] + if first is None: + self.handle_close() + return + + # handle classic producer behavior + obs = self.ac_out_buffer_size + try: + data = first[:obs] + except TypeError: + data = first.more() + if data: + self.producer_fifo.appendleft(data) + else: + del self.producer_fifo[0] + continue + + if isinstance(data, str) and self.use_encoding: + data = bytes(data, self.encoding) + + # send the data + try: + num_sent = self.send(data) + except OSError: + self.handle_error() + return + + if num_sent: + if num_sent < len(data) or obs < len(first): + self.producer_fifo[0] = first[num_sent:] + else: + del self.producer_fifo[0] + # we tried to send some actual data + return + + def discard_buffers(self): + # Emergencies only! + self.ac_in_buffer = b'' + del self.incoming[:] + self.producer_fifo.clear() + + +class simple_producer: + + def __init__(self, data, buffer_size=512): + self.data = data + self.buffer_size = buffer_size + + def more(self): + if len(self.data) > self.buffer_size: + result = self.data[:self.buffer_size] + self.data = self.data[self.buffer_size:] + return result + else: + result = self.data + self.data = b'' + return result + + +# Given 'haystack', see if any prefix of 'needle' is at its end. This +# assumes an exact match has already been checked. Return the number of +# characters matched. +# for example: +# f_p_a_e("qwerty\r", "\r\n") => 1 +# f_p_a_e("qwertydkjf", "\r\n") => 0 +# f_p_a_e("qwerty\r\n", "\r\n") => + +# this could maybe be made faster with a computed regex? +# [answer: no; circa Python-2.0, Jan 2001] +# new python: 28961/s +# old python: 18307/s +# re: 12820/s +# regex: 14035/s + +def find_prefix_at_end(haystack, needle): + l = len(needle) - 1 + while l and not haystack.endswith(needle[:l]): + l -= 1 + return l diff --git a/Lib/test/support/asyncore.py b/Lib/test/support/asyncore.py new file mode 100644 index 0000000000..b397aca556 --- /dev/null +++ b/Lib/test/support/asyncore.py @@ -0,0 +1,649 @@ +# TODO: This module was deprecated and removed from CPython 3.12 +# Now it is a test-only helper. Any attempts to rewrite exising tests that +# are using this module and remove it completely are appreciated! +# See: https://github.com/python/cpython/issues/72719 + +# -*- Mode: Python -*- +# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp +# Author: Sam Rushing + +# ====================================================================== +# Copyright 1996 by Sam Rushing +# +# All Rights Reserved +# +# Permission to use, copy, modify, and distribute this software and +# its documentation for any purpose and without fee is hereby +# granted, provided that the above copyright notice appear in all +# copies and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of Sam +# Rushing not be used in advertising or publicity pertaining to +# distribution of the software without specific, written prior +# permission. +# +# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, +# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN +# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR +# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS +# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +# ====================================================================== + +"""Basic infrastructure for asynchronous socket service clients and servers. + +There are only two ways to have a program on a single processor do "more +than one thing at a time". Multi-threaded programming is the simplest and +most popular way to do it, but there is another very different technique, +that lets you have nearly all the advantages of multi-threading, without +actually using multiple threads. it's really only practical if your program +is largely I/O bound. If your program is CPU bound, then pre-emptive +scheduled threads are probably what you really need. Network servers are +rarely CPU-bound, however. + +If your operating system supports the select() system call in its I/O +library (and nearly all do), then you can use it to juggle multiple +communication channels at once; doing other work while your I/O is taking +place in the "background." Although this strategy can seem strange and +complex, especially at first, it is in many ways easier to understand and +control than multi-threaded programming. The module documented here solves +many of the difficult problems for you, making the task of building +sophisticated high-performance network servers and clients a snap. +""" + +import select +import socket +import sys +import time +import warnings + +import os +from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, EINVAL, \ + ENOTCONN, ESHUTDOWN, EISCONN, EBADF, ECONNABORTED, EPIPE, EAGAIN, \ + errorcode + + +_DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, + EBADF}) + +try: + socket_map +except NameError: + socket_map = {} + +def _strerror(err): + try: + return os.strerror(err) + except (ValueError, OverflowError, NameError): + if err in errorcode: + return errorcode[err] + return "Unknown error %s" %err + +class ExitNow(Exception): + pass + +_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit) + +def read(obj): + try: + obj.handle_read_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + +def write(obj): + try: + obj.handle_write_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + +def _exception(obj): + try: + obj.handle_expt_event() + except _reraised_exceptions: + raise + except: + obj.handle_error() + +def readwrite(obj, flags): + try: + if flags & select.POLLIN: + obj.handle_read_event() + if flags & select.POLLOUT: + obj.handle_write_event() + if flags & select.POLLPRI: + obj.handle_expt_event() + if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL): + obj.handle_close() + except OSError as e: + if e.errno not in _DISCONNECTED: + obj.handle_error() + else: + obj.handle_close() + except _reraised_exceptions: + raise + except: + obj.handle_error() + +def poll(timeout=0.0, map=None): + if map is None: + map = socket_map + if map: + r = []; w = []; e = [] + for fd, obj in list(map.items()): + is_r = obj.readable() + is_w = obj.writable() + if is_r: + r.append(fd) + # accepting sockets should not be writable + if is_w and not obj.accepting: + w.append(fd) + if is_r or is_w: + e.append(fd) + if [] == r == w == e: + time.sleep(timeout) + return + + r, w, e = select.select(r, w, e, timeout) + + for fd in r: + obj = map.get(fd) + if obj is None: + continue + read(obj) + + for fd in w: + obj = map.get(fd) + if obj is None: + continue + write(obj) + + for fd in e: + obj = map.get(fd) + if obj is None: + continue + _exception(obj) + +def poll2(timeout=0.0, map=None): + # Use the poll() support added to the select module in Python 2.0 + if map is None: + map = socket_map + if timeout is not None: + # timeout is in milliseconds + timeout = int(timeout*1000) + pollster = select.poll() + if map: + for fd, obj in list(map.items()): + flags = 0 + if obj.readable(): + flags |= select.POLLIN | select.POLLPRI + # accepting sockets should not be writable + if obj.writable() and not obj.accepting: + flags |= select.POLLOUT + if flags: + pollster.register(fd, flags) + + r = pollster.poll(timeout) + for fd, flags in r: + obj = map.get(fd) + if obj is None: + continue + readwrite(obj, flags) + +poll3 = poll2 # Alias for backward compatibility + +def loop(timeout=30.0, use_poll=False, map=None, count=None): + if map is None: + map = socket_map + + if use_poll and hasattr(select, 'poll'): + poll_fun = poll2 + else: + poll_fun = poll + + if count is None: + while map: + poll_fun(timeout, map) + + else: + while map and count > 0: + poll_fun(timeout, map) + count = count - 1 + +class dispatcher: + + debug = False + connected = False + accepting = False + connecting = False + closing = False + addr = None + ignore_log_types = frozenset({'warning'}) + + def __init__(self, sock=None, map=None): + if map is None: + self._map = socket_map + else: + self._map = map + + self._fileno = None + + if sock: + # Set to nonblocking just to make sure for cases where we + # get a socket from a blocking source. + sock.setblocking(False) + self.set_socket(sock, map) + self.connected = True + # The constructor no longer requires that the socket + # passed be connected. + try: + self.addr = sock.getpeername() + except OSError as err: + if err.errno in (ENOTCONN, EINVAL): + # To handle the case where we got an unconnected + # socket. + self.connected = False + else: + # The socket is broken in some unknown way, alert + # the user and remove it from the map (to prevent + # polling of broken sockets). + self.del_channel(map) + raise + else: + self.socket = None + + def __repr__(self): + status = [self.__class__.__module__+"."+self.__class__.__qualname__] + if self.accepting and self.addr: + status.append('listening') + elif self.connected: + status.append('connected') + if self.addr is not None: + try: + status.append('%s:%d' % self.addr) + except TypeError: + status.append(repr(self.addr)) + return '<%s at %#x>' % (' '.join(status), id(self)) + + def add_channel(self, map=None): + #self.log_info('adding channel %s' % self) + if map is None: + map = self._map + map[self._fileno] = self + + def del_channel(self, map=None): + fd = self._fileno + if map is None: + map = self._map + if fd in map: + #self.log_info('closing channel %d:%s' % (fd, self)) + del map[fd] + self._fileno = None + + def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM): + self.family_and_type = family, type + sock = socket.socket(family, type) + sock.setblocking(False) + self.set_socket(sock) + + def set_socket(self, sock, map=None): + self.socket = sock + self._fileno = sock.fileno() + self.add_channel(map) + + def set_reuse_addr(self): + # try to re-use a server port if possible + try: + self.socket.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR, + self.socket.getsockopt(socket.SOL_SOCKET, + socket.SO_REUSEADDR) | 1 + ) + except OSError: + pass + + # ================================================== + # predicates for select() + # these are used as filters for the lists of sockets + # to pass to select(). + # ================================================== + + def readable(self): + return True + + def writable(self): + return True + + # ================================================== + # socket object methods. + # ================================================== + + def listen(self, num): + self.accepting = True + if os.name == 'nt' and num > 5: + num = 5 + return self.socket.listen(num) + + def bind(self, addr): + self.addr = addr + return self.socket.bind(addr) + + def connect(self, address): + self.connected = False + self.connecting = True + err = self.socket.connect_ex(address) + if err in (EINPROGRESS, EALREADY, EWOULDBLOCK) \ + or err == EINVAL and os.name == 'nt': + self.addr = address + return + if err in (0, EISCONN): + self.addr = address + self.handle_connect_event() + else: + raise OSError(err, errorcode[err]) + + def accept(self): + # XXX can return either an address pair or None + try: + conn, addr = self.socket.accept() + except TypeError: + return None + except OSError as why: + if why.errno in (EWOULDBLOCK, ECONNABORTED, EAGAIN): + return None + else: + raise + else: + return conn, addr + + def send(self, data): + try: + result = self.socket.send(data) + return result + except OSError as why: + if why.errno == EWOULDBLOCK: + return 0 + elif why.errno in _DISCONNECTED: + self.handle_close() + return 0 + else: + raise + + def recv(self, buffer_size): + try: + data = self.socket.recv(buffer_size) + if not data: + # a closed connection is indicated by signaling + # a read condition, and having recv() return 0. + self.handle_close() + return b'' + else: + return data + except OSError as why: + # winsock sometimes raises ENOTCONN + if why.errno in _DISCONNECTED: + self.handle_close() + return b'' + else: + raise + + def close(self): + self.connected = False + self.accepting = False + self.connecting = False + self.del_channel() + if self.socket is not None: + try: + self.socket.close() + except OSError as why: + if why.errno not in (ENOTCONN, EBADF): + raise + + # log and log_info may be overridden to provide more sophisticated + # logging and warning methods. In general, log is for 'hit' logging + # and 'log_info' is for informational, warning and error logging. + + def log(self, message): + sys.stderr.write('log: %s\n' % str(message)) + + def log_info(self, message, type='info'): + if type not in self.ignore_log_types: + print('%s: %s' % (type, message)) + + def handle_read_event(self): + if self.accepting: + # accepting sockets are never connected, they "spawn" new + # sockets that are connected + self.handle_accept() + elif not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_read() + else: + self.handle_read() + + def handle_connect_event(self): + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise OSError(err, _strerror(err)) + self.handle_connect() + self.connected = True + self.connecting = False + + def handle_write_event(self): + if self.accepting: + # Accepting sockets shouldn't get a write event. + # We will pretend it didn't happen. + return + + if not self.connected: + if self.connecting: + self.handle_connect_event() + self.handle_write() + + def handle_expt_event(self): + # handle_expt_event() is called if there might be an error on the + # socket, or if there is OOB data + # check for the error condition first + err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # we can get here when select.select() says that there is an + # exceptional condition on the socket + # since there is an error, we'll go ahead and close the socket + # like we would in a subclassed handle_read() that received no + # data + self.handle_close() + else: + self.handle_expt() + + def handle_error(self): + nil, t, v, tbinfo = compact_traceback() + + # sometimes a user repr method will crash. + try: + self_repr = repr(self) + except: + self_repr = '<__repr__(self) failed for object at %0x>' % id(self) + + self.log_info( + 'uncaptured python exception, closing channel %s (%s:%s %s)' % ( + self_repr, + t, + v, + tbinfo + ), + 'error' + ) + self.handle_close() + + def handle_expt(self): + self.log_info('unhandled incoming priority event', 'warning') + + def handle_read(self): + self.log_info('unhandled read event', 'warning') + + def handle_write(self): + self.log_info('unhandled write event', 'warning') + + def handle_connect(self): + self.log_info('unhandled connect event', 'warning') + + def handle_accept(self): + pair = self.accept() + if pair is not None: + self.handle_accepted(*pair) + + def handle_accepted(self, sock, addr): + sock.close() + self.log_info('unhandled accepted event', 'warning') + + def handle_close(self): + self.log_info('unhandled close event', 'warning') + self.close() + +# --------------------------------------------------------------------------- +# adds simple buffered output capability, useful for simple clients. +# [for more sophisticated usage use asynchat.async_chat] +# --------------------------------------------------------------------------- + +class dispatcher_with_send(dispatcher): + + def __init__(self, sock=None, map=None): + dispatcher.__init__(self, sock, map) + self.out_buffer = b'' + + def initiate_send(self): + num_sent = 0 + num_sent = dispatcher.send(self, self.out_buffer[:65536]) + self.out_buffer = self.out_buffer[num_sent:] + + def handle_write(self): + self.initiate_send() + + def writable(self): + return (not self.connected) or len(self.out_buffer) + + def send(self, data): + if self.debug: + self.log_info('sending %s' % repr(data)) + self.out_buffer = self.out_buffer + data + self.initiate_send() + +# --------------------------------------------------------------------------- +# used for debugging. +# --------------------------------------------------------------------------- + +def compact_traceback(): + exc = sys.exception() + tb = exc.__traceback__ + if not tb: # Must have a traceback + raise AssertionError("traceback does not exist") + tbinfo = [] + while tb: + tbinfo.append(( + tb.tb_frame.f_code.co_filename, + tb.tb_frame.f_code.co_name, + str(tb.tb_lineno) + )) + tb = tb.tb_next + + # just to be safe + del tb + + file, function, line = tbinfo[-1] + info = ' '.join(['[%s|%s|%s]' % x for x in tbinfo]) + return (file, function, line), type(exc), exc, info + +def close_all(map=None, ignore_all=False): + if map is None: + map = socket_map + for x in list(map.values()): + try: + x.close() + except OSError as x: + if x.errno == EBADF: + pass + elif not ignore_all: + raise + except _reraised_exceptions: + raise + except: + if not ignore_all: + raise + map.clear() + +# Asynchronous File I/O: +# +# After a little research (reading man pages on various unixen, and +# digging through the linux kernel), I've determined that select() +# isn't meant for doing asynchronous file i/o. +# Heartening, though - reading linux/mm/filemap.c shows that linux +# supports asynchronous read-ahead. So _MOST_ of the time, the data +# will be sitting in memory for us already when we go to read it. +# +# What other OS's (besides NT) support async file i/o? [VMS?] +# +# Regardless, this is useful for pipes, and stdin/stdout... + +if os.name == 'posix': + class file_wrapper: + # Here we override just enough to make a file + # look like a socket for the purposes of asyncore. + # The passed fd is automatically os.dup()'d + + def __init__(self, fd): + self.fd = os.dup(fd) + + def __del__(self): + if self.fd >= 0: + warnings.warn("unclosed file %r" % self, ResourceWarning, + source=self) + self.close() + + def recv(self, *args): + return os.read(self.fd, *args) + + def send(self, *args): + return os.write(self.fd, *args) + + def getsockopt(self, level, optname, buflen=None): + if (level == socket.SOL_SOCKET and + optname == socket.SO_ERROR and + not buflen): + return 0 + raise NotImplementedError("Only asyncore specific behaviour " + "implemented.") + + read = recv + write = send + + def close(self): + if self.fd < 0: + return + fd = self.fd + self.fd = -1 + os.close(fd) + + def fileno(self): + return self.fd + + class file_dispatcher(dispatcher): + + def __init__(self, fd, map=None): + dispatcher.__init__(self, None, map) + self.connected = True + try: + fd = fd.fileno() + except AttributeError: + pass + self.set_file(fd) + # set it to non-blocking mode + os.set_blocking(fd, False) + + def set_file(self, fd): + self.socket = file_wrapper(fd) + self._fileno = self.socket.fileno() + self.add_channel() From de626b8627d2af71291a691bbf01204a8a05a880 Mon Sep 17 00:00:00 2001 From: NakanoMiku Date: Mon, 25 Dec 2023 15:19:28 +0800 Subject: [PATCH 204/893] Update test_file.py from CPython v3.12.0 --- Lib/test/test_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_file.py b/Lib/test/test_file.py index 247c9675e6..d998af936a 100644 --- a/Lib/test/test_file.py +++ b/Lib/test/test_file.py @@ -217,7 +217,7 @@ def testSetBufferSize(self): self._checkBufferSize(1) def testTruncateOnWindows(self): - # SF bug + # SF bug # "file.truncate fault on windows" f = self.open(TESTFN, 'wb') From 727f97fd489ae235fde8986d60ac1cfaa1ba4324 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 28 Dec 2023 01:23:35 +0900 Subject: [PATCH 205/893] Fix malachite-bigint version --- Cargo.lock | 16 ++++++++-------- Cargo.toml | 12 ++++++------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c0c59f1552..4edd2d644b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1257,9 +1257,9 @@ dependencies = [ [[package]] name = "malachite-bigint" -version = "0.1.1" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76c3eca3b5df299486144c8423c45c24bdf9e82e2452c8a1eeda547c4d8b5d41" +checksum = "17703a19c80bbdd0b7919f0f104f3b0597f7de4fc4e90a477c15366a5ba03faa" dependencies = [ "derive_more", "malachite", @@ -1946,7 +1946,7 @@ dependencies = [ [[package]] name = "rustpython-ast" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" +source = "git+https://github.com/RustPython/Parser.git?rev=29c4728dbedc7e69cc2560b9b34058bbba9b1303#29c4728dbedc7e69cc2560b9b34058bbba9b1303" dependencies = [ "is-macro", "malachite-bigint", @@ -2058,7 +2058,7 @@ dependencies = [ [[package]] name = "rustpython-format" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" +source = "git+https://github.com/RustPython/Parser.git?rev=29c4728dbedc7e69cc2560b9b34058bbba9b1303#29c4728dbedc7e69cc2560b9b34058bbba9b1303" dependencies = [ "bitflags 2.4.0", "itertools 0.11.0", @@ -2085,7 +2085,7 @@ dependencies = [ [[package]] name = "rustpython-literal" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" +source = "git+https://github.com/RustPython/Parser.git?rev=29c4728dbedc7e69cc2560b9b34058bbba9b1303#29c4728dbedc7e69cc2560b9b34058bbba9b1303" dependencies = [ "hexf-parse", "is-macro", @@ -2097,7 +2097,7 @@ dependencies = [ [[package]] name = "rustpython-parser" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" +source = "git+https://github.com/RustPython/Parser.git?rev=29c4728dbedc7e69cc2560b9b34058bbba9b1303#29c4728dbedc7e69cc2560b9b34058bbba9b1303" dependencies = [ "anyhow", "is-macro", @@ -2120,7 +2120,7 @@ dependencies = [ [[package]] name = "rustpython-parser-core" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" +source = "git+https://github.com/RustPython/Parser.git?rev=29c4728dbedc7e69cc2560b9b34058bbba9b1303#29c4728dbedc7e69cc2560b9b34058bbba9b1303" dependencies = [ "is-macro", "memchr", @@ -2130,7 +2130,7 @@ dependencies = [ [[package]] name = "rustpython-parser-vendored" version = "0.3.1" -source = "git+https://github.com/RustPython/Parser.git?rev=52edf4525ec300f2b69670f3991784bbcf140564#52edf4525ec300f2b69670f3991784bbcf140564" +source = "git+https://github.com/RustPython/Parser.git?rev=29c4728dbedc7e69cc2560b9b34058bbba9b1303#29c4728dbedc7e69cc2560b9b34058bbba9b1303" dependencies = [ "memchr", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index 680f090cc4..a91f2ea05f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,11 +29,11 @@ rustpython-pylib = { path = "pylib", version = "0.3.0" } rustpython-stdlib = { path = "stdlib", default-features = false, version = "0.3.0" } rustpython-doc = { git = "https://github.com/RustPython/__doc__", tag = "0.3.0", version = "0.3.0" } -rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } -rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } -rustpython-parser = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } -rustpython-ast = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } -rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "52edf4525ec300f2b69670f3991784bbcf140564" } +rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "29c4728dbedc7e69cc2560b9b34058bbba9b1303" } +rustpython-parser-core = { git = "https://github.com/RustPython/Parser.git", rev = "29c4728dbedc7e69cc2560b9b34058bbba9b1303" } +rustpython-parser = { git = "https://github.com/RustPython/Parser.git", rev = "29c4728dbedc7e69cc2560b9b34058bbba9b1303" } +rustpython-ast = { git = "https://github.com/RustPython/Parser.git", rev = "29c4728dbedc7e69cc2560b9b34058bbba9b1303" } +rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "29c4728dbedc7e69cc2560b9b34058bbba9b1303" } # rustpython-literal = { path = "../RustPython-parser/literal" } # rustpython-parser-core = { path = "../RustPython-parser/core" } # rustpython-parser = { path = "../RustPython-parser/parser" } @@ -59,7 +59,7 @@ is-macro = "0.3.0" libc = "0.2.133" log = "0.4.16" nix = "0.26" -malachite-bigint = "0.1.1" +malachite-bigint = "0.2.0" malachite-q = "0.4.4" malachite-base = "0.4.4" num-complex = "0.4.0" From 16a3edd4328e6536165b55d1dee5f66ddd5d72b3 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 28 Dec 2023 02:34:22 +0900 Subject: [PATCH 206/893] Fix wasi CI --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e616f6d10c..4f31ba74c1 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -415,4 +415,4 @@ jobs: - name: build rustpython run: cargo build --release --target wasm32-wasi --features freeze-stdlib,stdlib --verbose - name: run snippets - run: wasmer run --dir . target/wasm32-wasi/release/rustpython.wasm -- extra_tests/snippets/stdlib_random.py + run: wasmer run --dir `pwd` target/wasm32-wasi/release/rustpython.wasm -- `pwd`/extra_tests/snippets/stdlib_random.py From 459cad8407cc42c7db9041667e49065991f9c5e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Moskal?= Date: Thu, 28 Dec 2023 05:08:08 +0100 Subject: [PATCH 207/893] update indexmap to 1.9.3 (#5128) --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4edd2d644b..7bd76c9783 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -983,9 +983,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.2" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown", diff --git a/Cargo.toml b/Cargo.toml index a91f2ea05f..35a03003ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ crossbeam-utils = "0.8.16" flame = "0.2.2" glob = "0.3" hex = "0.4.3" -indexmap = "1.8.1" +indexmap = { version = "1.9.3", features = ["std"] } insta = "1.33.0" itertools = "0.11.0" is-macro = "0.3.0" From fc91cd8bc7bf52b54b8f0cd0a1f29f95aa7d3e4b Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 28 Dec 2023 21:51:26 +0900 Subject: [PATCH 208/893] clean up winapi features (#5141) --- stdlib/Cargo.toml | 5 ++--- vm/Cargo.toml | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 1cd6f09b5e..cd04f47b83 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -112,9 +112,8 @@ widestring = { workspace = true } [target.'cfg(windows)'.dependencies.winapi] version = "0.3.9" features = [ - "winsock2", "handleapi", "ws2def", "std", "winbase", "wincrypt", "fileapi", "processenv", - "namedpipeapi", "winnt", "processthreadsapi", "errhandlingapi", "winuser", "synchapi", "wincon", - "impl-default", "vcruntime", "ifdef", "netioapi" + "winsock2", "ws2def", "std", "wincrypt", "fileapi", + "impl-default", "vcruntime", "ifdef", "netioapi", "profileapi", ] [target.'cfg(target_os = "macos")'.dependencies] diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 7ba9a44f14..c8fd0174ca 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -139,9 +139,9 @@ features = [ [target.'cfg(windows)'.dependencies.winapi] version = "0.3.9" features = [ - "winsock2", "handleapi", "ws2def", "std", "winbase", "wincrypt", "fileapi", "processenv", - "namedpipeapi", "winnt", "processthreadsapi", "errhandlingapi", "winuser", "synchapi", "wincon", - "impl-default", "vcruntime", "ifdef", "netioapi", "memoryapi", "profileapi", "sysinfoapi" + "winsock2", "handleapi", "std", "winbase", "processenv", + "winnt", "processthreadsapi", "errhandlingapi", "wincon", + "impl-default", "vcruntime", "sysinfoapi", ] [target.'cfg(target_arch = "wasm32")'.dependencies] From adf0788895d510efb3df390e0ff83defd7881821 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 21:32:47 +0900 Subject: [PATCH 209/893] bump up windows{-sys} --- Cargo.lock | 80 ++++++++++++++++++++++++++++++++++++++++++++++----- vm/Cargo.toml | 6 ++-- 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7bd76c9783..3ce8580250 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2289,7 +2289,7 @@ dependencies = [ "widestring", "winapi", "windows", - "windows-sys 0.48.0", + "windows-sys 0.52.0", "winreg", ] @@ -3159,21 +3159,21 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.51.1" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca229916c5ee38c2f2bc1e9d8f04df975b4bd93f9955dc69fabb5d91270045c9" +checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.48.5", + "windows-targets 0.52.0", ] [[package]] name = "windows-core" -version = "0.51.1" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.48.5", + "windows-targets 0.52.0", ] [[package]] @@ -3222,6 +3222,15 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", +] + [[package]] name = "windows-targets" version = "0.42.1" @@ -3252,6 +3261,21 @@ dependencies = [ "windows_x86_64_msvc 0.48.5", ] +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.1" @@ -3264,6 +3288,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.36.1" @@ -3282,6 +3312,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.36.1" @@ -3300,6 +3336,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.36.1" @@ -3318,6 +3360,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.36.1" @@ -3336,6 +3384,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.1" @@ -3348,6 +3402,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.36.1" @@ -3366,6 +3426,12 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winreg" version = "0.10.1" diff --git a/vm/Cargo.toml b/vm/Cargo.toml index c8fd0174ca..341603b83d 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -113,7 +113,7 @@ widestring = { workspace = true } winreg = "0.10.1" [target.'cfg(windows)'.dependencies.windows] -version = "0.51.1" +version = "0.52.0" features = [ "Win32_Foundation", "Win32_System_LibraryLoader", @@ -122,14 +122,16 @@ features = [ ] [target.'cfg(windows)'.dependencies.windows-sys] -version = "0.48.0" +version = "0.52.0" features = [ "Win32_Foundation", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_LibraryLoader", + "Win32_System_Memory", "Win32_System_Pipes", + "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_UI_Shell", From 7513017e21510512486b90f0cec43b51b0b88c0e Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 21:41:42 +0900 Subject: [PATCH 210/893] replace sysinfoapi to windows-sys --- vm/src/stdlib/sys.rs | 15 ++++++--------- vm/src/stdlib/time.rs | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index 65d50eaa15..4c8b6a060f 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -487,18 +487,15 @@ mod sys { fn getwindowsversion(vm: &VirtualMachine) -> PyResult { use std::ffi::OsString; use std::os::windows::ffi::OsStringExt; - use winapi::um::{ - sysinfoapi::GetVersionExW, - winnt::{LPOSVERSIONINFOEXW, LPOSVERSIONINFOW, OSVERSIONINFOEXW}, + use windows_sys::Win32::System::SystemInformation::{ + GetVersionExW, OSVERSIONINFOEXW, OSVERSIONINFOW, }; - let mut version = OSVERSIONINFOEXW { - dwOSVersionInfoSize: std::mem::size_of::() as u32, - ..OSVERSIONINFOEXW::default() - }; + let mut version: OSVERSIONINFOEXW = unsafe { std::mem::zeroed() }; + version.dwOSVersionInfoSize = std::mem::size_of::() as u32; let result = unsafe { - let osvi = &mut version as LPOSVERSIONINFOEXW as LPOSVERSIONINFOW; - // SAFETY: GetVersionExW accepts a pointer of OSVERSIONINFOW, but winapi crate's type currently doesn't allow to do so. + let osvi = &mut version as *mut OSVERSIONINFOEXW as *mut OSVERSIONINFOW; + // SAFETY: GetVersionExW accepts a pointer of OSVERSIONINFOW, but windows-sys crate's type currently doesn't allow to do so. // https://docs.microsoft.com/en-us/windows/win32/api/sysinfoapi/nf-sysinfoapi-getversionexw#parameters GetVersionExW(osvi) }; diff --git a/vm/src/stdlib/time.rs b/vm/src/stdlib/time.rs index 3df9a04c12..269dce44ab 100644 --- a/vm/src/stdlib/time.rs +++ b/vm/src/stdlib/time.rs @@ -633,7 +633,7 @@ mod platform { GetCurrentProcess, GetCurrentThread, GetProcessTimes, GetThreadTimes, }; use winapi::um::profileapi::{QueryPerformanceCounter, QueryPerformanceFrequency}; - use winapi::um::sysinfoapi::{GetSystemTimeAdjustment, GetTickCount64}; + use windows_sys::Win32::System::SystemInformation::{GetSystemTimeAdjustment, GetTickCount64}; fn u64_from_filetime(time: FILETIME) -> u64 { unsafe { From 8a84a479f1ad17fb3bf2fafa43113c0128be5327 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 21:51:16 +0900 Subject: [PATCH 211/893] remove processthreadsapi --- vm/Cargo.toml | 4 ++-- vm/src/stdlib/nt.rs | 11 ++++++----- vm/src/stdlib/os.rs | 25 +++++++++++++++++-------- vm/src/stdlib/time.rs | 11 ++++++----- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 341603b83d..5dd559d41d 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -142,8 +142,8 @@ features = [ version = "0.3.9" features = [ "winsock2", "handleapi", "std", "winbase", "processenv", - "winnt", "processthreadsapi", "errhandlingapi", "wincon", - "impl-default", "vcruntime", "sysinfoapi", + "winnt", "errhandlingapi", "wincon", + "impl-default", "vcruntime", ] [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index 713648933d..c8f8c3f255 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -136,7 +136,8 @@ pub(crate) mod module { #[pyfunction] fn kill(pid: i32, sig: isize, vm: &VirtualMachine) -> PyResult<()> { { - use um::{handleapi, processthreadsapi, wincon, winnt}; + use um::{wincon, winnt}; + use windows_sys::Win32::{Foundation::CloseHandle, System::Threading}; let sig = sig as u32; let pid = pid as u32; @@ -146,13 +147,13 @@ pub(crate) mod module { return res; } - let h = unsafe { processthreadsapi::OpenProcess(winnt::PROCESS_ALL_ACCESS, 0, pid) }; - if h.is_null() { + let h = unsafe { Threading::OpenProcess(winnt::PROCESS_ALL_ACCESS, 0, pid) }; + if h == 0 { return Err(errno_err(vm)); } - let ret = unsafe { processthreadsapi::TerminateProcess(h, sig) }; + let ret = unsafe { Threading::TerminateProcess(h, sig) }; let res = if ret == 0 { Err(errno_err(vm)) } else { Ok(()) }; - unsafe { handleapi::CloseHandle(h) }; + unsafe { CloseHandle(h) }; res } } diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 8164e5ea48..8319de732a 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1414,19 +1414,28 @@ pub(super) mod _os { fn times(vm: &VirtualMachine) -> PyResult { #[cfg(windows)] { - use winapi::shared::minwindef::FILETIME; - use winapi::um::processthreadsapi::{GetCurrentProcess, GetProcessTimes}; + use std::mem::MaybeUninit; + use windows_sys::Win32::{Foundation::FILETIME, System::Threading}; - let mut _create = FILETIME::default(); - let mut _exit = FILETIME::default(); - let mut kernel = FILETIME::default(); - let mut user = FILETIME::default(); + let mut _create = MaybeUninit::::uninit(); + let mut _exit = MaybeUninit::::uninit(); + let mut kernel = MaybeUninit::::uninit(); + let mut user = MaybeUninit::::uninit(); unsafe { - let h_proc = GetCurrentProcess(); - GetProcessTimes(h_proc, &mut _create, &mut _exit, &mut kernel, &mut user); + let h_proc = Threading::GetCurrentProcess(); + Threading::GetProcessTimes( + h_proc, + _create.as_mut_ptr(), + _exit.as_mut_ptr(), + kernel.as_mut_ptr(), + user.as_mut_ptr(), + ); } + let kernel = unsafe { kernel.assume_init() }; + let user = unsafe { user.assume_init() }; + let times_result = TimesResult { user: user.dwHighDateTime as f64 * 429.4967296 + user.dwLowDateTime as f64 * 1e-7, system: kernel.dwHighDateTime as f64 * 429.4967296 diff --git a/vm/src/stdlib/time.rs b/vm/src/stdlib/time.rs index 269dce44ab..883ef3342f 100644 --- a/vm/src/stdlib/time.rs +++ b/vm/src/stdlib/time.rs @@ -628,12 +628,13 @@ mod platform { PyRef, PyResult, VirtualMachine, }; use std::time::Duration; - use winapi::shared::{minwindef::FILETIME, ntdef::ULARGE_INTEGER}; - use winapi::um::processthreadsapi::{ - GetCurrentProcess, GetCurrentThread, GetProcessTimes, GetThreadTimes, - }; + use winapi::shared::ntdef::ULARGE_INTEGER; use winapi::um::profileapi::{QueryPerformanceCounter, QueryPerformanceFrequency}; - use windows_sys::Win32::System::SystemInformation::{GetSystemTimeAdjustment, GetTickCount64}; + use windows_sys::Win32::{ + Foundation::FILETIME, + System::SystemInformation::{GetSystemTimeAdjustment, GetTickCount64}, + System::Threading::{GetCurrentProcess, GetCurrentThread, GetProcessTimes, GetThreadTimes}, + }; fn u64_from_filetime(time: FILETIME) -> u64 { unsafe { From 6df97329658c0493b713705bb8e19b4e8a3a4d8e Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 22:02:10 +0900 Subject: [PATCH 212/893] replace wincon to windows-sys --- vm/Cargo.toml | 4 +-- vm/src/stdlib/nt.rs | 63 ++++++++++++++++++++++----------------------- vm/src/windows.rs | 4 +-- 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 5dd559d41d..a157bcf735 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -141,8 +141,8 @@ features = [ [target.'cfg(windows)'.dependencies.winapi] version = "0.3.9" features = [ - "winsock2", "handleapi", "std", "winbase", "processenv", - "winnt", "errhandlingapi", "wincon", + "winsock2", "handleapi", "std", "winbase", + "winnt", "errhandlingapi", "impl-default", "vcruntime", ] diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index c8f8c3f255..2f78f67ad3 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -10,8 +10,10 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { #[pymodule(name = "nt", with(super::os::_os))] pub(crate) mod module { + #[cfg(target_env = "msvc")] + use crate::builtins::PyListRef; use crate::{ - builtins::{PyStrRef, PyTupleRef}, + builtins::{PyDictRef, PyStrRef, PyTupleRef}, common::{crt_fd::Fd, os::errno, suppress_iph}, convert::ToPyException, function::Either, @@ -23,13 +25,14 @@ pub(crate) mod module { }; use std::{ env, fs, io, + mem::MaybeUninit, os::windows::ffi::{OsStrExt, OsStringExt}, }; - - use crate::builtins::PyDictRef; - #[cfg(target_env = "msvc")] - use crate::builtins::PyListRef; use winapi::{um, vc::vcruntime::intptr_t}; + use windows_sys::Win32::{ + Foundation::{CloseHandle, INVALID_HANDLE_VALUE}, + System::{Console, Threading}, + }; #[pyattr] use libc::{O_BINARY, O_TEMPORARY}; @@ -135,27 +138,23 @@ pub(crate) mod module { #[pyfunction] fn kill(pid: i32, sig: isize, vm: &VirtualMachine) -> PyResult<()> { - { - use um::{wincon, winnt}; - use windows_sys::Win32::{Foundation::CloseHandle, System::Threading}; - let sig = sig as u32; - let pid = pid as u32; - - if sig == wincon::CTRL_C_EVENT || sig == wincon::CTRL_BREAK_EVENT { - let ret = unsafe { wincon::GenerateConsoleCtrlEvent(sig, pid) }; - let res = if ret == 0 { Err(errno_err(vm)) } else { Ok(()) }; - return res; - } + let sig = sig as u32; + let pid = pid as u32; - let h = unsafe { Threading::OpenProcess(winnt::PROCESS_ALL_ACCESS, 0, pid) }; - if h == 0 { - return Err(errno_err(vm)); - } - let ret = unsafe { Threading::TerminateProcess(h, sig) }; + if sig == Console::CTRL_C_EVENT || sig == Console::CTRL_BREAK_EVENT { + let ret = unsafe { Console::GenerateConsoleCtrlEvent(sig, pid) }; let res = if ret == 0 { Err(errno_err(vm)) } else { Ok(()) }; - unsafe { CloseHandle(h) }; - res + return res; } + + let h = unsafe { Threading::OpenProcess(Threading::PROCESS_ALL_ACCESS, 0, pid) }; + if h == 0 { + return Err(errno_err(vm)); + } + let ret = unsafe { Threading::TerminateProcess(h, sig) }; + let res = if ret == 0 { Err(errno_err(vm)) } else { Ok(()) }; + unsafe { CloseHandle(h) }; + res } #[pyfunction] @@ -164,22 +163,22 @@ pub(crate) mod module { vm: &VirtualMachine, ) -> PyResult<_os::PyTerminalSize> { let (columns, lines) = { - use um::{handleapi, processenv, winbase, wincon}; let stdhandle = match fd { - OptionalArg::Present(0) => winbase::STD_INPUT_HANDLE, - OptionalArg::Present(1) | OptionalArg::Missing => winbase::STD_OUTPUT_HANDLE, - OptionalArg::Present(2) => winbase::STD_ERROR_HANDLE, + OptionalArg::Present(0) => Console::STD_INPUT_HANDLE, + OptionalArg::Present(1) | OptionalArg::Missing => Console::STD_OUTPUT_HANDLE, + OptionalArg::Present(2) => Console::STD_ERROR_HANDLE, _ => return Err(vm.new_value_error("bad file descriptor".to_owned())), }; - let h = unsafe { processenv::GetStdHandle(stdhandle) }; - if h.is_null() { + let h = unsafe { Console::GetStdHandle(stdhandle) }; + if h == 0 { return Err(vm.new_os_error("handle cannot be retrieved".to_owned())); } - if h == handleapi::INVALID_HANDLE_VALUE { + if h == INVALID_HANDLE_VALUE { return Err(errno_err(vm)); } - let mut csbi = wincon::CONSOLE_SCREEN_BUFFER_INFO::default(); - let ret = unsafe { wincon::GetConsoleScreenBufferInfo(h, &mut csbi) }; + let mut csbi = MaybeUninit::uninit(); + let ret = unsafe { Console::GetConsoleScreenBufferInfo(h, csbi.as_mut_ptr()) }; + let csbi = unsafe { csbi.assume_init() }; if ret == 0 { return Err(errno_err(vm)); } diff --git a/vm/src/windows.rs b/vm/src/windows.rs index 9216f839fe..e749241c07 100644 --- a/vm/src/windows.rs +++ b/vm/src/windows.rs @@ -4,7 +4,7 @@ use crate::{ PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; use windows::Win32::Foundation::HANDLE; -use windows_sys::Win32::Foundation::{BOOL, HANDLE as RAW_HANDLE}; +use windows_sys::Win32::Foundation::{BOOL, HANDLE as RAW_HANDLE, INVALID_HANDLE_VALUE}; pub(crate) trait WindowsSysResultValue { type Ok: ToPyObject; @@ -15,7 +15,7 @@ pub(crate) trait WindowsSysResultValue { impl WindowsSysResultValue for RAW_HANDLE { type Ok = HANDLE; fn is_err(&self) -> bool { - *self == windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE + *self == INVALID_HANDLE_VALUE } fn into_ok(self) -> Self::Ok { HANDLE(self) From ee128eac7c67e7d74fc557ee77a46f2f182d6f4e Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 22:10:50 +0900 Subject: [PATCH 213/893] replace errorhandling to windows-sys --- vm/Cargo.toml | 5 +++-- vm/src/stdlib/msvcrt.rs | 10 +++++----- vm/src/stdlib/nt.rs | 34 ++++++++++++++++++---------------- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index a157bcf735..f7a38d6700 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -128,6 +128,7 @@ features = [ "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Console", + "Win32_System_Diagnostics_Debug", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Pipes", @@ -142,8 +143,8 @@ features = [ version = "0.3.9" features = [ "winsock2", "handleapi", "std", "winbase", - "winnt", "errhandlingapi", - "impl-default", "vcruntime", + "winnt", + "impl-default", ] [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/vm/src/stdlib/msvcrt.rs b/vm/src/stdlib/msvcrt.rs index b6f933aea2..fc6ed24d35 100644 --- a/vm/src/stdlib/msvcrt.rs +++ b/vm/src/stdlib/msvcrt.rs @@ -9,9 +9,9 @@ mod msvcrt { PyRef, PyResult, VirtualMachine, }; use itertools::Itertools; - use winapi::{ - shared::minwindef::UINT, - um::{handleapi::INVALID_HANDLE_VALUE, winnt::HANDLE}, + use windows_sys::Win32::{ + Foundation::{HANDLE, INVALID_HANDLE_VALUE}, + System::Diagnostics::Debug, }; #[pyattr] @@ -111,7 +111,7 @@ mod msvcrt { #[allow(non_snake_case)] #[pyfunction] - fn SetErrorMode(mode: UINT, _: &VirtualMachine) -> UINT { - unsafe { suppress_iph!(winapi::um::errhandlingapi::SetErrorMode(mode)) } + fn SetErrorMode(mode: Debug::THREAD_ERROR_MODE, _: &VirtualMachine) -> u32 { + unsafe { suppress_iph!(Debug::SetErrorMode(mode)) } } } diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index 2f78f67ad3..37064d20cd 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -23,14 +23,16 @@ pub(crate) mod module { }, PyResult, TryFromObject, VirtualMachine, }; + use libc::intptr_t; use std::{ env, fs, io, mem::MaybeUninit, os::windows::ffi::{OsStrExt, OsStringExt}, }; - use winapi::{um, vc::vcruntime::intptr_t}; + use winapi::um; use windows_sys::Win32::{ Foundation::{CloseHandle, INVALID_HANDLE_VALUE}, + Storage::FileSystem, System::{Console, Threading}, }; @@ -39,9 +41,9 @@ pub(crate) mod module { #[pyfunction] pub(super) fn access(path: OsPath, mode: u8, vm: &VirtualMachine) -> PyResult { - use um::{fileapi, winnt}; - let attr = unsafe { fileapi::GetFileAttributesW(path.to_widecstring(vm)?.as_ptr()) }; - Ok(attr != fileapi::INVALID_FILE_ATTRIBUTES + use um::winnt; + let attr = unsafe { FileSystem::GetFileAttributesW(path.to_widecstring(vm)?.as_ptr()) }; + Ok(attr != FileSystem::INVALID_FILE_ATTRIBUTES && (mode & 2 == 0 || attr & winnt::FILE_ATTRIBUTE_READONLY == 0 || attr & winnt::FILE_ATTRIBUTE_DIRECTORY != 0)) @@ -252,7 +254,7 @@ pub(crate) mod module { let wpath = path.to_widecstring(vm)?; let mut buffer = vec![0u16; winapi::shared::minwindef::MAX_PATH]; let ret = unsafe { - um::fileapi::GetFullPathNameW( + FileSystem::GetFullPathNameW( wpath.as_ptr(), buffer.len() as _, buffer.as_mut_ptr(), @@ -265,7 +267,7 @@ pub(crate) mod module { if ret as usize > buffer.len() { buffer.resize(ret as usize, 0); let ret = unsafe { - um::fileapi::GetFullPathNameW( + FileSystem::GetFullPathNameW( wpath.as_ptr(), buffer.len() as _, buffer.as_mut_ptr(), @@ -286,7 +288,7 @@ pub(crate) mod module { let buflen = std::cmp::max(wide.len(), winapi::shared::minwindef::MAX_PATH); let mut buffer = vec![0u16; buflen]; let ret = unsafe { - um::fileapi::GetVolumePathNameW(wide.as_ptr(), buffer.as_mut_ptr(), buflen as _) + FileSystem::GetVolumePathNameW(wide.as_ptr(), buffer.as_mut_ptr(), buflen as _) }; if ret == 0 { return Err(errno_err(vm)); @@ -335,17 +337,17 @@ pub(crate) mod module { #[pyfunction] fn _getdiskusage(path: OsPath, vm: &VirtualMachine) -> PyResult<(u64, u64)> { - use um::fileapi::GetDiskFreeSpaceExW; - use winapi::shared::{ntdef::ULARGE_INTEGER, winerror}; + use winapi::shared::winerror; + use FileSystem::GetDiskFreeSpaceExW; let wpath = path.to_widecstring(vm)?; - let mut _free_to_me = ULARGE_INTEGER::default(); - let mut total = ULARGE_INTEGER::default(); - let mut free = ULARGE_INTEGER::default(); + let mut _free_to_me: u64 = 0; + let mut total: u64 = 0; + let mut free: u64 = 0; let ret = unsafe { GetDiskFreeSpaceExW(wpath.as_ptr(), &mut _free_to_me, &mut total, &mut free) }; if ret != 0 { - return Ok(unsafe { (*total.QuadPart(), *free.QuadPart()) }); + return Ok((total, free)); } let err = io::Error::last_os_error(); if err.raw_os_error() == Some(winerror::ERROR_DIRECTORY as i32) { @@ -359,7 +361,7 @@ pub(crate) mod module { return if ret == 0 { Err(errno_err(vm)) } else { - Ok(unsafe { (*total.QuadPart(), *free.QuadPart()) }) + Ok((total, free)) }; } } @@ -367,7 +369,7 @@ pub(crate) mod module { } #[pyfunction] - fn get_handle_inheritable(handle: intptr_t, vm: &VirtualMachine) -> PyResult { + fn get_handle_inheritable(handle: isize, vm: &VirtualMachine) -> PyResult { let mut flags = 0; if unsafe { um::handleapi::GetHandleInformation(handle as _, &mut flags) } == 0 { Err(errno_err(vm)) @@ -376,7 +378,7 @@ pub(crate) mod module { } } - pub fn raw_set_handle_inheritable(handle: intptr_t, inheritable: bool) -> io::Result<()> { + pub fn raw_set_handle_inheritable(handle: isize, inheritable: bool) -> io::Result<()> { use um::winbase::HANDLE_FLAG_INHERIT; let flags = if inheritable { HANDLE_FLAG_INHERIT } else { 0 }; let res = From d01909a5248fa59094f121ec21407c2bcd60e791 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 22:28:10 +0900 Subject: [PATCH 214/893] replace handleapi to windows-sys --- vm/Cargo.toml | 2 +- vm/src/stdlib/nt.rs | 32 +++++++++++++++++--------------- vm/src/stdlib/os.rs | 20 ++++++++++---------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index f7a38d6700..cedbeaf558 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -142,7 +142,7 @@ features = [ [target.'cfg(windows)'.dependencies.winapi] version = "0.3.9" features = [ - "winsock2", "handleapi", "std", "winbase", + "winsock2", "std", "winbase", "winnt", "impl-default", ] diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index 37064d20cd..dbc7032d40 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -29,9 +29,8 @@ pub(crate) mod module { mem::MaybeUninit, os::windows::ffi::{OsStrExt, OsStringExt}, }; - use winapi::um; use windows_sys::Win32::{ - Foundation::{CloseHandle, INVALID_HANDLE_VALUE}, + Foundation::{self, INVALID_HANDLE_VALUE}, Storage::FileSystem, System::{Console, Threading}, }; @@ -41,12 +40,11 @@ pub(crate) mod module { #[pyfunction] pub(super) fn access(path: OsPath, mode: u8, vm: &VirtualMachine) -> PyResult { - use um::winnt; let attr = unsafe { FileSystem::GetFileAttributesW(path.to_widecstring(vm)?.as_ptr()) }; Ok(attr != FileSystem::INVALID_FILE_ATTRIBUTES && (mode & 2 == 0 - || attr & winnt::FILE_ATTRIBUTE_READONLY == 0 - || attr & winnt::FILE_ATTRIBUTE_DIRECTORY != 0)) + || attr & FileSystem::FILE_ATTRIBUTE_READONLY == 0 + || attr & FileSystem::FILE_ATTRIBUTE_DIRECTORY != 0)) } #[derive(FromArgs)] @@ -155,7 +153,7 @@ pub(crate) mod module { } let ret = unsafe { Threading::TerminateProcess(h, sig) }; let res = if ret == 0 { Err(errno_err(vm)) } else { Ok(()) }; - unsafe { CloseHandle(h) }; + unsafe { Foundation::CloseHandle(h) }; res } @@ -369,20 +367,24 @@ pub(crate) mod module { } #[pyfunction] - fn get_handle_inheritable(handle: isize, vm: &VirtualMachine) -> PyResult { + fn get_handle_inheritable(handle: intptr_t, vm: &VirtualMachine) -> PyResult { let mut flags = 0; - if unsafe { um::handleapi::GetHandleInformation(handle as _, &mut flags) } == 0 { + if unsafe { Foundation::GetHandleInformation(handle as _, &mut flags) } == 0 { Err(errno_err(vm)) } else { - Ok(flags & um::winbase::HANDLE_FLAG_INHERIT != 0) + Ok(flags & Foundation::HANDLE_FLAG_INHERIT != 0) } } - pub fn raw_set_handle_inheritable(handle: isize, inheritable: bool) -> io::Result<()> { - use um::winbase::HANDLE_FLAG_INHERIT; - let flags = if inheritable { HANDLE_FLAG_INHERIT } else { 0 }; - let res = - unsafe { um::handleapi::SetHandleInformation(handle as _, HANDLE_FLAG_INHERIT, flags) }; + pub fn raw_set_handle_inheritable(handle: intptr_t, inheritable: bool) -> io::Result<()> { + let flags = if inheritable { + Foundation::HANDLE_FLAG_INHERIT + } else { + 0 + }; + let res = unsafe { + Foundation::SetHandleInformation(handle as _, Foundation::HANDLE_FLAG_INHERIT, flags) + }; if res == 0 { Err(errno()) } else { @@ -410,7 +412,7 @@ pub(crate) mod module { let [] = dir_fd.0; let _ = mode; let wide = path.to_widecstring(vm)?; - let res = unsafe { um::fileapi::CreateDirectoryW(wide.as_ptr(), std::ptr::null_mut()) }; + let res = unsafe { FileSystem::CreateDirectoryW(wide.as_ptr(), std::ptr::null_mut()) }; if res == 0 { return Err(errno_err(vm)); } diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 8319de732a..8a1ad20587 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1200,17 +1200,18 @@ pub(super) mod _os { let res = unsafe { suppress_iph!(libc::lseek(fd, position, how)) }; #[cfg(windows)] let res = unsafe { - use winapi::um::{fileapi, winnt}; + use winapi::um::winnt; + use windows_sys::Win32::Storage::FileSystem; let handle = Fd(fd).to_raw_handle().map_err(|e| e.into_pyexception(vm))?; let mut li = winnt::LARGE_INTEGER::default(); *li.QuadPart_mut() = position; - let ret = fileapi::SetFilePointer( - handle, + let ret = FileSystem::SetFilePointer( + handle as _, li.u().LowPart as _, &mut li.u_mut().HighPart, how as _, ); - if ret == fileapi::INVALID_SET_FILE_POINTER { + if ret == FileSystem::INVALID_SET_FILE_POINTER { -1 } else { li.u_mut().LowPart = ret; @@ -1357,10 +1358,8 @@ pub(super) mod _os { #[cfg(windows)] { use std::{fs::OpenOptions, os::windows::prelude::*}; - use winapi::{ - shared::minwindef::{DWORD, FILETIME}, - um::fileapi::SetFileTime, - }; + use winapi::shared::minwindef::DWORD; + use windows_sys::Win32::{Foundation::FILETIME, Storage::FileSystem}; let [] = dir_fd.0; @@ -1382,8 +1381,9 @@ pub(super) mod _os { .open(path) .map_err(|err| err.into_pyexception(vm))?; - let ret = - unsafe { SetFileTime(f.as_raw_handle() as _, std::ptr::null(), &acc, &modif) }; + let ret = unsafe { + FileSystem::SetFileTime(f.as_raw_handle() as _, std::ptr::null(), &acc, &modif) + }; if ret == 0 { Err(io::Error::last_os_error().into_pyexception(vm)) From cccfb0883560422ec612c42121299a17385065b7 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 22:40:49 +0900 Subject: [PATCH 215/893] replace winbase winnt to windows-sys --- vm/Cargo.toml | 8 +++----- vm/src/stdlib/msvcrt.rs | 2 +- vm/src/stdlib/nt.rs | 2 +- vm/src/stdlib/os.rs | 2 +- vm/src/stdlib/signal.rs | 21 ++++++++++++++------- vm/src/stdlib/winreg.rs | 4 ++-- 6 files changed, 22 insertions(+), 17 deletions(-) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index cedbeaf558..7bdf3617ed 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -125,6 +125,7 @@ features = [ version = "0.52.0" features = [ "Win32_Foundation", + "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Console", @@ -132,6 +133,7 @@ features = [ "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Pipes", + "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", @@ -141,11 +143,7 @@ features = [ [target.'cfg(windows)'.dependencies.winapi] version = "0.3.9" -features = [ - "winsock2", "std", "winbase", - "winnt", - "impl-default", -] +features = ["std", "winnt"] [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = "0.2.80" diff --git a/vm/src/stdlib/msvcrt.rs b/vm/src/stdlib/msvcrt.rs index fc6ed24d35..03ddb44f22 100644 --- a/vm/src/stdlib/msvcrt.rs +++ b/vm/src/stdlib/msvcrt.rs @@ -15,7 +15,7 @@ mod msvcrt { }; #[pyattr] - use winapi::um::winbase::{ + use windows_sys::Win32::System::Diagnostics::Debug::{ SEM_FAILCRITICALERRORS, SEM_NOALIGNMENTFAULTEXCEPT, SEM_NOGPFAULTERRORBOX, SEM_NOOPENFILEERRORBOX, }; diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index dbc7032d40..e55aed883c 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -428,6 +428,6 @@ pub fn init_winsock() { static WSA_INIT: parking_lot::Once = parking_lot::Once::new(); WSA_INIT.call_once(|| unsafe { let mut wsa_data = std::mem::MaybeUninit::uninit(); - let _ = winapi::um::winsock2::WSAStartup(0x0101, wsa_data.as_mut_ptr()); + let _ = windows_sys::Win32::Networking::WinSock::WSAStartup(0x0101, wsa_data.as_mut_ptr()); }) } diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 8a1ad20587..94b1938e43 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1377,7 +1377,7 @@ pub(super) mod _os { let f = OpenOptions::new() .write(true) - .custom_flags(winapi::um::winbase::FILE_FLAG_BACKUP_SEMANTICS) + .custom_flags(windows_sys::Win32::Storage::FileSystem::FILE_FLAG_BACKUP_SEMANTICS) .open(path) .map_err(|err| err.into_pyexception(vm))?; diff --git a/vm/src/stdlib/signal.rs b/vm/src/stdlib/signal.rs index b30428546c..4a698a7633 100644 --- a/vm/src/stdlib/signal.rs +++ b/vm/src/stdlib/signal.rs @@ -19,7 +19,6 @@ pub(crate) mod _signal { cfg_if::cfg_if! { if #[cfg(windows)] { - use winapi::um::winsock2; type WakeupFd = libc::SOCKET; const INVALID_WAKEUP: WakeupFd = (-1isize) as usize; static WAKEUP: atomic::AtomicUsize = atomic::AtomicUsize::new(INVALID_WAKEUP); @@ -200,14 +199,16 @@ pub(crate) mod _signal { #[cfg(windows)] let is_socket = if fd != INVALID_WAKEUP { + use windows_sys::Win32::Networking::WinSock; + crate::stdlib::nt::init_winsock(); let mut res = 0i32; let mut res_size = std::mem::size_of::() as i32; let res = unsafe { - winsock2::getsockopt( + WinSock::getsockopt( fd, - winsock2::SOL_SOCKET, - winsock2::SO_ERROR, + WinSock::SOL_SOCKET, + WinSock::SO_ERROR, &mut res as *mut i32 as *mut _, &mut res_size, ) @@ -217,7 +218,7 @@ pub(crate) mod _signal { if !is_socket { let err = std::io::Error::last_os_error(); // if getsockopt failed for some other reason, throw - if err.raw_os_error() != Some(winsock2::WSAENOTSOCK) { + if err.raw_os_error() != Some(WinSock::WSAENOTSOCK) { return Err(err.into_pyexception(vm)); } } @@ -263,8 +264,14 @@ pub(crate) mod _signal { let sigbyte = signum as u8; #[cfg(windows)] if WAKEUP_IS_SOCKET.load(Ordering::Relaxed) { - let _res = - unsafe { winsock2::send(wakeup_fd, &sigbyte as *const u8 as *const _, 1, 0) }; + let _res = unsafe { + windows_sys::Win32::Networking::WinSock::send( + wakeup_fd, + &sigbyte as *const u8 as *const _, + 1, + 0, + ) + }; return; } let _res = unsafe { libc::write(wakeup_fd as _, &sigbyte as *const u8 as *const _, 1) }; diff --git a/vm/src/stdlib/winreg.rs b/vm/src/stdlib/winreg.rs index f66dd30c78..b31694f9e6 100644 --- a/vm/src/stdlib/winreg.rs +++ b/vm/src/stdlib/winreg.rs @@ -38,14 +38,14 @@ mod winreg { // access rights #[pyattr] - pub use winapi::um::winnt::{ + pub use windows_sys::Win32::System::Registry::{ KEY_ALL_ACCESS, KEY_CREATE_LINK, KEY_CREATE_SUB_KEY, KEY_ENUMERATE_SUB_KEYS, KEY_EXECUTE, KEY_NOTIFY, KEY_QUERY_VALUE, KEY_READ, KEY_SET_VALUE, KEY_WOW64_32KEY, KEY_WOW64_64KEY, KEY_WRITE, }; // value types #[pyattr] - pub use winapi::um::winnt::{ + pub use windows_sys::Win32::System::Registry::{ REG_BINARY, REG_DWORD, REG_DWORD_BIG_ENDIAN, REG_DWORD_LITTLE_ENDIAN, REG_EXPAND_SZ, REG_FULL_RESOURCE_DESCRIPTOR, REG_LINK, REG_MULTI_SZ, REG_NONE, REG_QWORD, REG_QWORD_LITTLE_ENDIAN, REG_RESOURCE_LIST, REG_RESOURCE_REQUIREMENTS_LIST, REG_SZ, From 4729ca3af0ba63644ddf88b1cd0a5137394e9bee Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 23:05:41 +0900 Subject: [PATCH 216/893] Drop winapi from rustpython-vm --- Cargo.lock | 1 - vm/Cargo.toml | 5 +---- vm/src/stdlib/errno.rs | 25 ++++++++++++++++++++++--- vm/src/stdlib/nt.rs | 7 +++---- vm/src/stdlib/os.rs | 19 +++++++++---------- vm/src/stdlib/time.rs | 29 +++++++++-------------------- vm/src/stdlib/winreg.rs | 6 +++--- 7 files changed, 47 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3ce8580250..2d4c8a6ad5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2287,7 +2287,6 @@ dependencies = [ "wasm-bindgen", "which", "widestring", - "winapi", "windows", "windows-sys 0.52.0", "winreg", diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 7bdf3617ed..f061f54a85 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -132,6 +132,7 @@ features = [ "Win32_System_Diagnostics_Debug", "Win32_System_LibraryLoader", "Win32_System_Memory", + "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", @@ -141,10 +142,6 @@ features = [ "Win32_UI_WindowsAndMessaging", ] -[target.'cfg(windows)'.dependencies.winapi] -version = "0.3.9" -features = ["std", "winnt"] - [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen = "0.2.80" diff --git a/vm/src/stdlib/errno.rs b/vm/src/stdlib/errno.rs index fe36b4965e..a142d68a34 100644 --- a/vm/src/stdlib/errno.rs +++ b/vm/src/stdlib/errno.rs @@ -24,12 +24,31 @@ pub fn make_module(vm: &VirtualMachine) -> PyRef { pub mod errors { pub use libc::*; #[cfg(windows)] - pub use winapi::shared::winerror::*; + pub use windows_sys::Win32::{ + Foundation::*, + Networking::WinSock::{ + WSABASEERR, WSADESCRIPTION_LEN, WSAEACCES, WSAEADDRINUSE, WSAEADDRNOTAVAIL, + WSAEAFNOSUPPORT, WSAEALREADY, WSAEBADF, WSAECANCELLED, WSAECONNABORTED, + WSAECONNREFUSED, WSAECONNRESET, WSAEDESTADDRREQ, WSAEDISCON, WSAEDQUOT, WSAEFAULT, + WSAEHOSTDOWN, WSAEHOSTUNREACH, WSAEINPROGRESS, WSAEINTR, WSAEINVAL, + WSAEINVALIDPROCTABLE, WSAEINVALIDPROVIDER, WSAEISCONN, WSAELOOP, WSAEMFILE, + WSAEMSGSIZE, WSAENAMETOOLONG, WSAENETDOWN, WSAENETRESET, WSAENETUNREACH, WSAENOBUFS, + WSAENOMORE, WSAENOPROTOOPT, WSAENOTCONN, WSAENOTEMPTY, WSAENOTSOCK, WSAEOPNOTSUPP, + WSAEPFNOSUPPORT, WSAEPROCLIM, WSAEPROTONOSUPPORT, WSAEPROTOTYPE, + WSAEPROVIDERFAILEDINIT, WSAEREFUSED, WSAEREMOTE, WSAESHUTDOWN, WSAESOCKTNOSUPPORT, + WSAESTALE, WSAETIMEDOUT, WSAETOOMANYREFS, WSAEUSERS, WSAEWOULDBLOCK, WSAHOST_NOT_FOUND, + WSAID_ACCEPTEX, WSAID_CONNECTEX, WSAID_DISCONNECTEX, WSAID_GETACCEPTEXSOCKADDRS, + WSAID_TRANSMITFILE, WSAID_TRANSMITPACKETS, WSAID_WSAPOLL, WSAID_WSARECVMSG, + WSANOTINITIALISED, WSANO_DATA, WSANO_RECOVERY, WSAPROTOCOL_LEN, WSASERVICE_NOT_FOUND, + WSASYSCALLFAILURE, WSASYSNOTREADY, WSASYS_STATUS_LEN, WSATRY_AGAIN, WSATYPE_NOT_FOUND, + WSAVERNOTSUPPORTED, + }, + }; #[cfg(windows)] macro_rules! reexport_wsa { ($($errname:ident),*$(,)?) => { paste::paste! { - $(pub const $errname: i32 = [] as i32;)* + $(pub const $errname: i32 = windows_sys::Win32::Networking::WinSock:: [] as i32;)* } } } @@ -43,7 +62,7 @@ pub mod errors { // TODO: EBADF should be here once winerrs are translated to errnos but it messes up some things atm } #[cfg(windows)] - pub const WSAHOS: i32 = WSAHOST_NOT_FOUND as i32; + pub const WSAHOS: i32 = WSAHOST_NOT_FOUND; } #[cfg(any(unix, windows, target_os = "wasi"))] diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index e55aed883c..206faa82e1 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -250,7 +250,7 @@ pub(crate) mod module { #[pyfunction] fn _getfullpathname(path: OsPath, vm: &VirtualMachine) -> PyResult { let wpath = path.to_widecstring(vm)?; - let mut buffer = vec![0u16; winapi::shared::minwindef::MAX_PATH]; + let mut buffer = vec![0u16; Foundation::MAX_PATH as usize]; let ret = unsafe { FileSystem::GetFullPathNameW( wpath.as_ptr(), @@ -283,7 +283,7 @@ pub(crate) mod module { #[pyfunction] fn _getvolumepathname(path: OsPath, vm: &VirtualMachine) -> PyResult { let wide = path.to_widecstring(vm)?; - let buflen = std::cmp::max(wide.len(), winapi::shared::minwindef::MAX_PATH); + let buflen = std::cmp::max(wide.len(), Foundation::MAX_PATH as usize); let mut buffer = vec![0u16; buflen]; let ret = unsafe { FileSystem::GetVolumePathNameW(wide.as_ptr(), buffer.as_mut_ptr(), buflen as _) @@ -335,7 +335,6 @@ pub(crate) mod module { #[pyfunction] fn _getdiskusage(path: OsPath, vm: &VirtualMachine) -> PyResult<(u64, u64)> { - use winapi::shared::winerror; use FileSystem::GetDiskFreeSpaceExW; let wpath = path.to_widecstring(vm)?; @@ -348,7 +347,7 @@ pub(crate) mod module { return Ok((total, free)); } let err = io::Error::last_os_error(); - if err.raw_os_error() == Some(winerror::ERROR_DIRECTORY as i32) { + if err.raw_os_error() == Some(Foundation::ERROR_DIRECTORY as i32) { if let Some(parent) = path.as_ref().parent() { let parent = widestring::WideCString::from_os_str(parent).unwrap(); diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 94b1938e43..376c18fb3a 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -1200,22 +1200,20 @@ pub(super) mod _os { let res = unsafe { suppress_iph!(libc::lseek(fd, position, how)) }; #[cfg(windows)] let res = unsafe { - use winapi::um::winnt; use windows_sys::Win32::Storage::FileSystem; let handle = Fd(fd).to_raw_handle().map_err(|e| e.into_pyexception(vm))?; - let mut li = winnt::LARGE_INTEGER::default(); - *li.QuadPart_mut() = position; + let mut distance_to_move: [i32; 2] = std::mem::transmute(position); let ret = FileSystem::SetFilePointer( handle as _, - li.u().LowPart as _, - &mut li.u_mut().HighPart, + distance_to_move[0], + &mut distance_to_move[1], how as _, ); if ret == FileSystem::INVALID_SET_FILE_POINTER { -1 } else { - li.u_mut().LowPart = ret; - *li.QuadPart() + distance_to_move[0] = ret as _; + std::mem::transmute(distance_to_move) } }; if res < 0 { @@ -1358,7 +1356,7 @@ pub(super) mod _os { #[cfg(windows)] { use std::{fs::OpenOptions, os::windows::prelude::*}; - use winapi::shared::minwindef::DWORD; + type DWORD = u32; use windows_sys::Win32::{Foundation::FILETIME, Storage::FileSystem}; let [] = dir_fd.0; @@ -1605,9 +1603,10 @@ pub(super) mod _os { if #[cfg(any(target_os = "android", target_os = "redox"))] { Ok(Some("UTF-8".to_owned())) } else if #[cfg(windows)] { + use windows_sys::Win32::System::Console; let cp = match fd { - 0 => unsafe { winapi::um::consoleapi::GetConsoleCP() }, - 1 | 2 => unsafe { winapi::um::consoleapi::GetConsoleOutputCP() }, + 0 => unsafe { Console::GetConsoleCP() }, + 1 | 2 => unsafe { Console::GetConsoleOutputCP() }, _ => 0, }; diff --git a/vm/src/stdlib/time.rs b/vm/src/stdlib/time.rs index 883ef3342f..9717a69665 100644 --- a/vm/src/stdlib/time.rs +++ b/vm/src/stdlib/time.rs @@ -628,43 +628,33 @@ mod platform { PyRef, PyResult, VirtualMachine, }; use std::time::Duration; - use winapi::shared::ntdef::ULARGE_INTEGER; - use winapi::um::profileapi::{QueryPerformanceCounter, QueryPerformanceFrequency}; use windows_sys::Win32::{ Foundation::FILETIME, + System::Performance::{QueryPerformanceCounter, QueryPerformanceFrequency}, System::SystemInformation::{GetSystemTimeAdjustment, GetTickCount64}, System::Threading::{GetCurrentProcess, GetCurrentThread, GetProcessTimes, GetThreadTimes}, }; fn u64_from_filetime(time: FILETIME) -> u64 { - unsafe { - let mut large = std::mem::MaybeUninit::::uninit(); - { - let m = (*large.as_mut_ptr()).u_mut(); - m.LowPart = time.dwLowDateTime; - m.HighPart = time.dwHighDateTime; - } - let large = large.assume_init(); - *large.QuadPart() - } + let large: [u32; 2] = [time.dwLowDateTime, time.dwHighDateTime]; + unsafe { std::mem::transmute(large) } } fn win_perf_counter_frequency(vm: &VirtualMachine) -> PyResult { - let freq = unsafe { + let frequency = unsafe { let mut freq = std::mem::MaybeUninit::uninit(); if QueryPerformanceFrequency(freq.as_mut_ptr()) == 0 { return Err(errno_err(vm)); } freq.assume_init() }; - let frequency = unsafe { freq.QuadPart() }; - if *frequency < 1 { + if frequency < 1 { Err(vm.new_runtime_error("invalid QueryPerformanceFrequency".to_owned())) - } else if *frequency > i64::MAX / SEC_TO_NS { + } else if frequency > i64::MAX / SEC_TO_NS { Err(vm.new_overflow_error("QueryPerformanceFrequency is too large".to_owned())) } else { - Ok(*frequency) + Ok(frequency) } } @@ -678,15 +668,14 @@ mod platform { } pub(super) fn get_perf_time(vm: &VirtualMachine) -> PyResult { - let now = unsafe { + let ticks = unsafe { let mut performance_count = std::mem::MaybeUninit::uninit(); QueryPerformanceCounter(performance_count.as_mut_ptr()); performance_count.assume_init() }; - let ticks = unsafe { now.QuadPart() }; Ok(Duration::from_nanos(time_muldiv( - *ticks, + ticks, SEC_TO_NS, global_frequency(vm)?, ))) diff --git a/vm/src/stdlib/winreg.rs b/vm/src/stdlib/winreg.rs index b31694f9e6..b992de9cb9 100644 --- a/vm/src/stdlib/winreg.rs +++ b/vm/src/stdlib/winreg.rs @@ -34,7 +34,7 @@ mod winreg { }; use ::winreg::{enums::RegType, RegKey, RegValue}; use std::{ffi::OsStr, io}; - use winapi::shared::winerror; + use windows_sys::Win32::Foundation; // access rights #[pyattr] @@ -201,7 +201,7 @@ mod winreg { key.with_key(|k| k.enum_keys().nth(index as usize)) .unwrap_or_else(|| { Err(io::Error::from_raw_os_error( - winerror::ERROR_NO_MORE_ITEMS as i32, + Foundation::ERROR_NO_MORE_ITEMS as i32, )) }) .map_err(|e| e.to_pyexception(vm)) @@ -217,7 +217,7 @@ mod winreg { .with_key(|k| k.enum_values().nth(index as usize)) .unwrap_or_else(|| { Err(io::Error::from_raw_os_error( - winerror::ERROR_NO_MORE_ITEMS as i32, + Foundation::ERROR_NO_MORE_ITEMS as i32, )) }) .map_err(|e| e.to_pyexception(vm))?; From 756088e7fb204e008fcddb14be82b4bb64efd640 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 28 Dec 2023 23:39:45 +0900 Subject: [PATCH 217/893] more winapi to windows-sys --- Cargo.lock | 1 + stdlib/Cargo.toml | 12 ++++++-- stdlib/src/multiprocessing.rs | 8 +++--- stdlib/src/select.rs | 10 +++---- stdlib/src/socket.rs | 52 +++++++++++++++++++++++++++-------- stdlib/src/ssl.rs | 8 +++--- 6 files changed, 64 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2d4c8a6ad5..d8fbbc5944 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2213,6 +2213,7 @@ dependencies = [ "uuid", "widestring", "winapi", + "windows-sys 0.52.0", "xml-rs", ] diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index cd04f47b83..50193ee0f8 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -112,8 +112,16 @@ widestring = { workspace = true } [target.'cfg(windows)'.dependencies.winapi] version = "0.3.9" features = [ - "winsock2", "ws2def", "std", "wincrypt", "fileapi", - "impl-default", "vcruntime", "ifdef", "netioapi", "profileapi", + "winsock2", "ifdef", "netioapi", +] + +[target.'cfg(windows)'.dependencies.windows-sys] +version = "0.52.0" +features = [ + "Win32_Networking_WinSock", + "Win32_NetworkManagement_IpHelper", + "Win32_NetworkManagement_Ndis", + "Win32_Security_Cryptography", ] [target.'cfg(target_os = "macos")'.dependencies] diff --git a/stdlib/src/multiprocessing.rs b/stdlib/src/multiprocessing.rs index f508f9a5d5..a6d902eb63 100644 --- a/stdlib/src/multiprocessing.rs +++ b/stdlib/src/multiprocessing.rs @@ -4,11 +4,11 @@ pub(crate) use _multiprocessing::make_module; #[pymodule] mod _multiprocessing { use crate::vm::{function::ArgBytesLike, stdlib::os, PyResult, VirtualMachine}; - use winapi::um::winsock2::{self, SOCKET}; + use windows_sys::Win32::Networking::WinSock::{self, SOCKET}; #[pyfunction] fn closesocket(socket: usize, vm: &VirtualMachine) -> PyResult<()> { - let res = unsafe { winsock2::closesocket(socket as SOCKET) }; + let res = unsafe { WinSock::closesocket(socket as SOCKET) }; if res == 0 { Err(os::errno_err(vm)) } else { @@ -20,7 +20,7 @@ mod _multiprocessing { fn recv(socket: usize, size: usize, vm: &VirtualMachine) -> PyResult { let mut buf = vec![0; size]; let nread = - unsafe { winsock2::recv(socket as SOCKET, buf.as_mut_ptr() as *mut _, size as i32, 0) }; + unsafe { WinSock::recv(socket as SOCKET, buf.as_mut_ptr() as *mut _, size as i32, 0) }; if nread < 0 { Err(os::errno_err(vm)) } else { @@ -31,7 +31,7 @@ mod _multiprocessing { #[pyfunction] fn send(socket: usize, buf: ArgBytesLike, vm: &VirtualMachine) -> PyResult { let ret = buf.with_ref(|b| unsafe { - winsock2::send(socket as SOCKET, b.as_ptr() as *const _, b.len() as i32, 0) + WinSock::send(socket as SOCKET, b.as_ptr() as *const _, b.len() as i32, 0) }); if ret < 0 { Err(os::errno_err(vm)) diff --git a/stdlib/src/select.rs b/stdlib/src/select.rs index 586305b1de..48705f4c9b 100644 --- a/stdlib/src/select.rs +++ b/stdlib/src/select.rs @@ -30,8 +30,8 @@ mod platform { #[allow(non_snake_case)] #[cfg(windows)] mod platform { - use winapi::um::winsock2; - pub use winsock2::{fd_set, select, timeval, FD_SETSIZE, SOCKET as RawFd}; + use windows_sys::Win32::Networking::WinSock; + pub use WinSock::{select, FD_SET as fd_set, FD_SETSIZE, SOCKET as RawFd, TIMEVAL as timeval}; // based off winsock2.h: https://gist.github.com/piscisaureus/906386#file-winsock2-h-L128-L141 @@ -45,7 +45,7 @@ mod platform { slot = slot.add(1); } // slot == &fd_array[fd_count] at this point - if fd_count < FD_SETSIZE as u32 { + if fd_count < FD_SETSIZE { *slot = fd as RawFd; (*set).fd_count += 1; } @@ -56,12 +56,12 @@ mod platform { } pub unsafe fn FD_ISSET(fd: RawFd, set: *mut fd_set) -> bool { - use winapi::um::winsock2::__WSAFDIsSet; + use WinSock::__WSAFDIsSet; __WSAFDIsSet(fd as _, set) != 0 } pub fn check_err(x: i32) -> bool { - x == winsock2::SOCKET_ERROR + x == WinSock::SOCKET_ERROR } } diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index b417a1739c..682cdd9654 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -34,18 +34,46 @@ mod _socket { use libc as c; #[cfg(windows)] mod c { - pub use winapi::shared::ifdef::IF_MAX_STRING_SIZE as IF_NAMESIZE; - pub use winapi::shared::mstcpip::*; pub use winapi::shared::netioapi::{if_indextoname, if_nametoindex}; - pub use winapi::shared::ws2def::*; - pub use winapi::shared::ws2ipdef::*; + pub use winapi::shared::ws2def::{ + INADDR_ANY, INADDR_BROADCAST, INADDR_LOOPBACK, INADDR_NONE, + }; pub use winapi::um::winsock2::{ - IPPORT_RESERVED, SD_BOTH as SHUT_RDWR, SD_RECEIVE as SHUT_RD, SD_SEND as SHUT_WR, - SOCK_DGRAM, SOCK_RAW, SOCK_RDM, SOCK_SEQPACKET, SOCK_STREAM, SOL_SOCKET, SO_BROADCAST, - SO_ERROR, SO_EXCLUSIVEADDRUSE, SO_LINGER, SO_OOBINLINE, SO_REUSEADDR, SO_TYPE, - SO_USELOOPBACK, *, + getprotobyname, getservbyname, getservbyport, getsockopt, setsockopt, + SO_EXCLUSIVEADDRUSE, + }; + pub use winapi::um::ws2tcpip::{ + EAI_AGAIN, EAI_BADFLAGS, EAI_FAIL, EAI_FAMILY, EAI_MEMORY, EAI_NODATA, EAI_NONAME, + EAI_SERVICE, EAI_SOCKTYPE, + }; + pub use windows_sys::Win32::Networking::WinSock::{ + AF_DECnet, AF_APPLETALK, AF_IPX, AF_LINK, AI_ADDRCONFIG, AI_ALL, AI_CANONNAME, + AI_NUMERICSERV, AI_V4MAPPED, IPPORT_RESERVED, IPPROTO_AH, IPPROTO_DSTOPTS, IPPROTO_EGP, + IPPROTO_ESP, IPPROTO_FRAGMENT, IPPROTO_GGP, IPPROTO_HOPOPTS, IPPROTO_ICMP, + IPPROTO_ICMPV6, IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_IP, IPPROTO_IP as IPPROTO_IPIP, + IPPROTO_IPV4, IPPROTO_IPV6, IPPROTO_ND, IPPROTO_NONE, IPPROTO_PIM, IPPROTO_PUP, + IPPROTO_RAW, IPPROTO_ROUTING, IPPROTO_TCP, IPPROTO_UDP, IPV6_CHECKSUM, IPV6_DONTFRAG, + IPV6_HOPLIMIT, IPV6_HOPOPTS, IPV6_JOIN_GROUP, IPV6_LEAVE_GROUP, IPV6_MULTICAST_HOPS, + IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_PKTINFO, IPV6_RECVRTHDR, IPV6_RECVTCLASS, + IPV6_RTHDR, IPV6_TCLASS, IPV6_UNICAST_HOPS, IPV6_V6ONLY, IP_ADD_MEMBERSHIP, + IP_DROP_MEMBERSHIP, IP_HDRINCL, IP_MULTICAST_IF, IP_MULTICAST_LOOP, IP_MULTICAST_TTL, + IP_OPTIONS, IP_RECVDSTADDR, IP_TOS, IP_TTL, MSG_BCAST, MSG_CTRUNC, MSG_DONTROUTE, + MSG_MCAST, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, NI_MAXHOST, NI_MAXSERV, + NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, RCVALL_IPLEVEL, RCVALL_OFF, + RCVALL_ON, RCVALL_SOCKETLEVELONLY, SD_BOTH as SHUT_RDWR, SD_RECEIVE as SHUT_RD, + SD_SEND as SHUT_WR, SIO_KEEPALIVE_VALS, SIO_LOOPBACK_FAST_PATH, SIO_RCVALL, SOCK_DGRAM, + SOCK_RAW, SOCK_RDM, SOCK_SEQPACKET, SOCK_STREAM, SOL_SOCKET, SOMAXCONN, SO_BROADCAST, + SO_ERROR, SO_LINGER, SO_OOBINLINE, SO_REUSEADDR, SO_TYPE, SO_USELOOPBACK, TCP_NODELAY, + WSAEBADF, WSAECONNRESET, WSAENOTSOCK, WSAEWOULDBLOCK, }; - pub use winapi::um::ws2tcpip::*; + pub const IF_NAMESIZE: usize = + windows_sys::Win32::NetworkManagement::Ndis::IF_MAX_STRING_SIZE as _; + pub const AF_UNSPEC: i32 = windows_sys::Win32::Networking::WinSock::AF_UNSPEC as _; + pub const AF_INET: i32 = windows_sys::Win32::Networking::WinSock::AF_INET as _; + pub const AF_INET6: i32 = windows_sys::Win32::Networking::WinSock::AF_INET6 as _; + pub const AI_PASSIVE: i32 = windows_sys::Win32::Networking::WinSock::AI_PASSIVE as _; + pub const AI_NUMERICHOST: i32 = + windows_sys::Win32::Networking::WinSock::AI_NUMERICHOST as _; } // constants #[pyattr(name = "has_ipv6")] @@ -658,7 +686,7 @@ mod _socket { #[cfg(windows)] #[pyattr] - use winapi::shared::ws2def::{ + use windows_sys::Win32::Networking::WinSock::{ IPPROTO_CBT, IPPROTO_ICLFXBM, IPPROTO_IGP, IPPROTO_L2TP, IPPROTO_PGM, IPPROTO_RDP, IPPROTO_SCTP, IPPROTO_ST, }; @@ -2216,7 +2244,7 @@ mod _socket { } #[cfg(windows)] { - winapi::um::winsock2::INVALID_SOCKET as RawSocket + windows_sys::Win32::Networking::WinSock::INVALID_SOCKET as RawSocket } }; @@ -2329,7 +2357,7 @@ mod _socket { #[cfg(unix)] use libc::close; #[cfg(windows)] - use winapi::um::winsock2::closesocket as close; + use windows_sys::Win32::Networking::WinSock::closesocket as close; let ret = unsafe { close(x as _) }; if ret < 0 { let err = crate::common::os::errno(); diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index 746bc53911..d050f0ed3b 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -1453,7 +1453,7 @@ mod windows { #[pyfunction] fn enum_certificates(store_name: PyStrRef, vm: &VirtualMachine) -> PyResult> { use schannel::{cert_context::ValidUses, cert_store::CertStore, RawPointer}; - use winapi::um::wincrypt; + use windows_sys::Win32::Security::Cryptography; // TODO: check every store for it, not just 2 of them: // https://github.com/python/cpython/blob/3.8/Modules/_ssl.c#L5603-L5610 @@ -1465,12 +1465,12 @@ mod windows { let certs = stores.iter().flat_map(|s| s.certs()).map(|c| { let cert = vm.ctx.new_bytes(c.to_der().to_owned()); let enc_type = unsafe { - let ptr = c.as_ptr() as wincrypt::PCCERT_CONTEXT; + let ptr = c.as_ptr() as *const Cryptography::CERT_CONTEXT; (*ptr).dwCertEncodingType }; let enc_type = match enc_type { - wincrypt::X509_ASN_ENCODING => vm.new_pyobj(ascii!("x509_asn")), - wincrypt::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")), + Cryptography::X509_ASN_ENCODING => vm.new_pyobj(ascii!("x509_asn")), + Cryptography::PKCS_7_ASN_ENCODING => vm.new_pyobj(ascii!("pkcs_7_asn")), other => vm.new_pyobj(other), }; let usage: PyObjectRef = match c.valid_uses()? { From a309cb5d2ca72a324b3be3727c19cc2f1ad0a0a4 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 30 Dec 2023 04:09:02 +0900 Subject: [PATCH 218/893] Fix 1.75 clippy warnings --- common/src/hash.rs | 4 +--- compiler/codegen/src/compile.rs | 5 ++--- stdlib/src/socket.rs | 2 +- vm/src/builtins/object.rs | 4 +++- vm/src/exceptions.rs | 4 ++-- vm/src/warn.rs | 20 ++++++++------------ 6 files changed, 17 insertions(+), 22 deletions(-) diff --git a/common/src/hash.rs b/common/src/hash.rs index 6169003ab1..f514dac326 100644 --- a/common/src/hash.rs +++ b/common/src/hash.rs @@ -59,9 +59,7 @@ impl HashSecret { impl HashSecret { pub fn hash_value(&self, data: &T) -> PyHash { - let mut hasher = self.build_hasher(); - data.hash(&mut hasher); - fix_sentinel(mod_int(hasher.finish() as PyHash)) + fix_sentinel(mod_int(self.hash_one(data) as _)) } pub fn hash_iter<'a, T: 'a, I, F, E>(&self, iter: I, hashf: F) -> Result diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 217c2dc02b..e5976047a3 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -507,9 +507,8 @@ impl Compiler { SymbolScope::Cell => { cache = &mut info.cellvar_cache; NameOpType::Deref - } - // // TODO: is this right? - // SymbolScope::Unknown => NameOpType::Global, + } // TODO: is this right? + // SymbolScope::Unknown => NameOpType::Global, }; if NameUsage::Load == usage && name == "__debug__" { diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index b417a1739c..fd9ad74fda 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -54,7 +54,7 @@ mod _socket { // put IPPROTO_MAX later use c::{ AF_INET, AF_INET6, AF_UNSPEC, INADDR_ANY, INADDR_LOOPBACK, INADDR_NONE, IPPROTO_ICMP, - IPPROTO_ICMPV6, IPPROTO_IP, IPPROTO_IP as IPPROTO_IPIP, IPPROTO_IPV6, IPPROTO_TCP, + IPPROTO_ICMPV6, IPPROTO_IP, IPPROTO_IPIP, IPPROTO_IPV6, IPPROTO_TCP, IPPROTO_TCP as SOL_TCP, IPPROTO_UDP, MSG_CTRUNC, MSG_DONTROUTE, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, NI_MAXHOST, NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, SHUT_RD, SHUT_RDWR, SHUT_WR, SOCK_DGRAM, SOCK_STREAM, SOL_SOCKET, diff --git a/vm/src/builtins/object.rs b/vm/src/builtins/object.rs index 351e559df6..efe6aa980f 100644 --- a/vm/src/builtins/object.rs +++ b/vm/src/builtins/object.rs @@ -61,7 +61,9 @@ impl Constructor for PyBaseObject { name, methods ))); } - _ => unreachable!("unimplemented_abstract_method_count is always positive"), + // TODO: remove `allow` when redox build doesn't complain about it + #[allow(unreachable_patterns)] + _ => unreachable!(), } } } diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 19b035d980..d83e7e48d8 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -779,7 +779,7 @@ impl ExceptionZoo { let errno_getter = ctx.new_readonly_getset("errno", excs.os_error, |exc: PyBaseExceptionRef| { let args = exc.args(); - args.get(0) + args.first() .filter(|_| args.len() > 1 && args.len() <= 5) .cloned() }); @@ -1116,7 +1116,7 @@ pub(super) mod types { args: ::rustpython_vm::function::FuncArgs, vm: &::rustpython_vm::VirtualMachine, ) -> ::rustpython_vm::PyResult<()> { - zelf.set_attr("value", vm.unwrap_or_none(args.args.get(0).cloned()), vm)?; + zelf.set_attr("value", vm.unwrap_or_none(args.args.first().cloned()), vm)?; Ok(()) } } diff --git a/vm/src/warn.rs b/vm/src/warn.rs index ba45714853..d0acccbf29 100644 --- a/vm/src/warn.rs +++ b/vm/src/warn.rs @@ -87,17 +87,13 @@ pub fn warn( } fn get_default_action(vm: &VirtualMachine) -> PyResult { - vm.state - .warnings - .default_action - .clone() - .try_into() - .map_err(|_| { - vm.new_value_error(format!( - "_warnings.defaultaction must be a string, not '{}'", - vm.state.warnings.default_action - )) - }) + Ok(vm.state.warnings.default_action.clone().into()) + // .map_err(|_| { + // vm.new_value_error(format!( + // "_warnings.defaultaction must be a string, not '{}'", + // vm.state.warnings.default_action + // )) + // }) } fn get_filter( @@ -125,7 +121,7 @@ fn get_filter( .ok_or_else(|| vm.new_value_error(format!("_warnings.filters item {i} isn't a 5-tuple")))?; /* Python code: action, msg, cat, mod, ln = item */ - let action = if let Some(action) = tmp_item.get(0) { + let action = if let Some(action) = tmp_item.first() { action.str(vm).map(|action| action.into_object()) } else { Err(vm.new_type_error("action must be a string".to_string())) From 506c8a633ee6c2b361f7f7847678db184d5462b2 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 30 Dec 2023 12:17:22 +0900 Subject: [PATCH 219/893] Fix redox and nightly --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- compiler/codegen/src/symboltable.rs | 4 ++-- stdlib/src/resource.rs | 4 ++-- stdlib/src/socket.rs | 14 +++++++------- vm/src/builtins/type.rs | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7bd76c9783..e81b3ce974 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1120,9 +1120,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.141" +version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = "libffi" diff --git a/Cargo.toml b/Cargo.toml index 35a03003ac..71e9d221dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,7 +56,7 @@ indexmap = { version = "1.9.3", features = ["std"] } insta = "1.33.0" itertools = "0.11.0" is-macro = "0.3.0" -libc = "0.2.133" +libc = "0.2.151" log = "0.4.16" nix = "0.26" malachite-bigint = "0.2.0" diff --git a/compiler/codegen/src/symboltable.rs b/compiler/codegen/src/symboltable.rs index 5283abad53..75e327fe34 100644 --- a/compiler/codegen/src/symboltable.rs +++ b/compiler/codegen/src/symboltable.rs @@ -230,10 +230,10 @@ mod stack { res.unwrap_or_else(|x| panic::resume_unwind(x)) } - pub fn iter(&self) -> impl Iterator + DoubleEndedIterator + '_ { + pub fn iter(&self) -> impl DoubleEndedIterator + '_ { self.as_ref().iter().copied() } - pub fn iter_mut(&mut self) -> impl Iterator + DoubleEndedIterator + '_ { + pub fn iter_mut(&mut self) -> impl DoubleEndedIterator + '_ { self.as_mut().iter_mut().map(|x| &mut **x) } // pub fn top(&self) -> Option<&T> { diff --git a/stdlib/src/resource.rs b/stdlib/src/resource.rs index 85ba5d7be2..075e191284 100644 --- a/stdlib/src/resource.rs +++ b/stdlib/src/resource.rs @@ -11,10 +11,10 @@ mod resource { use std::{io, mem}; cfg_if::cfg_if! { - if #[cfg(any(target_os = "linux", target_os = "android"))] { + if #[cfg(target_os = "android")] { use libc::RLIM_NLIMITS; } else { - // in bsd-ish platforms, this constant isn't abi-stable across os versions, so we just + // This constant isn't abi-stable across os versions, so we just // pick a high number so we don't get false positive ValueErrors and just bubble up the // EINVAL that get/setrlimit return on an invalid resource const RLIM_NLIMITS: i32 = 256; diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index fd9ad74fda..d0642f19e9 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -54,19 +54,19 @@ mod _socket { // put IPPROTO_MAX later use c::{ AF_INET, AF_INET6, AF_UNSPEC, INADDR_ANY, INADDR_LOOPBACK, INADDR_NONE, IPPROTO_ICMP, - IPPROTO_ICMPV6, IPPROTO_IP, IPPROTO_IPIP, IPPROTO_IPV6, IPPROTO_TCP, - IPPROTO_TCP as SOL_TCP, IPPROTO_UDP, MSG_CTRUNC, MSG_DONTROUTE, MSG_OOB, MSG_PEEK, - MSG_TRUNC, MSG_WAITALL, NI_DGRAM, NI_MAXHOST, NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, - NI_NUMERICSERV, SHUT_RD, SHUT_RDWR, SHUT_WR, SOCK_DGRAM, SOCK_STREAM, SOL_SOCKET, - SO_BROADCAST, SO_ERROR, SO_LINGER, SO_OOBINLINE, SO_REUSEADDR, SO_TYPE, TCP_NODELAY, + IPPROTO_ICMPV6, IPPROTO_IP, IPPROTO_IPV6, IPPROTO_TCP, IPPROTO_TCP as SOL_TCP, IPPROTO_UDP, + MSG_CTRUNC, MSG_DONTROUTE, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, NI_DGRAM, NI_MAXHOST, + NI_NAMEREQD, NI_NOFQDN, NI_NUMERICHOST, NI_NUMERICSERV, SHUT_RD, SHUT_RDWR, SHUT_WR, + SOCK_DGRAM, SOCK_STREAM, SOL_SOCKET, SO_BROADCAST, SO_ERROR, SO_LINGER, SO_OOBINLINE, + SO_REUSEADDR, SO_TYPE, TCP_NODELAY, }; #[cfg(not(target_os = "redox"))] #[pyattr] use c::{ AF_DECnet, AF_APPLETALK, AF_IPX, IPPROTO_AH, IPPROTO_DSTOPTS, IPPROTO_EGP, IPPROTO_ESP, - IPPROTO_FRAGMENT, IPPROTO_HOPOPTS, IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_NONE, IPPROTO_PIM, - IPPROTO_PUP, IPPROTO_RAW, IPPROTO_ROUTING, + IPPROTO_FRAGMENT, IPPROTO_HOPOPTS, IPPROTO_IDP, IPPROTO_IGMP, IPPROTO_IPIP, IPPROTO_NONE, + IPPROTO_PIM, IPPROTO_PUP, IPPROTO_RAW, IPPROTO_ROUTING, }; #[cfg(unix)] diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index e0da4f7dc4..6e1ddebb25 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -299,7 +299,7 @@ impl PyType { } } - pub fn iter_mro(&self) -> impl Iterator + DoubleEndedIterator { + pub fn iter_mro(&self) -> impl DoubleEndedIterator { std::iter::once(self).chain(self.mro.iter().map(|cls| -> &PyType { cls })) } @@ -420,7 +420,7 @@ impl Py { self.as_object().is(cls.borrow()) || self.mro.iter().any(|c| c.is(cls.borrow())) } - pub fn iter_mro(&self) -> impl Iterator> + DoubleEndedIterator { + pub fn iter_mro(&self) -> impl DoubleEndedIterator> { std::iter::once(self).chain(self.mro.iter().map(|x| x.deref())) } From 32f662ae806848dd4fad7c38dc5398636b6222f0 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Sat, 30 Dec 2023 17:14:08 +0900 Subject: [PATCH 220/893] Bump openssl from 0.10.55 to 0.10.62 --- Cargo.lock | 14 +++++++------- stdlib/Cargo.toml | 2 +- stdlib/src/ssl.rs | 3 +-- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 67adca537c..140f078ff5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1480,11 +1480,11 @@ checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" [[package]] name = "openssl" -version = "0.10.55" +version = "0.10.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" +checksum = "8cde4d2d9200ad5909f8dac647e29482e07c3a35de8a13fce7c9c7747ad9f671" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", "cfg-if", "foreign-types", "libc", @@ -1512,18 +1512,18 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "111.25.0+1.1.1t" +version = "300.2.1+3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3173cd3626c43e3854b1b727422a276e568d9ec5fe8cec197822cf52cfb743d6" +checksum = "3fe476c29791a5ca0d1273c697e96085bbabbbea2ef7afd5617e78a4b40332d3" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.90" +version = "0.9.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" +checksum = "c1665caf8ab2dc9aef43d1c0023bd904633a6a05cb30b0ad59bec2ae986e57a7" dependencies = [ "cc", "libc", diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 50193ee0f8..7c394804b1 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -96,7 +96,7 @@ termios = "0.3.3" gethostname = "0.2.3" socket2 = { version = "0.4.4", features = ["all"] } dns-lookup = "1.0.8" -openssl = { version = "0.10.55", optional = true } +openssl = { version = "0.10.62", optional = true } openssl-sys = { version = "0.9.80", optional = true } openssl-probe = { version = "0.1.5", optional = true } foreign-types-shared = { version = "0.1.1", optional = true } diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index d050f0ed3b..cf3a802038 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -703,9 +703,8 @@ mod _ssl { let certs = self .ctx() .cert_store() - .objects() + .all_certificates() .iter() - .filter_map(|obj| obj.x509()) .map(|cert| cert_to_py(vm, cert, binary_form)) .collect::, _>>()?; Ok(certs) From 1ab133dae8c47978d59fd24196b53850f884af65 Mon Sep 17 00:00:00 2001 From: Evan Rittenhouse Date: Mon, 8 Jan 2024 00:03:57 -0600 Subject: [PATCH 221/893] None.__ne__(None) should be NotImplemented (#5124) --- extra_tests/snippets/builtin_none.py | 2 +- vm/src/builtins/singletons.rs | 29 +++++++++++++++++++++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/extra_tests/snippets/builtin_none.py b/extra_tests/snippets/builtin_none.py index c75f04ea73..c8a07f518d 100644 --- a/extra_tests/snippets/builtin_none.py +++ b/extra_tests/snippets/builtin_none.py @@ -22,4 +22,4 @@ def none2(): assert None.__eq__(3) is NotImplemented assert None.__ne__(3) is NotImplemented assert None.__eq__(None) is True -# assert None.__ne__(None) is False # changed in 3.12 +assert None.__ne__(None) is NotImplemented diff --git a/vm/src/builtins/singletons.rs b/vm/src/builtins/singletons.rs index 65b171a262..d4487b586d 100644 --- a/vm/src/builtins/singletons.rs +++ b/vm/src/builtins/singletons.rs @@ -2,9 +2,10 @@ use super::{PyStrRef, PyType, PyTypeRef}; use crate::{ class::PyClassImpl, convert::ToPyObject, + function::{PyArithmeticValue, PyComparisonValue}, protocol::PyNumberMethods, - types::{AsNumber, Constructor, Representable}, - Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, + types::{AsNumber, Comparable, Constructor, PyComparisonOp, Representable}, + Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, }; #[pyclass(module = false, name = "NoneType")] @@ -42,7 +43,7 @@ impl Constructor for PyNone { } } -#[pyclass(with(Constructor, AsNumber, Representable))] +#[pyclass(with(Constructor, Comparable, AsNumber, Representable))] impl PyNone { #[pymethod(magic)] fn bool(&self) -> bool { @@ -72,6 +73,28 @@ impl AsNumber for PyNone { } } +impl Comparable for PyNone { + fn cmp( + _zelf: &Py, + other: &PyObject, + op: PyComparisonOp, + vm: &VirtualMachine, + ) -> PyResult { + let value = match op { + PyComparisonOp::Eq => { + if vm.is_none(other) { + PyArithmeticValue::Implemented(true) + } else { + PyArithmeticValue::NotImplemented + } + } + _ => PyComparisonValue::NotImplemented, + }; + + Ok(value) + } +} + #[pyclass(module = false, name = "NotImplementedType")] #[derive(Debug)] pub struct PyNotImplemented; From 602015fca14edcc2d9a98f1dc8923a7bb480ab17 Mon Sep 17 00:00:00 2001 From: Noa Date: Tue, 2 Jan 2024 17:32:49 -0600 Subject: [PATCH 222/893] Update nix and socket2 --- Cargo.lock | 138 ++++++++++++++++------------------------- Cargo.toml | 5 +- stdlib/Cargo.toml | 4 +- stdlib/src/socket.rs | 30 ++++----- vm/src/stdlib/posix.rs | 37 ++++++++--- 5 files changed, 97 insertions(+), 117 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 140f078ff5..481ed44d18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -257,13 +257,11 @@ dependencies = [ [[package]] name = "clipboard-win" -version = "4.5.0" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7191c27c2357d9b7ef96baac1773290d4ca63b24205b82a3fd8a0637afcf0362" +checksum = "c57002a5d9be777c1ef967e33674dac9ebd310d8893e4e3437b14d5f0f6372cc" dependencies = [ "error-code", - "str-buf", - "winapi", ] [[package]] @@ -673,14 +671,14 @@ dependencies = [ [[package]] name = "dns-lookup" -version = "1.0.8" +version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53ecafc952c4528d9b51a458d1a8904b81783feff9fde08ab6ed2545ff396872" +checksum = "e5766087c2235fec47fafa4cfecc81e494ee679d0fd4a59887ea0919bfb0e4fc" dependencies = [ "cfg-if", "libc", "socket2", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -732,34 +730,19 @@ dependencies = [ [[package]] name = "errno" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" -dependencies = [ - "errno-dragonfly", - "libc", - "windows-sys 0.48.0", -] - -[[package]] -name = "errno-dragonfly" -version = "0.1.2" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ - "cc", "libc", + "windows-sys 0.52.0", ] [[package]] name = "error-code" -version = "2.3.1" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f18991e7bf11e7ffee451b5318b5c1a73c52d0d0ada6e5a3017c8c1ced6a21" -dependencies = [ - "libc", - "str-buf", -] +checksum = "281e452d3bad4005426416cdba5ccfd4f5c1280e10099e21db27f7c1c28347fc" [[package]] name = "exitcode" @@ -769,13 +752,13 @@ checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193" [[package]] name = "fd-lock" -version = "3.0.12" +version = "4.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ae6b3d9530211fb3b12a95374b8b0823be812f53d09e18c5675c0146b09642" +checksum = "7e5768da2206272c81ef0b5e951a41862938a6070da63bcea197899942d3b947" dependencies = [ "cfg-if", "rustix", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -939,12 +922,6 @@ dependencies = [ "libc", ] -[[package]] -name = "hermit-abi" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" - [[package]] name = "hex" version = "0.4.3" @@ -957,6 +934,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "iana-time-zone" version = "0.1.53" @@ -1004,17 +990,6 @@ dependencies = [ "yaml-rust", ] -[[package]] -name = "io-lifetimes" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" -dependencies = [ - "hermit-abi 0.3.1", - "libc", - "windows-sys 0.48.0", -] - [[package]] name = "is-macro" version = "0.3.0" @@ -1183,9 +1158,9 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.3.1" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "lock_api" @@ -1217,9 +1192,9 @@ dependencies = [ [[package]] name = "mac_address" -version = "1.1.4" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b238e3235c8382b7653c6408ed1b08dd379bdb9fdf990fb0bbae3db2cc0ae963" +checksum = "4863ee94f19ed315bf3bc00299338d857d4b5bc856af375cc97d237382ad3856" dependencies = [ "nix 0.23.2", "winapi", @@ -1344,6 +1319,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + [[package]] name = "miniz_oxide" version = "0.6.2" @@ -1386,16 +1370,14 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.2" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", "cfg-if", "libc", - "memoffset 0.7.1", - "pin-utils", - "static_assertions", + "memoffset 0.9.0", ] [[package]] @@ -1615,12 +1597,6 @@ dependencies = [ "siphasher", ] -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - [[package]] name = "pkg-config" version = "0.3.26" @@ -1907,16 +1883,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.11" +version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85597d61f83914ddeba6a47b3b8ffe7365107221c2e557ed94426489fefb5f77" +checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", "errno", - "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -2176,7 +2151,7 @@ dependencies = [ "memchr", "memmap2", "mt19937", - "nix 0.26.2", + "nix 0.27.1", "num-complex", "num-integer", "num-traits", @@ -2245,7 +2220,7 @@ dependencies = [ "malachite-bigint", "memchr", "memoffset 0.6.5", - "nix 0.26.2", + "nix 0.27.1", "num-complex", "num-integer", "num-traits", @@ -2319,21 +2294,20 @@ checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70" [[package]] name = "rustyline" -version = "11.0.0" +version = "13.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfc8644681285d1fb67a467fb3021bfea306b99b4146b166a1fe3ada965eece" +checksum = "02a2d683a4ac90aeef5b1013933f6d977bd37d51ff3f4dad829d4931a7e6be86" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", "cfg-if", "clipboard-win", - "dirs-next", "fd-lock", + "home", "libc", "log", "memchr", - "nix 0.26.2", + "nix 0.27.1", "radix_trie", - "scopeguard", "unicode-segmentation", "unicode-width", "utf8parse", @@ -2493,12 +2467,12 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "socket2" -version = "0.4.7" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -2518,12 +2492,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "str-buf" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e08d8363704e6c71fc928674353e6b7c23dcea9d82d7012c8faf2a3a025f8d0" - [[package]] name = "strsim" version = "0.8.0" diff --git a/Cargo.toml b/Cargo.toml index 71e9d221dd..b7b3db3dc8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ itertools = "0.11.0" is-macro = "0.3.0" libc = "0.2.151" log = "0.4.16" -nix = "0.26" +nix = { version = "0.27", features = ["fs", "user", "process", "term", "time", "signal", "ioctl", "socket", "sched", "zerocopy", "dir", "hostname", "net", "poll"] } malachite-bigint = "0.2.0" malachite-q = "0.4.4" malachite-base = "0.4.4" @@ -70,7 +70,7 @@ once_cell = "1.18" parking_lot = "0.12.1" paste = "1.0.7" rand = "0.8.5" -rustyline = "11" +rustyline = "13" serde = { version = "1.0.133", default-features = false } schannel = "0.1.22" static_assertions = "1.1" @@ -152,5 +152,4 @@ lto = "thin" [patch.crates-io] # REDOX START, Uncomment when you want to compile/check with redoxer -# nix = { git = "https://github.com/coolreader18/nix", branch = "0.26.2-redox" } # REDOX END diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 7c394804b1..77b79cda3d 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -94,8 +94,8 @@ termios = "0.3.3" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] gethostname = "0.2.3" -socket2 = { version = "0.4.4", features = ["all"] } -dns-lookup = "1.0.8" +socket2 = { version = "0.5.4", features = ["all"] } +dns-lookup = "2" openssl = { version = "0.10.62", optional = true } openssl-sys = { version = "0.9.80", optional = true } openssl-probe = { version = "0.1.5", optional = true } diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index c6624df546..12231bac99 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -21,7 +21,7 @@ mod _socket { }; use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; - use socket2::{Domain, Protocol, Socket, Type as SocketType}; + use socket2::Socket; use std::{ ffi, io::{self, Read, Write}, @@ -1117,11 +1117,7 @@ mod _socket { if proto == -1 { proto = 0 } - sock = Socket::new( - Domain::from(family), - SocketType::from(socket_kind), - Some(Protocol::from(proto)), - )?; + sock = Socket::new(family.into(), socket_kind.into(), Some(proto.into()))?; }; Ok(zelf.init_inner(family, socket_kind, proto, sock)?) } @@ -1195,7 +1191,7 @@ mod _socket { let mut buf = buf.borrow_buf_mut(); let buf = &mut *buf; self.sock_op(vm, SelectKind::Read, || { - sock.recv_with_flags(slice_as_uninit(buf), flags) + sock.recv_with_flags(unsafe { slice_as_uninit(buf) }, flags) }) } @@ -1245,7 +1241,7 @@ mod _socket { let flags = flags.unwrap_or(0); let sock = self.sock()?; let (n, addr) = self.sock_op(vm, SelectKind::Read, || { - sock.recv_from_with_flags(slice_as_uninit(buf), flags) + sock.recv_from_with_flags(unsafe { slice_as_uninit(buf) }, flags) })?; Ok((n, get_addr_tuple(&addr, vm))) } @@ -1581,16 +1577,13 @@ mod _socket { return get_ip_addr_tuple(&addr, vm); } #[cfg(unix)] - use nix::sys::socket::{SockaddrLike, UnixAddr}; - #[cfg(unix)] - if let Some(unix_addr) = unsafe { UnixAddr::from_raw(addr.as_ptr(), Some(addr.len())) } { + if addr.is_unix() { use std::os::unix::ffi::OsStrExt; - #[cfg(any(target_os = "android", target_os = "linux"))] - if let Some(abstractpath) = unix_addr.as_abstract() { + if let Some(abstractpath) = addr.as_abstract_namespace() { return vm.ctx.new_bytes([b"\0", abstractpath].concat()).into(); } // necessary on macos - let path = ffi::OsStr::as_bytes(unix_addr.path().unwrap_or("".as_ref()).as_ref()); + let path = ffi::OsStr::as_bytes(addr.as_pathname().unwrap_or("".as_ref()).as_ref()); let nul_pos = memchr::memchr(b'\0', path).unwrap_or(path.len()); let path = ffi::OsStr::from_bytes(&path[..nul_pos]); return vm.ctx.new_str(path.to_string_lossy()).into(); @@ -1678,8 +1671,8 @@ mod _socket { Ok(s.to_string_lossy().into_owned()) } - fn slice_as_uninit(v: &mut [T]) -> &mut [MaybeUninit] { - unsafe { &mut *(v as *mut [T] as *mut [MaybeUninit]) } + unsafe fn slice_as_uninit(v: &mut [T]) -> &mut [MaybeUninit] { + &mut *(v as *mut [T] as *mut [MaybeUninit]) } enum IoOrPyException { @@ -1733,7 +1726,6 @@ mod _socket { kind: SelectKind, interval: Option, ) -> io::Result { - let fd = sock_fileno(sock); #[cfg(unix)] { use nix::poll::*; @@ -1742,7 +1734,7 @@ mod _socket { SelectKind::Write => PollFlags::POLLOUT, SelectKind::Connect => PollFlags::POLLOUT | PollFlags::POLLERR, }; - let mut pollfd = [PollFd::new(fd, events)]; + let mut pollfd = [PollFd::new(sock, events)]; let timeout = match interval { Some(d) => d.as_millis() as _, None => -1, @@ -1754,6 +1746,8 @@ mod _socket { { use crate::select; + let fd = sock_fileno(sock); + let mut reads = select::FdSet::new(); let mut writes = select::FdSet::new(); let mut errs = select::FdSet::new(); diff --git a/vm/src/stdlib/posix.rs b/vm/src/stdlib/posix.rs index 1adf0006ef..4351acf86b 100644 --- a/vm/src/stdlib/posix.rs +++ b/vm/src/stdlib/posix.rs @@ -41,7 +41,7 @@ pub mod module { env, ffi::{CStr, CString}, fs, io, - os::unix::io::RawFd, + os::fd::{AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd}, }; use strum_macros::{EnumIter, EnumString}; @@ -162,6 +162,24 @@ pub mod module { #[pyattr] const _COPYFILE_DATA: u32 = 1 << 3; + impl TryFromObject for BorrowedFd<'_> { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let fd = i32::try_from_object(vm, obj)?; + if fd == -1 { + return Err(io::Error::from_raw_os_error(libc::EBADF).into_pyexception(vm)); + } + // SAFETY: none, really. but, python's os api of passing around file descriptors + // everywhere isn't really io-safe anyway, so, this is passed to the user. + Ok(unsafe { BorrowedFd::borrow_raw(fd) }) + } + } + + impl ToPyObject for OwnedFd { + fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + self.into_raw_fd().to_pyobject(vm) + } + } + // Flags for os_access bitflags! { #[derive(Copy, Clone, Debug, PartialEq)] @@ -1196,10 +1214,11 @@ pub mod module { #[cfg(not(target_os = "redox"))] #[pyfunction] - fn openpty(vm: &VirtualMachine) -> PyResult<(i32, i32)> { + fn openpty(vm: &VirtualMachine) -> PyResult<(OwnedFd, OwnedFd)> { let r = nix::pty::openpty(None, None).map_err(|err| err.into_pyexception(vm))?; - for fd in &[r.master, r.slave] { - super::raw_set_inheritable(*fd, false).map_err(|e| e.into_pyexception(vm))?; + for fd in [&r.master, &r.slave] { + super::raw_set_inheritable(fd.as_raw_fd(), false) + .map_err(|e| e.into_pyexception(vm))?; } Ok((r.master, r.slave)) } @@ -2013,9 +2032,9 @@ pub mod module { #[cfg(any(target_os = "linux", target_os = "macos"))] #[derive(FromArgs)] - struct SendFileArgs { - out_fd: i32, - in_fd: i32, + struct SendFileArgs<'fd> { + out_fd: BorrowedFd<'fd>, + in_fd: BorrowedFd<'fd>, offset: crate::common::crt_fd::Offset, count: i64, #[cfg(target_os = "macos")] @@ -2033,7 +2052,7 @@ pub mod module { #[cfg(target_os = "linux")] #[pyfunction] - fn sendfile(args: SendFileArgs, vm: &VirtualMachine) -> PyResult { + fn sendfile(args: SendFileArgs<'_>, vm: &VirtualMachine) -> PyResult { let mut file_offset = args.offset; let res = nix::sys::sendfile::sendfile( @@ -2062,7 +2081,7 @@ pub mod module { #[cfg(target_os = "macos")] #[pyfunction] - fn sendfile(args: SendFileArgs, vm: &VirtualMachine) -> PyResult { + fn sendfile(args: SendFileArgs<'_>, vm: &VirtualMachine) -> PyResult { let headers = _extract_vec_bytes(args.headers, vm)?; let count = headers .as_ref() From 9cc571be95922d6f9d0b2a48c5fb54b3f66723f0 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Tue, 9 Jan 2024 20:53:40 +0900 Subject: [PATCH 223/893] Fix windows stdlib build --- stdlib/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index 77b79cda3d..d077603c4b 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -112,7 +112,7 @@ widestring = { workspace = true } [target.'cfg(windows)'.dependencies.winapi] version = "0.3.9" features = [ - "winsock2", "ifdef", "netioapi", + "winsock2", "ifdef", "netioapi", "ws2tcpip", ] [target.'cfg(windows)'.dependencies.windows-sys] From 28f0fa48a4bb069c758df9789e478210eda38977 Mon Sep 17 00:00:00 2001 From: NakanoMiku <91249276+NakanoMiku39@users.noreply.github.com> Date: Thu, 11 Jan 2024 16:48:56 +0800 Subject: [PATCH 224/893] Fix abc error messages (#5140) Co-authored-by: Jeong, YunWon --- Lib/test/test_abc.py | 12 ------------ Lib/test/test_dataclasses.py | 1 + vm/src/builtins/object.rs | 6 +++--- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/Lib/test/test_abc.py b/Lib/test/test_abc.py index d912954a41..ac46ea67bb 100644 --- a/Lib/test/test_abc.py +++ b/Lib/test/test_abc.py @@ -149,8 +149,6 @@ def foo(): return 4 self.assertEqual(D.foo(), 4) self.assertEqual(D().foo(), 4) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_object_new_with_one_abstractmethod(self): class C(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -159,8 +157,6 @@ def method_one(self): msg = r"class C without an implementation for abstract method 'method_one'" self.assertRaisesRegex(TypeError, msg, C) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_object_new_with_many_abstractmethods(self): class C(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -526,8 +522,6 @@ def foo(self): self.assertEqual(A.__abstractmethods__, set()) A() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_update_new_abstractmethods(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -544,8 +538,6 @@ def updated_foo(self): msg = "class A without an implementation for abstract methods 'bar', 'foo'" self.assertRaisesRegex(TypeError, msg, A) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_update_implementation(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -597,8 +589,6 @@ def updated_foo(self): A() self.assertFalse(hasattr(A, '__abstractmethods__')) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_update_del_implementation(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod @@ -618,8 +608,6 @@ def foo(self): msg = "class B without an implementation for abstract method 'foo'" self.assertRaisesRegex(TypeError, msg, B) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_update_layered_implementation(self): class A(metaclass=abc_ABCMeta): @abc.abstractmethod diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 33bd9d0cb6..2e51e43ae8 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -3676,6 +3676,7 @@ class Date(Ordered): self.assertFalse(inspect.isabstract(Date)) self.assertGreater(Date(2020,12,25), Date(2020,8,31)) + @unittest.expectedFailure def test_maintain_abc(self): class A(abc.ABC): @abc.abstractmethod diff --git a/vm/src/builtins/object.rs b/vm/src/builtins/object.rs index efe6aa980f..3b3d1ab365 100644 --- a/vm/src/builtins/object.rs +++ b/vm/src/builtins/object.rs @@ -41,7 +41,7 @@ impl Constructor for PyBaseObject { if let Some(unimplemented_abstract_method_count) = abs_methods.length_opt(vm) { let methods: Vec = abs_methods.try_to_value(vm)?; let methods: String = - Itertools::intersperse(methods.iter().map(|name| name.as_str()), ", ") + Itertools::intersperse(methods.iter().map(|name| name.as_str()), "', '") .collect(); let unimplemented_abstract_method_count = unimplemented_abstract_method_count?; @@ -51,13 +51,13 @@ impl Constructor for PyBaseObject { 0 => {} 1 => { return Err(vm.new_type_error(format!( - "Can't instantiate abstract class {} with abstract method {}", + "class {} without an implementation for abstract method '{}'", name, methods ))); } 2.. => { return Err(vm.new_type_error(format!( - "Can't instantiate abstract class {} with abstract methods {}", + "class {} without an implementation for abstract methods '{}'", name, methods ))); } From aaae5662311850646bcf81097de5fbdf0e718360 Mon Sep 17 00:00:00 2001 From: "kenny the :/" Date: Fri, 12 Jan 2024 20:16:53 +0800 Subject: [PATCH 225/893] Raise error on power with negative number (#5143) --- Lib/test/test_complex.py | 2 -- vm/src/builtins/complex.rs | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py index 72999c2f7d..b26dd00d8f 100644 --- a/Lib/test/test_complex.py +++ b/Lib/test/test_complex.py @@ -236,8 +236,6 @@ def test_divmod_zero_division(self): for a, b in ZERO_DIVISION: self.assertRaises(TypeError, divmod, a, b) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_pow(self): self.assertAlmostEqual(pow(1+1j, 0+0j), 1.0) self.assertAlmostEqual(pow(0+0j, 2+0j), 0.0) diff --git a/vm/src/builtins/complex.rs b/vm/src/builtins/complex.rs index 609dcb9b6c..4a3125c138 100644 --- a/vm/src/builtins/complex.rs +++ b/vm/src/builtins/complex.rs @@ -103,7 +103,7 @@ fn inner_div(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult PyResult { if v1.is_zero() { - return if v2.im != 0.0 { + return if v2.re < 0.0 || v2.im != 0.0 { let msg = format!("{v1} cannot be raised to a negative or complex power"); Err(vm.new_zero_division_error(msg)) } else if v2.is_zero() { From a777d22a537bdda61a7df5f4c6af6369de1cadc6 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 8 Dec 2023 22:36:46 +0200 Subject: [PATCH 226/893] fix _count --- src/constants.rs | 1 + src/engine.rs | 23 +++++++++++++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index f3962b339a..0d5bb41939 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -101,6 +101,7 @@ pub enum SreCatCode { UNI_NOT_LINEBREAK = 17, } bitflags! { + #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct SreFlag: u16 { const TEMPLATE = 1; const IGNORECASE = 2; diff --git a/src/engine.rs b/src/engine.rs index 7334516c2f..f49560c2c6 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -672,7 +672,9 @@ fn op_min_repeat_one( ctx.count = if min_count == 0 { 0 } else { - let count = _count(req, state, ctx, min_count); + let mut next_ctx = *ctx; + next_ctx.skip_code(4); + let count = _count(req, state, next_ctx, min_count); if count < min_count { return ctx.failure(); } @@ -713,7 +715,9 @@ fn op_min_repeat_one( state.string_position = ctx.string_position; - if _count(req, state, ctx, 1) == 0 { + let mut next_ctx = *ctx; + next_ctx.skip_code(4); + if _count(req, state, next_ctx, 1) == 0 { state.marks.pop_discard(); return ctx.failure(); } @@ -741,7 +745,9 @@ fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut state.string_position = ctx.string_position; - let count = _count(req, state, ctx, max_count); + let mut next_ctx = *ctx; + next_ctx.skip_code(4); + let count = _count(req, state, next_ctx, max_count); ctx.skip_char(req, count); if count < min_count { return ctx.failure(); @@ -1428,17 +1434,16 @@ fn charset(set: &[u32], ch: u32) -> bool { fn _count( req: &Request, state: &mut State, - ctx: &MatchContext, + mut ctx: MatchContext, max_count: usize, ) -> usize { - let mut ctx = *ctx; let max_count = std::cmp::min(max_count, ctx.remaining_chars(req)); let end = ctx.string_position + max_count; let opcode = SreOpcode::try_from(ctx.peek_code(req, 0)).unwrap(); match opcode { SreOpcode::ANY => { - while !ctx.string_position < end && !ctx.at_linebreak(req) { + while ctx.string_position < end && !ctx.at_linebreak(req) { ctx.skip_char(req, 1); } } @@ -1446,8 +1451,7 @@ fn _count( ctx.skip_char(req, max_count); } SreOpcode::IN => { - while !ctx.string_position < end && charset(&ctx.pattern(req)[2..], ctx.peek_char(req)) - { + while ctx.string_position < end && charset(&ctx.pattern(req)[2..], ctx.peek_char(req)) { ctx.skip_char(req, 1); } } @@ -1483,7 +1487,6 @@ fn _count( /* General case */ let mut count = 0; - ctx.skip_code(4); let reset_position = ctx.code_position; while count < max_count { @@ -1511,7 +1514,7 @@ fn general_count_literal bool>( mut f: F, ) { let ch = ctx.peek_code(req, 1); - while !ctx.string_position < end && f(ch, ctx.peek_char(req)) { + while ctx.string_position < end && f(ch, ctx.peek_char(req)) { ctx.skip_char(req, 1); } } From d73cc5f58c94e2589efadf1b21f4b5a811869682 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 8 Dec 2023 22:42:38 +0200 Subject: [PATCH 227/893] update version and dependency --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 30e403b54c..de1d68cf6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.4.2" +version = "0.4.3" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" @@ -10,6 +10,6 @@ keywords = ["regex"] include = ["LICENSE", "src/**/*.rs"] [dependencies] -num_enum = "0.5.9" +num_enum = "0.7" bitflags = "2" optional = "0.5" From 9070e12e0df3ece6949dc3d9d650fa2d37d8eb54 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 10 Dec 2023 20:50:43 +0200 Subject: [PATCH 228/893] refactor _match with nest loop --- benches/benches.rs | 28 +- src/engine.rs | 894 +++++++++++++++++++++++++++++++++------------ tests/tests.rs | 30 +- 3 files changed, 677 insertions(+), 275 deletions(-) diff --git a/benches/benches.rs b/benches/benches.rs index 604cf91f42..fe470d023c 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -10,10 +10,7 @@ struct Pattern { } impl Pattern { - fn state<'a, S: engine::StrDrive>( - &self, - string: S, - ) -> (engine::Request<'a, S>, engine::State) { + fn state<'a, S: engine::StrDrive>(&self, string: S) -> (engine::Request<'a, S>, engine::State) { self.state_range(string, 0..usize::MAX) } @@ -21,7 +18,7 @@ impl Pattern { &self, string: S, range: std::ops::Range, - ) -> (engine::Request<'a, S>, engine::State) { + ) -> (engine::Request<'a, S>, engine::State) { let req = engine::Request::new(string, range.start, range.end, self.code, false); let state = engine::State::default(); (req, state) @@ -93,29 +90,22 @@ fn benchmarks(b: &mut Bencher) { b.iter(move || { for (p, s) in &tests { let (req, mut state) = p.state(s.clone()); - state.search(req); - assert!(state.has_matched); + assert!(state.search(req)); let (req, mut state) = p.state(s.clone()); - state.pymatch(req); - assert!(state.has_matched); + assert!(state.pymatch(&req)); let (mut req, mut state) = p.state(s.clone()); req.match_all = true; - state.pymatch(req); - assert!(state.has_matched); + assert!(state.pymatch(&req)); let s2 = format!("{}{}{}", " ".repeat(10000), s, " ".repeat(10000)); let (req, mut state) = p.state_range(s2.as_str(), 0..usize::MAX); - state.search(req); - assert!(state.has_matched); + assert!(state.search(req)); let (req, mut state) = p.state_range(s2.as_str(), 10000..usize::MAX); - state.pymatch(req); - assert!(state.has_matched); + assert!(state.pymatch(&req)); let (req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); - state.pymatch(req); - assert!(state.has_matched); + assert!(state.pymatch(&req)); let (mut req, mut state) = p.state_range(s2.as_str(), 10000..10000 + s.len()); req.match_all = true; - state.pymatch(req); - assert!(state.has_matched); + assert!(state.pymatch(&req)); } }) } diff --git a/src/engine.rs b/src/engine.rs index f49560c2c6..7474f29013 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -6,14 +6,13 @@ use super::constants::{SreAtCode, SreCatCode, SreOpcode}; use super::MAXREPEAT; use optional::Optioned; use std::convert::TryFrom; -use std::ops::Deref; const fn is_py_ascii_whitespace(b: u8) -> bool { matches!(b, b'\t' | b'\n' | b'\x0C' | b'\r' | b' ' | b'\x0B') } #[derive(Debug, Clone, Copy)] -pub struct Request<'a, S: StrDrive> { +pub struct Request<'a, S> { pub string: S, pub start: usize, pub end: usize, @@ -61,14 +60,6 @@ impl Default for Marks { } } -impl Deref for Marks { - type Target = Vec>; - - fn deref(&self) -> &Self::Target { - &self.marks - } -} - impl Marks { pub fn get(&self, group_index: usize) -> (Optioned, Optioned) { let marks_index = 2 * group_index; @@ -83,6 +74,10 @@ impl Marks { self.last_index } + pub fn raw(&self) -> &[Optioned] { + self.marks.as_slice() + } + fn set(&mut self, mark_nr: usize, position: usize) { if mark_nr & 1 != 0 { self.last_index = mark_nr as isize / 2 + 1; @@ -120,70 +115,23 @@ impl Marks { } } -#[derive(Debug)] -pub struct State { - pub marks: Marks, - context_stack: Vec>, - repeat_stack: Vec, +#[derive(Debug, Default)] +pub struct State { pub start: usize, + pub marks: Marks, pub string_position: usize, - next_context: Option>, - popped_has_matched: bool, - pub has_matched: bool, -} - -impl Default for State { - fn default() -> Self { - Self { - marks: Marks::default(), - context_stack: Vec::new(), - repeat_stack: Vec::new(), - start: 0, - string_position: 0, - next_context: None, - popped_has_matched: false, - has_matched: false, - } - } + repeat_stack: Vec, } -impl State { +impl State { pub fn reset(&mut self, start: usize) { self.marks.clear(); - self.context_stack.clear(); self.repeat_stack.clear(); self.start = start; self.string_position = start; - self.next_context = None; - self.popped_has_matched = false; - self.has_matched = false; - } - - fn _match(&mut self, req: &mut Request) { - while let Some(mut ctx) = self.context_stack.pop() { - if let Some(handler) = ctx.handler.take() { - handler(req, self, &mut ctx); - } else if ctx.remaining_codes(req) > 0 { - let code = ctx.peek_code(req, 0); - let code = SreOpcode::try_from(code).unwrap(); - dispatch(req, self, &mut ctx, code); - } else { - ctx.failure(); - } - - if let Some(has_matched) = ctx.has_matched { - self.popped_has_matched = has_matched; - } else { - self.context_stack.push(ctx); - if let Some(next_ctx) = self.next_context.take() { - self.context_stack.push(next_ctx); - } - } - } - self.has_matched = self.popped_has_matched; } - pub fn pymatch(&mut self, mut req: Request) { + pub fn pymatch(&mut self, req: &Request) -> bool { self.start = req.start; self.string_position = req.start; @@ -191,24 +139,20 @@ impl State { string_position: req.start, string_offset: req.string.offset(0, req.start), code_position: 0, - has_matched: None, toplevel: true, - handler: None, + jump: Jump::OpCode, repeat_ctx_id: usize::MAX, count: -1, }; - self.context_stack.push(ctx); - - self._match(&mut req); + _match(&req, self, ctx) } - pub fn search(&mut self, mut req: Request) { + pub fn search(&mut self, mut req: Request) -> bool { self.start = req.start; self.string_position = req.start; - // TODO: optimize by op info and skip prefix if req.start > req.end { - return; + return false; } let mut end = req.end; @@ -219,9 +163,8 @@ impl State { string_position: req.start, string_offset: start_offset, code_position: 0, - has_matched: None, toplevel: true, - handler: None, + jump: Jump::OpCode, repeat_ctx_id: usize::MAX, count: -1, }; @@ -229,11 +172,10 @@ impl State { if ctx.peek_code(&req, 0) == SreOpcode::INFO as u32 { /* optimization info block */ /* <1=skip> <2=flags> <3=min> <4=max> <5=prefix info> */ - let req = &mut req; - let min = ctx.peek_code(req, 3) as usize; + let min = ctx.peek_code(&req, 3) as usize; - if ctx.remaining_chars(req) < min { - return; + if ctx.remaining_chars(&req) < min { + return false; } if min > 1 { @@ -249,42 +191,44 @@ impl State { } } - let flags = SreInfo::from_bits_truncate(ctx.peek_code(req, 2)); + let flags = SreInfo::from_bits_truncate(ctx.peek_code(&req, 2)); if flags.contains(SreInfo::PREFIX) { if flags.contains(SreInfo::LITERAL) { - search_info_literal::(req, self, ctx); + return search_info_literal::(&mut req, self, ctx); } else { - search_info_literal::(req, self, ctx); + return search_info_literal::(&mut req, self, ctx); } - return; } else if flags.contains(SreInfo::CHARSET) { - return search_info_charset(req, self, ctx); + return search_info_charset(&mut req, self, ctx); } // fallback to general search } - self.context_stack.push(ctx); - self._match(&mut req); + if _match(&req, self, ctx) { + return true; + } req.must_advance = false; ctx.toplevel = false; - while !self.has_matched && req.start < end { + while req.start < end { req.start += 1; start_offset = req.string.offset(start_offset, 1); self.reset(req.start); ctx.string_position = req.start; ctx.string_offset = start_offset; - self.context_stack.push(ctx); - self._match(&mut req); + if _match(&req, self, ctx) { + return true; + } } + false } } pub struct SearchIter<'a, S: StrDrive> { pub req: Request<'a, S>, - pub state: State, + pub state: State, } impl<'a, S: StrDrive> Iterator for SearchIter<'a, S> { @@ -296,8 +240,7 @@ impl<'a, S: StrDrive> Iterator for SearchIter<'a, S> { } self.state.reset(self.req.start); - self.state.search(self.req); - if !self.state.has_matched { + if !self.state.search(self.req) { return None; } @@ -308,10 +251,537 @@ impl<'a, S: StrDrive> Iterator for SearchIter<'a, S> { } } +#[derive(Debug, Clone, Copy)] +enum Jump { + OpCode, + Assert1, + AssertNot1, + Branch1, + Branch2, + Repeat1, + UntilBacktrace, + MaxUntil2, + MaxUntil3, + MinUntil1, + RepeatOne1, + RepeatOne2, + MinRepeatOne1, + MinRepeatOne2, +} + +fn _match(req: &Request, state: &mut State, ctx: MatchContext) -> bool { + let mut context_stack = vec![ctx]; + let mut popped_result = false; + + 'coro: loop { + let Some(mut ctx) = context_stack.pop() else { + break; + }; + + popped_result = 'result: loop { + let yield_ = 'context: loop { + match ctx.jump { + Jump::OpCode => {} + Jump::Assert1 => { + if popped_result { + ctx.skip_code_from(req, 1); + } else { + break 'result false; + } + } + Jump::AssertNot1 => { + if popped_result { + break 'result false; + } + ctx.skip_code_from(req, 1); + } + Jump::Branch1 => { + let branch_offset = ctx.count as usize; + let next_length = ctx.peek_code(req, branch_offset) as isize; + if next_length == 0 { + state.marks.pop_discard(); + break 'result false; + } + state.string_position = ctx.string_position; + let next_ctx = ctx.next_offset(branch_offset + 1, Jump::Branch2); + ctx.count += next_length; + break 'context next_ctx; + } + Jump::Branch2 => { + if popped_result { + break 'result true; + } + state.marks.pop_keep(); + ctx.jump = Jump::Branch1; + continue 'context; + } + Jump::Repeat1 => { + state.repeat_stack.pop(); + break 'result popped_result; + } + Jump::UntilBacktrace => { + if !popped_result { + state.repeat_stack[ctx.repeat_ctx_id].count -= 1; + state.string_position = ctx.string_position; + } + break 'result popped_result; + } + Jump::MaxUntil2 => { + let save_last_position = ctx.count as usize; + let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; + repeat_ctx.last_position = save_last_position; + + if popped_result { + state.marks.pop_discard(); + break 'result true; + } + + state.marks.pop(); + repeat_ctx.count -= 1; + state.string_position = ctx.string_position; + + /* cannot match more repeated items here. make sure the + tail matches */ + let mut next_ctx = ctx.next_offset(1, Jump::MaxUntil3); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + break 'context next_ctx; + } + Jump::MaxUntil3 => { + if !popped_result { + state.string_position = ctx.string_position; + } + break 'result popped_result; + } + Jump::MinUntil1 => { + if popped_result { + break 'result true; + } + ctx.repeat_ctx_id = ctx.count as usize; + let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; + state.string_position = ctx.string_position; + state.marks.pop(); + + // match more until tail matches + if repeat_ctx.count as usize >= repeat_ctx.max_count + && repeat_ctx.max_count != MAXREPEAT + || state.string_position == repeat_ctx.last_position + { + repeat_ctx.count -= 1; + break 'result false; + } + + /* zero-width match protection */ + repeat_ctx.last_position = state.string_position; + + break 'context ctx + .next_at(repeat_ctx.code_position + 4, Jump::UntilBacktrace); + } + Jump::RepeatOne1 => { + let min_count = ctx.peek_code(req, 2) as isize; + let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); + if next_code == SreOpcode::LITERAL as u32 { + // Special case: Tail starts with a literal. Skip positions where + // the rest of the pattern cannot possibly match. + let c = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 2); + while ctx.at_end(req) || ctx.peek_char(req) != c { + if ctx.count <= min_count { + state.marks.pop_discard(); + break 'result false; + } + ctx.back_skip_char(req, 1); + ctx.count -= 1; + } + } + + state.string_position = ctx.string_position; + // General case: backtracking + break 'context ctx.next_peek_from(1, req, Jump::RepeatOne2); + } + Jump::RepeatOne2 => { + if popped_result { + break 'result true; + } + + let min_count = ctx.peek_code(req, 2) as isize; + if ctx.count <= min_count { + state.marks.pop_discard(); + break 'result false; + } + + ctx.back_skip_char(req, 1); + ctx.count -= 1; + + state.marks.pop_keep(); + ctx.jump = Jump::RepeatOne1; + continue 'context; + } + Jump::MinRepeatOne1 => { + let max_count = ctx.peek_code(req, 3) as usize; + if max_count == MAXREPEAT || ctx.count as usize <= max_count { + state.string_position = ctx.string_position; + break 'context ctx.next_peek_from(1, req, Jump::MinRepeatOne2); + } else { + state.marks.pop_discard(); + break 'result false; + } + } + Jump::MinRepeatOne2 => { + if popped_result { + break 'result true; + } + + state.string_position = ctx.string_position; + + let mut count_ctx = ctx; + count_ctx.skip_code(4); + if _count(req, state, count_ctx, 1) == 0 { + state.marks.pop_discard(); + break 'result false; + } + + ctx.skip_char(req, 1); + ctx.count += 1; + state.marks.pop_keep(); + ctx.jump = Jump::MinRepeatOne1; + continue 'context; + } + } + ctx.jump = Jump::OpCode; + + loop { + macro_rules! general_op_literal { + ($f:expr) => {{ + if ctx.at_end(req) || !$f(ctx.peek_code(req, 1), ctx.peek_char(req)) { + break 'result false; + } + ctx.skip_code(2); + ctx.skip_char(req, 1); + }}; + } + + macro_rules! general_op_in { + ($f:expr) => {{ + if ctx.at_end(req) || !$f(&ctx.pattern(req)[2..], ctx.peek_char(req)) { + break 'result false; + } + ctx.skip_code_from(req, 1); + ctx.skip_char(req, 1); + }}; + } + + macro_rules! general_op_groupref { + ($f:expr) => {{ + let (group_start, group_end) = + state.marks.get(ctx.peek_code(req, 1) as usize); + let (group_start, group_end) = if group_start.is_some() + && group_end.is_some() + && group_start.unpack() <= group_end.unpack() + { + (group_start.unpack(), group_end.unpack()) + } else { + break 'result false; + }; + + let mut gctx = MatchContext { + string_position: group_start, + string_offset: req.string.offset(0, group_start), + ..ctx + }; + + for _ in group_start..group_end { + if ctx.at_end(req) + || $f(ctx.peek_char(req)) != $f(gctx.peek_char(req)) + { + break 'result false; + } + ctx.skip_char(req, 1); + gctx.skip_char(req, 1); + } + + ctx.skip_code(2); + }}; + } + + if ctx.remaining_codes(req) == 0 { + break 'result false; + } + let opcode = ctx.peek_code(req, 0); + let opcode = SreOpcode::try_from(opcode).unwrap(); + + match opcode { + SreOpcode::FAILURE => break 'result false, + SreOpcode::SUCCESS => { + if ctx.can_success(req) { + state.string_position = ctx.string_position; + break 'result true; + } + break 'result false; + } + SreOpcode::ANY => { + if ctx.at_end(req) || ctx.at_linebreak(req) { + break 'result false; + } + ctx.skip_code(1); + ctx.skip_char(req, 1); + } + SreOpcode::ANY_ALL => { + if ctx.at_end(req) { + break 'result false; + } + ctx.skip_code(1); + ctx.skip_char(req, 1); + } + SreOpcode::ASSERT => { + let back = ctx.peek_code(req, 2) as usize; + if ctx.string_position < back { + break 'result false; + } + + let mut next_ctx = ctx.next_offset(3, Jump::Assert1); + next_ctx.toplevel = false; + next_ctx.back_skip_char(req, back); + state.string_position = next_ctx.string_position; + break 'context next_ctx; + } + SreOpcode::ASSERT_NOT => { + let back = ctx.peek_code(req, 2) as usize; + if ctx.string_position < back { + ctx.skip_code_from(req, 1); + continue; + } + + let mut next_ctx = ctx.next_offset(3, Jump::AssertNot1); + next_ctx.toplevel = false; + next_ctx.back_skip_char(req, back); + state.string_position = next_ctx.string_position; + break 'context next_ctx; + } + SreOpcode::AT => { + let atcode = SreAtCode::try_from(ctx.peek_code(req, 1)).unwrap(); + if at(req, &ctx, atcode) { + ctx.skip_code(2); + } else { + break 'result false; + } + } + SreOpcode::BRANCH => { + state.marks.push(); + ctx.count = 1; + ctx.jump = Jump::Branch1; + continue 'context; + } + SreOpcode::CATEGORY => { + let catcode = SreCatCode::try_from(ctx.peek_code(req, 1)).unwrap(); + if ctx.at_end(req) || !category(catcode, ctx.peek_char(req)) { + break 'result false; + } + ctx.skip_code(2); + ctx.skip_char(req, 1); + } + SreOpcode::IN => general_op_in!(charset), + SreOpcode::IN_IGNORE => { + general_op_in!(|set, c| charset(set, lower_ascii(c))) + } + SreOpcode::IN_UNI_IGNORE => { + general_op_in!(|set, c| charset(set, lower_unicode(c))) + } + SreOpcode::IN_LOC_IGNORE => general_op_in!(charset_loc_ignore), + SreOpcode::INFO => { + let min = ctx.peek_code(req, 3) as usize; + if ctx.remaining_chars(req) < min { + break 'result false; + } + ctx.skip_code_from(req, 1); + } + SreOpcode::MARK => { + state + .marks + .set(ctx.peek_code(req, 1) as usize, ctx.string_position); + ctx.skip_code(2); + } + SreOpcode::JUMP => ctx.skip_code_from(req, 1), + SreOpcode::REPEAT => { + let repeat_ctx = RepeatContext { + count: -1, + min_count: ctx.peek_code(req, 2) as usize, + max_count: ctx.peek_code(req, 3) as usize, + code_position: ctx.code_position, + last_position: std::usize::MAX, + prev_id: ctx.repeat_ctx_id, + }; + state.repeat_stack.push(repeat_ctx); + let repeat_ctx_id = state.repeat_stack.len() - 1; + state.string_position = ctx.string_position; + let mut next_ctx = ctx.next_peek_from(1, req, Jump::Repeat1); + next_ctx.repeat_ctx_id = repeat_ctx_id; + break 'context next_ctx; + } + SreOpcode::MAX_UNTIL => { + let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; + state.string_position = ctx.string_position; + repeat_ctx.count += 1; + + if (repeat_ctx.count as usize) < repeat_ctx.min_count { + // not enough matches + break 'context ctx + .next_at(repeat_ctx.code_position + 4, Jump::UntilBacktrace); + } + + if ((repeat_ctx.count as usize) < repeat_ctx.max_count + || repeat_ctx.max_count == MAXREPEAT) + && state.string_position != repeat_ctx.last_position + { + /* we may have enough matches, but if we can + match another item, do so */ + state.marks.push(); + ctx.count = repeat_ctx.last_position as isize; + repeat_ctx.last_position = state.string_position; + + break 'context ctx + .next_at(repeat_ctx.code_position + 4, Jump::MaxUntil2); + } + + /* cannot match more repeated items here. make sure the + tail matches */ + let mut next_ctx = ctx.next_offset(1, Jump::MaxUntil3); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + break 'context next_ctx; + } + SreOpcode::MIN_UNTIL => { + let repeat_ctx = state.repeat_stack.last_mut().unwrap(); + state.string_position = ctx.string_position; + repeat_ctx.count += 1; + + if (repeat_ctx.count as usize) < repeat_ctx.min_count { + // not enough matches + break 'context ctx + .next_at(repeat_ctx.code_position + 4, Jump::UntilBacktrace); + } + + state.marks.push(); + ctx.count = ctx.repeat_ctx_id as isize; + let mut next_ctx = ctx.next_offset(1, Jump::MinUntil1); + next_ctx.repeat_ctx_id = repeat_ctx.prev_id; + break 'context next_ctx; + } + SreOpcode::REPEAT_ONE => { + let min_count = ctx.peek_code(req, 2) as usize; + let max_count = ctx.peek_code(req, 3) as usize; + + if ctx.remaining_chars(req) < min_count { + break 'result false; + } + + state.string_position = ctx.string_position; + + let mut next_ctx = ctx; + next_ctx.skip_code(4); + let count = _count(req, state, next_ctx, max_count); + ctx.skip_char(req, count); + if count < min_count { + break 'result false; + } + + let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { + // tail is empty. we're finished + state.string_position = ctx.string_position; + break 'result true; + } + + state.marks.push(); + ctx.count = count as isize; + ctx.jump = Jump::RepeatOne1; + continue 'context; + } + SreOpcode::MIN_REPEAT_ONE => { + let min_count = ctx.peek_code(req, 2) as usize; + if ctx.remaining_chars(req) < min_count { + break 'result false; + } + + state.string_position = ctx.string_position; + ctx.count = if min_count == 0 { + 0 + } else { + let mut count_ctx = ctx; + count_ctx.skip_code(4); + let count = _count(req, state, count_ctx, min_count); + if count < min_count { + break 'result false; + } + ctx.skip_char(req, count); + count as isize + }; + + let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); + if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { + // tail is empty. we're finished + state.string_position = ctx.string_position; + break 'result true; + } + + state.marks.push(); + ctx.jump = Jump::MinRepeatOne1; + continue 'context; + } + SreOpcode::LITERAL => general_op_literal!(|code, c| code == c), + SreOpcode::NOT_LITERAL => general_op_literal!(|code, c| code != c), + SreOpcode::LITERAL_IGNORE => { + general_op_literal!(|code, c| code == lower_ascii(c)) + } + SreOpcode::NOT_LITERAL_IGNORE => { + general_op_literal!(|code, c| code != lower_ascii(c)) + } + SreOpcode::LITERAL_UNI_IGNORE => { + general_op_literal!(|code, c| code == lower_unicode(c)) + } + SreOpcode::NOT_LITERAL_UNI_IGNORE => { + general_op_literal!(|code, c| code != lower_unicode(c)) + } + SreOpcode::LITERAL_LOC_IGNORE => general_op_literal!(char_loc_ignore), + SreOpcode::NOT_LITERAL_LOC_IGNORE => { + general_op_literal!(|code, c| !char_loc_ignore(code, c)) + } + SreOpcode::GROUPREF => general_op_groupref!(|x| x), + SreOpcode::GROUPREF_IGNORE => general_op_groupref!(lower_ascii), + SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref!(lower_locate), + SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref!(lower_unicode), + SreOpcode::GROUPREF_EXISTS => { + let (group_start, group_end) = + state.marks.get(ctx.peek_code(req, 1) as usize); + if group_start.is_some() + && group_end.is_some() + && group_start.unpack() <= group_end.unpack() + { + ctx.skip_code(3); + } else { + ctx.skip_code_from(req, 2) + } + } + SreOpcode::CALL => todo!(), + SreOpcode::CHARSET => todo!(), + SreOpcode::BIGCHARSET => todo!(), + SreOpcode::NEGATE => todo!(), + SreOpcode::RANGE => todo!(), + SreOpcode::RANGE_UNI_IGNORE => todo!(), + SreOpcode::SUBPATTERN => todo!(), + } + } + }; + context_stack.push(ctx); + context_stack.push(yield_); + continue 'coro; + }; + } + popped_result +} + +/* fn dispatch( req: &Request, - state: &mut State, - ctx: &mut MatchContext, + state: &mut State, + ctx: &mut MatchContext, opcode: SreOpcode, ) { match opcode { @@ -422,12 +892,13 @@ fn dispatch( _ => unreachable!("unexpected opcode"), } } +*/ fn search_info_literal( req: &mut Request, - state: &mut State, - mut ctx: MatchContext, -) { + state: &mut State, + mut ctx: MatchContext, +) -> bool { /* pattern starts with a known prefix */ /* */ let len = ctx.peek_code(req, 5) as usize; @@ -450,7 +921,7 @@ fn search_info_literal( while ctx.peek_char(req) != c { ctx.skip_char(req, 1); if ctx.at_end(req) { - return; + return false; } } @@ -460,18 +931,14 @@ fn search_info_literal( // literal only if LITERAL { - state.has_matched = true; - return; + return true; } let mut next_ctx = ctx; next_ctx.skip_char(req, skip); - state.context_stack.push(next_ctx); - state._match(req); - - if state.has_matched { - return; + if _match(req, state, next_ctx) { + return true; } ctx.skip_char(req, 1); @@ -483,12 +950,12 @@ fn search_info_literal( while ctx.peek_char(req) != c { ctx.skip_char(req, 1); if ctx.at_end(req) { - return; + return false; } } ctx.skip_char(req, 1); if ctx.at_end(req) { - return; + return false; } let mut i = 1; @@ -498,7 +965,7 @@ fn search_info_literal( if i != len { ctx.skip_char(req, 1); if ctx.at_end(req) { - return; + return false; } continue; } @@ -509,8 +976,7 @@ fn search_info_literal( // literal only if LITERAL { - state.has_matched = true; - return; + return true; } let mut next_ctx = ctx; @@ -521,16 +987,13 @@ fn search_info_literal( next_ctx.string_offset = req.string.offset(0, state.string_position); } - state.context_stack.push(next_ctx); - state._match(req); - - if state.has_matched { - return; + if _match(req, state, next_ctx) { + return true; } ctx.skip_char(req, 1); if ctx.at_end(req) { - return; + return false; } state.marks.clear(); } @@ -542,13 +1005,14 @@ fn search_info_literal( } } } + false } fn search_info_charset( req: &mut Request, - state: &mut State, - mut ctx: MatchContext, -) { + state: &mut State, + mut ctx: MatchContext, +) -> bool { let set = &ctx.pattern(req)[5..]; ctx.skip_code_from(req, 1); @@ -560,18 +1024,15 @@ fn search_info_charset( ctx.skip_char(req, 1); } if ctx.at_end(req) { - return; + return false; } req.start = ctx.string_position; state.start = ctx.string_position; state.string_position = ctx.string_position; - state.context_stack.push(ctx); - state._match(req); - - if state.has_matched { - return; + if _match(req, state, ctx) { + return true; } ctx.skip_char(req, 1); @@ -579,9 +1040,10 @@ fn search_info_charset( } } +/* /* assert subpattern */ /* */ -fn op_assert(req: &Request, state: &mut State, ctx: &mut MatchContext) { +fn op_assert(req: &Request, state: &mut State, ctx: &mut MatchContext) { let back = ctx.peek_code(req, 2) as usize; if ctx.string_position < back { return ctx.failure(); @@ -601,7 +1063,7 @@ fn op_assert(req: &Request, state: &mut State, ctx: &mut Matc /* assert not subpattern */ /* */ -fn op_assert_not(req: &Request, state: &mut State, ctx: &mut MatchContext) { +fn op_assert_not(req: &Request, state: &mut State, ctx: &mut MatchContext) { let back = ctx.peek_code(req, 2) as usize; if ctx.string_position < back { @@ -622,17 +1084,13 @@ fn op_assert_not(req: &Request, state: &mut State, ctx: &mut // alternation // <0=skip> code ... -fn op_branch(req: &Request, state: &mut State, ctx: &mut MatchContext) { +fn op_branch(req: &Request, state: &mut State, ctx: &mut MatchContext) { state.marks.push(); ctx.count = 1; create_context(req, state, ctx); - fn create_context( - req: &Request, - state: &mut State, - ctx: &mut MatchContext, - ) { + fn create_context(req: &Request, state: &mut State, ctx: &mut MatchContext) { let branch_offset = ctx.count as usize; let next_length = ctx.peek_code(req, branch_offset) as isize; if next_length == 0 { @@ -646,7 +1104,7 @@ fn op_branch(req: &Request, state: &mut State, ctx: &mut Matc ctx.next_offset(branch_offset + 1, state, callback); } - fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { + fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { if state.popped_has_matched { return ctx.success(); } @@ -656,11 +1114,7 @@ fn op_branch(req: &Request, state: &mut State, ctx: &mut Matc } /* <1=min> <2=max> item tail */ -fn op_min_repeat_one( - req: &Request, - state: &mut State, - ctx: &mut MatchContext, -) { +fn op_min_repeat_one(req: &Request, state: &mut State, ctx: &mut MatchContext) { let min_count = ctx.peek_code(req, 2) as usize; if ctx.remaining_chars(req) < min_count { @@ -692,23 +1146,19 @@ fn op_min_repeat_one( state.marks.push(); create_context(req, state, ctx); - fn create_context( - req: &Request, - state: &mut State, - ctx: &mut MatchContext, - ) { + fn create_context(req: &Request, state: &mut State, ctx: &mut MatchContext) { let max_count = ctx.peek_code(req, 3) as usize; if max_count == MAXREPEAT || ctx.count as usize <= max_count { state.string_position = ctx.string_position; - ctx.next_from(1, req, state, callback); + ctx.next_peek_from(1, req, state, callback); } else { state.marks.pop_discard(); ctx.failure(); } } - fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { + fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { if state.popped_has_matched { return ctx.success(); } @@ -735,7 +1185,7 @@ exactly one character wide, and we're not already collecting backtracking points. for other cases, use the MAX_REPEAT operator */ /* <1=min> <2=max> item tail */ -fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut MatchContext) { +fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut MatchContext) { let min_count = ctx.peek_code(req, 2) as usize; let max_count = ctx.peek_code(req, 3) as usize; @@ -764,11 +1214,7 @@ fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut ctx.count = count as isize; create_context(req, state, ctx); - fn create_context( - req: &Request, - state: &mut State, - ctx: &mut MatchContext, - ) { + fn create_context(req: &Request, state: &mut State, ctx: &mut MatchContext) { let min_count = ctx.peek_code(req, 2) as isize; let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); if next_code == SreOpcode::LITERAL as u32 { @@ -788,10 +1234,10 @@ fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut state.string_position = ctx.string_position; // General case: backtracking - ctx.next_from(1, req, state, callback); + ctx.next_peek_from(1, req, state, callback); } - fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { + fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { if state.popped_has_matched { return ctx.success(); } @@ -810,6 +1256,7 @@ fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut create_context(req, state, ctx); } } +*/ #[derive(Debug, Clone, Copy)] struct RepeatContext { @@ -821,10 +1268,11 @@ struct RepeatContext { prev_id: usize, } +/* /* create repeat context. all the hard work is done by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ /* <1=min> <2=max> item tail */ -fn op_repeat(req: &Request, state: &mut State, ctx: &mut MatchContext) { +fn op_repeat(req: &Request, state: &mut State, ctx: &mut MatchContext) { let repeat_ctx = RepeatContext { count: -1, min_count: ctx.peek_code(req, 2) as usize, @@ -840,7 +1288,7 @@ fn op_repeat(req: &Request, state: &mut State, ctx: &mut Matc let repeat_ctx_id = state.repeat_stack.len() - 1; - let next_ctx = ctx.next_from(1, req, state, |_, state, ctx| { + let next_ctx = ctx.next_peek_from(1, req, state, |_, state, ctx| { ctx.has_matched = Some(state.popped_has_matched); state.repeat_stack.pop(); }); @@ -848,7 +1296,7 @@ fn op_repeat(req: &Request, state: &mut State, ctx: &mut Matc } /* minimizing repeat */ -fn op_min_until(state: &mut State, ctx: &mut MatchContext) { +fn op_min_until(state: &mut State, ctx: &mut MatchContext) { let repeat_ctx = state.repeat_stack.last_mut().unwrap(); state.string_position = ctx.string_position; @@ -915,7 +1363,7 @@ fn op_min_until(state: &mut State, ctx: &mut MatchContext) { } /* maximizing repeat */ -fn op_max_until(state: &mut State, ctx: &mut MatchContext) { +fn op_max_until(state: &mut State, ctx: &mut MatchContext) { let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; state.string_position = ctx.string_position; @@ -976,7 +1424,7 @@ fn op_max_until(state: &mut State, ctx: &mut MatchContext) { let next_ctx = ctx.next_offset(1, state, tail_callback); next_ctx.repeat_ctx_id = repeat_ctx_prev_id; - fn tail_callback(_: &Request, state: &mut State, ctx: &mut MatchContext) { + fn tail_callback(_: &Request, state: &mut State, ctx: &mut MatchContext) { if state.popped_has_matched { ctx.success(); } else { @@ -985,6 +1433,7 @@ fn op_max_until(state: &mut State, ctx: &mut MatchContext) { } } } +*/ pub trait StrDrive: Copy { fn offset(&self, offset: usize, skip: usize) -> usize; @@ -1061,67 +1510,49 @@ impl<'a> StrDrive for &'a [u8] { } } -type OpFunc = for<'a> fn(&Request<'a, S>, &mut State, &mut MatchContext); - #[derive(Clone, Copy)] -struct MatchContext { +struct MatchContext { string_position: usize, string_offset: usize, code_position: usize, - has_matched: Option, toplevel: bool, - handler: Option>, + jump: Jump, repeat_ctx_id: usize, count: isize, } -impl std::fmt::Debug for MatchContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MatchContext") - .field("string_position", &self.string_position) - .field("string_offset", &self.string_offset) - .field("code_position", &self.code_position) - .field("has_matched", &self.has_matched) - .field("toplevel", &self.toplevel) - .field("handler", &self.handler.map(|x| x as usize)) - .field("repeat_ctx_id", &self.repeat_ctx_id) - .field("count", &self.count) - .finish() - } -} - -impl MatchContext { - fn pattern<'a>(&self, req: &Request<'a, S>) -> &'a [u32] { +impl MatchContext { + fn pattern<'a, S>(&self, req: &Request<'a, S>) -> &'a [u32] { &req.pattern_codes[self.code_position..] } - fn remaining_codes(&self, req: &Request) -> usize { + fn remaining_codes(&self, req: &Request) -> usize { req.pattern_codes.len() - self.code_position } - fn remaining_chars(&self, req: &Request) -> usize { + fn remaining_chars(&self, req: &Request) -> usize { req.end - self.string_position } - fn peek_char(&self, req: &Request) -> u32 { + fn peek_char(&self, req: &Request) -> u32 { req.string.peek(self.string_offset) } - fn skip_char(&mut self, req: &Request, skip: usize) { + fn skip_char(&mut self, req: &Request, skip: usize) { self.string_position += skip; self.string_offset = req.string.offset(self.string_offset, skip); } - fn back_peek_char(&self, req: &Request) -> u32 { + fn back_peek_char(&self, req: &Request) -> u32 { req.string.back_peek(self.string_offset) } - fn back_skip_char(&mut self, req: &Request, skip: usize) { + fn back_skip_char(&mut self, req: &Request, skip: usize) { self.string_position -= skip; self.string_offset = req.string.back_offset(self.string_offset, skip); } - fn peek_code(&self, req: &Request, peek: usize) -> u32 { + fn peek_code(&self, req: &Request, peek: usize) -> u32 { req.pattern_codes[self.code_position + peek] } @@ -1129,7 +1560,7 @@ impl MatchContext { self.code_position += skip; } - fn skip_code_from(&mut self, req: &Request, peek: usize) { + fn skip_code_from(&mut self, req: &Request, peek: usize) { self.skip_code(self.peek_code(req, peek) as usize + 1); } @@ -1138,15 +1569,19 @@ impl MatchContext { self.string_position == 0 } - fn at_end(&self, req: &Request) -> bool { + fn at_end(&self, req: &Request) -> bool { self.string_position == req.end } - fn at_linebreak(&self, req: &Request) -> bool { + fn at_linebreak(&self, req: &Request) -> bool { !self.at_end(req) && is_linebreak(self.peek_char(req)) } - fn at_boundary bool>(&self, req: &Request, mut word_checker: F) -> bool { + fn at_boundary bool>( + &self, + req: &Request, + mut word_checker: F, + ) -> bool { if self.at_beginning() && self.at_end(req) { return false; } @@ -1155,7 +1590,7 @@ impl MatchContext { this != that } - fn at_non_boundary bool>( + fn at_non_boundary bool>( &self, req: &Request, mut word_checker: F, @@ -1168,7 +1603,7 @@ impl MatchContext { this == that } - fn can_success(&self, req: &Request) -> bool { + fn can_success(&self, req: &Request) -> bool { if !self.toplevel { return true; } @@ -1181,51 +1616,29 @@ impl MatchContext { true } - fn success(&mut self) { - self.has_matched = Some(true); + #[must_use] + fn next_peek_from(&mut self, peek: usize, req: &Request, jump: Jump) -> Self { + self.next_offset(self.peek_code(req, peek) as usize + 1, jump) } - fn failure(&mut self) { - self.has_matched = Some(false); + #[must_use] + fn next_offset(&mut self, offset: usize, jump: Jump) -> Self { + self.next_at(self.code_position + offset, jump) } - fn next_from<'b>( - &mut self, - peek: usize, - req: &Request, - state: &'b mut State, - f: OpFunc, - ) -> &'b mut Self { - self.next_offset(self.peek_code(req, peek) as usize + 1, state, f) - } - - fn next_offset<'b>( - &mut self, - offset: usize, - state: &'b mut State, - f: OpFunc, - ) -> &'b mut Self { - self.next_at(self.code_position + offset, state, f) - } - - fn next_at<'b>( - &mut self, - code_position: usize, - state: &'b mut State, - f: OpFunc, - ) -> &'b mut Self { - self.handler = Some(f); - state.next_context.insert(MatchContext { + #[must_use] + fn next_at(&mut self, code_position: usize, jump: Jump) -> Self { + self.jump = jump; + MatchContext { code_position, - has_matched: None, - handler: None, + jump: Jump::OpCode, count: -1, ..*self - }) + } } } -fn at(req: &Request, ctx: &MatchContext, atcode: SreAtCode) -> bool { +fn at(req: &Request, ctx: &MatchContext, atcode: SreAtCode) -> bool { match atcode { SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => ctx.at_beginning(), SreAtCode::BEGINNING_LINE => ctx.at_beginning() || is_linebreak(ctx.back_peek_char(req)), @@ -1243,9 +1656,10 @@ fn at(req: &Request, ctx: &MatchContext, atcode: SreAtCode) - } } +/* fn general_op_literal bool>( req: &Request, - ctx: &mut MatchContext, + ctx: &mut MatchContext, f: F, ) { if ctx.at_end(req) || !f(ctx.peek_code(req, 1), ctx.peek_char(req)) { @@ -1258,7 +1672,7 @@ fn general_op_literal bool>( fn general_op_in bool>( req: &Request, - ctx: &mut MatchContext, + ctx: &mut MatchContext, f: F, ) { if ctx.at_end(req) || !f(&ctx.pattern(req)[2..], ctx.peek_char(req)) { @@ -1271,8 +1685,8 @@ fn general_op_in bool>( fn general_op_groupref u32>( req: &Request, - state: &State, - ctx: &mut MatchContext, + state: &State, + ctx: &mut MatchContext, mut f: F, ) { let (group_start, group_end) = state.marks.get(ctx.peek_code(req, 1) as usize); @@ -1301,6 +1715,7 @@ fn general_op_groupref u32>( ctx.skip_code(2); } +*/ fn char_loc_ignore(code: u32, c: u32) -> bool { code == c || code == lower_locate(c) || code == upper_locate(c) @@ -1433,8 +1848,8 @@ fn charset(set: &[u32], ch: u32) -> bool { fn _count( req: &Request, - state: &mut State, - mut ctx: MatchContext, + state: &mut State, + mut ctx: MatchContext, max_count: usize, ) -> usize { let max_count = std::cmp::min(max_count, ctx.remaining_chars(req)); @@ -1491,12 +1906,15 @@ fn _count( while count < max_count { ctx.code_position = reset_position; - let code = ctx.peek_code(req, 0); - let code = SreOpcode::try_from(code).unwrap(); - dispatch(req, state, &mut ctx, code); - if ctx.has_matched == Some(false) { + if !_match(req, state, ctx) { break; } + // let code = ctx.peek_code(req, 0); + // let code = SreOpcode::try_from(code).unwrap(); + // dispatch(req, state, &mut ctx, code); + // if ctx.has_matched == Some(false) { + // break; + // } count += 1; } return count; @@ -1509,7 +1927,7 @@ fn _count( fn general_count_literal bool>( req: &Request, - ctx: &mut MatchContext, + ctx: &mut MatchContext, end: usize, mut f: F, ) { diff --git a/tests/tests.rs b/tests/tests.rs index 5212226f4e..4e282a6f97 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -8,7 +8,7 @@ impl Pattern { fn state<'a, S: engine::StrDrive>( &self, string: S, - ) -> (engine::Request<'a, S>, engine::State) { + ) -> (engine::Request<'a, S>, engine::State) { let req = engine::Request::new(string, 0, usize::MAX, self.code, false); let state = engine::State::default(); (req, state) @@ -22,8 +22,7 @@ fn test_2427() { #[rustfmt::skip] let lookbehind = Pattern { code: &[15, 4, 0, 1, 1, 5, 5, 1, 17, 46, 1, 17, 120, 6, 10, 1] }; // END GENERATED let (req, mut state) = lookbehind.state("x"); - state.pymatch(req); - assert!(state.has_matched); + assert!(state.pymatch(&req)); } #[test] @@ -33,8 +32,7 @@ fn test_assert() { #[rustfmt::skip] let positive_lookbehind = Pattern { code: &[15, 4, 0, 3, 3, 4, 9, 3, 17, 97, 17, 98, 17, 99, 1, 17, 100, 17, 101, 17, 102, 1] }; // END GENERATED let (req, mut state) = positive_lookbehind.state("abcdef"); - state.search(req); - assert!(state.has_matched); + assert!(state.search(req)); } #[test] @@ -44,8 +42,7 @@ fn test_string_boundaries() { #[rustfmt::skip] let big_b = Pattern { code: &[15, 4, 0, 0, 0, 6, 11, 1] }; // END GENERATED let (req, mut state) = big_b.state(""); - state.search(req); - assert!(!state.has_matched); + assert!(!state.search(req)); } #[test] @@ -56,7 +53,7 @@ fn test_zerowidth() { // END GENERATED let (mut req, mut state) = p.state("a:"); req.must_advance = true; - state.search(req); + assert!(state.search(req)); assert_eq!(state.string_position, 1); } @@ -68,8 +65,8 @@ fn test_repeat_context_panic() { #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 25, 0, 4294967295, 27, 6, 0, 4294967295, 17, 97, 1, 24, 11, 0, 1, 18, 0, 17, 120, 17, 120, 18, 1, 20, 17, 122, 19, 1] }; // END GENERATED let (req, mut state) = p.state("axxzaz"); - state.pymatch(req); - assert_eq!(*state.marks, vec![Optioned::some(1), Optioned::some(3)]); + assert!(state.pymatch(&req)); + assert_eq!(*state.marks.raw(), vec![Optioned::some(1), Optioned::some(3)]); } #[test] @@ -79,7 +76,7 @@ fn test_double_max_until() { #[rustfmt::skip] let p = Pattern { code: &[15, 4, 0, 0, 4294967295, 24, 18, 0, 4294967295, 18, 0, 24, 9, 0, 1, 18, 2, 17, 49, 18, 3, 19, 18, 1, 19, 1] }; // END GENERATED let (req, mut state) = p.state("1111"); - state.pymatch(req); + assert!(state.pymatch(&req)); assert_eq!(state.string_position, 4); } @@ -90,7 +87,7 @@ fn test_info_single() { #[rustfmt::skip] let p = Pattern { code: &[15, 8, 1, 1, 4294967295, 1, 1, 97, 0, 17, 97, 25, 6, 0, 4294967295, 17, 97, 1, 1] }; // END GENERATED let (req, mut state) = p.state("baaaa"); - state.search(req); + assert!(state.search(req)); assert_eq!(state.start, 1); assert_eq!(state.string_position, 5); } @@ -102,8 +99,7 @@ fn test_info_single2() { #[rustfmt::skip] let p = Pattern { code: &[15, 8, 1, 4, 6, 1, 1, 80, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 1] }; // END GENERATED let (req, mut state) = p.state("Perl"); - state.search(req); - assert!(state.has_matched); + assert!(state.search(req)); } #[test] @@ -113,8 +109,7 @@ fn test_info_literal() { #[rustfmt::skip] let p = Pattern { code: &[15, 14, 1, 5, 4294967295, 4, 4, 97, 98, 97, 98, 0, 0, 1, 2, 17, 97, 17, 98, 17, 97, 17, 98, 25, 6, 1, 4294967295, 17, 99, 1, 1] }; // END GENERATED let (req, mut state) = p.state("!ababc"); - state.search(req); - assert!(state.has_matched); + assert!(state.search(req)); } #[test] @@ -124,6 +119,5 @@ fn test_info_literal2() { #[rustfmt::skip] let p = Pattern { code: &[15, 18, 1, 12, 12, 6, 0, 112, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 18, 0, 17, 112, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 12, 0, 1] }; // END GENERATED let (req, mut state) = p.state("pythonpython"); - state.search(req); - assert!(state.has_matched); + assert!(state.search(req)); } From 39c0106e873645ff8f212bcfa21c79216e5f7d89 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Mon, 11 Dec 2023 20:17:51 +0200 Subject: [PATCH 229/893] update to cpython 3.12 op code --- benches/benches.rs | 22 +++++++------- src/constants.rs | 72 ++++++++++++++++++++++++---------------------- src/engine.rs | 4 ++- tests/tests.rs | 20 ++++++------- 4 files changed, 61 insertions(+), 57 deletions(-) diff --git a/benches/benches.rs b/benches/benches.rs index fe470d023c..f70138f920 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -30,47 +30,47 @@ fn benchmarks(b: &mut Bencher) { // # test common prefix // pattern p1 = re.compile('Python|Perl') # , 'Perl'), # Alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p1 = Pattern { code: &[15, 8, 1, 4, 6, 1, 1, 80, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 1] }; + #[rustfmt::skip] let p1 = Pattern { code: &[14, 8, 1, 4, 6, 1, 1, 80, 0, 16, 80, 7, 13, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 11, 9, 16, 101, 16, 114, 16, 108, 15, 2, 0, 1] }; // END GENERATED // pattern p2 = re.compile('(Python|Perl)') #, 'Perl'), # Grouped alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p2 = Pattern { code: &[15, 8, 1, 4, 6, 1, 0, 80, 0, 18, 0, 17, 80, 7, 13, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 11, 9, 17, 101, 17, 114, 17, 108, 16, 2, 0, 18, 1, 1] }; + #[rustfmt::skip] let p2 = Pattern { code: &[14, 8, 1, 4, 6, 1, 0, 80, 0, 17, 0, 16, 80, 7, 13, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 11, 9, 16, 101, 16, 114, 16, 108, 15, 2, 0, 17, 1, 1] }; // END GENERATED // pattern p3 = re.compile('Python|Perl|Tcl') #, 'Perl'), # Alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p3 = Pattern { code: &[15, 9, 4, 3, 6, 17, 80, 17, 84, 0, 7, 15, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 22, 11, 17, 80, 17, 101, 17, 114, 17, 108, 16, 11, 9, 17, 84, 17, 99, 17, 108, 16, 2, 0, 1] }; + #[rustfmt::skip] let p3 = Pattern { code: &[14, 9, 4, 3, 6, 16, 80, 16, 84, 0, 7, 15, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 22, 11, 16, 80, 16, 101, 16, 114, 16, 108, 15, 11, 9, 16, 84, 16, 99, 16, 108, 15, 2, 0, 1] }; // END GENERATED // pattern p4 = re.compile('(Python|Perl|Tcl)') #, 'Perl'), # Grouped alternation // START GENERATED by generate_tests.py - #[rustfmt::skip] let p4 = Pattern { code: &[15, 9, 4, 3, 6, 17, 80, 17, 84, 0, 18, 0, 7, 15, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 16, 22, 11, 17, 80, 17, 101, 17, 114, 17, 108, 16, 11, 9, 17, 84, 17, 99, 17, 108, 16, 2, 0, 18, 1, 1] }; + #[rustfmt::skip] let p4 = Pattern { code: &[14, 9, 4, 3, 6, 16, 80, 16, 84, 0, 17, 0, 7, 15, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 15, 22, 11, 16, 80, 16, 101, 16, 114, 16, 108, 15, 11, 9, 16, 84, 16, 99, 16, 108, 15, 2, 0, 17, 1, 1] }; // END GENERATED // pattern p5 = re.compile('(Python)\\1') #, 'PythonPython'), # Backreference // START GENERATED by generate_tests.py - #[rustfmt::skip] let p5 = Pattern { code: &[15, 18, 1, 12, 12, 6, 0, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 18, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 12, 0, 1] }; + #[rustfmt::skip] let p5 = Pattern { code: &[14, 18, 1, 12, 12, 6, 0, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 17, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 17, 1, 11, 0, 1] }; // END GENERATED // pattern p6 = re.compile('([0a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # Disable the fastmap optimization // START GENERATED by generate_tests.py - #[rustfmt::skip] let p6 = Pattern { code: &[15, 4, 0, 2, 4294967295, 24, 31, 1, 4294967295, 18, 0, 14, 7, 17, 48, 23, 97, 122, 0, 25, 13, 0, 4294967295, 14, 8, 23, 97, 122, 23, 48, 57, 0, 1, 17, 44, 18, 1, 19, 1] }; + #[rustfmt::skip] let p6 = Pattern { code: &[14, 4, 0, 2, 4294967295, 23, 31, 1, 4294967295, 17, 0, 13, 7, 16, 48, 22, 97, 122, 0, 24, 13, 0, 4294967295, 13, 8, 22, 97, 122, 22, 48, 57, 0, 1, 16, 44, 17, 1, 18, 1] }; // END GENERATED // pattern p7 = re.compile('([a-z][a-z0-9]*,)+') #, 'a5,b7,c9,'), # A few sets // START GENERATED by generate_tests.py - #[rustfmt::skip] let p7 = Pattern { code: &[15, 4, 0, 2, 4294967295, 24, 29, 1, 4294967295, 18, 0, 14, 5, 23, 97, 122, 0, 25, 13, 0, 4294967295, 14, 8, 23, 97, 122, 23, 48, 57, 0, 1, 17, 44, 18, 1, 19, 1] }; + #[rustfmt::skip] let p7 = Pattern { code: &[14, 4, 0, 2, 4294967295, 23, 29, 1, 4294967295, 17, 0, 13, 5, 22, 97, 122, 0, 24, 13, 0, 4294967295, 13, 8, 22, 97, 122, 22, 48, 57, 0, 1, 16, 44, 17, 1, 18, 1] }; // END GENERATED // pattern p8 = re.compile('Python') #, 'Python'), # Simple text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p8 = Pattern { code: &[15, 18, 3, 6, 6, 6, 6, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 1] }; + #[rustfmt::skip] let p8 = Pattern { code: &[14, 18, 3, 6, 6, 6, 6, 80, 121, 116, 104, 111, 110, 0, 0, 0, 0, 0, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 1] }; // END GENERATED // pattern p9 = re.compile('.*Python') #, 'Python'), # Bad text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p9 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 1] }; + #[rustfmt::skip] let p9 = Pattern { code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 1] }; // END GENERATED // pattern p10 = re.compile('.*Python.*') #, 'Python'), # Worse text literal // START GENERATED by generate_tests.py - #[rustfmt::skip] let p10 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 25, 5, 0, 4294967295, 2, 1, 1] }; + #[rustfmt::skip] let p10 = Pattern { code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 24, 5, 0, 4294967295, 2, 1, 1] }; // END GENERATED // pattern p11 = re.compile('.*(Python)') #, 'Python'), # Bad text literal with grouping // START GENERATED by generate_tests.py - #[rustfmt::skip] let p11 = Pattern { code: &[15, 4, 0, 6, 4294967295, 25, 5, 0, 4294967295, 2, 1, 18, 0, 17, 80, 17, 121, 17, 116, 17, 104, 17, 111, 17, 110, 18, 1, 1] }; + #[rustfmt::skip] let p11 = Pattern { code: &[14, 4, 0, 6, 4294967295, 24, 5, 0, 4294967295, 2, 1, 17, 0, 16, 80, 16, 121, 16, 116, 16, 104, 16, 111, 16, 110, 17, 1, 1] }; // END GENERATED let tests = [ diff --git a/src/constants.rs b/src/constants.rs index 0d5bb41939..dc61c33b2c 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -13,7 +13,7 @@ use bitflags::bitflags; -pub const SRE_MAGIC: usize = 20171005; +pub const SRE_MAGIC: usize = 20221023; #[derive(num_enum::TryFromPrimitive, Debug)] #[repr(u32)] #[allow(non_camel_case_types, clippy::upper_case_acronyms)] @@ -26,39 +26,41 @@ pub enum SreOpcode { ASSERT_NOT = 5, AT = 6, BRANCH = 7, - CALL = 8, - CATEGORY = 9, - CHARSET = 10, - BIGCHARSET = 11, - GROUPREF = 12, - GROUPREF_EXISTS = 13, - IN = 14, - INFO = 15, - JUMP = 16, - LITERAL = 17, - MARK = 18, - MAX_UNTIL = 19, - MIN_UNTIL = 20, - NOT_LITERAL = 21, - NEGATE = 22, - RANGE = 23, - REPEAT = 24, - REPEAT_ONE = 25, - SUBPATTERN = 26, - MIN_REPEAT_ONE = 27, - GROUPREF_IGNORE = 28, - IN_IGNORE = 29, - LITERAL_IGNORE = 30, - NOT_LITERAL_IGNORE = 31, - GROUPREF_LOC_IGNORE = 32, - IN_LOC_IGNORE = 33, - LITERAL_LOC_IGNORE = 34, - NOT_LITERAL_LOC_IGNORE = 35, - GROUPREF_UNI_IGNORE = 36, - IN_UNI_IGNORE = 37, - LITERAL_UNI_IGNORE = 38, - NOT_LITERAL_UNI_IGNORE = 39, - RANGE_UNI_IGNORE = 40, + CATEGORY = 8, + CHARSET = 9, + BIGCHARSET = 10, + GROUPREF = 11, + GROUPREF_EXISTS = 12, + IN = 13, + INFO = 14, + JUMP = 15, + LITERAL = 16, + MARK = 17, + MAX_UNTIL = 18, + MIN_UNTIL = 19, + NOT_LITERAL = 20, + NEGATE = 21, + RANGE = 22, + REPEAT = 23, + REPEAT_ONE = 24, + SUBPATTERN = 25, + MIN_REPEAT_ONE = 26, + ATOMIC_GROUP = 27, + POSSESSIVE_REPEAT = 28, + POSSESSIVE_REPEAT_ONE = 29, + GROUPREF_IGNORE = 30, + IN_IGNORE = 31, + LITERAL_IGNORE = 32, + NOT_LITERAL_IGNORE = 33, + GROUPREF_LOC_IGNORE = 34, + IN_LOC_IGNORE = 35, + LITERAL_LOC_IGNORE = 36, + NOT_LITERAL_LOC_IGNORE = 37, + GROUPREF_UNI_IGNORE = 38, + IN_UNI_IGNORE = 39, + LITERAL_UNI_IGNORE = 40, + NOT_LITERAL_UNI_IGNORE = 41, + RANGE_UNI_IGNORE = 42, } #[derive(num_enum::TryFromPrimitive, Debug)] #[repr(u32)] @@ -101,7 +103,7 @@ pub enum SreCatCode { UNI_NOT_LINEBREAK = 17, } bitflags! { - #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct SreFlag: u16 { const TEMPLATE = 1; const IGNORECASE = 2; diff --git a/src/engine.rs b/src/engine.rs index 7474f29013..e44a1f4a09 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -759,13 +759,15 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - ctx.skip_code_from(req, 2) } } - SreOpcode::CALL => todo!(), SreOpcode::CHARSET => todo!(), SreOpcode::BIGCHARSET => todo!(), SreOpcode::NEGATE => todo!(), SreOpcode::RANGE => todo!(), SreOpcode::RANGE_UNI_IGNORE => todo!(), SreOpcode::SUBPATTERN => todo!(), + SreOpcode::ATOMIC_GROUP => todo!(), + SreOpcode::POSSESSIVE_REPEAT => todo!(), + SreOpcode::POSSESSIVE_REPEAT_ONE => todo!(), } } }; diff --git a/tests/tests.rs b/tests/tests.rs index 4e282a6f97..0a1dc407fc 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -19,7 +19,7 @@ impl Pattern { fn test_2427() { // pattern lookbehind = re.compile(r'(? Date: Mon, 11 Dec 2023 22:28:36 +0200 Subject: [PATCH 230/893] fix _count general case --- src/engine.rs | 18 ++++++++---------- tests/tests.rs | 10 ++++++++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index e44a1f4a09..3a24fe812b 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1904,19 +1904,17 @@ fn _count( /* General case */ let mut count = 0; - let reset_position = ctx.code_position; - while count < max_count { - ctx.code_position = reset_position; - if !_match(req, state, ctx) { + let sub_ctx = MatchContext { + toplevel: true, + jump: Jump::OpCode, + repeat_ctx_id: usize::MAX, + count: -1, + ..ctx + }; + if !_match(req, state, sub_ctx) { break; } - // let code = ctx.peek_code(req, 0); - // let code = SreOpcode::try_from(code).unwrap(); - // dispatch(req, state, &mut ctx, code); - // if ctx.has_matched == Some(false) { - // break; - // } count += 1; } return count; diff --git a/tests/tests.rs b/tests/tests.rs index 0a1dc407fc..a452efb740 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -121,3 +121,13 @@ fn test_info_literal2() { let (req, mut state) = p.state("pythonpython"); assert!(state.search(req)); } + +#[test] +fn test_repeat_in_assertions() { + // pattern p = re.compile('^([ab]*?)(?=(b)?)c', re.IGNORECASE) + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[14, 4, 0, 1, 4294967295, 6, 0, 17, 0, 26, 10, 0, 4294967295, 39, 5, 22, 97, 98, 0, 1, 17, 1, 4, 14, 0, 23, 9, 0, 1, 17, 2, 40, 98, 17, 3, 18, 1, 40, 99, 1] }; + // END GENERATED + let (req, mut state) = p.state("abc"); + assert!(state.search(req)); +} \ No newline at end of file From 99ed744c57c7aa8e3a0d33bf7893b26d5be86434 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 13 Dec 2023 21:53:30 +0200 Subject: [PATCH 231/893] impl atomic group & possessive repeat --- src/engine.rs | 108 +++++++++++++++++++++++++++++++++++++++++++++++-- tests/tests.rs | 32 +++++++++++++++ 2 files changed, 136 insertions(+), 4 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 3a24fe812b..daed81212f 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -267,6 +267,11 @@ enum Jump { RepeatOne2, MinRepeatOne1, MinRepeatOne2, + AtomicGroup1, + PossessiveRepeat1, + PossessiveRepeat2, + PossessiveRepeat3, + PossessiveRepeat4, } fn _match(req: &Request, state: &mut State, ctx: MatchContext) -> bool { @@ -445,6 +450,73 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - ctx.jump = Jump::MinRepeatOne1; continue 'context; } + Jump::AtomicGroup1 => { + if popped_result { + ctx.skip_code_from(req, 1); + ctx.string_position = state.string_position; + ctx.string_offset = req.string.offset(0, state.string_position); + // dispatch opcode + } else { + state.string_position = ctx.string_position; + break 'result false; + } + } + Jump::PossessiveRepeat1 => { + let min_count = ctx.peek_code(req, 2) as isize; + if ctx.count < min_count { + break 'context ctx.next_offset(4, Jump::PossessiveRepeat2); + } + // zero match protection + ctx.string_position = usize::MAX; + ctx.jump = Jump::PossessiveRepeat3; + continue 'context; + } + Jump::PossessiveRepeat2 => { + if popped_result { + ctx.count += 1; + ctx.jump = Jump::PossessiveRepeat1; + continue 'context; + } else { + state.string_position = ctx.string_position; + break 'result false; + } + } + Jump::PossessiveRepeat3 => { + let max_count = ctx.peek_code(req, 3) as usize; + if ((ctx.count as usize) < max_count || max_count == MAXREPEAT) + && ctx.string_position != state.string_position + { + state.marks.push(); + ctx.string_position = state.string_position; + ctx.string_offset = req.string.offset(0, state.string_position); + break 'context ctx.next_offset(4, Jump::PossessiveRepeat4); + } + ctx.string_position = state.string_position; + ctx.string_offset = req.string.offset(0, state.string_position); + // popped_result = false; + // ctx.jump = Jump::PossessiveRepeat4; + // continue 'context; + ctx.skip_code_from(req, 1); + ctx.skip_code(1); + // if ctx.remaining_codes(req) > 1 && ctx.toplevel { + // ctx.skip_code(1); + // } + } + Jump::PossessiveRepeat4 => { + if popped_result { + state.marks.pop_discard(); + ctx.count += 1; + ctx.jump = Jump::PossessiveRepeat3; + continue 'context; + } + state.marks.pop(); + state.string_position = ctx.string_position; + ctx.skip_code_from(req, 1); + ctx.skip_code(1); + // if ctx.remaining_codes(req) > 1 && ctx.toplevel { + // ctx.skip_code(1); + // } + } } ctx.jump = Jump::OpCode; @@ -759,15 +831,43 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - ctx.skip_code_from(req, 2) } } + /* pattern tail */ + SreOpcode::ATOMIC_GROUP => { + state.string_position = ctx.string_position; + break 'context ctx.next_offset(2, Jump::AtomicGroup1); + } + /* <1=min> <2=max> pattern + tail */ + SreOpcode::POSSESSIVE_REPEAT => { + state.string_position = ctx.string_position; + ctx.count = 0; + ctx.jump = Jump::PossessiveRepeat1; + continue 'context; + } + /* <1=min> <2=max> item + tail */ + SreOpcode::POSSESSIVE_REPEAT_ONE => { + let min_count = ctx.peek_code(req, 2) as usize; + let max_count = ctx.peek_code(req, 3) as usize; + if ctx.remaining_chars(req) < min_count { + break 'result false; + } + state.string_position = ctx.string_position; + let mut count_ctx = ctx; + count_ctx.skip_code(4); + let count = _count(req, state, count_ctx, max_count); + if count < min_count { + break 'result false; + } + ctx.skip_char(req, count); + ctx.skip_code_from(req, 1); + } SreOpcode::CHARSET => todo!(), SreOpcode::BIGCHARSET => todo!(), SreOpcode::NEGATE => todo!(), SreOpcode::RANGE => todo!(), SreOpcode::RANGE_UNI_IGNORE => todo!(), SreOpcode::SUBPATTERN => todo!(), - SreOpcode::ATOMIC_GROUP => todo!(), - SreOpcode::POSSESSIVE_REPEAT => todo!(), - SreOpcode::POSSESSIVE_REPEAT_ONE => todo!(), } } }; @@ -1906,7 +2006,7 @@ fn _count( while count < max_count { let sub_ctx = MatchContext { - toplevel: true, + toplevel: false, jump: Jump::OpCode, repeat_ctx_id: usize::MAX, count: -1, diff --git a/tests/tests.rs b/tests/tests.rs index a452efb740..4c56c42fae 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -130,4 +130,36 @@ fn test_repeat_in_assertions() { // END GENERATED let (req, mut state) = p.state("abc"); assert!(state.search(req)); +} + +#[test] +fn test_possessive_quantifier() { + // pattern p = re.compile('e++a') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[14, 4, 0, 2, 4294967295, 29, 6, 1, 4294967295, 16, 101, 1, 16, 97, 1] }; + // END GENERATED + let (req, mut state) = p.state("eeea"); + assert!(state.pymatch(&req)); +} + +#[test] +fn test_possessive_atomic_group() { + // pattern p = re.compile('(?>x)++x') + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[14, 4, 0, 2, 4294967295, 28, 8, 1, 4294967295, 27, 4, 16, 120, 1, 1, 16, 120, 1] }; + // END GENERATED + let (req, mut state) = p.state("xxx"); + assert!(!state.pymatch(&req)); +} + +#[test] +fn test_bug_20998() { + // pattern p = re.compile('[a-c]+', re.I) + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[14, 4, 0, 1, 4294967295, 24, 10, 1, 4294967295, 39, 5, 22, 97, 99, 0, 1, 1] }; + // END GENERATED + let (mut req, mut state) = p.state("ABC"); + req.match_all = true; + assert!(state.pymatch(&req)); + assert_eq!(state.string_position, 3); } \ No newline at end of file From 9378497346147500328d6b96e59386caecbfbeab Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 13 Dec 2023 21:59:38 +0200 Subject: [PATCH 232/893] clearup --- src/engine.rs | 592 ++------------------------------------------------ 1 file changed, 14 insertions(+), 578 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index daed81212f..3486c52d9f 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -493,14 +493,8 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - } ctx.string_position = state.string_position; ctx.string_offset = req.string.offset(0, state.string_position); - // popped_result = false; - // ctx.jump = Jump::PossessiveRepeat4; - // continue 'context; ctx.skip_code_from(req, 1); ctx.skip_code(1); - // if ctx.remaining_codes(req) > 1 && ctx.toplevel { - // ctx.skip_code(1); - // } } Jump::PossessiveRepeat4 => { if popped_result { @@ -513,9 +507,6 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - state.string_position = ctx.string_position; ctx.skip_code_from(req, 1); ctx.skip_code(1); - // if ctx.remaining_codes(req) > 1 && ctx.toplevel { - // ctx.skip_code(1); - // } } } ctx.jump = Jump::OpCode; @@ -603,6 +594,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - ctx.skip_code(1); ctx.skip_char(req, 1); } + /* */ SreOpcode::ASSERT => { let back = ctx.peek_code(req, 2) as usize; if ctx.string_position < back { @@ -615,6 +607,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - state.string_position = next_ctx.string_position; break 'context next_ctx; } + /* */ SreOpcode::ASSERT_NOT => { let back = ctx.peek_code(req, 2) as usize; if ctx.string_position < back { @@ -636,6 +629,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - break 'result false; } } + // <0=skip> code ... SreOpcode::BRANCH => { state.marks.push(); ctx.count = 1; @@ -672,6 +666,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - ctx.skip_code(2); } SreOpcode::JUMP => ctx.skip_code_from(req, 1), + /* <1=min> <2=max> item tail */ SreOpcode::REPEAT => { let repeat_ctx = RepeatContext { count: -1, @@ -736,6 +731,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - next_ctx.repeat_ctx_id = repeat_ctx.prev_id; break 'context next_ctx; } + /* <1=min> <2=max> item tail */ SreOpcode::REPEAT_ONE => { let min_count = ctx.peek_code(req, 2) as usize; let max_count = ctx.peek_code(req, 3) as usize; @@ -766,6 +762,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - ctx.jump = Jump::RepeatOne1; continue 'context; } + /* <1=min> <2=max> item tail */ SreOpcode::MIN_REPEAT_ONE => { let min_count = ctx.peek_code(req, 2) as usize; if ctx.remaining_chars(req) < min_count { @@ -862,12 +859,14 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - ctx.skip_char(req, count); ctx.skip_code_from(req, 1); } - SreOpcode::CHARSET => todo!(), - SreOpcode::BIGCHARSET => todo!(), - SreOpcode::NEGATE => todo!(), - SreOpcode::RANGE => todo!(), - SreOpcode::RANGE_UNI_IGNORE => todo!(), - SreOpcode::SUBPATTERN => todo!(), + SreOpcode::CHARSET + | SreOpcode::BIGCHARSET + | SreOpcode::NEGATE + | SreOpcode::RANGE + | SreOpcode::RANGE_UNI_IGNORE + | SreOpcode::SUBPATTERN => { + unreachable!("unexpected opcode on main dispatch") + } } } }; @@ -879,123 +878,6 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - popped_result } -/* -fn dispatch( - req: &Request, - state: &mut State, - ctx: &mut MatchContext, - opcode: SreOpcode, -) { - match opcode { - SreOpcode::FAILURE => { - ctx.failure(); - } - SreOpcode::SUCCESS => { - if ctx.can_success(req) { - state.string_position = ctx.string_position; - ctx.success(); - } else { - ctx.failure(); - } - } - SreOpcode::ANY => { - if ctx.at_end(req) || ctx.at_linebreak(req) { - ctx.failure(); - } else { - ctx.skip_code(1); - ctx.skip_char(req, 1); - } - } - SreOpcode::ANY_ALL => { - if ctx.at_end(req) { - ctx.failure(); - } else { - ctx.skip_code(1); - ctx.skip_char(req, 1); - } - } - SreOpcode::ASSERT => op_assert(req, state, ctx), - SreOpcode::ASSERT_NOT => op_assert_not(req, state, ctx), - SreOpcode::AT => { - let atcode = SreAtCode::try_from(ctx.peek_code(req, 1)).unwrap(); - if at(req, ctx, atcode) { - ctx.skip_code(2); - } else { - ctx.failure(); - } - } - SreOpcode::BRANCH => op_branch(req, state, ctx), - SreOpcode::CATEGORY => { - let catcode = SreCatCode::try_from(ctx.peek_code(req, 1)).unwrap(); - if ctx.at_end(req) || !category(catcode, ctx.peek_char(req)) { - ctx.failure(); - } else { - ctx.skip_code(2); - ctx.skip_char(req, 1); - } - } - SreOpcode::IN => general_op_in(req, ctx, charset), - SreOpcode::IN_IGNORE => general_op_in(req, ctx, |set, c| charset(set, lower_ascii(c))), - SreOpcode::IN_UNI_IGNORE => { - general_op_in(req, ctx, |set, c| charset(set, lower_unicode(c))) - } - SreOpcode::IN_LOC_IGNORE => general_op_in(req, ctx, charset_loc_ignore), - SreOpcode::INFO => { - let min = ctx.peek_code(req, 3) as usize; - if ctx.remaining_chars(req) < min { - ctx.failure(); - } else { - ctx.skip_code_from(req, 1); - } - } - SreOpcode::JUMP => ctx.skip_code_from(req, 1), - SreOpcode::LITERAL => general_op_literal(req, ctx, |code, c| code == c), - SreOpcode::NOT_LITERAL => general_op_literal(req, ctx, |code, c| code != c), - SreOpcode::LITERAL_IGNORE => general_op_literal(req, ctx, |code, c| code == lower_ascii(c)), - SreOpcode::NOT_LITERAL_IGNORE => { - general_op_literal(req, ctx, |code, c| code != lower_ascii(c)) - } - SreOpcode::LITERAL_UNI_IGNORE => { - general_op_literal(req, ctx, |code, c| code == lower_unicode(c)) - } - SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_op_literal(req, ctx, |code, c| code != lower_unicode(c)) - } - SreOpcode::LITERAL_LOC_IGNORE => general_op_literal(req, ctx, char_loc_ignore), - SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_op_literal(req, ctx, |code, c| !char_loc_ignore(code, c)) - } - SreOpcode::MARK => { - state - .marks - .set(ctx.peek_code(req, 1) as usize, ctx.string_position); - ctx.skip_code(2); - } - SreOpcode::MAX_UNTIL => op_max_until(state, ctx), - SreOpcode::MIN_UNTIL => op_min_until(state, ctx), - SreOpcode::REPEAT => op_repeat(req, state, ctx), - SreOpcode::REPEAT_ONE => op_repeat_one(req, state, ctx), - SreOpcode::MIN_REPEAT_ONE => op_min_repeat_one(req, state, ctx), - SreOpcode::GROUPREF => general_op_groupref(req, state, ctx, |x| x), - SreOpcode::GROUPREF_IGNORE => general_op_groupref(req, state, ctx, lower_ascii), - SreOpcode::GROUPREF_LOC_IGNORE => general_op_groupref(req, state, ctx, lower_locate), - SreOpcode::GROUPREF_UNI_IGNORE => general_op_groupref(req, state, ctx, lower_unicode), - SreOpcode::GROUPREF_EXISTS => { - let (group_start, group_end) = state.marks.get(ctx.peek_code(req, 1) as usize); - if group_start.is_some() - && group_end.is_some() - && group_start.unpack() <= group_end.unpack() - { - ctx.skip_code(3); - } else { - ctx.skip_code_from(req, 2) - } - } - _ => unreachable!("unexpected opcode"), - } -} -*/ - fn search_info_literal( req: &mut Request, state: &mut State, @@ -1142,224 +1024,6 @@ fn search_info_charset( } } -/* -/* assert subpattern */ -/* */ -fn op_assert(req: &Request, state: &mut State, ctx: &mut MatchContext) { - let back = ctx.peek_code(req, 2) as usize; - if ctx.string_position < back { - return ctx.failure(); - } - - let next_ctx = ctx.next_offset(3, state, |req, state, ctx| { - if state.popped_has_matched { - ctx.skip_code_from(req, 1); - } else { - ctx.failure(); - } - }); - next_ctx.toplevel = false; - next_ctx.back_skip_char(req, back); - state.string_position = next_ctx.string_position; -} - -/* assert not subpattern */ -/* */ -fn op_assert_not(req: &Request, state: &mut State, ctx: &mut MatchContext) { - let back = ctx.peek_code(req, 2) as usize; - - if ctx.string_position < back { - return ctx.skip_code_from(req, 1); - } - - let next_ctx = ctx.next_offset(3, state, |req, state, ctx| { - if state.popped_has_matched { - ctx.failure(); - } else { - ctx.skip_code_from(req, 1); - } - }); - next_ctx.toplevel = false; - next_ctx.back_skip_char(req, back); - state.string_position = next_ctx.string_position; -} - -// alternation -// <0=skip> code ... -fn op_branch(req: &Request, state: &mut State, ctx: &mut MatchContext) { - state.marks.push(); - - ctx.count = 1; - create_context(req, state, ctx); - - fn create_context(req: &Request, state: &mut State, ctx: &mut MatchContext) { - let branch_offset = ctx.count as usize; - let next_length = ctx.peek_code(req, branch_offset) as isize; - if next_length == 0 { - state.marks.pop_discard(); - return ctx.failure(); - } - - state.string_position = ctx.string_position; - - ctx.count += next_length; - ctx.next_offset(branch_offset + 1, state, callback); - } - - fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { - if state.popped_has_matched { - return ctx.success(); - } - state.marks.pop_keep(); - create_context(req, state, ctx); - } -} - -/* <1=min> <2=max> item tail */ -fn op_min_repeat_one(req: &Request, state: &mut State, ctx: &mut MatchContext) { - let min_count = ctx.peek_code(req, 2) as usize; - - if ctx.remaining_chars(req) < min_count { - return ctx.failure(); - } - - state.string_position = ctx.string_position; - - ctx.count = if min_count == 0 { - 0 - } else { - let mut next_ctx = *ctx; - next_ctx.skip_code(4); - let count = _count(req, state, next_ctx, min_count); - if count < min_count { - return ctx.failure(); - } - ctx.skip_char(req, count); - count as isize - }; - - let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { - // tail is empty. we're finished - state.string_position = ctx.string_position; - return ctx.success(); - } - - state.marks.push(); - create_context(req, state, ctx); - - fn create_context(req: &Request, state: &mut State, ctx: &mut MatchContext) { - let max_count = ctx.peek_code(req, 3) as usize; - - if max_count == MAXREPEAT || ctx.count as usize <= max_count { - state.string_position = ctx.string_position; - ctx.next_peek_from(1, req, state, callback); - } else { - state.marks.pop_discard(); - ctx.failure(); - } - } - - fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { - if state.popped_has_matched { - return ctx.success(); - } - - state.string_position = ctx.string_position; - - let mut next_ctx = *ctx; - next_ctx.skip_code(4); - if _count(req, state, next_ctx, 1) == 0 { - state.marks.pop_discard(); - return ctx.failure(); - } - - ctx.skip_char(req, 1); - ctx.count += 1; - state.marks.pop_keep(); - create_context(req, state, ctx); - } -} - -/* match repeated sequence (maximizing regexp) */ -/* this operator only works if the repeated item is -exactly one character wide, and we're not already -collecting backtracking points. for other cases, -use the MAX_REPEAT operator */ -/* <1=min> <2=max> item tail */ -fn op_repeat_one(req: &Request, state: &mut State, ctx: &mut MatchContext) { - let min_count = ctx.peek_code(req, 2) as usize; - let max_count = ctx.peek_code(req, 3) as usize; - - if ctx.remaining_chars(req) < min_count { - return ctx.failure(); - } - - state.string_position = ctx.string_position; - - let mut next_ctx = *ctx; - next_ctx.skip_code(4); - let count = _count(req, state, next_ctx, max_count); - ctx.skip_char(req, count); - if count < min_count { - return ctx.failure(); - } - - let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); - if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { - // tail is empty. we're finished - state.string_position = ctx.string_position; - return ctx.success(); - } - - state.marks.push(); - ctx.count = count as isize; - create_context(req, state, ctx); - - fn create_context(req: &Request, state: &mut State, ctx: &mut MatchContext) { - let min_count = ctx.peek_code(req, 2) as isize; - let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); - if next_code == SreOpcode::LITERAL as u32 { - // Special case: Tail starts with a literal. Skip positions where - // the rest of the pattern cannot possibly match. - let c = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 2); - while ctx.at_end(req) || ctx.peek_char(req) != c { - if ctx.count <= min_count { - state.marks.pop_discard(); - return ctx.failure(); - } - ctx.back_skip_char(req, 1); - ctx.count -= 1; - } - } - - state.string_position = ctx.string_position; - - // General case: backtracking - ctx.next_peek_from(1, req, state, callback); - } - - fn callback(req: &Request, state: &mut State, ctx: &mut MatchContext) { - if state.popped_has_matched { - return ctx.success(); - } - - let min_count = ctx.peek_code(req, 2) as isize; - - if ctx.count <= min_count { - state.marks.pop_discard(); - return ctx.failure(); - } - - ctx.back_skip_char(req, 1); - ctx.count -= 1; - - state.marks.pop_keep(); - create_context(req, state, ctx); - } -} -*/ - #[derive(Debug, Clone, Copy)] struct RepeatContext { count: isize, @@ -1370,173 +1034,6 @@ struct RepeatContext { prev_id: usize, } -/* -/* create repeat context. all the hard work is done -by the UNTIL operator (MAX_UNTIL, MIN_UNTIL) */ -/* <1=min> <2=max> item tail */ -fn op_repeat(req: &Request, state: &mut State, ctx: &mut MatchContext) { - let repeat_ctx = RepeatContext { - count: -1, - min_count: ctx.peek_code(req, 2) as usize, - max_count: ctx.peek_code(req, 3) as usize, - code_position: ctx.code_position, - last_position: std::usize::MAX, - prev_id: ctx.repeat_ctx_id, - }; - - state.repeat_stack.push(repeat_ctx); - - state.string_position = ctx.string_position; - - let repeat_ctx_id = state.repeat_stack.len() - 1; - - let next_ctx = ctx.next_peek_from(1, req, state, |_, state, ctx| { - ctx.has_matched = Some(state.popped_has_matched); - state.repeat_stack.pop(); - }); - next_ctx.repeat_ctx_id = repeat_ctx_id; -} - -/* minimizing repeat */ -fn op_min_until(state: &mut State, ctx: &mut MatchContext) { - let repeat_ctx = state.repeat_stack.last_mut().unwrap(); - - state.string_position = ctx.string_position; - - repeat_ctx.count += 1; - - if (repeat_ctx.count as usize) < repeat_ctx.min_count { - // not enough matches - ctx.next_at(repeat_ctx.code_position + 4, state, |_, state, ctx| { - if state.popped_has_matched { - ctx.success(); - } else { - state.repeat_stack[ctx.repeat_ctx_id].count -= 1; - state.string_position = ctx.string_position; - ctx.failure(); - } - }); - return; - } - - state.marks.push(); - - ctx.count = ctx.repeat_ctx_id as isize; - - let repeat_ctx_prev_id = repeat_ctx.prev_id; - - // see if the tail matches - let next_ctx = ctx.next_offset(1, state, |_, state, ctx| { - if state.popped_has_matched { - return ctx.success(); - } - - ctx.repeat_ctx_id = ctx.count as usize; - - let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; - - state.string_position = ctx.string_position; - - state.marks.pop(); - - // match more until tail matches - - if repeat_ctx.count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT - || state.string_position == repeat_ctx.last_position - { - repeat_ctx.count -= 1; - return ctx.failure(); - } - - /* zero-width match protection */ - repeat_ctx.last_position = state.string_position; - - ctx.next_at(repeat_ctx.code_position + 4, state, |_, state, ctx| { - if state.popped_has_matched { - ctx.success(); - } else { - state.repeat_stack[ctx.repeat_ctx_id].count -= 1; - state.string_position = ctx.string_position; - ctx.failure(); - } - }); - }); - next_ctx.repeat_ctx_id = repeat_ctx_prev_id; -} - -/* maximizing repeat */ -fn op_max_until(state: &mut State, ctx: &mut MatchContext) { - let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; - - state.string_position = ctx.string_position; - - repeat_ctx.count += 1; - - if (repeat_ctx.count as usize) < repeat_ctx.min_count { - // not enough matches - ctx.next_at(repeat_ctx.code_position + 4, state, |_, state, ctx| { - if state.popped_has_matched { - ctx.success(); - } else { - state.repeat_stack[ctx.repeat_ctx_id].count -= 1; - state.string_position = ctx.string_position; - ctx.failure(); - } - }); - return; - } - - if ((repeat_ctx.count as usize) < repeat_ctx.max_count || repeat_ctx.max_count == MAXREPEAT) - && state.string_position != repeat_ctx.last_position - { - /* we may have enough matches, but if we can - match another item, do so */ - state.marks.push(); - - ctx.count = repeat_ctx.last_position as isize; - repeat_ctx.last_position = state.string_position; - - ctx.next_at(repeat_ctx.code_position + 4, state, |_, state, ctx| { - let save_last_position = ctx.count as usize; - let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; - repeat_ctx.last_position = save_last_position; - - if state.popped_has_matched { - state.marks.pop_discard(); - return ctx.success(); - } - - state.marks.pop(); - repeat_ctx.count -= 1; - - state.string_position = ctx.string_position; - - /* cannot match more repeated items here. make sure the - tail matches */ - let repeat_ctx_prev_id = repeat_ctx.prev_id; - let next_ctx = ctx.next_offset(1, state, tail_callback); - next_ctx.repeat_ctx_id = repeat_ctx_prev_id; - }); - return; - } - - /* cannot match more repeated items here. make sure the - tail matches */ - let repeat_ctx_prev_id = repeat_ctx.prev_id; - let next_ctx = ctx.next_offset(1, state, tail_callback); - next_ctx.repeat_ctx_id = repeat_ctx_prev_id; - - fn tail_callback(_: &Request, state: &mut State, ctx: &mut MatchContext) { - if state.popped_has_matched { - ctx.success(); - } else { - state.string_position = ctx.string_position; - ctx.failure(); - } - } -} -*/ - pub trait StrDrive: Copy { fn offset(&self, offset: usize, skip: usize) -> usize; fn count(&self) -> usize; @@ -1758,67 +1255,6 @@ fn at(req: &Request, ctx: &MatchContext, atcode: SreAtCode) -> b } } -/* -fn general_op_literal bool>( - req: &Request, - ctx: &mut MatchContext, - f: F, -) { - if ctx.at_end(req) || !f(ctx.peek_code(req, 1), ctx.peek_char(req)) { - ctx.failure(); - } else { - ctx.skip_code(2); - ctx.skip_char(req, 1); - } -} - -fn general_op_in bool>( - req: &Request, - ctx: &mut MatchContext, - f: F, -) { - if ctx.at_end(req) || !f(&ctx.pattern(req)[2..], ctx.peek_char(req)) { - ctx.failure(); - } else { - ctx.skip_code_from(req, 1); - ctx.skip_char(req, 1); - } -} - -fn general_op_groupref u32>( - req: &Request, - state: &State, - ctx: &mut MatchContext, - mut f: F, -) { - let (group_start, group_end) = state.marks.get(ctx.peek_code(req, 1) as usize); - let (group_start, group_end) = if group_start.is_some() - && group_end.is_some() - && group_start.unpack() <= group_end.unpack() - { - (group_start.unpack(), group_end.unpack()) - } else { - return ctx.failure(); - }; - - let mut gctx = MatchContext { - string_position: group_start, - string_offset: req.string.offset(0, group_start), - ..*ctx - }; - - for _ in group_start..group_end { - if ctx.at_end(req) || f(ctx.peek_char(req)) != f(gctx.peek_char(req)) { - return ctx.failure(); - } - ctx.skip_char(req, 1); - gctx.skip_char(req, 1); - } - - ctx.skip_code(2); -} -*/ - fn char_loc_ignore(code: u32, c: u32) -> bool { code == c || code == lower_locate(c) || code == upper_locate(c) } From 003c45dbffbfefe5a7a47899836d042bddfbb8a9 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 13 Dec 2023 22:00:35 +0200 Subject: [PATCH 233/893] remove unneccesary INFO logic on main dispatch --- src/engine.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 3486c52d9f..5cf79ea147 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -652,20 +652,13 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - general_op_in!(|set, c| charset(set, lower_unicode(c))) } SreOpcode::IN_LOC_IGNORE => general_op_in!(charset_loc_ignore), - SreOpcode::INFO => { - let min = ctx.peek_code(req, 3) as usize; - if ctx.remaining_chars(req) < min { - break 'result false; - } - ctx.skip_code_from(req, 1); - } SreOpcode::MARK => { state .marks .set(ctx.peek_code(req, 1) as usize, ctx.string_position); ctx.skip_code(2); } - SreOpcode::JUMP => ctx.skip_code_from(req, 1), + SreOpcode::INFO | SreOpcode::JUMP => ctx.skip_code_from(req, 1), /* <1=min> <2=max> item tail */ SreOpcode::REPEAT => { let repeat_ctx = RepeatContext { From 41bdcfe2212d08c4068595babe3c497cd2a035fb Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 13 Dec 2023 22:13:54 +0200 Subject: [PATCH 234/893] bump version to 0.5.0 --- Cargo.toml | 2 +- src/engine.rs | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index de1d68cf6d..b0ec8eab2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.4.3" +version = "0.5.0" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" diff --git a/src/engine.rs b/src/engine.rs index 5cf79ea147..5ee9af7d1b 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,8 +1,6 @@ // good luck to those that follow; here be dragons -use crate::constants::SreInfo; - -use super::constants::{SreAtCode, SreCatCode, SreOpcode}; +use super::constants::{SreAtCode, SreCatCode, SreOpcode, SreInfo}; use super::MAXREPEAT; use optional::Optioned; use std::convert::TryFrom; @@ -284,7 +282,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - }; popped_result = 'result: loop { - let yield_ = 'context: loop { + let yielded = 'context: loop { match ctx.jump { Jump::OpCode => {} Jump::Assert1 => { @@ -864,7 +862,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - } }; context_stack.push(ctx); - context_stack.push(yield_); + context_stack.push(yielded); continue 'coro; }; } From 169368b7f06fd502a5a77fd56cf5b7ec46e3a975 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 13 Dec 2023 22:43:17 +0200 Subject: [PATCH 235/893] fix some clippy --- src/engine.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 5ee9af7d1b..b6bf6a6fb6 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -142,7 +142,7 @@ impl State { repeat_ctx_id: usize::MAX, count: -1, }; - _match(&req, self, ctx) + _match(req, self, ctx) } pub fn search(&mut self, mut req: Request) -> bool { @@ -1400,16 +1400,16 @@ fn _count( } } SreOpcode::LITERAL => { - general_count_literal(req, &mut ctx, end, |code, c| code == c as u32); + general_count_literal(req, &mut ctx, end, |code, c| code == c); } SreOpcode::NOT_LITERAL => { - general_count_literal(req, &mut ctx, end, |code, c| code != c as u32); + general_count_literal(req, &mut ctx, end, |code, c| code != c); } SreOpcode::LITERAL_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| code == lower_ascii(c) as u32); + general_count_literal(req, &mut ctx, end, |code, c| code == lower_ascii(c)); } SreOpcode::NOT_LITERAL_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| code != lower_ascii(c) as u32); + general_count_literal(req, &mut ctx, end, |code, c| code != lower_ascii(c)); } SreOpcode::LITERAL_LOC_IGNORE => { general_count_literal(req, &mut ctx, end, char_loc_ignore); @@ -1419,12 +1419,12 @@ fn _count( } SreOpcode::LITERAL_UNI_IGNORE => { general_count_literal(req, &mut ctx, end, |code, c| { - code == lower_unicode(c) as u32 + code == lower_unicode(c) }); } SreOpcode::NOT_LITERAL_UNI_IGNORE => { general_count_literal(req, &mut ctx, end, |code, c| { - code != lower_unicode(c) as u32 + code != lower_unicode(c) }); } _ => { From 454aa4b6544cc53b3443328015ff923103f906c4 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 3 Jan 2024 15:34:01 +0200 Subject: [PATCH 236/893] fix count not advance --- src/engine.rs | 28 ++++++++++++++-------------- tests/tests.rs | 13 ++++++++++++- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index b6bf6a6fb6..86b8d20d8a 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,6 +1,6 @@ // good luck to those that follow; here be dragons -use super::constants::{SreAtCode, SreCatCode, SreOpcode, SreInfo}; +use super::constants::{SreAtCode, SreCatCode, SreInfo, SreOpcode}; use super::MAXREPEAT; use optional::Optioned; use std::convert::TryFrom; @@ -1418,31 +1418,31 @@ fn _count( general_count_literal(req, &mut ctx, end, |code, c| !char_loc_ignore(code, c)); } SreOpcode::LITERAL_UNI_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| { - code == lower_unicode(c) - }); + general_count_literal(req, &mut ctx, end, |code, c| code == lower_unicode(c)); } SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| { - code != lower_unicode(c) - }); + general_count_literal(req, &mut ctx, end, |code, c| code != lower_unicode(c)); } _ => { /* General case */ let mut count = 0; + let mut sub_ctx = MatchContext { + toplevel: false, + jump: Jump::OpCode, + repeat_ctx_id: usize::MAX, + count: -1, + ..ctx + }; + while count < max_count { - let sub_ctx = MatchContext { - toplevel: false, - jump: Jump::OpCode, - repeat_ctx_id: usize::MAX, - count: -1, - ..ctx - }; if !_match(req, state, sub_ctx) { break; } count += 1; + sub_ctx.skip_char(req, 1); + // ctx.string_position = state.string_position; + // ctx.string_offset = req.string.offset(0, state.string_position); } return count; } diff --git a/tests/tests.rs b/tests/tests.rs index 4c56c42fae..efeb2d2838 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -162,4 +162,15 @@ fn test_bug_20998() { req.match_all = true; assert!(state.pymatch(&req)); assert_eq!(state.string_position, 3); -} \ No newline at end of file +} + +#[test] +fn test_bigcharset() { + // pattern p = re.compile('[a-z]*', re.I) + // START GENERATED by generate_tests.py + #[rustfmt::skip] let p = Pattern { code: &[14, 4, 0, 0, 4294967295, 24, 97, 0, 4294967295, 39, 92, 10, 3, 33685760, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 33686018, 0, 0, 0, 134217726, 0, 0, 0, 0, 0, 131072, 0, 2147483648, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1] }; + // END GENERATED + let (req, mut state) = p.state("x "); + assert!(state.pymatch(&req)); + assert_eq!(state.string_position, 1); +} From 17e1152de63cd9b70e4ad3b061979baa7ee54a35 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 3 Jan 2024 17:17:27 +0200 Subject: [PATCH 237/893] fix assert not mark --- src/engine.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/engine.rs b/src/engine.rs index 86b8d20d8a..964775fd8f 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -296,6 +296,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - if popped_result { break 'result false; } + state.marks.pop(); ctx.skip_code_from(req, 1); } Jump::Branch1 => { @@ -612,6 +613,7 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - ctx.skip_code_from(req, 1); continue; } + state.marks.push(); let mut next_ctx = ctx.next_offset(3, Jump::AssertNot1); next_ctx.toplevel = false; From 2fe129252fd0d798dc0cdac78737bd641825d48b Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Thu, 4 Jan 2024 07:51:35 +0200 Subject: [PATCH 238/893] improve ctx in _match lazy create stack vec --- src/constants.rs | 4 ++-- src/engine.rs | 31 ++++++++++++++++++++++++------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index dc61c33b2c..9fe792ce17 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -14,7 +14,7 @@ use bitflags::bitflags; pub const SRE_MAGIC: usize = 20221023; -#[derive(num_enum::TryFromPrimitive, Debug)] +#[derive(num_enum::TryFromPrimitive, Debug, PartialEq, Eq)] #[repr(u32)] #[allow(non_camel_case_types, clippy::upper_case_acronyms)] pub enum SreOpcode { @@ -62,7 +62,7 @@ pub enum SreOpcode { NOT_LITERAL_UNI_IGNORE = 41, RANGE_UNI_IGNORE = 42, } -#[derive(num_enum::TryFromPrimitive, Debug)] +#[derive(num_enum::TryFromPrimitive, Debug, PartialEq, Eq)] #[repr(u32)] #[allow(non_camel_case_types, clippy::upper_case_acronyms)] pub enum SreAtCode { diff --git a/src/engine.rs b/src/engine.rs index 964775fd8f..c23ffb7a2f 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -207,6 +207,15 @@ impl State { return true; } + if ctx.try_peek_code_as::(&req, 1).unwrap() == SreOpcode::AT + && (ctx.try_peek_code_as::(&req, 2).unwrap() == SreAtCode::BEGINNING + || ctx.try_peek_code_as::(&req, 2).unwrap() + == SreAtCode::BEGINNING_STRING) + { + self.reset(req.end); + return false; + } + req.must_advance = false; ctx.toplevel = false; while req.start < end { @@ -272,15 +281,11 @@ enum Jump { PossessiveRepeat4, } -fn _match(req: &Request, state: &mut State, ctx: MatchContext) -> bool { - let mut context_stack = vec![ctx]; +fn _match(req: &Request, state: &mut State, mut ctx: MatchContext) -> bool { + let mut context_stack = vec![]; let mut popped_result = false; 'coro: loop { - let Some(mut ctx) = context_stack.pop() else { - break; - }; - popped_result = 'result: loop { let yielded = 'context: loop { match ctx.jump { @@ -864,9 +869,14 @@ fn _match(req: &Request, state: &mut State, ctx: MatchContext) - } }; context_stack.push(ctx); - context_stack.push(yielded); + ctx = yielded; continue 'coro; }; + if let Some(popped_ctx) = context_stack.pop() { + ctx = popped_ctx; + } else { + break; + } } popped_result } @@ -1148,6 +1158,13 @@ impl MatchContext { req.pattern_codes[self.code_position + peek] } + fn try_peek_code_as(&self, req: &Request, peek: usize) -> Result + where + T: TryFrom, + { + self.peek_code(req, peek).try_into() + } + fn skip_code(&mut self, skip: usize) { self.code_position += skip; } From f9b2d10c710d45ebd0cc9788294cd5806cc9e8ac Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Fri, 5 Jan 2024 07:23:50 +0200 Subject: [PATCH 239/893] improve search at_beginning --- src/engine.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index c23ffb7a2f..dbb58025fe 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -201,15 +201,17 @@ impl State { return search_info_charset(&mut req, self, ctx); } // fallback to general search + // skip OP INFO + ctx.skip_code_from(&req, 1); } if _match(&req, self, ctx) { return true; } - if ctx.try_peek_code_as::(&req, 1).unwrap() == SreOpcode::AT - && (ctx.try_peek_code_as::(&req, 2).unwrap() == SreAtCode::BEGINNING - || ctx.try_peek_code_as::(&req, 2).unwrap() + if ctx.try_peek_code_as::(&req, 0).unwrap() == SreOpcode::AT + && (ctx.try_peek_code_as::(&req, 1).unwrap() == SreAtCode::BEGINNING + || ctx.try_peek_code_as::(&req, 1).unwrap() == SreAtCode::BEGINNING_STRING) { self.reset(req.end); From 118a00c012810900fe89277cebe6a5f09ff286d1 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 7 Jan 2024 08:42:18 +0200 Subject: [PATCH 240/893] refactor _count general case --- src/engine.rs | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index dbb58025fe..ca44f4994a 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1446,26 +1446,20 @@ fn _count( } _ => { /* General case */ - let mut count = 0; - - let mut sub_ctx = MatchContext { - toplevel: false, - jump: Jump::OpCode, - repeat_ctx_id: usize::MAX, - count: -1, - ..ctx + ctx.toplevel = false; + ctx.jump = Jump::OpCode; + ctx.repeat_ctx_id = usize::MAX; + ctx.count = -1; + + let mut sub_state = State { + marks: Marks::default(), + repeat_stack: vec![], + ..*state }; - while count < max_count { - if !_match(req, state, sub_ctx) { - break; - } - count += 1; - sub_ctx.skip_char(req, 1); - // ctx.string_position = state.string_position; - // ctx.string_offset = req.string.offset(0, state.string_position); + while ctx.string_position < end && _match(req, &mut sub_state, ctx) { + ctx.skip_char(req, 1); } - return count; } } From c93ea30b3b5849edc3ac888a1bfec69f08b537c1 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 13 Jan 2024 16:03:38 +0200 Subject: [PATCH 241/893] improve use StringCursor replace index based position --- Cargo.toml | 2 +- benches/benches.rs | 12 +- src/engine.rs | 511 ++++++++++++++------------------------------- src/lib.rs | 5 + src/string.rs | 381 +++++++++++++++++++++++++++++++++ tests/tests.rs | 31 +-- 6 files changed, 562 insertions(+), 380 deletions(-) create mode 100644 src/string.rs diff --git a/Cargo.toml b/Cargo.toml index b0ec8eab2d..e54f124ac0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sre-engine" -version = "0.5.0" +version = "0.6.0" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" repository = "https://github.com/RustPython/sre-engine" diff --git a/benches/benches.rs b/benches/benches.rs index f70138f920..e89adab0dd 100644 --- a/benches/benches.rs +++ b/benches/benches.rs @@ -3,24 +3,24 @@ extern crate test; use test::Bencher; -use sre_engine::engine; +use sre_engine::{Request, State, StrDrive}; struct Pattern { code: &'static [u32], } impl Pattern { - fn state<'a, S: engine::StrDrive>(&self, string: S) -> (engine::Request<'a, S>, engine::State) { + fn state<'a, S: StrDrive>(&self, string: S) -> (Request<'a, S>, State) { self.state_range(string, 0..usize::MAX) } - fn state_range<'a, S: engine::StrDrive>( + fn state_range<'a, S: StrDrive>( &self, string: S, range: std::ops::Range, - ) -> (engine::Request<'a, S>, engine::State) { - let req = engine::Request::new(string, range.start, range.end, self.code, false); - let state = engine::State::default(); + ) -> (Request<'a, S>, State) { + let req = Request::new(string, range.start, range.end, self.code, false); + let state = State::default(); (req, state) } } diff --git a/src/engine.rs b/src/engine.rs index ca44f4994a..97489633d8 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -1,14 +1,14 @@ // good luck to those that follow; here be dragons -use super::constants::{SreAtCode, SreCatCode, SreInfo, SreOpcode}; -use super::MAXREPEAT; +use crate::string::{ + is_digit, is_linebreak, is_loc_word, is_space, is_uni_digit, is_uni_linebreak, is_uni_space, + is_uni_word, is_word, lower_ascii, lower_locate, lower_unicode, upper_locate, upper_unicode, +}; + +use super::{SreAtCode, SreCatCode, SreInfo, SreOpcode, StrDrive, StringCursor, MAXREPEAT}; use optional::Optioned; use std::convert::TryFrom; -const fn is_py_ascii_whitespace(b: u8) -> bool { - matches!(b, b'\t' | b'\n' | b'\x0C' | b'\r' | b' ' | b'\x0B') -} - #[derive(Debug, Clone, Copy)] pub struct Request<'a, S> { pub string: S, @@ -117,25 +117,29 @@ impl Marks { pub struct State { pub start: usize, pub marks: Marks, - pub string_position: usize, + pub cursor: StringCursor, repeat_stack: Vec, } impl State { - pub fn reset(&mut self, start: usize) { + pub fn reset(&mut self, req: &Request, start: usize) { self.marks.clear(); self.repeat_stack.clear(); self.start = start; - self.string_position = start; + if self.cursor.ptr.is_null() || self.cursor.position > self.start { + self.cursor = req.string.create_cursor(self.start); + } else if self.cursor.position < self.start { + let skip = self.start - self.cursor.position; + S::skip(&mut self.cursor, skip); + } } pub fn pymatch(&mut self, req: &Request) -> bool { self.start = req.start; - self.string_position = req.start; + self.cursor = req.string.create_cursor(self.start); let ctx = MatchContext { - string_position: req.start, - string_offset: req.string.offset(0, req.start), + cursor: self.cursor, code_position: 0, toplevel: true, jump: Jump::OpCode, @@ -147,7 +151,7 @@ impl State { pub fn search(&mut self, mut req: Request) -> bool { self.start = req.start; - self.string_position = req.start; + self.cursor = req.string.create_cursor(self.start); if req.start > req.end { return false; @@ -155,11 +159,8 @@ impl State { let mut end = req.end; - let mut start_offset = req.string.offset(0, req.start); - let mut ctx = MatchContext { - string_position: req.start, - string_offset: start_offset, + cursor: self.cursor, code_position: 0, toplevel: true, jump: Jump::OpCode, @@ -183,9 +184,9 @@ impl State { end -= min - 1; // adjust ctx position - if end < ctx.string_position { - ctx.string_position = end; - ctx.string_offset = req.string.offset(0, ctx.string_position); + if end < ctx.cursor.position { + let skip = end - self.cursor.position; + S::skip(&mut self.cursor, skip); } } @@ -214,7 +215,7 @@ impl State { || ctx.try_peek_code_as::(&req, 1).unwrap() == SreAtCode::BEGINNING_STRING) { - self.reset(req.end); + self.reset(&req, req.end); return false; } @@ -222,10 +223,8 @@ impl State { ctx.toplevel = false; while req.start < end { req.start += 1; - start_offset = req.string.offset(start_offset, 1); - self.reset(req.start); - ctx.string_position = req.start; - ctx.string_offset = start_offset; + self.reset(&req, req.start); + ctx.cursor = self.cursor; if _match(&req, self, ctx) { return true; @@ -248,13 +247,13 @@ impl<'a, S: StrDrive> Iterator for SearchIter<'a, S> { return None; } - self.state.reset(self.req.start); + self.state.reset(&self.req, self.req.start); if !self.state.search(self.req) { return None; } - self.req.must_advance = self.state.string_position == self.state.start; - self.req.start = self.state.string_position; + self.req.must_advance = self.state.cursor.position == self.state.start; + self.req.start = self.state.cursor.position; Some(()) } @@ -313,7 +312,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex state.marks.pop_discard(); break 'result false; } - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; let next_ctx = ctx.next_offset(branch_offset + 1, Jump::Branch2); ctx.count += next_length; break 'context next_ctx; @@ -333,7 +332,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex Jump::UntilBacktrace => { if !popped_result { state.repeat_stack[ctx.repeat_ctx_id].count -= 1; - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; } break 'result popped_result; } @@ -349,7 +348,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex state.marks.pop(); repeat_ctx.count -= 1; - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; /* cannot match more repeated items here. make sure the tail matches */ @@ -359,7 +358,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex } Jump::MaxUntil3 => { if !popped_result { - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; } break 'result popped_result; } @@ -369,20 +368,20 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex } ctx.repeat_ctx_id = ctx.count as usize; let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; state.marks.pop(); // match more until tail matches if repeat_ctx.count as usize >= repeat_ctx.max_count && repeat_ctx.max_count != MAXREPEAT - || state.string_position == repeat_ctx.last_position + || state.cursor.position == repeat_ctx.last_position { repeat_ctx.count -= 1; break 'result false; } /* zero-width match protection */ - repeat_ctx.last_position = state.string_position; + repeat_ctx.last_position = state.cursor.position; break 'context ctx .next_at(repeat_ctx.code_position + 4, Jump::UntilBacktrace); @@ -394,17 +393,17 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex // Special case: Tail starts with a literal. Skip positions where // the rest of the pattern cannot possibly match. let c = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 2); - while ctx.at_end(req) || ctx.peek_char(req) != c { + while ctx.at_end(req) || ctx.peek_char::() != c { if ctx.count <= min_count { state.marks.pop_discard(); break 'result false; } - ctx.back_skip_char(req, 1); + ctx.back_advance_char::(); ctx.count -= 1; } } - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; // General case: backtracking break 'context ctx.next_peek_from(1, req, Jump::RepeatOne2); } @@ -419,7 +418,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex break 'result false; } - ctx.back_skip_char(req, 1); + ctx.back_advance_char::(); ctx.count -= 1; state.marks.pop_keep(); @@ -429,7 +428,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex Jump::MinRepeatOne1 => { let max_count = ctx.peek_code(req, 3) as usize; if max_count == MAXREPEAT || ctx.count as usize <= max_count { - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; break 'context ctx.next_peek_from(1, req, Jump::MinRepeatOne2); } else { state.marks.pop_discard(); @@ -441,7 +440,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex break 'result true; } - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; let mut count_ctx = ctx; count_ctx.skip_code(4); @@ -450,7 +449,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex break 'result false; } - ctx.skip_char(req, 1); + ctx.advance_char::(); ctx.count += 1; state.marks.pop_keep(); ctx.jump = Jump::MinRepeatOne1; @@ -459,11 +458,10 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex Jump::AtomicGroup1 => { if popped_result { ctx.skip_code_from(req, 1); - ctx.string_position = state.string_position; - ctx.string_offset = req.string.offset(0, state.string_position); + ctx.cursor = state.cursor; // dispatch opcode } else { - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; break 'result false; } } @@ -473,7 +471,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex break 'context ctx.next_offset(4, Jump::PossessiveRepeat2); } // zero match protection - ctx.string_position = usize::MAX; + ctx.cursor.position = usize::MAX; ctx.jump = Jump::PossessiveRepeat3; continue 'context; } @@ -483,22 +481,20 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex ctx.jump = Jump::PossessiveRepeat1; continue 'context; } else { - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; break 'result false; } } Jump::PossessiveRepeat3 => { let max_count = ctx.peek_code(req, 3) as usize; if ((ctx.count as usize) < max_count || max_count == MAXREPEAT) - && ctx.string_position != state.string_position + && ctx.cursor.position != state.cursor.position { state.marks.push(); - ctx.string_position = state.string_position; - ctx.string_offset = req.string.offset(0, state.string_position); + ctx.cursor = state.cursor; break 'context ctx.next_offset(4, Jump::PossessiveRepeat4); } - ctx.string_position = state.string_position; - ctx.string_offset = req.string.offset(0, state.string_position); + ctx.cursor = state.cursor; ctx.skip_code_from(req, 1); ctx.skip_code(1); } @@ -510,7 +506,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex continue 'context; } state.marks.pop(); - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; ctx.skip_code_from(req, 1); ctx.skip_code(1); } @@ -520,21 +516,22 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex loop { macro_rules! general_op_literal { ($f:expr) => {{ - if ctx.at_end(req) || !$f(ctx.peek_code(req, 1), ctx.peek_char(req)) { + if ctx.at_end(req) || !$f(ctx.peek_code(req, 1), ctx.peek_char::()) { break 'result false; } ctx.skip_code(2); - ctx.skip_char(req, 1); + ctx.advance_char::(); }}; } macro_rules! general_op_in { ($f:expr) => {{ - if ctx.at_end(req) || !$f(&ctx.pattern(req)[2..], ctx.peek_char(req)) { + if ctx.at_end(req) || !$f(&ctx.pattern(req)[2..], ctx.peek_char::()) + { break 'result false; } ctx.skip_code_from(req, 1); - ctx.skip_char(req, 1); + ctx.advance_char::(); }}; } @@ -552,19 +549,18 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex }; let mut gctx = MatchContext { - string_position: group_start, - string_offset: req.string.offset(0, group_start), + cursor: req.string.create_cursor(group_start), ..ctx }; for _ in group_start..group_end { if ctx.at_end(req) - || $f(ctx.peek_char(req)) != $f(gctx.peek_char(req)) + || $f(ctx.peek_char::()) != $f(gctx.peek_char::()) { break 'result false; } - ctx.skip_char(req, 1); - gctx.skip_char(req, 1); + ctx.advance_char::(); + gctx.advance_char::(); } ctx.skip_code(2); @@ -581,7 +577,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex SreOpcode::FAILURE => break 'result false, SreOpcode::SUCCESS => { if ctx.can_success(req) { - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; break 'result true; } break 'result false; @@ -591,32 +587,32 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex break 'result false; } ctx.skip_code(1); - ctx.skip_char(req, 1); + ctx.advance_char::(); } SreOpcode::ANY_ALL => { if ctx.at_end(req) { break 'result false; } ctx.skip_code(1); - ctx.skip_char(req, 1); + ctx.advance_char::(); } /* */ SreOpcode::ASSERT => { let back = ctx.peek_code(req, 2) as usize; - if ctx.string_position < back { + if ctx.cursor.position < back { break 'result false; } let mut next_ctx = ctx.next_offset(3, Jump::Assert1); next_ctx.toplevel = false; - next_ctx.back_skip_char(req, back); - state.string_position = next_ctx.string_position; + next_ctx.back_skip_char::(back); + state.cursor = next_ctx.cursor; break 'context next_ctx; } /* */ SreOpcode::ASSERT_NOT => { let back = ctx.peek_code(req, 2) as usize; - if ctx.string_position < back { + if ctx.cursor.position < back { ctx.skip_code_from(req, 1); continue; } @@ -624,8 +620,8 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex let mut next_ctx = ctx.next_offset(3, Jump::AssertNot1); next_ctx.toplevel = false; - next_ctx.back_skip_char(req, back); - state.string_position = next_ctx.string_position; + next_ctx.back_skip_char::(back); + state.cursor = next_ctx.cursor; break 'context next_ctx; } SreOpcode::AT => { @@ -645,11 +641,11 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex } SreOpcode::CATEGORY => { let catcode = SreCatCode::try_from(ctx.peek_code(req, 1)).unwrap(); - if ctx.at_end(req) || !category(catcode, ctx.peek_char(req)) { + if ctx.at_end(req) || !category(catcode, ctx.peek_char::()) { break 'result false; } ctx.skip_code(2); - ctx.skip_char(req, 1); + ctx.advance_char::(); } SreOpcode::IN => general_op_in!(charset), SreOpcode::IN_IGNORE => { @@ -662,7 +658,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex SreOpcode::MARK => { state .marks - .set(ctx.peek_code(req, 1) as usize, ctx.string_position); + .set(ctx.peek_code(req, 1) as usize, ctx.cursor.position); ctx.skip_code(2); } SreOpcode::INFO | SreOpcode::JUMP => ctx.skip_code_from(req, 1), @@ -678,14 +674,14 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex }; state.repeat_stack.push(repeat_ctx); let repeat_ctx_id = state.repeat_stack.len() - 1; - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; let mut next_ctx = ctx.next_peek_from(1, req, Jump::Repeat1); next_ctx.repeat_ctx_id = repeat_ctx_id; break 'context next_ctx; } SreOpcode::MAX_UNTIL => { let repeat_ctx = &mut state.repeat_stack[ctx.repeat_ctx_id]; - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; repeat_ctx.count += 1; if (repeat_ctx.count as usize) < repeat_ctx.min_count { @@ -696,13 +692,13 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex if ((repeat_ctx.count as usize) < repeat_ctx.max_count || repeat_ctx.max_count == MAXREPEAT) - && state.string_position != repeat_ctx.last_position + && state.cursor.position != repeat_ctx.last_position { /* we may have enough matches, but if we can match another item, do so */ state.marks.push(); ctx.count = repeat_ctx.last_position as isize; - repeat_ctx.last_position = state.string_position; + repeat_ctx.last_position = state.cursor.position; break 'context ctx .next_at(repeat_ctx.code_position + 4, Jump::MaxUntil2); @@ -716,7 +712,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex } SreOpcode::MIN_UNTIL => { let repeat_ctx = state.repeat_stack.last_mut().unwrap(); - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; repeat_ctx.count += 1; if (repeat_ctx.count as usize) < repeat_ctx.min_count { @@ -740,12 +736,12 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex break 'result false; } - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; let mut next_ctx = ctx; next_ctx.skip_code(4); let count = _count(req, state, next_ctx, max_count); - ctx.skip_char(req, count); + ctx.skip_char::(count); if count < min_count { break 'result false; } @@ -753,7 +749,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { // tail is empty. we're finished - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; break 'result true; } @@ -769,7 +765,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex break 'result false; } - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; ctx.count = if min_count == 0 { 0 } else { @@ -779,14 +775,14 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex if count < min_count { break 'result false; } - ctx.skip_char(req, count); + ctx.skip_char::(count); count as isize }; let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { // tail is empty. we're finished - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; break 'result true; } @@ -830,13 +826,13 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex } /* pattern tail */ SreOpcode::ATOMIC_GROUP => { - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; break 'context ctx.next_offset(2, Jump::AtomicGroup1); } /* <1=min> <2=max> pattern tail */ SreOpcode::POSSESSIVE_REPEAT => { - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; ctx.count = 0; ctx.jump = Jump::PossessiveRepeat1; continue 'context; @@ -849,14 +845,14 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex if ctx.remaining_chars(req) < min_count { break 'result false; } - state.string_position = ctx.string_position; + state.cursor = ctx.cursor; let mut count_ctx = ctx; count_ctx.skip_code(4); let count = _count(req, state, count_ctx, max_count); if count < min_count { break 'result false; } - ctx.skip_char(req, count); + ctx.skip_char::(count); ctx.skip_code_from(req, 1); } SreOpcode::CHARSET @@ -907,16 +903,17 @@ fn search_info_literal( while !ctx.at_end(req) { // find the next matched literal - while ctx.peek_char(req) != c { - ctx.skip_char(req, 1); + while ctx.peek_char::() != c { + ctx.advance_char::(); if ctx.at_end(req) { return false; } } - req.start = ctx.string_position; - state.start = ctx.string_position; - state.string_position = ctx.string_position + skip; + req.start = ctx.cursor.position; + state.start = req.start; + state.cursor = ctx.cursor; + S::skip(&mut state.cursor, skip); // literal only if LITERAL { @@ -924,44 +921,46 @@ fn search_info_literal( } let mut next_ctx = ctx; - next_ctx.skip_char(req, skip); + next_ctx.skip_char::(skip); if _match(req, state, next_ctx) { return true; } - ctx.skip_char(req, 1); + ctx.advance_char::(); state.marks.clear(); } } else { while !ctx.at_end(req) { let c = prefix[0]; - while ctx.peek_char(req) != c { - ctx.skip_char(req, 1); + while ctx.peek_char::() != c { + ctx.advance_char::(); if ctx.at_end(req) { return false; } } - ctx.skip_char(req, 1); + ctx.advance_char::(); if ctx.at_end(req) { return false; } let mut i = 1; loop { - if ctx.peek_char(req) == prefix[i] { + if ctx.peek_char::() == prefix[i] { i += 1; if i != len { - ctx.skip_char(req, 1); + ctx.advance_char::(); if ctx.at_end(req) { return false; } continue; } - req.start = ctx.string_position - (len - 1); - state.start = req.start; - state.string_position = state.start + skip; + req.start = ctx.cursor.position - (len - 1); + state.reset(req, req.start); + S::skip(&mut state.cursor, skip); + // state.start = req.start; + // state.cursor = req.string.create_cursor(req.start + skip); // literal only if LITERAL { @@ -970,17 +969,16 @@ fn search_info_literal( let mut next_ctx = ctx; if skip != 0 { - next_ctx.skip_char(req, 1); + next_ctx.advance_char::(); } else { - next_ctx.string_position = state.string_position; - next_ctx.string_offset = req.string.offset(0, state.string_position); + next_ctx.cursor = state.cursor; } if _match(req, state, next_ctx) { return true; } - ctx.skip_char(req, 1); + ctx.advance_char::(); if ctx.at_end(req) { return false; } @@ -1009,22 +1007,22 @@ fn search_info_charset( req.must_advance = false; loop { - while !ctx.at_end(req) && !charset(set, ctx.peek_char(req)) { - ctx.skip_char(req, 1); + while !ctx.at_end(req) && !charset(set, ctx.peek_char::()) { + ctx.advance_char::(); } if ctx.at_end(req) { return false; } - req.start = ctx.string_position; - state.start = ctx.string_position; - state.string_position = ctx.string_position; + req.start = ctx.cursor.position; + state.start = ctx.cursor.position; + state.cursor = ctx.cursor; if _match(req, state, ctx) { return true; } - ctx.skip_char(req, 1); + ctx.advance_char::(); state.marks.clear(); } } @@ -1039,85 +1037,9 @@ struct RepeatContext { prev_id: usize, } -pub trait StrDrive: Copy { - fn offset(&self, offset: usize, skip: usize) -> usize; - fn count(&self) -> usize; - fn peek(&self, offset: usize) -> u32; - fn back_peek(&self, offset: usize) -> u32; - fn back_offset(&self, offset: usize, skip: usize) -> usize; -} - -impl StrDrive for &str { - fn offset(&self, offset: usize, skip: usize) -> usize { - self.get(offset..) - .and_then(|s| s.char_indices().nth(skip).map(|x| x.0 + offset)) - .unwrap_or(self.len()) - } - - fn count(&self) -> usize { - self.chars().count() - } - - fn peek(&self, offset: usize) -> u32 { - unsafe { self.get_unchecked(offset..) } - .chars() - .next() - .unwrap() as u32 - } - - fn back_peek(&self, offset: usize) -> u32 { - let bytes = self.as_bytes(); - let back_offset = utf8_back_peek_offset(bytes, offset); - match offset - back_offset { - 1 => u32::from_be_bytes([0, 0, 0, bytes[offset - 1]]), - 2 => u32::from_be_bytes([0, 0, bytes[offset - 2], bytes[offset - 1]]), - 3 => u32::from_be_bytes([0, bytes[offset - 3], bytes[offset - 2], bytes[offset - 1]]), - 4 => u32::from_be_bytes([ - bytes[offset - 4], - bytes[offset - 3], - bytes[offset - 2], - bytes[offset - 1], - ]), - _ => unreachable!(), - } - } - - fn back_offset(&self, offset: usize, skip: usize) -> usize { - let bytes = self.as_bytes(); - let mut back_offset = offset; - for _ in 0..skip { - back_offset = utf8_back_peek_offset(bytes, back_offset); - } - back_offset - } -} - -impl<'a> StrDrive for &'a [u8] { - fn offset(&self, offset: usize, skip: usize) -> usize { - offset + skip - } - - fn count(&self) -> usize { - self.len() - } - - fn peek(&self, offset: usize) -> u32 { - self[offset] as u32 - } - - fn back_peek(&self, offset: usize) -> u32 { - self[offset - 1] as u32 - } - - fn back_offset(&self, offset: usize, skip: usize) -> usize { - offset - skip - } -} - #[derive(Clone, Copy)] struct MatchContext { - string_position: usize, - string_offset: usize, + cursor: StringCursor, code_position: usize, toplevel: bool, jump: Jump, @@ -1135,25 +1057,31 @@ impl MatchContext { } fn remaining_chars(&self, req: &Request) -> usize { - req.end - self.string_position + req.end - self.cursor.position + } + + fn peek_char(&self) -> u32 { + S::peek(&self.cursor) + } + + fn skip_char(&mut self, skip: usize) { + S::skip(&mut self.cursor, skip); } - fn peek_char(&self, req: &Request) -> u32 { - req.string.peek(self.string_offset) + fn advance_char(&mut self) -> u32 { + S::advance(&mut self.cursor) } - fn skip_char(&mut self, req: &Request, skip: usize) { - self.string_position += skip; - self.string_offset = req.string.offset(self.string_offset, skip); + fn back_peek_char(&self) -> u32 { + S::back_peek(&self.cursor) } - fn back_peek_char(&self, req: &Request) -> u32 { - req.string.back_peek(self.string_offset) + fn back_skip_char(&mut self, skip: usize) { + S::back_skip(&mut self.cursor, skip); } - fn back_skip_char(&mut self, req: &Request, skip: usize) { - self.string_position -= skip; - self.string_offset = req.string.back_offset(self.string_offset, skip); + fn back_advance_char(&mut self) -> u32 { + S::back_advance(&mut self.cursor) } fn peek_code(&self, req: &Request, peek: usize) -> u32 { @@ -1177,15 +1105,15 @@ impl MatchContext { fn at_beginning(&self) -> bool { // self.ctx().string_position == self.state().start - self.string_position == 0 + self.cursor.position == 0 } fn at_end(&self, req: &Request) -> bool { - self.string_position == req.end + self.cursor.position == req.end } fn at_linebreak(&self, req: &Request) -> bool { - !self.at_end(req) && is_linebreak(self.peek_char(req)) + !self.at_end(req) && is_linebreak(self.peek_char::()) } fn at_boundary bool>( @@ -1196,8 +1124,8 @@ impl MatchContext { if self.at_beginning() && self.at_end(req) { return false; } - let that = !self.at_beginning() && word_checker(self.back_peek_char(req)); - let this = !self.at_end(req) && word_checker(self.peek_char(req)); + let that = !self.at_beginning() && word_checker(self.back_peek_char::()); + let this = !self.at_end(req) && word_checker(self.peek_char::()); this != that } @@ -1209,8 +1137,8 @@ impl MatchContext { if self.at_beginning() && self.at_end(req) { return false; } - let that = !self.at_beginning() && word_checker(self.back_peek_char(req)); - let this = !self.at_end(req) && word_checker(self.peek_char(req)); + let that = !self.at_beginning() && word_checker(self.back_peek_char::()); + let this = !self.at_end(req) && word_checker(self.peek_char::()); this == that } @@ -1221,7 +1149,7 @@ impl MatchContext { if req.match_all && !self.at_end(req) { return false; } - if req.must_advance && self.string_position == req.start { + if req.must_advance && self.cursor.position == req.start { return false; } true @@ -1252,7 +1180,7 @@ impl MatchContext { fn at(req: &Request, ctx: &MatchContext, atcode: SreAtCode) -> bool { match atcode { SreAtCode::BEGINNING | SreAtCode::BEGINNING_STRING => ctx.at_beginning(), - SreAtCode::BEGINNING_LINE => ctx.at_beginning() || is_linebreak(ctx.back_peek_char(req)), + SreAtCode::BEGINNING_LINE => ctx.at_beginning() || is_linebreak(ctx.back_peek_char::()), SreAtCode::BOUNDARY => ctx.at_boundary(req, is_word), SreAtCode::NON_BOUNDARY => ctx.at_non_boundary(req, is_word), SreAtCode::END => { @@ -1403,21 +1331,22 @@ fn _count( max_count: usize, ) -> usize { let max_count = std::cmp::min(max_count, ctx.remaining_chars(req)); - let end = ctx.string_position + max_count; + let end = ctx.cursor.position + max_count; let opcode = SreOpcode::try_from(ctx.peek_code(req, 0)).unwrap(); match opcode { SreOpcode::ANY => { - while ctx.string_position < end && !ctx.at_linebreak(req) { - ctx.skip_char(req, 1); + while ctx.cursor.position < end && !ctx.at_linebreak(req) { + ctx.advance_char::(); } } SreOpcode::ANY_ALL => { - ctx.skip_char(req, max_count); + ctx.skip_char::(max_count); } SreOpcode::IN => { - while ctx.string_position < end && charset(&ctx.pattern(req)[2..], ctx.peek_char(req)) { - ctx.skip_char(req, 1); + while ctx.cursor.position < end && charset(&ctx.pattern(req)[2..], ctx.peek_char::()) + { + ctx.advance_char::(); } } SreOpcode::LITERAL => { @@ -1457,14 +1386,14 @@ fn _count( ..*state }; - while ctx.string_position < end && _match(req, &mut sub_state, ctx) { - ctx.skip_char(req, 1); + while ctx.cursor.position < end && _match(req, &mut sub_state, ctx) { + ctx.advance_char::(); } } } // TODO: return offset - ctx.string_position - state.string_position + ctx.cursor.position - state.cursor.position } fn general_count_literal bool>( @@ -1474,145 +1403,7 @@ fn general_count_literal bool>( mut f: F, ) { let ch = ctx.peek_code(req, 1); - while ctx.string_position < end && f(ch, ctx.peek_char(req)) { - ctx.skip_char(req, 1); - } -} - -fn is_word(ch: u32) -> bool { - ch == '_' as u32 - || u8::try_from(ch) - .map(|x| x.is_ascii_alphanumeric()) - .unwrap_or(false) -} -fn is_space(ch: u32) -> bool { - u8::try_from(ch) - .map(is_py_ascii_whitespace) - .unwrap_or(false) -} -fn is_digit(ch: u32) -> bool { - u8::try_from(ch) - .map(|x| x.is_ascii_digit()) - .unwrap_or(false) -} -fn is_loc_alnum(ch: u32) -> bool { - // FIXME: Ignore the locales - u8::try_from(ch) - .map(|x| x.is_ascii_alphanumeric()) - .unwrap_or(false) -} -fn is_loc_word(ch: u32) -> bool { - ch == '_' as u32 || is_loc_alnum(ch) -} -fn is_linebreak(ch: u32) -> bool { - ch == '\n' as u32 -} -pub fn lower_ascii(ch: u32) -> u32 { - u8::try_from(ch) - .map(|x| x.to_ascii_lowercase() as u32) - .unwrap_or(ch) -} -fn lower_locate(ch: u32) -> u32 { - // FIXME: Ignore the locales - lower_ascii(ch) -} -fn upper_locate(ch: u32) -> u32 { - // FIXME: Ignore the locales - u8::try_from(ch) - .map(|x| x.to_ascii_uppercase() as u32) - .unwrap_or(ch) -} -fn is_uni_digit(ch: u32) -> bool { - // TODO: check with cpython - char::try_from(ch) - .map(|x| x.is_ascii_digit()) - .unwrap_or(false) -} -fn is_uni_space(ch: u32) -> bool { - // TODO: check with cpython - is_space(ch) - || matches!( - ch, - 0x0009 - | 0x000A - | 0x000B - | 0x000C - | 0x000D - | 0x001C - | 0x001D - | 0x001E - | 0x001F - | 0x0020 - | 0x0085 - | 0x00A0 - | 0x1680 - | 0x2000 - | 0x2001 - | 0x2002 - | 0x2003 - | 0x2004 - | 0x2005 - | 0x2006 - | 0x2007 - | 0x2008 - | 0x2009 - | 0x200A - | 0x2028 - | 0x2029 - | 0x202F - | 0x205F - | 0x3000 - ) -} -fn is_uni_linebreak(ch: u32) -> bool { - matches!( - ch, - 0x000A | 0x000B | 0x000C | 0x000D | 0x001C | 0x001D | 0x001E | 0x0085 | 0x2028 | 0x2029 - ) -} -fn is_uni_alnum(ch: u32) -> bool { - // TODO: check with cpython - char::try_from(ch) - .map(|x| x.is_alphanumeric()) - .unwrap_or(false) -} -fn is_uni_word(ch: u32) -> bool { - ch == '_' as u32 || is_uni_alnum(ch) -} -pub fn lower_unicode(ch: u32) -> u32 { - // TODO: check with cpython - char::try_from(ch) - .map(|x| x.to_lowercase().next().unwrap() as u32) - .unwrap_or(ch) -} -pub fn upper_unicode(ch: u32) -> u32 { - // TODO: check with cpython - char::try_from(ch) - .map(|x| x.to_uppercase().next().unwrap() as u32) - .unwrap_or(ch) -} - -fn is_utf8_first_byte(b: u8) -> bool { - // In UTF-8, there are three kinds of byte... - // 0xxxxxxx : ASCII - // 10xxxxxx : 2nd, 3rd or 4th byte of code - // 11xxxxxx : 1st byte of multibyte code - (b & 0b10000000 == 0) || (b & 0b11000000 == 0b11000000) -} - -fn utf8_back_peek_offset(bytes: &[u8], offset: usize) -> usize { - let mut offset = offset - 1; - if !is_utf8_first_byte(bytes[offset]) { - offset -= 1; - if !is_utf8_first_byte(bytes[offset]) { - offset -= 1; - if !is_utf8_first_byte(bytes[offset]) { - offset -= 1; - if !is_utf8_first_byte(bytes[offset]) { - panic!("not utf-8 code point"); - } - } - } + while ctx.cursor.position < end && f(ch, ctx.peek_char::()) { + ctx.advance_char::(); } - offset } diff --git a/src/lib.rs b/src/lib.rs index c23e807501..fd9f367dc6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,10 @@ pub mod constants; pub mod engine; +pub mod string; + +pub use constants::{SreAtCode, SreCatCode, SreFlag, SreInfo, SreOpcode, SRE_MAGIC}; +pub use engine::{Request, SearchIter, State}; +pub use string::{StrDrive, StringCursor}; pub const CODESIZE: usize = 4; diff --git a/src/string.rs b/src/string.rs new file mode 100644 index 0000000000..464901af83 --- /dev/null +++ b/src/string.rs @@ -0,0 +1,381 @@ +#[derive(Debug, Clone, Copy)] +pub struct StringCursor { + pub(crate) ptr: *const u8, + pub position: usize, +} + +impl Default for StringCursor { + fn default() -> Self { + Self { + ptr: std::ptr::null(), + position: 0, + } + } +} + +pub trait StrDrive: Copy { + fn count(&self) -> usize; + fn create_cursor(&self, n: usize) -> StringCursor; + fn advance(cursor: &mut StringCursor) -> u32; + fn peek(cursor: &StringCursor) -> u32; + fn skip(cursor: &mut StringCursor, n: usize); + fn back_advance(cursor: &mut StringCursor) -> u32; + fn back_peek(cursor: &StringCursor) -> u32; + fn back_skip(cursor: &mut StringCursor, n: usize); +} + +impl<'a> StrDrive for &'a [u8] { + #[inline] + fn count(&self) -> usize { + self.len() + } + + #[inline] + fn create_cursor(&self, n: usize) -> StringCursor { + StringCursor { + ptr: self[n..].as_ptr(), + position: n, + } + } + + #[inline] + fn advance(cursor: &mut StringCursor) -> u32 { + cursor.position += 1; + unsafe { cursor.ptr = cursor.ptr.add(1) }; + unsafe { *cursor.ptr as u32 } + } + + #[inline] + fn peek(cursor: &StringCursor) -> u32 { + unsafe { *cursor.ptr as u32 } + } + + #[inline] + fn skip(cursor: &mut StringCursor, n: usize) { + cursor.position += n; + unsafe { cursor.ptr = cursor.ptr.add(n) }; + } + + #[inline] + fn back_advance(cursor: &mut StringCursor) -> u32 { + cursor.position -= 1; + unsafe { cursor.ptr = cursor.ptr.sub(1) }; + unsafe { *cursor.ptr as u32 } + } + + #[inline] + fn back_peek(cursor: &StringCursor) -> u32 { + unsafe { *cursor.ptr.offset(-1) as u32 } + } + + #[inline] + fn back_skip(cursor: &mut StringCursor, n: usize) { + cursor.position -= n; + unsafe { cursor.ptr = cursor.ptr.sub(n) }; + } +} + +impl StrDrive for &str { + #[inline] + fn count(&self) -> usize { + self.chars().count() + } + + #[inline] + fn create_cursor(&self, n: usize) -> StringCursor { + let mut ptr = self.as_ptr(); + for _ in 0..n { + unsafe { next_code_point(&mut ptr) }; + } + StringCursor { ptr, position: n } + } + + #[inline] + fn advance(cursor: &mut StringCursor) -> u32 { + cursor.position += 1; + unsafe { next_code_point(&mut cursor.ptr) } + } + + #[inline] + fn peek(cursor: &StringCursor) -> u32 { + let mut ptr = cursor.ptr; + unsafe { next_code_point(&mut ptr) } + } + + #[inline] + fn skip(cursor: &mut StringCursor, n: usize) { + cursor.position += n; + for _ in 0..n { + unsafe { next_code_point(&mut cursor.ptr) }; + } + } + + #[inline] + fn back_advance(cursor: &mut StringCursor) -> u32 { + cursor.position -= 1; + unsafe { next_code_point_reverse(&mut cursor.ptr) } + } + + #[inline] + fn back_peek(cursor: &StringCursor) -> u32 { + let mut ptr = cursor.ptr; + unsafe { next_code_point_reverse(&mut ptr) } + } + + #[inline] + fn back_skip(cursor: &mut StringCursor, n: usize) { + cursor.position -= n; + for _ in 0..n { + unsafe { next_code_point_reverse(&mut cursor.ptr) }; + } + } +} + +/// Reads the next code point out of a byte iterator (assuming a +/// UTF-8-like encoding). +/// +/// # Safety +/// +/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string +#[inline] +unsafe fn next_code_point(ptr: &mut *const u8) -> u32 { + // Decode UTF-8 + let x = **ptr; + *ptr = ptr.offset(1); + + if x < 128 { + return x as u32; + } + + // Multibyte case follows + // Decode from a byte combination out of: [[[x y] z] w] + // NOTE: Performance is sensitive to the exact formulation here + let init = utf8_first_byte(x, 2); + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let y = **ptr; + *ptr = ptr.offset(1); + let mut ch = utf8_acc_cont_byte(init, y); + if x >= 0xE0 { + // [[x y z] w] case + // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let z = **ptr; + *ptr = ptr.offset(1); + let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z); + ch = init << 12 | y_z; + if x >= 0xF0 { + // [x y z w] case + // use only the lower 3 bits of `init` + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let w = **ptr; + *ptr = ptr.offset(1); + ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w); + } + } + + ch +} + +/// Reads the last code point out of a byte iterator (assuming a +/// UTF-8-like encoding). +/// +/// # Safety +/// +/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string +#[inline] +unsafe fn next_code_point_reverse(ptr: &mut *const u8) -> u32 { + // Decode UTF-8 + *ptr = ptr.offset(-1); + let w = match **ptr { + next_byte if next_byte < 128 => return next_byte as u32, + back_byte => back_byte, + }; + + // Multibyte case follows + // Decode from a byte combination out of: [x [y [z w]]] + let mut ch; + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + *ptr = ptr.offset(-1); + let z = **ptr; + ch = utf8_first_byte(z, 2); + if utf8_is_cont_byte(z) { + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + *ptr = ptr.offset(-1); + let y = **ptr; + ch = utf8_first_byte(y, 3); + if utf8_is_cont_byte(y) { + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + *ptr = ptr.offset(-1); + let x = **ptr; + ch = utf8_first_byte(x, 4); + ch = utf8_acc_cont_byte(ch, y); + } + ch = utf8_acc_cont_byte(ch, z); + } + ch = utf8_acc_cont_byte(ch, w); + + ch +} + +/// Returns the initial codepoint accumulator for the first byte. +/// The first byte is special, only want bottom 5 bits for width 2, 4 bits +/// for width 3, and 3 bits for width 4. +#[inline] +const fn utf8_first_byte(byte: u8, width: u32) -> u32 { + (byte & (0x7F >> width)) as u32 +} + +/// Returns the value of `ch` updated with continuation byte `byte`. +#[inline] +const fn utf8_acc_cont_byte(ch: u32, byte: u8) -> u32 { + (ch << 6) | (byte & CONT_MASK) as u32 +} + +/// Checks whether the byte is a UTF-8 continuation byte (i.e., starts with the +/// bits `10`). +#[inline] +const fn utf8_is_cont_byte(byte: u8) -> bool { + (byte as i8) < -64 +} + +/// Mask of the value bits of a continuation byte. +const CONT_MASK: u8 = 0b0011_1111; + +const fn is_py_ascii_whitespace(b: u8) -> bool { + matches!(b, b'\t' | b'\n' | b'\x0C' | b'\r' | b' ' | b'\x0B') +} + +#[inline] +pub(crate) fn is_word(ch: u32) -> bool { + ch == '_' as u32 + || u8::try_from(ch) + .map(|x| x.is_ascii_alphanumeric()) + .unwrap_or(false) +} +#[inline] +pub(crate) fn is_space(ch: u32) -> bool { + u8::try_from(ch) + .map(is_py_ascii_whitespace) + .unwrap_or(false) +} +#[inline] +pub(crate) fn is_digit(ch: u32) -> bool { + u8::try_from(ch) + .map(|x| x.is_ascii_digit()) + .unwrap_or(false) +} +#[inline] +pub(crate) fn is_loc_alnum(ch: u32) -> bool { + // FIXME: Ignore the locales + u8::try_from(ch) + .map(|x| x.is_ascii_alphanumeric()) + .unwrap_or(false) +} +#[inline] +pub(crate) fn is_loc_word(ch: u32) -> bool { + ch == '_' as u32 || is_loc_alnum(ch) +} +#[inline] +pub(crate) fn is_linebreak(ch: u32) -> bool { + ch == '\n' as u32 +} +#[inline] +pub fn lower_ascii(ch: u32) -> u32 { + u8::try_from(ch) + .map(|x| x.to_ascii_lowercase() as u32) + .unwrap_or(ch) +} +#[inline] +pub(crate) fn lower_locate(ch: u32) -> u32 { + // FIXME: Ignore the locales + lower_ascii(ch) +} +#[inline] +pub(crate) fn upper_locate(ch: u32) -> u32 { + // FIXME: Ignore the locales + u8::try_from(ch) + .map(|x| x.to_ascii_uppercase() as u32) + .unwrap_or(ch) +} +#[inline] +pub(crate) fn is_uni_digit(ch: u32) -> bool { + // TODO: check with cpython + char::try_from(ch) + .map(|x| x.is_ascii_digit()) + .unwrap_or(false) +} +#[inline] +pub(crate) fn is_uni_space(ch: u32) -> bool { + // TODO: check with cpython + is_space(ch) + || matches!( + ch, + 0x0009 + | 0x000A + | 0x000B + | 0x000C + | 0x000D + | 0x001C + | 0x001D + | 0x001E + | 0x001F + | 0x0020 + | 0x0085 + | 0x00A0 + | 0x1680 + | 0x2000 + | 0x2001 + | 0x2002 + | 0x2003 + | 0x2004 + | 0x2005 + | 0x2006 + | 0x2007 + | 0x2008 + | 0x2009 + | 0x200A + | 0x2028 + | 0x2029 + | 0x202F + | 0x205F + | 0x3000 + ) +} +#[inline] +pub(crate) fn is_uni_linebreak(ch: u32) -> bool { + matches!( + ch, + 0x000A | 0x000B | 0x000C | 0x000D | 0x001C | 0x001D | 0x001E | 0x0085 | 0x2028 | 0x2029 + ) +} +#[inline] +pub(crate) fn is_uni_alnum(ch: u32) -> bool { + // TODO: check with cpython + char::try_from(ch) + .map(|x| x.is_alphanumeric()) + .unwrap_or(false) +} +#[inline] +pub(crate) fn is_uni_word(ch: u32) -> bool { + ch == '_' as u32 || is_uni_alnum(ch) +} +#[inline] +pub fn lower_unicode(ch: u32) -> u32 { + // TODO: check with cpython + char::try_from(ch) + .map(|x| x.to_lowercase().next().unwrap() as u32) + .unwrap_or(ch) +} +#[inline] +pub fn upper_unicode(ch: u32) -> u32 { + // TODO: check with cpython + char::try_from(ch) + .map(|x| x.to_uppercase().next().unwrap() as u32) + .unwrap_or(ch) +} diff --git a/tests/tests.rs b/tests/tests.rs index efeb2d2838..f589c62e6e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,16 +1,13 @@ -use sre_engine::engine; +use sre_engine::{Request, State, StrDrive}; struct Pattern { code: &'static [u32], } impl Pattern { - fn state<'a, S: engine::StrDrive>( - &self, - string: S, - ) -> (engine::Request<'a, S>, engine::State) { - let req = engine::Request::new(string, 0, usize::MAX, self.code, false); - let state = engine::State::default(); + fn state<'a, S: StrDrive>(&self, string: S) -> (Request<'a, S>, State) { + let req = Request::new(string, 0, usize::MAX, self.code, false); + let state = State::default(); (req, state) } } @@ -54,7 +51,7 @@ fn test_zerowidth() { let (mut req, mut state) = p.state("a:"); req.must_advance = true; assert!(state.search(req)); - assert_eq!(state.string_position, 1); + assert_eq!(state.cursor.position, 1); } #[test] @@ -66,7 +63,10 @@ fn test_repeat_context_panic() { // END GENERATED let (req, mut state) = p.state("axxzaz"); assert!(state.pymatch(&req)); - assert_eq!(*state.marks.raw(), vec![Optioned::some(1), Optioned::some(3)]); + assert_eq!( + *state.marks.raw(), + vec![Optioned::some(1), Optioned::some(3)] + ); } #[test] @@ -77,7 +77,7 @@ fn test_double_max_until() { // END GENERATED let (req, mut state) = p.state("1111"); assert!(state.pymatch(&req)); - assert_eq!(state.string_position, 4); + assert_eq!(state.cursor.position, 4); } #[test] @@ -89,7 +89,7 @@ fn test_info_single() { let (req, mut state) = p.state("baaaa"); assert!(state.search(req)); assert_eq!(state.start, 1); - assert_eq!(state.string_position, 5); + assert_eq!(state.cursor.position, 5); } #[test] @@ -161,7 +161,7 @@ fn test_bug_20998() { let (mut req, mut state) = p.state("ABC"); req.match_all = true; assert!(state.pymatch(&req)); - assert_eq!(state.string_position, 3); + assert_eq!(state.cursor.position, 3); } #[test] @@ -172,5 +172,10 @@ fn test_bigcharset() { // END GENERATED let (req, mut state) = p.state("x "); assert!(state.pymatch(&req)); - assert_eq!(state.string_position, 1); + assert_eq!(state.cursor.position, 1); +} + +#[test] +fn test_search_nonascii() { + // pattern p = re.compile('\xe0+') } From 10e51ba68909e9f09860c4a5d727c00a74cb0d7c Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 14 Jan 2024 10:03:01 +0200 Subject: [PATCH 242/893] improve: use adjust_cursor reduce double calc --- src/engine.rs | 17 +++++++---------- src/string.rs | 25 +++++++++++++++++++++---- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index 97489633d8..a854f8d898 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -7,7 +7,7 @@ use crate::string::{ use super::{SreAtCode, SreCatCode, SreInfo, SreOpcode, StrDrive, StringCursor, MAXREPEAT}; use optional::Optioned; -use std::convert::TryFrom; +use std::{convert::TryFrom, ptr::null}; #[derive(Debug, Clone, Copy)] pub struct Request<'a, S> { @@ -126,17 +126,12 @@ impl State { self.marks.clear(); self.repeat_stack.clear(); self.start = start; - if self.cursor.ptr.is_null() || self.cursor.position > self.start { - self.cursor = req.string.create_cursor(self.start); - } else if self.cursor.position < self.start { - let skip = self.start - self.cursor.position; - S::skip(&mut self.cursor, skip); - } + req.string.adjust_cursor(&mut self.cursor, start); } pub fn pymatch(&mut self, req: &Request) -> bool { self.start = req.start; - self.cursor = req.string.create_cursor(self.start); + req.string.adjust_cursor(&mut self.cursor, self.start); let ctx = MatchContext { cursor: self.cursor, @@ -151,7 +146,7 @@ impl State { pub fn search(&mut self, mut req: Request) -> bool { self.start = req.start; - self.cursor = req.string.create_cursor(self.start); + req.string.adjust_cursor(&mut self.cursor, self.start); if req.start > req.end { return false; @@ -215,7 +210,9 @@ impl State { || ctx.try_peek_code_as::(&req, 1).unwrap() == SreAtCode::BEGINNING_STRING) { - self.reset(&req, req.end); + self.cursor.position = req.end; + self.cursor.ptr = null(); + // self.reset(&req, req.end); return false; } diff --git a/src/string.rs b/src/string.rs index 464901af83..1340c37423 100644 --- a/src/string.rs +++ b/src/string.rs @@ -16,6 +16,7 @@ impl Default for StringCursor { pub trait StrDrive: Copy { fn count(&self) -> usize; fn create_cursor(&self, n: usize) -> StringCursor; + fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize); fn advance(cursor: &mut StringCursor) -> u32; fn peek(cursor: &StringCursor) -> u32; fn skip(cursor: &mut StringCursor, n: usize); @@ -38,6 +39,12 @@ impl<'a> StrDrive for &'a [u8] { } } + #[inline] + fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) { + cursor.position = n; + cursor.ptr = self[n..].as_ptr(); + } + #[inline] fn advance(cursor: &mut StringCursor) -> u32 { cursor.position += 1; @@ -83,11 +90,21 @@ impl StrDrive for &str { #[inline] fn create_cursor(&self, n: usize) -> StringCursor { - let mut ptr = self.as_ptr(); - for _ in 0..n { - unsafe { next_code_point(&mut ptr) }; + let mut cursor = StringCursor { + ptr: self.as_ptr(), + position: 0, + }; + Self::skip(&mut cursor, n); + cursor + } + + #[inline] + fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) { + if cursor.ptr.is_null() || cursor.position > n { + *cursor = Self::create_cursor(&self, n); + } else if cursor.position < n { + Self::skip(cursor, n - cursor.position); } - StringCursor { ptr, position: n } } #[inline] From 21fc2059b70ebd5bf4a7c524c40e7d4347e065dc Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sun, 14 Jan 2024 16:02:05 +0200 Subject: [PATCH 243/893] improve: fix double count on _count --- src/engine.rs | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/engine.rs b/src/engine.rs index a854f8d898..34f00234e5 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -441,7 +441,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex let mut count_ctx = ctx; count_ctx.skip_code(4); - if _count(req, state, count_ctx, 1) == 0 { + if _count(req, state, &mut count_ctx, 1) == 0 { state.marks.pop_discard(); break 'result false; } @@ -735,13 +735,13 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex state.cursor = ctx.cursor; - let mut next_ctx = ctx; - next_ctx.skip_code(4); - let count = _count(req, state, next_ctx, max_count); - ctx.skip_char::(count); + let mut count_ctx = ctx; + count_ctx.skip_code(4); + let count = _count(req, state, &mut count_ctx, max_count); if count < min_count { break 'result false; } + ctx.cursor = count_ctx.cursor; let next_code = ctx.peek_code(req, ctx.peek_code(req, 1) as usize + 1); if next_code == SreOpcode::SUCCESS as u32 && ctx.can_success(req) { @@ -768,11 +768,11 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex } else { let mut count_ctx = ctx; count_ctx.skip_code(4); - let count = _count(req, state, count_ctx, min_count); + let count = _count(req, state, &mut count_ctx, min_count); if count < min_count { break 'result false; } - ctx.skip_char::(count); + ctx.cursor = count_ctx.cursor; count as isize }; @@ -845,11 +845,11 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex state.cursor = ctx.cursor; let mut count_ctx = ctx; count_ctx.skip_code(4); - let count = _count(req, state, count_ctx, max_count); + let count = _count(req, state, &mut count_ctx, max_count); if count < min_count { break 'result false; } - ctx.skip_char::(count); + ctx.cursor = count_ctx.cursor; ctx.skip_code_from(req, 1); } SreOpcode::CHARSET @@ -1324,7 +1324,7 @@ fn charset(set: &[u32], ch: u32) -> bool { fn _count( req: &Request, state: &mut State, - mut ctx: MatchContext, + ctx: &mut MatchContext, max_count: usize, ) -> usize { let max_count = std::cmp::min(max_count, ctx.remaining_chars(req)); @@ -1347,28 +1347,28 @@ fn _count( } } SreOpcode::LITERAL => { - general_count_literal(req, &mut ctx, end, |code, c| code == c); + general_count_literal(req, ctx, end, |code, c| code == c); } SreOpcode::NOT_LITERAL => { - general_count_literal(req, &mut ctx, end, |code, c| code != c); + general_count_literal(req, ctx, end, |code, c| code != c); } SreOpcode::LITERAL_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| code == lower_ascii(c)); + general_count_literal(req, ctx, end, |code, c| code == lower_ascii(c)); } SreOpcode::NOT_LITERAL_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| code != lower_ascii(c)); + general_count_literal(req, ctx, end, |code, c| code != lower_ascii(c)); } SreOpcode::LITERAL_LOC_IGNORE => { - general_count_literal(req, &mut ctx, end, char_loc_ignore); + general_count_literal(req, ctx, end, char_loc_ignore); } SreOpcode::NOT_LITERAL_LOC_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| !char_loc_ignore(code, c)); + general_count_literal(req, ctx, end, |code, c| !char_loc_ignore(code, c)); } SreOpcode::LITERAL_UNI_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| code == lower_unicode(c)); + general_count_literal(req, ctx, end, |code, c| code == lower_unicode(c)); } SreOpcode::NOT_LITERAL_UNI_IGNORE => { - general_count_literal(req, &mut ctx, end, |code, c| code != lower_unicode(c)); + general_count_literal(req, ctx, end, |code, c| code != lower_unicode(c)); } _ => { /* General case */ @@ -1383,7 +1383,7 @@ fn _count( ..*state }; - while ctx.cursor.position < end && _match(req, &mut sub_state, ctx) { + while ctx.cursor.position < end && _match(req, &mut sub_state, *ctx) { ctx.advance_char::(); } } From 6917b4c2ca32058af641deee269a0df05ad2667f Mon Sep 17 00:00:00 2001 From: deantvv Date: Tue, 23 Jan 2024 19:09:22 +0800 Subject: [PATCH 244/893] os: ns_to_sec rounding (#5150) In cpython, they use `_PyTime_ROUND_FLOOR` to read time. But in RustPython, we use `[Duration::from_secs_f64](https://doc.rust-lang.org/std/time/struct.Duration.html#method.try_from_secs_f64)` to read time. Therefore, RustPython isn't affected by the rounding issue in the way that cpython does. We can safely ignore the `0.5*1e-9` bit in `ns_to_sec` function. --- Lib/test/test_os.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index a7f8cfe900..097124b7b5 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -815,10 +815,10 @@ def ns_to_sec(ns): # Convert a number of nanosecond (int) to a number of seconds (float). # Round towards infinity by adding 0.5 nanosecond to avoid rounding # issue, os.utime() rounds towards minus infinity. - return (ns * 1e-9) + 0.5e-9 + # XXX: RUSTCPYTHON os.utime() use `[Duration::from_secs_f64](https://doc.rust-lang.org/std/time/struct.Duration.html#method.try_from_secs_f64)` + # return (ns * 1e-9) + 0.5e-9 + return (ns * 1e-9) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_utime_by_indexed(self): # pass times as floating point seconds as the second indexed parameter def set_time(filename, ns): @@ -830,8 +830,6 @@ def set_time(filename, ns): os.utime(filename, (atime, mtime)) self._test_utime(set_time) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_utime_by_times(self): def set_time(filename, ns): atime_ns, mtime_ns = ns From 3eda1cf3b4a29eb35c30f2a3ed272771c642c065 Mon Sep 17 00:00:00 2001 From: Alin-Ioan Alexandru <118962201+alinioan@users.noreply.github.com> Date: Thu, 25 Jan 2024 07:54:06 +0200 Subject: [PATCH 245/893] Deprecation warning fix for __complex__ (#5152) --- Lib/test/test_complex.py | 2 -- vm/src/builtins/complex.rs | 34 ++++++++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py index b26dd00d8f..6f93c29146 100644 --- a/Lib/test/test_complex.py +++ b/Lib/test/test_complex.py @@ -340,8 +340,6 @@ def test_boolcontext(self): def test_conjugate(self): self.assertClose(complex(5.3, 9.8).conjugate(), 5.3-9.8j) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_constructor(self): class NS: def __init__(self, value): self.value = value diff --git a/vm/src/builtins/complex.rs b/vm/src/builtins/complex.rs index 4a3125c138..284ee7e42d 100644 --- a/vm/src/builtins/complex.rs +++ b/vm/src/builtins/complex.rs @@ -9,6 +9,7 @@ use crate::{ }, identifier, protocol::PyNumberMethods, + stdlib::warnings, types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; @@ -59,14 +60,31 @@ impl PyObjectRef { } if let Some(method) = vm.get_method(self.clone(), identifier!(vm, __complex__)) { let result = method?.call((), vm)?; - // TODO: returning strict subclasses of complex in __complex__ is deprecated - return match result.payload::() { - Some(complex_obj) => Ok(Some((complex_obj.value, true))), - None => Err(vm.new_type_error(format!( - "__complex__ returned non-complex (type '{}')", - result.class().name() - ))), - }; + + let ret_class = result.class().to_owned(); + if let Some(ret) = result.downcast_ref::() { + warnings::warn( + vm.ctx.exceptions.deprecation_warning, + format!( + "__complex__ returned non-complex (type {}). \ + The ability to return an instance of a strict subclass of complex \ + is deprecated, and may be removed in a future version of Python.", + ret_class + ), + 1, + vm, + )?; + + return Ok(Some((ret.value, true))); + } else { + return match result.payload::() { + Some(complex_obj) => Ok(Some((complex_obj.value, true))), + None => Err(vm.new_type_error(format!( + "__complex__ returned non-complex (type '{}')", + result.class().name() + ))), + }; + } } // `complex` does not have a `__complex__` by default, so subclasses might not either, // use the actual stored value in this case From 43d7c71a6836062b72cb1f2d8f24ef10c5bf17dd Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Wed, 7 Feb 2024 14:19:17 -0500 Subject: [PATCH 246/893] Fix integer rounding The [docs](https://docs.python.org/3/library/functions.html#round) specify that calling `round` with a negative precision removes significant digits, so that `round(12345, -2) == 12300`. The implementation was simply returning the original integer. Additionally, `round(a, b)` is implemented as `(a / 10^b) * 10^b`, using half-even rounding during the division. --- Lib/test/test_long.py | 2 -- vm/src/builtins/int.rs | 43 +++++++++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/Lib/test/test_long.py b/Lib/test/test_long.py index a25eff5b06..6bbb44a695 100644 --- a/Lib/test/test_long.py +++ b/Lib/test/test_long.py @@ -1141,8 +1141,6 @@ def test_bit_count(self): self.assertEqual((a ^ 63).bit_count(), 7) self.assertEqual(((a - 1) ^ 510).bit_count(), exp - 8) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_round(self): # check round-half-even algorithm. For round to nearest ten; # rounding map is invariant under adding multiples of 20 diff --git a/vm/src/builtins/int.rs b/vm/src/builtins/int.rs index 9b25a504fc..9d5463a464 100644 --- a/vm/src/builtins/int.rs +++ b/vm/src/builtins/int.rs @@ -502,19 +502,40 @@ impl PyInt { #[pymethod(magic)] fn round( zelf: PyRef, - precision: OptionalArg, + ndigits: OptionalArg, vm: &VirtualMachine, ) -> PyResult> { - match precision { - OptionalArg::Missing => (), - OptionalArg::Present(ref value) => { - // Only accept int type ndigits - let _ndigits = value.payload_if_subclass::(vm).ok_or_else(|| { - vm.new_type_error(format!( - "'{}' object cannot be interpreted as an integer", - value.class().name() - )) - })?; + if let OptionalArg::Present(ndigits) = ndigits { + let ndigits = ndigits.as_bigint(); + // round(12345, -2) == 12300 + // If precision >= 0, then any integer is already rounded correctly + if let Some(ndigits) = ndigits.neg().to_u32() { + if ndigits > 0 { + // Work with positive integers and negate at the end if necessary + let sign = if zelf.value.is_negative() { + BigInt::from(-1) + } else { + BigInt::from(1) + }; + let value = zelf.value.abs(); + + // Divide and multiply by the power of 10 to get the approximate answer + let pow10 = BigInt::from(10).pow(ndigits); + let quotient = &value / &pow10; + let rounded = "ient * &pow10; + + // Malachite division uses floor rounding, Python uses half-even + let remainder = &value - &rounded; + let halfpow10 = &pow10 / BigInt::from(2); + let correction = + if remainder > halfpow10 || (remainder == halfpow10 && quotient.is_odd()) { + pow10 + } else { + BigInt::from(0) + }; + let rounded = (rounded + correction) * sign; + return Ok(vm.ctx.new_int(rounded)); + } } } Ok(zelf) From 36b9219e3239ea8707e3f018cf6ceac42e254d83 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Wed, 7 Feb 2024 18:40:28 -0500 Subject: [PATCH 247/893] Bump ahash from 0.8.3 to 0.8.7 When building the project for the Miri CI step, the latest nightly version apparently conflicts with the older `ahash` 0.8.3. Updating to 0.8.7 fixed the build for me locally. --- Cargo.lock | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 481ed44d18..9a017ab9c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22,14 +22,15 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" -version = "0.8.3" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" dependencies = [ "cfg-if", "getrandom", "once_cell", "version_check", + "zerocopy", ] [[package]] @@ -3423,3 +3424,23 @@ checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85" dependencies = [ "linked-hash-map", ] + +[[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.32", +] From bdf228eb426be1a842276277c32a91b02510d3bc Mon Sep 17 00:00:00 2001 From: kingiler <68145845+kingiler@users.noreply.github.com> Date: Fri, 9 Feb 2024 13:42:39 +0000 Subject: [PATCH 248/893] Fix bug in `binascii` `uu` encoding. Pass more related unit tests. (#5160) * Fix bug in binascii, passes more unit tests. * Pass more additional tests due to this PR. --- Lib/test/test_binascii.py | 6 ------ Lib/test/test_codecs.py | 6 ------ Lib/test/test_uu.py | 14 -------------- stdlib/src/binascii.rs | 16 ++++++++++------ 4 files changed, 10 insertions(+), 32 deletions(-) diff --git a/Lib/test/test_binascii.py b/Lib/test/test_binascii.py index fd52b9895c..4ae89837cc 100644 --- a/Lib/test/test_binascii.py +++ b/Lib/test/test_binascii.py @@ -38,8 +38,6 @@ def test_functions(self): self.assertTrue(hasattr(getattr(binascii, name), '__call__')) self.assertRaises(TypeError, getattr(binascii, name)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_returned_value(self): # Limit to the minimum of all limits (b2a_uu) MAX_ALL = 45 @@ -186,8 +184,6 @@ def assertInvalidLength(data): assertInvalidLength(b'a' * (4 * 87 + 1)) assertInvalidLength(b'A\tB\nC ??DE') # only 5 valid characters - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_uu(self): MAX_UU = 45 for backtick in (True, False): @@ -404,8 +400,6 @@ def test_unicode_b2a(self): # crc_hqx needs 2 arguments self.assertRaises(TypeError, binascii.crc_hqx, "test", 0) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_unicode_a2b(self): # Unicode strings are accepted by a2b_* functions. MAX_ALL = 45 diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index 3a9c6d2741..0b972a58a5 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -2946,8 +2946,6 @@ def test_seek0(self): class TransformCodecTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_basics(self): binput = bytes(range(256)) for encoding in bytes_transform_encodings: @@ -2959,8 +2957,6 @@ def test_basics(self): self.assertEqual(size, len(o)) self.assertEqual(i, binput) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_read(self): for encoding in bytes_transform_encodings: with self.subTest(encoding=encoding): @@ -2969,8 +2965,6 @@ def test_read(self): sout = reader.read() self.assertEqual(sout, b"\x80") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_readline(self): for encoding in bytes_transform_encodings: with self.subTest(encoding=encoding): diff --git a/Lib/test/test_uu.py b/Lib/test/test_uu.py index 6b0b2f24f5..f71d877365 100644 --- a/Lib/test/test_uu.py +++ b/Lib/test/test_uu.py @@ -75,8 +75,6 @@ def test_encode(self): with self.assertRaises(TypeError): uu.encode(inp, out, "t1", 0o644, True) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decode(self): for backtick in True, False: inp = io.BytesIO(encodedtextwrapped(0o666, "t1", backtick=backtick)) @@ -110,8 +108,6 @@ def test_missingbegin(self): except uu.Error as e: self.assertEqual(str(e), "No valid begin line found in input file") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_garbage_padding(self): # Issue #22406 encodedtext1 = ( @@ -163,8 +159,6 @@ def tearDown(self): sys.stdin = self.stdin sys.stdout = self.stdout - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode(self): sys.stdin = FakeIO(plaintext.decode("ascii")) sys.stdout = FakeIO() @@ -172,8 +166,6 @@ def test_encode(self): self.assertEqual(sys.stdout.getvalue(), encodedtextwrapped(0o666, "t1").decode("ascii")) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decode(self): sys.stdin = FakeIO(encodedtextwrapped(0o666, "t1").decode("ascii")) sys.stdout = FakeIO() @@ -192,8 +184,6 @@ def setUp(self): self.addCleanup(os_helper.unlink, self.tmpin) self.addCleanup(os_helper.unlink, self.tmpout) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_encode(self): with open(self.tmpin, 'wb') as fin: fin.write(plaintext) @@ -212,8 +202,6 @@ def test_encode(self): s = fout.read() self.assertEqual(s, encodedtextwrapped(0o644, self.tmpin)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decode(self): with open(self.tmpin, 'wb') as f: f.write(encodedtextwrapped(0o644, self.tmpout)) @@ -226,8 +214,6 @@ def test_decode(self): self.assertEqual(s, plaintext) # XXX is there an xp way to verify the mode? - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_decode_filename(self): with open(self.tmpin, 'wb') as f: f.write(encodedtextwrapped(0o644, self.tmpout)) diff --git a/stdlib/src/binascii.rs b/stdlib/src/binascii.rs index d7467fc8a5..d348049f4c 100644 --- a/stdlib/src/binascii.rs +++ b/stdlib/src/binascii.rs @@ -289,7 +289,7 @@ mod decl { if [b'\r', b'\n'].contains(c) { return Ok(0); } - return Err(vm.new_value_error("Illegal char".to_string())); + return Err(super::new_binascii_error("Illegal char".to_owned(), vm)); } Ok((*c - b' ') & 0x3f) } @@ -645,7 +645,8 @@ mod decl { // Allocate the buffer let mut res = Vec::::with_capacity(length); - let trailing_garbage_error = || Err(vm.new_value_error("Trailing garbage".to_string())); + let trailing_garbage_error = + || Err(super::new_binascii_error("Trailing garbage".to_owned(), vm)); for chunk in b.get(1..).unwrap_or_default().chunks(4) { let (char_a, char_b, char_c, char_d) = { @@ -666,7 +667,7 @@ mod decl { } if res.len() < length { - res.push((char_b & 0xf) | char_c >> 2); + res.push((char_b & 0xf) << 4 | char_c >> 2); } else if char_c != 0 { return trailing_garbage_error(); } @@ -688,7 +689,7 @@ mod decl { #[derive(FromArgs)] struct BacktickArg { - #[pyarg(named, default = "true")] + #[pyarg(named, default = "false")] backtick: bool, } @@ -700,7 +701,7 @@ mod decl { ) -> PyResult> { #[inline] fn uu_b2a(num: u8, backtick: bool) -> u8 { - if backtick && num != 0 { + if backtick && num == 0 { 0x60 } else { b' ' + num @@ -710,7 +711,10 @@ mod decl { data.with_ref(|b| { let length = b.len(); if length > 45 { - return Err(vm.new_value_error("At most 45 bytes at once".to_string())); + return Err(super::new_binascii_error( + "At most 45 bytes at once".to_owned(), + vm, + )); } let mut res = Vec::::with_capacity(2 + ((length + 2) / 3) * 4); res.push(uu_b2a(length as u8, backtick)); From 9693ad9b11bbbe36f9007ab60831a5747a0bc0dd Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Fri, 9 Feb 2024 15:15:22 -0500 Subject: [PATCH 249/893] Replace rust-cpython with pyo3 in benchmarks The benchmarks have been broken since Python 3.10 deprecated the API they were using to parse and execute CPython baselines. Since then rust-cpython has been deprecated in favor of pyo3. - Replace `cpython` and `python3-sys` with `pyo3`. - Add `html_report` feature to `criterion`, it will be required in a future release. - Remove `anyhow`. It was unused and cargo cleaned it up automatically. --- Cargo.lock | 86 ++++++++++++++++++++++++++++++++++++++++++------------ Cargo.toml | 6 ++-- 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9a017ab9c1..c8d0342708 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -328,18 +328,6 @@ dependencies = [ "libc", ] -[[package]] -name = "cpython" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3052106c29da7390237bc2310c1928335733b286287754ea85e6093d2495280e" -dependencies = [ - "libc", - "num-traits", - "paste", - "python3-sys", -] - [[package]] name = "cranelift" version = "0.88.2" @@ -978,6 +966,12 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "indoc" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" + [[package]] name = "insta" version = "1.33.0" @@ -1686,13 +1680,64 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe7765e19fb2ba6fd4373b8d90399f5321683ea7c11b598c6bbaa3a72e9c83b8" [[package]] -name = "python3-sys" -version = "0.7.1" +name = "pyo3" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f8b50d72fb3015735aa403eebf19bbd72c093bfeeae24ee798be5f2f1aab52" +checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" dependencies = [ + "cfg-if", + "indoc", "libc", - "regex", + "memoffset 0.9.0", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.32", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.32", ] [[package]] @@ -1902,7 +1947,6 @@ dependencies = [ "atty", "cfg-if", "clap", - "cpython", "criterion", "dirs-next", "env_logger", @@ -1910,7 +1954,7 @@ dependencies = [ "flamescope", "libc", "log", - "python3-sys", + "pyo3", "rustpython-compiler", "rustpython-parser", "rustpython-pylib", @@ -2932,6 +2976,12 @@ dependencies = [ "time", ] +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + [[package]] name = "utf8parse" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index b7b3db3dc8..541a5aa3f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,6 @@ rustpython-format = { git = "https://github.com/RustPython/Parser.git", rev = "2 # rustpython-format = { path = "../RustPython-parser/format" } ahash = "0.8.3" -anyhow = "1.0.45" ascii = "1.0" atty = "0.2.14" bitflags = "2.4.0" @@ -118,9 +117,8 @@ libc = { workspace = true } rustyline = { workspace = true } [dev-dependencies] -cpython = "0.7.0" -criterion = "0.3.5" -python3-sys = "0.7.1" +criterion = { version = "0.3.5", features = ["html_reports"] } +pyo3 = { version = "0.20.2", features = ["auto-initialize"] } [[bench]] name = "execution" From ea1f72e92dc8399d2faea73ac59feda3bf6d7fc2 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Fri, 9 Feb 2024 16:12:47 -0500 Subject: [PATCH 250/893] Replace rust-cpython with pyo3 in benches Update the actual benchmark harnesses. Because the internal APIs previously used are no longer available, I opted to use `compile` and `exec` from within the CPython context to compile and execute code. There's probably more overhead to that than the internal API had, but that overhead should be consistent per benchmark. If anyone cares about hyperoptimizing benchmarks then they can optimize the harness as well. --- benches/execution.rs | 84 ++++++++-------------- benches/microbenchmarks.rs | 143 +++++++++++++++---------------------- 2 files changed, 87 insertions(+), 140 deletions(-) diff --git a/benches/execution.rs b/benches/execution.rs index 14fadfc2a5..de81cd0809 100644 --- a/benches/execution.rs +++ b/benches/execution.rs @@ -1,30 +1,32 @@ use criterion::measurement::WallTime; use criterion::{ - criterion_group, criterion_main, Bencher, BenchmarkGroup, BenchmarkId, Criterion, Throughput, + black_box, criterion_group, criterion_main, Bencher, BenchmarkGroup, BenchmarkId, Criterion, + Throughput, }; use rustpython_compiler::Mode; use rustpython_parser::ast; use rustpython_parser::Parse; -use rustpython_vm::{Interpreter, PyResult}; +use rustpython_vm::{Interpreter, PyResult, Settings}; use std::collections::HashMap; use std::path::Path; fn bench_cpython_code(b: &mut Bencher, source: &str) { - let gil = cpython::Python::acquire_gil(); - let python = gil.python(); - - b.iter(|| { - let res: cpython::PyResult<()> = python.run(source, None, None); - if let Err(e) = res { - e.print(python); - panic!("Error running source") - } - }); + pyo3::Python::with_gil(|py| { + b.iter(|| { + let module = + pyo3::types::PyModule::from_code(py, source, "", "").expect("Error running source"); + black_box(module); + }) + }) } fn bench_rustpy_code(b: &mut Bencher, name: &str, source: &str) { // NOTE: Take long time. - Interpreter::without_stdlib(Default::default()).enter(|vm| { + let mut settings = Settings::default(); + settings.path_list.push("Lib/".to_string()); + settings.dont_write_bytecode = true; + settings.no_user_site = true; + Interpreter::without_stdlib(settings).enter(|vm| { // Note: bench_cpython is both compiling and executing the code. // As such we compile the code in the benchmark loop as well. b.iter(|| { @@ -36,16 +38,12 @@ fn bench_rustpy_code(b: &mut Bencher, name: &str, source: &str) { }) } -pub fn benchmark_file_execution( - group: &mut BenchmarkGroup, - name: &str, - contents: &String, -) { +pub fn benchmark_file_execution(group: &mut BenchmarkGroup, name: &str, contents: &str) { group.bench_function(BenchmarkId::new(name, "cpython"), |b| { - bench_cpython_code(b, &contents) + bench_cpython_code(b, contents) }); group.bench_function(BenchmarkId::new(name, "rustpython"), |b| { - bench_rustpy_code(b, name, &contents) + bench_rustpy_code(b, name, contents) }); } @@ -55,44 +53,20 @@ pub fn benchmark_file_parsing(group: &mut BenchmarkGroup, name: &str, b.iter(|| ast::Suite::parse(contents, name).unwrap()) }); group.bench_function(BenchmarkId::new("cpython", name), |b| { - let gil = cpython::Python::acquire_gil(); - let py = gil.python(); - - let code = std::ffi::CString::new(contents).unwrap(); - let fname = cpython::PyString::new(py, name); - - b.iter(|| parse_program_cpython(py, &code, &fname)) + pyo3::Python::with_gil(|py| { + let builtins = + pyo3::types::PyModule::import(py, "builtins").expect("Failed to import builtins"); + let compile = builtins.getattr("compile").expect("no compile in builtins"); + b.iter(|| { + let x = compile + .call1((contents, name, "exec")) + .expect("Failed to parse code"); + black_box(x); + }) + }) }); } -fn parse_program_cpython( - py: cpython::Python<'_>, - code: &std::ffi::CStr, - fname: &cpython::PyString, -) { - extern "C" { - fn PyArena_New() -> *mut python3_sys::PyArena; - fn PyArena_Free(arena: *mut python3_sys::PyArena); - } - use cpython::PythonObject; - let fname = fname.as_object(); - unsafe { - let arena = PyArena_New(); - assert!(!arena.is_null()); - let ret = python3_sys::PyParser_ASTFromStringObject( - code.as_ptr() as _, - fname.as_ptr(), - python3_sys::Py_file_input, - std::ptr::null_mut(), - arena, - ); - if ret.is_null() { - cpython::PyErr::fetch(py).print(py); - } - PyArena_Free(arena); - } -} - pub fn benchmark_pystone(group: &mut BenchmarkGroup, contents: String) { // Default is 50_000. This takes a while, so reduce it to 30k. for idx in (10_000..=30_000).step_by(10_000) { diff --git a/benches/microbenchmarks.rs b/benches/microbenchmarks.rs index c30d86722a..befdc63fd1 100644 --- a/benches/microbenchmarks.rs +++ b/benches/microbenchmarks.rs @@ -5,7 +5,7 @@ use criterion::{ use rustpython_compiler::Mode; use rustpython_vm::{AsObject, Interpreter, PyResult, Settings}; use std::{ - ffi, fs, io, + fs, io, path::{Path, PathBuf}, }; @@ -36,95 +36,68 @@ pub struct MicroBenchmark { } fn bench_cpython_code(group: &mut BenchmarkGroup, bench: &MicroBenchmark) { - let gil = cpython::Python::acquire_gil(); - let py = gil.python(); - - let setup_code = ffi::CString::new(&*bench.setup).unwrap(); - let setup_name = ffi::CString::new(format!("{}_setup", bench.name)).unwrap(); - let setup_code = cpy_compile_code(py, &setup_code, &setup_name).unwrap(); - - let code = ffi::CString::new(&*bench.code).unwrap(); - let name = ffi::CString::new(&*bench.name).unwrap(); - let code = cpy_compile_code(py, &code, &name).unwrap(); - - let bench_func = |(globals, locals): &mut (cpython::PyDict, cpython::PyDict)| { - let res = cpy_run_code(py, &code, globals, locals); - if let Err(e) = res { - e.print(py); - panic!("Error running microbenchmark") - } - }; - - let bench_setup = |iterations| { - let globals = cpython::PyDict::new(py); - // setup the __builtins__ attribute - no other way to do this (other than manually) as far - // as I can tell - let _ = py.run("", Some(&globals), None); - let locals = cpython::PyDict::new(py); - if let Some(idx) = iterations { - globals.set_item(py, "ITERATIONS", idx).unwrap(); - } + pyo3::Python::with_gil(|py| { + let setup_name = format!("{}_setup", bench.name); + let setup_code = cpy_compile_code(py, &bench.setup, &setup_name).unwrap(); + + let code = cpy_compile_code(py, &bench.code, &bench.name).unwrap(); + + // Grab the exec function in advance so we don't have lookups in the hot code + let builtins = + pyo3::types::PyModule::import(py, "builtins").expect("Failed to import builtins"); + let exec = builtins.getattr("exec").expect("no exec in builtins"); + + let bench_func = |(globals, locals): &mut (&pyo3::types::PyDict, &pyo3::types::PyDict)| { + let res = exec.call((code, &*globals, &*locals), None); + if let Err(e) = res { + e.print(py); + panic!("Error running microbenchmark") + } + }; - let res = cpy_run_code(py, &setup_code, &globals, &locals); - if let Err(e) = res { - e.print(py); - panic!("Error running microbenchmark setup code") - } - (globals, locals) - }; - - if bench.iterate { - for idx in (100..=1_000).step_by(200) { - group.throughput(Throughput::Elements(idx as u64)); - group.bench_with_input(BenchmarkId::new("cpython", &bench.name), &idx, |b, idx| { - b.iter_batched_ref( - || bench_setup(Some(*idx)), - bench_func, - BatchSize::LargeInput, - ); - }); - } - } else { - group.bench_function(BenchmarkId::new("cpython", &bench.name), move |b| { - b.iter_batched_ref(|| bench_setup(None), bench_func, BatchSize::LargeInput); - }); - } -} + let bench_setup = |iterations| { + let globals = pyo3::types::PyDict::new(py); + let locals = pyo3::types::PyDict::new(py); + if let Some(idx) = iterations { + globals.set_item("ITERATIONS", idx).unwrap(); + } -unsafe fn cpy_res( - py: cpython::Python<'_>, - x: *mut python3_sys::PyObject, -) -> cpython::PyResult { - cpython::PyObject::from_owned_ptr_opt(py, x).ok_or_else(|| cpython::PyErr::fetch(py)) -} + let res = exec.call((setup_code, &globals, &locals), None); + if let Err(e) = res { + e.print(py); + panic!("Error running microbenchmark setup code") + } + (globals, locals) + }; -fn cpy_compile_code( - py: cpython::Python<'_>, - s: &ffi::CStr, - fname: &ffi::CStr, -) -> cpython::PyResult { - unsafe { - let res = - python3_sys::Py_CompileString(s.as_ptr(), fname.as_ptr(), python3_sys::Py_file_input); - cpy_res(py, res) - } + if bench.iterate { + for idx in (100..=1_000).step_by(200) { + group.throughput(Throughput::Elements(idx as u64)); + group.bench_with_input(BenchmarkId::new("cpython", &bench.name), &idx, |b, idx| { + b.iter_batched_ref( + || bench_setup(Some(*idx)), + bench_func, + BatchSize::LargeInput, + ); + }); + } + } else { + group.bench_function(BenchmarkId::new("cpython", &bench.name), move |b| { + b.iter_batched_ref(|| bench_setup(None), bench_func, BatchSize::LargeInput); + }); + } + }) } -fn cpy_run_code( - py: cpython::Python<'_>, - code: &cpython::PyObject, - locals: &cpython::PyDict, - globals: &cpython::PyDict, -) -> cpython::PyResult { - use cpython::PythonObject; - unsafe { - let res = python3_sys::PyEval_EvalCode( - code.as_ptr(), - locals.as_object().as_ptr(), - globals.as_object().as_ptr(), - ); - cpy_res(py, res) - } +fn cpy_compile_code<'a>( + py: pyo3::Python<'a>, + code: &str, + name: &str, +) -> pyo3::PyResult<&'a pyo3::types::PyCode> { + let builtins = + pyo3::types::PyModule::import(py, "builtins").expect("Failed to import builtins"); + let compile = builtins.getattr("compile").expect("no compile in builtins"); + compile.call1((code, name, "exec"))?.extract() } fn bench_rustpy_code(group: &mut BenchmarkGroup, bench: &MicroBenchmark) { From ffc52ef87c8a6e7bac8295478c988b1380f2d872 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Fri, 9 Feb 2024 16:33:36 -0500 Subject: [PATCH 251/893] Removed unused __bool__ methods Python does not define `list().__bool__`, `dict().__bool__`, and `str().__bool__`, and some tests were failing because they were defined. I could not find any references to them and deleting them doesn't seem to break anything. --- Lib/test/test_collections.py | 6 ------ vm/src/builtins/dict.rs | 5 ----- vm/src/builtins/list.rs | 5 ----- vm/src/builtins/str.rs | 5 ----- vm/src/dictdatatype.rs | 4 ---- 5 files changed, 25 deletions(-) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 68ca288fb1..81167d4d79 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -52,18 +52,12 @@ def _copy_test(self, obj): self.assertEqual(obj.data, obj_copy.data) self.assertIs(obj.test, obj_copy.test) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_str_protocol(self): self._superset_test(UserString, str) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_list_protocol(self): self._superset_test(UserList, list) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_dict_protocol(self): self._superset_test(UserDict, dict) diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 1a323b4c47..f5a1abdee2 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -242,11 +242,6 @@ impl PyDict { } } - #[pymethod(magic)] - fn bool(&self) -> bool { - !self.entries.is_empty() - } - #[pymethod(magic)] pub fn len(&self) -> usize { self.entries.len() diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index 03503a0cea..29be635de2 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -164,11 +164,6 @@ impl PyList { Ok(zelf) } - #[pymethod(magic)] - fn bool(&self) -> bool { - !self.borrow_vec().is_empty() - } - #[pymethod] fn clear(&self) { let _removed = std::mem::take(self.borrow_vec_mut().deref_mut()); diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index b6015cbe26..03207aa53e 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -411,11 +411,6 @@ impl PyStr { } } - #[pymethod(magic)] - fn bool(&self) -> bool { - !self.bytes.is_empty() - } - fn _contains(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult { if let Some(needle) = needle.payload::() { Ok(self.as_str().contains(needle.as_str())) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index f8aeb3f6da..d5e5f1e8f9 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -486,10 +486,6 @@ impl Dict { self.read().used } - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - pub fn size(&self) -> DictSize { self.read().size() } From bf461cdebc8cfb00a6e7e931978ab293911988ac Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Fri, 9 Feb 2024 21:02:40 -0500 Subject: [PATCH 252/893] Use CPython hash algorithm for frozenset The original hash algorithm just XOR'd all the hashes of the elements of the set, which is problematic. The CPython algorithm is required to pass the tests. - Replace `PyFrozenSet::hash` with CPython's algorithm - Remove unused `hash_iter_unorded` functions - Add `frozenset` benchmark - Enable tests - Lower performance expectations on effectiveness test - Adjust `slot::hash_wrapper` so that it doesn't rehash the computed hash value in the process of converting PyInt to PyHash. --- Lib/test/test_collections.py | 2 -- Lib/test/test_set.py | 11 ++++++++--- benches/microbenchmarks/frozenset.py | 5 +++++ common/src/hash.rs | 14 -------------- vm/src/builtins/set.rs | 23 ++++++++++++++++++++++- vm/src/types/slot.rs | 7 ++++++- vm/src/utils.rs | 7 ------- 7 files changed, 41 insertions(+), 28 deletions(-) create mode 100644 benches/microbenchmarks/frozenset.py diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 68ca288fb1..3408b1b10b 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -1838,8 +1838,6 @@ def __repr__(self): self.assertTrue(f1 != l1) self.assertTrue(f1 != l2) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_Set_hash_matches_frozenset(self): sets = [ {}, {1}, {None}, {-1}, {0.0}, {"abc"}, {1, 2, 3}, diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 4284393ca5..a1d421211a 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -707,8 +707,6 @@ def test_hash_caching(self): f = self.thetype('abcdcda') self.assertEqual(hash(f), hash(f)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_hash_effectiveness(self): n = 13 hashvalues = set() @@ -730,7 +728,14 @@ def powerset(s): for i in range(len(s)+1): yield from map(frozenset, itertools.combinations(s, i)) - for n in range(18): + # TODO the original test has: + # for n in range(18): + # Due to general performance overhead, hashing a frozenset takes + # about 50 times longer than in CPython. This test amplifies that + # exponentially, so the best we can do here reasonably is 13. + # Even if the internal hash function did nothing, it would still be + # about 40 times slower than CPython. + for n in range(13): t = 2 ** n mask = t - 1 for nums in (range, zf_range): diff --git a/benches/microbenchmarks/frozenset.py b/benches/microbenchmarks/frozenset.py new file mode 100644 index 0000000000..74bfb9ddb4 --- /dev/null +++ b/benches/microbenchmarks/frozenset.py @@ -0,0 +1,5 @@ +fs = frozenset(range(0, ITERATIONS)) + +# --- + +hash(fs) diff --git a/common/src/hash.rs b/common/src/hash.rs index f514dac326..558b0fe15f 100644 --- a/common/src/hash.rs +++ b/common/src/hash.rs @@ -130,20 +130,6 @@ pub fn hash_float(value: f64) -> Option { Some(fix_sentinel(x as PyHash * value.signum() as PyHash)) } -pub fn hash_iter_unordered<'a, T: 'a, I, F, E>(iter: I, hashf: F) -> Result -where - I: IntoIterator, - F: Fn(&'a T) -> Result, -{ - let mut hash: PyHash = 0; - for element in iter { - let item_hash = hashf(element)?; - // xor is commutative and hash should be independent of order - hash ^= item_hash; - } - Ok(fix_sentinel(mod_int(hash))) -} - pub fn hash_bigint(value: &BigInt) -> PyHash { let ret = match value.to_i64() { Some(i) => mod_int(i), diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index 29ead02915..9672b467ec 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -434,7 +434,28 @@ impl PySetInner { } fn hash(&self, vm: &VirtualMachine) -> PyResult { - crate::utils::hash_iter_unordered(self.elements().iter(), vm) + // Work to increase the bit dispersion for closely spaced hash values. + // This is important because some use cases have many combinations of a + // small number of elements with nearby hashes so that many distinct + // combinations collapse to only a handful of distinct hash values. + fn _shuffle_bits(h: u64) -> u64 { + ((h ^ 89869747) ^ (h.wrapping_shl(16))).wrapping_mul(3644798167) + } + // Factor in the number of active entries + let mut hash: u64 = (self.elements().len() as u64 + 1).wrapping_mul(1927868237); + // Xor-in shuffled bits from every entry's hash field because xor is + // commutative and a frozenset hash should be independent of order. + for element in self.elements().iter() { + hash ^= _shuffle_bits(element.hash(vm)? as u64); + } + // Disperse patterns arising in nested frozensets + hash ^= (hash >> 11) ^ (hash >> 25); + hash = hash.wrapping_mul(69069).wrapping_add(907133923); + // -1 is reserved as an error code + if hash == u64::MAX { + hash = 590923713; + } + Ok(hash as PyHash) } // Run operation, on failure, if item is a set/set subclass, convert it diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 132a0f68c4..578d13917d 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -15,6 +15,7 @@ use crate::{ AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; +use malachite_bigint::BigInt; use num_traits::{Signed, ToPrimitive}; use std::{borrow::Borrow, cmp::Ordering, ops::Deref}; @@ -254,7 +255,11 @@ fn hash_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { let py_int = hash_obj .payload_if_subclass::(vm) .ok_or_else(|| vm.new_type_error("__hash__ method should return an integer".to_owned()))?; - Ok(rustpython_common::hash::hash_bigint(py_int.as_bigint())) + let big_int = py_int.as_bigint(); + let hash: PyHash = big_int + .to_i64() + .unwrap_or_else(|| (big_int % BigInt::from(u64::MAX)).to_i64().unwrap()); + Ok(hash) } /// Marks a type as unhashable. Similar to PyObject_HashNotImplemented in CPython diff --git a/vm/src/utils.rs b/vm/src/utils.rs index ab45343f85..2c5ff79d3f 100644 --- a/vm/src/utils.rs +++ b/vm/src/utils.rs @@ -11,13 +11,6 @@ pub fn hash_iter<'a, I: IntoIterator>( vm.state.hash_secret.hash_iter(iter, |obj| obj.hash(vm)) } -pub fn hash_iter_unordered<'a, I: IntoIterator>( - iter: I, - vm: &VirtualMachine, -) -> PyResult { - rustpython_common::hash::hash_iter_unordered(iter, |obj| obj.hash(vm)) -} - impl ToPyObject for std::convert::Infallible { fn to_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef { match self {} From 074d228a7a4cb71fba3ddb3c9103e963a3355ca9 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Fri, 9 Feb 2024 21:41:45 -0500 Subject: [PATCH 253/893] Use correct TODO syntax Co-authored-by: fanninpm --- Lib/test/test_set.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index a1d421211a..523c39ce68 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -728,7 +728,8 @@ def powerset(s): for i in range(len(s)+1): yield from map(frozenset, itertools.combinations(s, i)) - # TODO the original test has: + # TODO: RUSTPYTHON + # The original test has: # for n in range(18): # Due to general performance overhead, hashing a frozenset takes # about 50 times longer than in CPython. This test amplifies that From 61e40de32b722fe50dc81a2ca89e86d3483454ce Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Sat, 10 Feb 2024 11:43:40 +0900 Subject: [PATCH 254/893] Fix typo in slot.rs (#5162) overriden -> overridden --- vm/src/types/slot.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 132a0f68c4..35a7cb1b11 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1290,7 +1290,7 @@ where #[cold] fn slot_iter(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { let repr = zelf.repr(vm)?; - unreachable!("slot must be overriden for {}", repr.as_str()); + unreachable!("slot must be overridden for {}", repr.as_str()); } fn __iter__(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult { From 69e8e4be438f242a8a8efefa82c7a3cc85ac02d6 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Sun, 18 Feb 2024 15:33:01 +0800 Subject: [PATCH 255/893] Update calendar.py from CPython v3.12.0 --- Lib/calendar.py | 76 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/Lib/calendar.py b/Lib/calendar.py index 657396439c..baab52a157 100644 --- a/Lib/calendar.py +++ b/Lib/calendar.py @@ -7,8 +7,10 @@ import sys import datetime +from enum import IntEnum, global_enum import locale as _locale from itertools import repeat +import warnings __all__ = ["IllegalMonthError", "IllegalWeekdayError", "setfirstweekday", "firstweekday", "isleap", "leapdays", "weekday", "monthrange", @@ -16,6 +18,9 @@ "timegm", "month_name", "month_abbr", "day_name", "day_abbr", "Calendar", "TextCalendar", "HTMLCalendar", "LocaleTextCalendar", "LocaleHTMLCalendar", "weekheader", + "Day", "Month", "JANUARY", "FEBRUARY", "MARCH", + "APRIL", "MAY", "JUNE", "JULY", + "AUGUST", "SEPTEMBER", "OCTOBER", "NOVEMBER", "DECEMBER", "MONDAY", "TUESDAY", "WEDNESDAY", "THURSDAY", "FRIDAY", "SATURDAY", "SUNDAY"] @@ -37,9 +42,46 @@ def __str__(self): return "bad weekday number %r; must be 0 (Monday) to 6 (Sunday)" % self.weekday -# Constants for months referenced later -January = 1 -February = 2 +def __getattr__(name): + if name in ('January', 'February'): + warnings.warn(f"The '{name}' attribute is deprecated, use '{name.upper()}' instead", + DeprecationWarning, stacklevel=2) + if name == 'January': + return 1 + else: + return 2 + + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +# Constants for months +@global_enum +class Month(IntEnum): + JANUARY = 1 + FEBRUARY = 2 + MARCH = 3 + APRIL = 4 + MAY = 5 + JUNE = 6 + JULY = 7 + AUGUST = 8 + SEPTEMBER = 9 + OCTOBER = 10 + NOVEMBER = 11 + DECEMBER = 12 + + +# Constants for days +@global_enum +class Day(IntEnum): + MONDAY = 0 + TUESDAY = 1 + WEDNESDAY = 2 + THURSDAY = 3 + FRIDAY = 4 + SATURDAY = 5 + SUNDAY = 6 + # Number of days per month (except for February in leap years) mdays = [0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31] @@ -95,9 +137,6 @@ def __len__(self): month_name = _localized_month('%B') month_abbr = _localized_month('%b') -# Constants for weekdays -(MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, SATURDAY, SUNDAY) = range(7) - def isleap(year): """Return True for leap years, False for non-leap years.""" @@ -116,7 +155,7 @@ def weekday(year, month, day): """Return weekday (0-6 ~ Mon-Sun) for year, month (1-12), day (1-31).""" if not datetime.MINYEAR <= year <= datetime.MAXYEAR: year = 2000 + year % 400 - return datetime.date(year, month, day).weekday() + return Day(datetime.date(year, month, day).weekday()) def monthrange(year, month): @@ -125,12 +164,12 @@ def monthrange(year, month): if not 1 <= month <= 12: raise IllegalMonthError(month) day1 = weekday(year, month, 1) - ndays = mdays[month] + (month == February and isleap(year)) + ndays = mdays[month] + (month == FEBRUARY and isleap(year)) return day1, ndays def _monthlen(year, month): - return mdays[month] + (month == February and isleap(year)) + return mdays[month] + (month == FEBRUARY and isleap(year)) def _prevmonth(year, month): @@ -260,10 +299,7 @@ def yeardatescalendar(self, year, width=3): Each month contains between 4 and 6 weeks and each week contains 1-7 days. Days are datetime.date objects. """ - months = [ - self.monthdatescalendar(year, i) - for i in range(January, January+12) - ] + months = [self.monthdatescalendar(year, m) for m in Month] return [months[i:i+width] for i in range(0, len(months), width) ] def yeardays2calendar(self, year, width=3): @@ -273,10 +309,7 @@ def yeardays2calendar(self, year, width=3): (day number, weekday number) tuples. Day numbers outside this month are zero. """ - months = [ - self.monthdays2calendar(year, i) - for i in range(January, January+12) - ] + months = [self.monthdays2calendar(year, m) for m in Month] return [months[i:i+width] for i in range(0, len(months), width) ] def yeardayscalendar(self, year, width=3): @@ -285,10 +318,7 @@ def yeardayscalendar(self, year, width=3): yeardatescalendar()). Entries in the week lists are day numbers. Day numbers outside this month are zero. """ - months = [ - self.monthdayscalendar(year, i) - for i in range(January, January+12) - ] + months = [self.monthdayscalendar(year, m) for m in Month] return [months[i:i+width] for i in range(0, len(months), width) ] @@ -509,7 +539,7 @@ def formatyear(self, theyear, width=3): a('\n') a('%s' % ( width, self.cssclass_year_head, theyear)) - for i in range(January, January+12, width): + for i in range(JANUARY, JANUARY+12, width): # months in this row months = range(i, min(i+width, 13)) a('') @@ -693,7 +723,7 @@ def main(args): parser.add_argument( "-L", "--locale", default=None, - help="locale to be used from month and weekday names" + help="locale to use for month and weekday names" ) parser.add_argument( "-e", "--encoding", From 258342e1cabdceee944b43e423117139fa3d792e Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Sun, 18 Feb 2024 15:36:45 +0800 Subject: [PATCH 256/893] Update test_calendar.py from CPython v3.12.0 --- Lib/test/test_calendar.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_calendar.py b/Lib/test/test_calendar.py index 42490c8366..a53766d455 100644 --- a/Lib/test/test_calendar.py +++ b/Lib/test/test_calendar.py @@ -8,6 +8,7 @@ import sys import datetime import os +import warnings # From https://en.wikipedia.org/wiki/Leap_year_starting_on_Saturday result_0_02_text = """\ @@ -490,6 +491,14 @@ def test_format(self): self.assertEqual(out.getvalue().strip(), "1 2 3") class CalendarTestCase(unittest.TestCase): + + def test_deprecation_warning(self): + with self.assertWarnsRegex( + DeprecationWarning, + "The 'January' attribute is deprecated, use 'JANUARY' instead" + ): + calendar.January + def test_isleap(self): # Make sure that the return is right for a few years, and # ensure that the return values are 1 or 0, not just true or @@ -568,11 +577,15 @@ def test_locale_calendar_formatweekday(self): try: # formatweekday uses different day names based on the available width. cal = calendar.LocaleTextCalendar(locale='en_US') + # For really short widths, the abbreviated name is truncated. + self.assertEqual(cal.formatweekday(0, 1), "M") + self.assertEqual(cal.formatweekday(0, 2), "Mo") # For short widths, a centered, abbreviated name is used. + self.assertEqual(cal.formatweekday(0, 3), "Mon") self.assertEqual(cal.formatweekday(0, 5), " Mon ") - # For really short widths, even the abbreviated name is truncated. - self.assertEqual(cal.formatweekday(0, 2), "Mo") + self.assertEqual(cal.formatweekday(0, 8), " Mon ") # For long widths, the full day name is used. + self.assertEqual(cal.formatweekday(0, 9), " Monday ") self.assertEqual(cal.formatweekday(0, 10), " Monday ") except locale.Error: raise unittest.SkipTest('cannot set the en_US locale') From a88c2fe00050255e0859ca6432034608c4f848ba Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Mon, 19 Feb 2024 00:39:47 +0900 Subject: [PATCH 257/893] Fix miri test failure (#5170) --- vm/src/class.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vm/src/class.rs b/vm/src/class.rs index a2eea21213..6ce73d32ce 100644 --- a/vm/src/class.rs +++ b/vm/src/class.rs @@ -127,7 +127,8 @@ pub trait PyClassImpl: PyClassDef { Self::extend_class(ctx, unsafe { // typ will be saved in static_cell let r: &Py = &typ; - &*(r as *const _) + let r: &'static Py = std::mem::transmute(r); + r }); typ })) From 97a0705d2e30a740f2cdfbfea00d92ae314ec682 Mon Sep 17 00:00:00 2001 From: Dmitry Erlikh Date: Sun, 18 Feb 2024 18:36:18 +0100 Subject: [PATCH 258/893] Fix Windows CI (#5168) * pin openssl version for windows CI * use cargo vcpkg * install openssl with vcpkg * use Swatinem/rust-cache right after dtolnay/rust-toolchain * cargo install --target-dir=target cargo-vcpkg --------- Co-authored-by: Dmitry Erlikh --- .github/workflows/ci.yaml | 22 +++++++--------------- Cargo.toml | 10 ++++++++++ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4f31ba74c1..e147497805 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -123,22 +123,18 @@ jobs: - uses: dtolnay/rust-toolchain@stable with: components: clippy + - uses: Swatinem/rust-cache@v2 + - name: Set up the Windows environment shell: bash run: | - choco install llvm openssl --no-progress - if [[ -d "C:\Program Files\OpenSSL-Win64" ]]; then - echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64" >> $GITHUB_ENV - else - echo "OPENSSL_DIR=C:\Program Files\OpenSSL" >> $GITHUB_ENV - fi + cargo install --target-dir=target -v cargo-vcpkg + cargo vcpkg -v build if: runner.os == 'Windows' - name: Set up the Mac environment run: brew install autoconf automake libtool if: runner.os == 'macOS' - - uses: Swatinem/rust-cache@v2 - - name: run clippy run: cargo clippy ${{ env.CARGO_ARGS }} --workspace --exclude rustpython_wasm -- -Dwarnings @@ -249,24 +245,20 @@ jobs: steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 - uses: actions/setup-python@v4 with: python-version: ${{ env.PYTHON_VERSION }} - name: Set up the Windows environment shell: bash run: | - choco install llvm openssl --no-progress - if [[ -d "C:\Program Files\OpenSSL-Win64" ]]; then - echo "OPENSSL_DIR=C:\Program Files\OpenSSL-Win64" >> $GITHUB_ENV - else - echo "OPENSSL_DIR=C:\Program Files\OpenSSL" >> $GITHUB_ENV - fi + cargo install cargo-vcpkg + cargo vcpkg build if: runner.os == 'Windows' - name: Set up the Mac environment run: brew install autoconf automake libtool openssl@3 if: runner.os == 'macOS' - - uses: Swatinem/rust-cache@v2 - name: build rustpython run: cargo build --release --verbose --features=threading ${{ env.CARGO_ARGS }} - uses: actions/setup-python@v4 diff --git a/Cargo.toml b/Cargo.toml index 541a5aa3f3..bfc882fdc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -151,3 +151,13 @@ lto = "thin" [patch.crates-io] # REDOX START, Uncomment when you want to compile/check with redoxer # REDOX END + +# Used only on Windows to build the vcpkg dependencies +[package.metadata.vcpkg] +git = "https://github.com/microsoft/vcpkg" +# The revision of the vcpkg repository to use +# https://github.com/microsoft/vcpkg/tags +rev = "2024.02.14" + +[package.metadata.vcpkg.target] +x86_64-pc-windows-msvc = { triplet = "x64-windows-static-md", dev-dependencies = ["openssl" ] } From defc3ac7f14d385a801e1a037ed90fe584ada288 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Mon, 19 Feb 2024 19:22:38 +0800 Subject: [PATCH 259/893] Update configparser.py from CPython v3.12.0 --- Lib/configparser.py | 63 +++++---------------------------------------- 1 file changed, 6 insertions(+), 57 deletions(-) diff --git a/Lib/configparser.py b/Lib/configparser.py index df2d7e335d..e8aae21794 100644 --- a/Lib/configparser.py +++ b/Lib/configparser.py @@ -59,7 +59,7 @@ instance. It will be used as the handler for option value pre-processing when using getters. RawConfigParser objects don't do any sort of interpolation, whereas ConfigParser uses an instance of - BasicInterpolation. The library also provides a ``zc.buildbot`` + BasicInterpolation. The library also provides a ``zc.buildout`` inspired ExtendedInterpolation implementation. When `converters` is given, it should be a dictionary where each key @@ -149,14 +149,14 @@ import sys import warnings -__all__ = ["NoSectionError", "DuplicateOptionError", "DuplicateSectionError", +__all__ = ("NoSectionError", "DuplicateOptionError", "DuplicateSectionError", "NoOptionError", "InterpolationError", "InterpolationDepthError", "InterpolationMissingOptionError", "InterpolationSyntaxError", "ParsingError", "MissingSectionHeaderError", - "ConfigParser", "SafeConfigParser", "RawConfigParser", + "ConfigParser", "RawConfigParser", "Interpolation", "BasicInterpolation", "ExtendedInterpolation", "LegacyInterpolation", "SectionProxy", "ConverterMapping", - "DEFAULTSECT", "MAX_INTERPOLATION_DEPTH"] + "DEFAULTSECT", "MAX_INTERPOLATION_DEPTH") _default_dict = dict DEFAULTSECT = "DEFAULT" @@ -298,41 +298,12 @@ def __init__(self, option, section, rawval): class ParsingError(Error): """Raised when a configuration file does not follow legal syntax.""" - def __init__(self, source=None, filename=None): - # Exactly one of `source'/`filename' arguments has to be given. - # `filename' kept for compatibility. - if filename and source: - raise ValueError("Cannot specify both `filename' and `source'. " - "Use `source'.") - elif not filename and not source: - raise ValueError("Required argument `source' not given.") - elif filename: - source = filename - Error.__init__(self, 'Source contains parsing errors: %r' % source) + def __init__(self, source): + super().__init__(f'Source contains parsing errors: {source!r}') self.source = source self.errors = [] self.args = (source, ) - @property - def filename(self): - """Deprecated, use `source'.""" - warnings.warn( - "The 'filename' attribute will be removed in Python 3.12. " - "Use 'source' instead.", - DeprecationWarning, stacklevel=2 - ) - return self.source - - @filename.setter - def filename(self, value): - """Deprecated, user `source'.""" - warnings.warn( - "The 'filename' attribute will be removed in Python 3.12. " - "Use 'source' instead.", - DeprecationWarning, stacklevel=2 - ) - self.source = value - def append(self, lineno, line): self.errors.append((lineno, line)) self.message += '\n\t[line %2d]: %s' % (lineno, line) @@ -769,15 +740,6 @@ def read_dict(self, dictionary, source=''): elements_added.add((section, key)) self.set(section, key, value) - def readfp(self, fp, filename=None): - """Deprecated, use read_file instead.""" - warnings.warn( - "This method will be removed in Python 3.12. " - "Use 'parser.read_file()' instead.", - DeprecationWarning, stacklevel=2 - ) - self.read_file(fp, source=filename) - def get(self, section, option, *, raw=False, vars=None, fallback=_UNSET): """Get an option value for a given section. @@ -1240,19 +1202,6 @@ def _read_defaults(self, defaults): self._interpolation = hold_interpolation -class SafeConfigParser(ConfigParser): - """ConfigParser alias for backwards compatibility purposes.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - warnings.warn( - "The SafeConfigParser class has been renamed to ConfigParser " - "in Python 3.2. This alias will be removed in Python 3.12." - " Use ConfigParser directly instead.", - DeprecationWarning, stacklevel=2 - ) - - class SectionProxy(MutableMapping): """A proxy for a single section from a parser.""" From d061837467702e3bb2e012c5c576afaa34330f80 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Mon, 19 Feb 2024 19:25:17 +0800 Subject: [PATCH 260/893] Update test_configparser.py from CPython v3.12.0 --- Lib/test/test_configparser.py | 46 ++++++----------------------------- 1 file changed, 7 insertions(+), 39 deletions(-) diff --git a/Lib/test/test_configparser.py b/Lib/test/test_configparser.py index 6d5c74ff48..01e8e6c675 100644 --- a/Lib/test/test_configparser.py +++ b/Lib/test/test_configparser.py @@ -114,7 +114,7 @@ def basic_test(self, cf): # The use of spaces in the section names serves as a # regression test for SourceForge bug #583248: - # http://www.python.org/sf/583248 + # https://bugs.python.org/issue583248 # API access eq(cf.get('Foo Bar', 'foo'), 'bar1') @@ -934,7 +934,7 @@ def test_items(self): ('name', 'value')]) def test_safe_interpolation(self): - # See http://www.python.org/sf/511737 + # See https://bugs.python.org/issue511737 cf = self.fromstring("[section]\n" "option1{eq}xxx\n" "option2{eq}%(option1)s/xxx\n" @@ -1614,23 +1614,12 @@ def test_interpolation_depth_error(self): self.assertEqual(error.section, 'section') def test_parsing_error(self): - with self.assertRaises(ValueError) as cm: + with self.assertRaises(TypeError) as cm: configparser.ParsingError() - self.assertEqual(str(cm.exception), "Required argument `source' not " - "given.") - with self.assertRaises(ValueError) as cm: - configparser.ParsingError(source='source', filename='filename') - self.assertEqual(str(cm.exception), "Cannot specify both `filename' " - "and `source'. Use `source'.") - error = configparser.ParsingError(filename='source') + error = configparser.ParsingError(source='source') + self.assertEqual(error.source, 'source') + error = configparser.ParsingError('source') self.assertEqual(error.source, 'source') - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always", DeprecationWarning) - self.assertEqual(error.filename, 'source') - error.filename = 'filename' - self.assertEqual(error.source, 'filename') - for warning in w: - self.assertTrue(warning.category is DeprecationWarning) def test_interpolation_validation(self): parser = configparser.ConfigParser() @@ -1649,27 +1638,6 @@ def test_interpolation_validation(self): self.assertEqual(str(cm.exception), "bad interpolation variable " "reference '%(()'") - def test_readfp_deprecation(self): - sio = io.StringIO(""" - [section] - option = value - """) - parser = configparser.ConfigParser() - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always", DeprecationWarning) - parser.readfp(sio, filename='StringIO') - for warning in w: - self.assertTrue(warning.category is DeprecationWarning) - self.assertEqual(len(parser), 2) - self.assertEqual(parser['section']['option'], 'value') - - def test_safeconfigparser_deprecation(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always", DeprecationWarning) - parser = configparser.SafeConfigParser() - for warning in w: - self.assertTrue(warning.category is DeprecationWarning) - def test_legacyinterpolation_deprecation(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always", DeprecationWarning) @@ -1843,7 +1811,7 @@ def test_parsingerror(self): self.assertEqual(e1.source, e2.source) self.assertEqual(e1.errors, e2.errors) self.assertEqual(repr(e1), repr(e2)) - e1 = configparser.ParsingError(filename='filename') + e1 = configparser.ParsingError('filename') e1.append(1, 'line1') e1.append(2, 'line2') e1.append(3, 'line3') From 078cd0d88c29bb70e631e9bdf8ac9112d50bab45 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Mon, 19 Feb 2024 20:04:09 +0800 Subject: [PATCH 261/893] Update code.py from CPython v3.12.0 --- Lib/code.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Lib/code.py b/Lib/code.py index 76000f8c8b..2bd5fa3e79 100644 --- a/Lib/code.py +++ b/Lib/code.py @@ -106,6 +106,7 @@ def showsyntaxerror(self, filename=None): """ type, value, tb = sys.exc_info() + sys.last_exc = value sys.last_type = type sys.last_value = value sys.last_traceback = tb @@ -119,7 +120,7 @@ def showsyntaxerror(self, filename=None): else: # Stuff in the right filename value = SyntaxError(msg, (filename, lineno, offset, line)) - sys.last_value = value + sys.last_exc = sys.last_value = value if sys.excepthook is sys.__excepthook__: lines = traceback.format_exception_only(type, value) self.write(''.join(lines)) @@ -138,6 +139,7 @@ def showtraceback(self): """ sys.last_type, sys.last_value, last_tb = ei = sys.exc_info() sys.last_traceback = last_tb + sys.last_exc = ei[1] try: lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next) if sys.excepthook is sys.__excepthook__: From def5661728a0efb7c55743654ed9ce6efdd78865 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Mon, 19 Feb 2024 20:13:59 +0800 Subject: [PATCH 262/893] Update test_code.py from CPython v3.12.0 --- Lib/test/test_code.py | 60 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py index 2661cbaa1f..e64b2d2f1a 100644 --- a/Lib/test/test_code.py +++ b/Lib/test/test_code.py @@ -150,7 +150,7 @@ gc_collect) from test.support.script_helper import assert_python_ok from test.support import threading_helper -from opcode import opmap +from opcode import opmap, opname COPY_FREE_VARS = opmap['COPY_FREE_VARS'] @@ -357,6 +357,28 @@ def func(): new_code = code = func.__code__.replace(co_linetable=b'') self.assertEqual(list(new_code.co_lines()), []) + def test_co_lnotab_is_deprecated(self): # TODO: remove in 3.14 + def func(): + pass + + with self.assertWarns(DeprecationWarning): + func.__code__.co_lnotab + + def test_invalid_bytecode(self): + def foo(): + pass + + # assert that opcode 229 is invalid + self.assertEqual(opname[229], '<229>') + + # change first opcode to 0xeb (=229) + foo.__code__ = foo.__code__.replace( + co_code=b'\xe5' + foo.__code__.co_code[1:]) + + msg = "unknown opcode 229" + with self.assertRaisesRegex(SystemError, msg): + foo() + # TODO: RUSTPYTHON @unittest.expectedFailure # @requires_debug_ranges() @@ -479,6 +501,32 @@ def f(): self.assertNotEqual(code_b, code_d) self.assertNotEqual(code_c, code_d) + def test_code_hash_uses_firstlineno(self): + c1 = (lambda: 1).__code__ + c2 = (lambda: 1).__code__ + self.assertNotEqual(c1, c2) + self.assertNotEqual(hash(c1), hash(c2)) + c3 = c1.replace(co_firstlineno=17) + self.assertNotEqual(c1, c3) + self.assertNotEqual(hash(c1), hash(c3)) + + def test_code_hash_uses_order(self): + # Swapping posonlyargcount and kwonlyargcount should change the hash. + c = (lambda x, y, *, z=1, w=1: 1).__code__ + self.assertEqual(c.co_argcount, 2) + self.assertEqual(c.co_posonlyargcount, 0) + self.assertEqual(c.co_kwonlyargcount, 2) + swapped = c.replace(co_posonlyargcount=2, co_kwonlyargcount=0) + self.assertNotEqual(c, swapped) + self.assertNotEqual(hash(c), hash(swapped)) + + def test_code_hash_uses_bytecode(self): + c = (lambda x, y: x + y).__code__ + d = (lambda x, y: x * y).__code__ + c1 = c.replace(co_code=d.co_code) + self.assertNotEqual(c, c1) + self.assertNotEqual(hash(c), hash(c1)) + def isinterned(s): return s is sys.intern(('_' + s + '_')[1:-1]) @@ -692,7 +740,8 @@ def test_positions(self): def check_lines(self, func): co = func.__code__ - lines1 = list(dedup(l for (_, _, l) in co.co_lines())) + lines1 = [line for _, _, line in co.co_lines()] + self.assertEqual(lines1, list(dedup(lines1))) lines2 = list(lines_from_postions(positions_from_location_table(co))) for l1, l2 in zip(lines1, lines2): self.assertEqual(l1, l2) @@ -714,6 +763,7 @@ def f(): pass PY_CODE_LOCATION_INFO_NO_COLUMNS = 13 f.__code__ = f.__code__.replace( + co_stacksize=1, co_firstlineno=42, co_code=bytes( [ @@ -742,15 +792,15 @@ def f(): py = ctypes.pythonapi freefunc = ctypes.CFUNCTYPE(None,ctypes.c_voidp) - RequestCodeExtraIndex = py._PyEval_RequestCodeExtraIndex + RequestCodeExtraIndex = py.PyUnstable_Eval_RequestCodeExtraIndex RequestCodeExtraIndex.argtypes = (freefunc,) RequestCodeExtraIndex.restype = ctypes.c_ssize_t - SetExtra = py._PyCode_SetExtra + SetExtra = py.PyUnstable_Code_SetExtra SetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t, ctypes.c_voidp) SetExtra.restype = ctypes.c_int - GetExtra = py._PyCode_GetExtra + GetExtra = py.PyUnstable_Code_GetExtra GetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t, ctypes.POINTER(ctypes.c_voidp)) GetExtra.restype = ctypes.c_int From 572053df82a6ad1b09e00be3d4cb130eb8eba762 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Mon, 19 Feb 2024 20:30:06 +0800 Subject: [PATCH 263/893] est_co_lnotab_is_deprecated, 'test_invalid_bytecode' and est_code_hash_uses_bytecode test exceptions, add TODO --- Lib/test/test_code.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py index e64b2d2f1a..1aceff4efc 100644 --- a/Lib/test/test_code.py +++ b/Lib/test/test_code.py @@ -357,6 +357,8 @@ def func(): new_code = code = func.__code__.replace(co_linetable=b'') self.assertEqual(list(new_code.co_lines()), []) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_co_lnotab_is_deprecated(self): # TODO: remove in 3.14 def func(): pass @@ -364,6 +366,8 @@ def func(): with self.assertWarns(DeprecationWarning): func.__code__.co_lnotab + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_invalid_bytecode(self): def foo(): pass @@ -520,6 +524,8 @@ def test_code_hash_uses_order(self): self.assertNotEqual(c, swapped) self.assertNotEqual(hash(c), hash(swapped)) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_code_hash_uses_bytecode(self): c = (lambda x, y: x + y).__code__ d = (lambda x, y: x * y).__code__ From 36a36b200d519c3262524ab17604c20df4824df8 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Tue, 20 Feb 2024 19:37:42 +0800 Subject: [PATCH 264/893] Update copy.py from CPython v3.12.0 --- Lib/copy.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/Lib/copy.py b/Lib/copy.py index 1b276afe08..da2908ef62 100644 --- a/Lib/copy.py +++ b/Lib/copy.py @@ -56,11 +56,6 @@ class Error(Exception): pass error = Error # backward compatibility -try: - from org.python.core import PyStringMap -except ImportError: - PyStringMap = None - __all__ = ["Error", "copy", "deepcopy"] def copy(x): @@ -106,13 +101,11 @@ def copy(x): def _copy_immutable(x): return x -for t in (type(None), int, float, bool, complex, str, tuple, +for t in (types.NoneType, int, float, bool, complex, str, tuple, bytes, frozenset, type, range, slice, property, - types.BuiltinFunctionType, type(Ellipsis), type(NotImplemented), - types.FunctionType, weakref.ref): - d[t] = _copy_immutable -t = getattr(types, "CodeType", None) -if t is not None: + types.BuiltinFunctionType, types.EllipsisType, + types.NotImplementedType, types.FunctionType, types.CodeType, + weakref.ref): d[t] = _copy_immutable d[list] = list.copy @@ -120,9 +113,6 @@ def _copy_immutable(x): d[set] = set.copy d[bytearray] = bytearray.copy -if PyStringMap is not None: - d[PyStringMap] = PyStringMap.copy - del d, t def deepcopy(x, memo=None, _nil=[]): @@ -181,9 +171,9 @@ def deepcopy(x, memo=None, _nil=[]): def _deepcopy_atomic(x, memo): return x -d[type(None)] = _deepcopy_atomic -d[type(Ellipsis)] = _deepcopy_atomic -d[type(NotImplemented)] = _deepcopy_atomic +d[types.NoneType] = _deepcopy_atomic +d[types.EllipsisType] = _deepcopy_atomic +d[types.NotImplementedType] = _deepcopy_atomic d[int] = _deepcopy_atomic d[float] = _deepcopy_atomic d[bool] = _deepcopy_atomic @@ -231,8 +221,6 @@ def _deepcopy_dict(x, memo, deepcopy=deepcopy): y[deepcopy(key, memo)] = deepcopy(value, memo) return y d[dict] = _deepcopy_dict -if PyStringMap is not None: - d[PyStringMap] = _deepcopy_dict def _deepcopy_method(x, memo): # Copy instance methods return type(x)(x.__func__, deepcopy(x.__self__, memo)) @@ -301,4 +289,4 @@ def _reconstruct(x, memo, func, args, y[key] = value return y -del types, weakref, PyStringMap +del types, weakref From 83b8c3a3fc806a3fe966f5540269abcf7d743520 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Tue, 20 Feb 2024 19:38:58 +0800 Subject: [PATCH 265/893] Update test_copy.py from CPython v3.12.0 --- Lib/test/test_copy.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py index 913cf3c8bc..cf3dc57930 100644 --- a/Lib/test/test_copy.py +++ b/Lib/test/test_copy.py @@ -91,9 +91,7 @@ def __getattribute__(self, name): # Type-specific _copy_xxx() methods def test_copy_atomic(self): - class Classic: - pass - class NewStyle(object): + class NewStyle: pass def f(): pass @@ -103,7 +101,7 @@ class WithMetaclass(metaclass=abc.ABCMeta): 42, 2**100, 3.14, True, False, 1j, "hello", "hello\u1234", f.__code__, b"world", bytes(range(256)), range(10), slice(1, 10, 2), - NewStyle, Classic, max, WithMetaclass, property()] + NewStyle, max, WithMetaclass, property()] for x in tests: self.assertIs(copy.copy(x), x) @@ -358,15 +356,13 @@ def __getattribute__(self, name): # Type-specific _deepcopy_xxx() methods def test_deepcopy_atomic(self): - class Classic: - pass - class NewStyle(object): + class NewStyle: pass def f(): pass tests = [None, ..., NotImplemented, 42, 2**100, 3.14, True, False, 1j, b"bytes", "hello", "hello\u1234", f.__code__, - NewStyle, range(10), Classic, max, property()] + NewStyle, range(10), max, property()] for x in tests: self.assertIs(copy.deepcopy(x), x) From 5dc6d5582abf8f6a32e0a8664484471ea6484261 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Fri, 23 Feb 2024 20:13:41 -0500 Subject: [PATCH 266/893] Update _collections_abc.py to 3.12.2 --- Lib/_collections_abc.py | 74 +++++++++++++++++++++++++++++++++++------ 1 file changed, 63 insertions(+), 11 deletions(-) diff --git a/Lib/_collections_abc.py b/Lib/_collections_abc.py index e96e4c3535..601107d2d8 100644 --- a/Lib/_collections_abc.py +++ b/Lib/_collections_abc.py @@ -6,6 +6,32 @@ Unit tests are in test_collections. """ +############ Maintenance notes ######################################### +# +# ABCs are different from other standard library modules in that they +# specify compliance tests. In general, once an ABC has been published, +# new methods (either abstract or concrete) cannot be added. +# +# Though classes that inherit from an ABC would automatically receive a +# new mixin method, registered classes would become non-compliant and +# violate the contract promised by ``isinstance(someobj, SomeABC)``. +# +# Though irritating, the correct procedure for adding new abstract or +# mixin methods is to create a new ABC as a subclass of the previous +# ABC. For example, union(), intersection(), and difference() cannot +# be added to Set but could go into a new ABC that extends Set. +# +# Because they are so hard to change, new ABCs should have their APIs +# carefully thought through prior to publication. +# +# Since ABCMeta only checks for the presence of methods, it is possible +# to alter the signature of a method by adding optional arguments +# or changing parameters names. This is still a bit dubious but at +# least it won't cause isinstance() to return an incorrect result. +# +# +####################################################################### + from abc import ABCMeta, abstractmethod import sys @@ -23,7 +49,7 @@ def _f(): pass "Mapping", "MutableMapping", "MappingView", "KeysView", "ItemsView", "ValuesView", "Sequence", "MutableSequence", - "ByteString", + "ByteString", "Buffer", ] # This module has been renamed from collections.abc to _collections_abc to @@ -413,6 +439,21 @@ def __subclasshook__(cls, C): return NotImplemented +class Buffer(metaclass=ABCMeta): + + __slots__ = () + + @abstractmethod + def __buffer__(self, flags: int, /) -> memoryview: + raise NotImplementedError + + @classmethod + def __subclasshook__(cls, C): + if cls is Buffer: + return _check_methods(C, "__buffer__") + return NotImplemented + + class _CallableGenericAlias(GenericAlias): """ Represent `Callable[argtypes, resulttype]`. @@ -455,15 +496,8 @@ def __getitem__(self, item): # rather than the default types.GenericAlias object. Most of the # code is copied from typing's _GenericAlias and the builtin # types.GenericAlias. - if not isinstance(item, tuple): item = (item,) - # A special case in PEP 612 where if X = Callable[P, int], - # then X[int, str] == X[[int, str]]. - if (len(self.__parameters__) == 1 - and _is_param_expr(self.__parameters__[0]) - and item and not _is_param_expr(item[0])): - item = (item,) new_args = super().__getitem__(item).__args__ @@ -491,9 +525,8 @@ def _type_repr(obj): Copied from :mod:`typing` since collections.abc shouldn't depend on that module. + (Keep this roughly in sync with the typing version.) """ - if isinstance(obj, GenericAlias): - return repr(obj) if isinstance(obj, type): if obj.__module__ == 'builtins': return obj.__qualname__ @@ -1038,8 +1071,27 @@ def count(self, value): Sequence.register(range) Sequence.register(memoryview) +class _DeprecateByteStringMeta(ABCMeta): + def __new__(cls, name, bases, namespace, **kwargs): + if name != "ByteString": + import warnings + + warnings._deprecated( + "collections.abc.ByteString", + remove=(3, 14), + ) + return super().__new__(cls, name, bases, namespace, **kwargs) + + def __instancecheck__(cls, instance): + import warnings + + warnings._deprecated( + "collections.abc.ByteString", + remove=(3, 14), + ) + return super().__instancecheck__(instance) -class ByteString(Sequence): +class ByteString(Sequence, metaclass=_DeprecateByteStringMeta): """This unifies bytes and bytearray. XXX Should add all their methods. From bcb35919a4772dc93f87ac9306524c3afb43dbc1 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Fri, 23 Feb 2024 20:14:41 -0500 Subject: [PATCH 267/893] Update collections/__init__.py to 3.12.2 --- Lib/collections/__init__.py | 45 +++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 59a2d520fe..f7348ee918 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -45,6 +45,11 @@ else: _collections_abc.MutableSequence.register(deque) +try: + from _collections import _deque_iterator +except ImportError: + pass + try: from _collections import defaultdict except ImportError: @@ -94,17 +99,19 @@ class OrderedDict(dict): # Individual links are kept alive by the hard reference in self.__map. # Those hard references disappear when a key is deleted from an OrderedDict. + def __new__(cls, /, *args, **kwds): + "Create the ordered dict object and set up the underlying structures." + self = dict.__new__(cls) + self.__hardroot = _Link() + self.__root = root = _proxy(self.__hardroot) + root.prev = root.next = root + self.__map = {} + return self + def __init__(self, other=(), /, **kwds): '''Initialize an ordered dictionary. The signature is the same as regular dictionaries. Keyword argument order is preserved. ''' - try: - self.__root - except AttributeError: - self.__hardroot = _Link() - self.__root = root = _proxy(self.__hardroot) - root.prev = root.next = root - self.__map = {} self.__update(other, **kwds) def __setitem__(self, key, value, @@ -271,7 +278,7 @@ def __repr__(self): 'od.__repr__() <==> repr(od)' if not self: return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, list(self.items())) + return '%s(%r)' % (self.__class__.__name__, dict(self.items())) def __reduce__(self): 'Return state information for pickling' @@ -511,9 +518,12 @@ def __getnewargs__(self): # specified a particular module. if module is None: try: - module = _sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - pass + module = _sys._getframemodulename(1) or '__main__' + except AttributeError: + try: + module = _sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass if module is not None: result.__module__ = module @@ -1015,8 +1025,8 @@ def __len__(self): def __iter__(self): d = {} - for mapping in reversed(self.maps): - d.update(dict.fromkeys(mapping)) # reuses stored hash values if possible + for mapping in map(dict.fromkeys, reversed(self.maps)): + d |= mapping # reuses stored hash values if possible return iter(d) def __contains__(self, key): @@ -1136,10 +1146,17 @@ def __delitem__(self, key): def __iter__(self): return iter(self.data) - # Modify __contains__ to work correctly when __missing__ is present + # Modify __contains__ and get() to work like dict + # does when __missing__ is present. def __contains__(self, key): return key in self.data + def get(self, key, default=None): + if key in self: + return self[key] + return default + + # Now, add the methods in dicts but not in MutableMapping def __repr__(self): return repr(self.data) From 0bd8c2504c51d64e8351a55153266e0059dd8137 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Fri, 23 Feb 2024 20:16:16 -0500 Subject: [PATCH 268/893] Update test_collections.py to 3.12.2 --- Lib/test/test_collections.py | 41 ++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 50e7282d17..d70490f153 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -25,7 +25,7 @@ from collections.abc import Set, MutableSet from collections.abc import Mapping, MutableMapping, KeysView, ItemsView, ValuesView from collections.abc import Sequence, MutableSequence -from collections.abc import ByteString +from collections.abc import ByteString, Buffer class TestUserObjects(unittest.TestCase): @@ -71,6 +71,14 @@ def test_dict_copy(self): obj[123] = "abc" self._copy_test(obj) + def test_dict_missing(self): + class A(UserDict): + def __missing__(self, key): + return 456 + self.assertEqual(A()[123], 456) + # get() ignores __missing__ on dict + self.assertIs(A().get(123), None) + ################################################################################ ### ChainMap (helper class for configparser and the string module) @@ -539,7 +547,7 @@ def test_odd_sizes(self): self.assertEqual(Dot(1)._replace(d=999), (999,)) self.assertEqual(Dot(1)._fields, ('d',)) - n = 5000 + n = support.EXCEEDS_RECURSION_LIMIT names = list(set(''.join([choice(string.ascii_letters) for j in range(10)]) for i in range(n))) n = len(names) @@ -1629,7 +1637,7 @@ def test_Set_from_iterable(self): class SetUsingInstanceFromIterable(MutableSet): def __init__(self, values, created_by): if not created_by: - raise ValueError(f'created_by must be specified') + raise ValueError('created_by must be specified') self.created_by = created_by self._values = set(values) @@ -1949,13 +1957,34 @@ def assert_index_same(seq1, seq2, index_args): def test_ByteString(self): for sample in [bytes, bytearray]: - self.assertIsInstance(sample(), ByteString) + with self.assertWarns(DeprecationWarning): + self.assertIsInstance(sample(), ByteString) self.assertTrue(issubclass(sample, ByteString)) for sample in [str, list, tuple]: - self.assertNotIsInstance(sample(), ByteString) + with self.assertWarns(DeprecationWarning): + self.assertNotIsInstance(sample(), ByteString) self.assertFalse(issubclass(sample, ByteString)) - self.assertNotIsInstance(memoryview(b""), ByteString) + with self.assertWarns(DeprecationWarning): + self.assertNotIsInstance(memoryview(b""), ByteString) self.assertFalse(issubclass(memoryview, ByteString)) + with self.assertWarns(DeprecationWarning): + self.validate_abstract_methods(ByteString, '__getitem__', '__len__') + + with self.assertWarns(DeprecationWarning): + class X(ByteString): pass + + with self.assertWarns(DeprecationWarning): + # No metaclass conflict + class Z(ByteString, Awaitable): pass + + def test_Buffer(self): + for sample in [bytes, bytearray, memoryview]: + self.assertIsInstance(sample(b"x"), Buffer) + self.assertTrue(issubclass(sample, Buffer)) + for sample in [str, list, tuple]: + self.assertNotIsInstance(sample(), Buffer) + self.assertFalse(issubclass(sample, Buffer)) + self.validate_abstract_methods(Buffer, '__buffer__') # TODO: RUSTPYTHON @unittest.expectedFailure From 9481df23e1219f7df851770fcddb1a1f08f85f38 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sat, 24 Feb 2024 11:02:41 -0500 Subject: [PATCH 269/893] Disable test_Buffer This test will not work until the `__buffer__` and `__release_buffer__` methods are implemented on the appropriate builtin types, which is outside the current scope. --- Lib/test/test_collections.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index d70490f153..ecd574ab83 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -1977,6 +1977,10 @@ class X(ByteString): pass # No metaclass conflict class Z(ByteString, Awaitable): pass + # TODO: RUSTPYTHON + # Need to implement __buffer__ and __release_buffer__ + # https://docs.python.org/3.13/reference/datamodel.html#emulating-buffer-types + @unittest.expectedFailure def test_Buffer(self): for sample in [bytes, bytearray, memoryview]: self.assertIsInstance(sample(b"x"), Buffer) From c9ec4507ad2dd1e309753a38f31cfe75ae8d89d9 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sat, 24 Feb 2024 11:58:23 -0500 Subject: [PATCH 270/893] Disable broken test_repr in test_typing.py This should be resolved when `typing.py` and `test_typing.py` are updated to 3.12. --- Lib/test/test_typing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index abcc03ce2d..3140ff9028 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -513,6 +513,8 @@ def f(): with self.assertRaises(TypeError): self.assertNotIsInstance(None, Callable[[], Any]) + # TODO: RUSTPYTHON update typing to 3.12 + @unittest.expectedFailure def test_repr(self): Callable = self.Callable fullname = f'{Callable.__module__}.Callable' From 6cd9e5442765340ec64a1280dd75756d0a309498 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sat, 24 Feb 2024 12:00:07 -0500 Subject: [PATCH 271/893] Update test_ordered_dict.py to 3.12.2 --- Lib/test/test_ordered_dict.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py index cfb87d7829..942748bd91 100644 --- a/Lib/test/test_ordered_dict.py +++ b/Lib/test/test_ordered_dict.py @@ -122,6 +122,17 @@ def items(self): self.OrderedDict(Spam()) self.assertEqual(calls, ['keys']) + def test_overridden_init(self): + # Sync-up pure Python OD class with C class where + # a consistent internal state is created in __new__ + # rather than __init__. + OrderedDict = self.OrderedDict + class ODNI(OrderedDict): + def __init__(*args, **kwargs): + pass + od = ODNI() + od['a'] = 1 # This used to fail because __init__ was bypassed + def test_fromkeys(self): OrderedDict = self.OrderedDict od = OrderedDict.fromkeys('abc') @@ -370,7 +381,7 @@ def test_repr(self): OrderedDict = self.OrderedDict od = OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]) self.assertEqual(repr(od), - "OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)])") + "OrderedDict({'c': 1, 'b': 2, 'a': 3, 'd': 4, 'e': 5, 'f': 6})") self.assertEqual(eval(repr(od)), od) self.assertEqual(repr(OrderedDict()), "OrderedDict()") @@ -380,7 +391,7 @@ def test_repr_recursive(self): od = OrderedDict.fromkeys('abc') od['x'] = od self.assertEqual(repr(od), - "OrderedDict([('a', None), ('b', None), ('c', None), ('x', ...)])") + "OrderedDict({'a': None, 'b': None, 'c': None, 'x': ...})") def test_repr_recursive_values(self): OrderedDict = self.OrderedDict From 407f251866f437987bb86419bbd672b9327db4ea Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sat, 24 Feb 2024 13:33:17 -0500 Subject: [PATCH 272/893] Un-skip passing typing test I missed that the typing test I disabled was on a base test class. Moving the expected failure to the subclass allows the passing test to pass. --- Lib/test/test_typing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 3140ff9028..95fd3748e6 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -513,8 +513,6 @@ def f(): with self.assertRaises(TypeError): self.assertNotIsInstance(None, Callable[[], Any]) - # TODO: RUSTPYTHON update typing to 3.12 - @unittest.expectedFailure def test_repr(self): Callable = self.Callable fullname = f'{Callable.__module__}.Callable' @@ -707,6 +705,11 @@ def test_paramspec(self): # TODO: RUSTPYTHON, remove when this passes def test_concatenate(self): # TODO: RUSTPYTHON, remove when this passes super().test_concatenate() # TODO: RUSTPYTHON, remove when this passes + # TODO: RUSTPYTHON might be fixed by updating typing to 3.12 + @unittest.expectedFailure + def test_repr(self): # TODO: RUSTPYTHON, remove when this passes + super().test_repr() # TODO: RUSTPYTHON, remove when this passes + class LiteralTests(BaseTestCase): def test_basics(self): From d26766b7bca4f6a11ac4f0f87e80ff7925d6cffe Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 29 Feb 2024 21:30:31 +0900 Subject: [PATCH 273/893] better symboltable error message --- compiler/codegen/src/compile.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index e5976047a3..0b0f2877c7 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -478,8 +478,8 @@ impl Compiler { self.check_forbidden_name(&name, usage)?; let symbol_table = self.symbol_table_stack.last().unwrap(); - let symbol = symbol_table.lookup(name.as_ref()).expect( - "The symbol must be present in the symbol table, even when it is undefined in python.", + let symbol = symbol_table.lookup(name.as_ref()).unwrap_or_else(|| + panic!("The symbol '{name}' must be present in the symbol table, even when it is undefined in python."), ); let info = self.code_stack.last_mut().unwrap(); let mut cache = &mut info.name_cache; From 88ee64d5858823dc98df9564bc3619bcd543ba61 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Tue, 20 Feb 2024 19:59:02 +0800 Subject: [PATCH 274/893] Update csv.py from CPython v3.12.0 --- Lib/csv.py | 48 ++++++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/Lib/csv.py b/Lib/csv.py index 2f38bb1a19..77f30c8d2b 100644 --- a/Lib/csv.py +++ b/Lib/csv.py @@ -4,17 +4,22 @@ """ import re -from _csv import Error, writer, reader, \ +import types +from _csv import Error, __version__, writer, reader, register_dialect, \ + unregister_dialect, get_dialect, list_dialects, \ + field_size_limit, \ QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE, \ + QUOTE_STRINGS, QUOTE_NOTNULL, \ __doc__ +from _csv import Dialect as _Dialect -from collections import OrderedDict from io import StringIO __all__ = ["QUOTE_MINIMAL", "QUOTE_ALL", "QUOTE_NONNUMERIC", "QUOTE_NONE", + "QUOTE_STRINGS", "QUOTE_NOTNULL", "Error", "Dialect", "__doc__", "excel", "excel_tab", "field_size_limit", "reader", "writer", - "Sniffer", + "register_dialect", "get_dialect", "list_dialects", "Sniffer", "unregister_dialect", "__version__", "DictReader", "DictWriter", "unix_dialect"] @@ -57,10 +62,12 @@ class excel(Dialect): skipinitialspace = False lineterminator = '\r\n' quoting = QUOTE_MINIMAL +register_dialect("excel", excel) class excel_tab(excel): """Describe the usual properties of Excel-generated TAB-delimited files.""" delimiter = '\t' +register_dialect("excel-tab", excel_tab) class unix_dialect(Dialect): """Describe the usual properties of Unix-generated CSV files.""" @@ -70,11 +77,14 @@ class unix_dialect(Dialect): skipinitialspace = False lineterminator = '\n' quoting = QUOTE_ALL +register_dialect("unix", unix_dialect) class DictReader: def __init__(self, f, fieldnames=None, restkey=None, restval=None, dialect="excel", *args, **kwds): + if fieldnames is not None and iter(fieldnames) is fieldnames: + fieldnames = list(fieldnames) self._fieldnames = fieldnames # list of keys for the dict self.restkey = restkey # key to catch long rows self.restval = restval # default value for short rows @@ -111,7 +121,7 @@ def __next__(self): # values while row == []: row = next(self.reader) - d = OrderedDict(zip(self.fieldnames, row)) + d = dict(zip(self.fieldnames, row)) lf = len(self.fieldnames) lr = len(row) if lf < lr: @@ -121,13 +131,18 @@ def __next__(self): d[key] = self.restval return d + __class_getitem__ = classmethod(types.GenericAlias) + class DictWriter: def __init__(self, f, fieldnames, restval="", extrasaction="raise", dialect="excel", *args, **kwds): + if fieldnames is not None and iter(fieldnames) is fieldnames: + fieldnames = list(fieldnames) self.fieldnames = fieldnames # list of keys for the dict self.restval = restval # for writing short dicts - if extrasaction.lower() not in ("raise", "ignore"): + extrasaction = extrasaction.lower() + if extrasaction not in ("raise", "ignore"): raise ValueError("extrasaction (%s) must be 'raise' or 'ignore'" % extrasaction) self.extrasaction = extrasaction @@ -135,7 +150,7 @@ def __init__(self, f, fieldnames, restval="", extrasaction="raise", def writeheader(self): header = dict(zip(self.fieldnames, self.fieldnames)) - self.writerow(header) + return self.writerow(header) def _dict_to_list(self, rowdict): if self.extrasaction == "raise": @@ -151,11 +166,8 @@ def writerow(self, rowdict): def writerows(self, rowdicts): return self.writer.writerows(map(self._dict_to_list, rowdicts)) -# Guard Sniffer's type checking against builds that exclude complex() -try: - complex -except NameError: - complex = float + __class_getitem__ = classmethod(types.GenericAlias) + class Sniffer: ''' @@ -404,14 +416,10 @@ def has_header(self, sample): continue # skip rows that have irregular number of columns for col in list(columnTypes.keys()): - - for thisType in [int, float, complex]: - try: - thisType(row[col]) - break - except (ValueError, OverflowError): - pass - else: + thisType = complex + try: + thisType(row[col]) + except (ValueError, OverflowError): # fallback to length of string thisType = len(row[col]) @@ -427,7 +435,7 @@ def has_header(self, sample): # on whether it's a header hasHeader = 0 for col, colType in columnTypes.items(): - if type(colType) == type(0): # it's a length + if isinstance(colType, int): # it's a length if len(header[col]) != colType: hasHeader += 1 else: From d2bf69e354467ef20e24d3dccff77664983dc48a Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Tue, 20 Feb 2024 20:19:32 +0800 Subject: [PATCH 275/893] Update test_csv.py from CPython v3.12.0 --- Lib/test/test_csv.py | 1441 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1441 insertions(+) create mode 100644 Lib/test/test_csv.py diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py new file mode 100644 index 0000000000..bc9961e083 --- /dev/null +++ b/Lib/test/test_csv.py @@ -0,0 +1,1441 @@ +# Copyright (C) 2001,2002 Python Software Foundation +# csv package unit tests + +import copy +import sys +import unittest +from io import StringIO +from tempfile import TemporaryFile +import csv +import gc +import pickle +from test import support +from test.support import warnings_helper, import_helper, check_disallow_instantiation +from itertools import permutations +from textwrap import dedent +from collections import OrderedDict + + +class BadIterable: + def __iter__(self): + raise OSError + + +class Test_Csv(unittest.TestCase): + """ + Test the underlying C csv parser in ways that are not appropriate + from the high level interface. Further tests of this nature are done + in TestDialectRegistry. + """ + def _test_arg_valid(self, ctor, arg): + self.assertRaises(TypeError, ctor) + self.assertRaises(TypeError, ctor, None) + self.assertRaises(TypeError, ctor, arg, bad_attr = 0) + self.assertRaises(TypeError, ctor, arg, delimiter = 0) + self.assertRaises(TypeError, ctor, arg, delimiter = 'XX') + self.assertRaises(csv.Error, ctor, arg, 'foo') + self.assertRaises(TypeError, ctor, arg, delimiter=None) + self.assertRaises(TypeError, ctor, arg, delimiter=1) + self.assertRaises(TypeError, ctor, arg, quotechar=1) + self.assertRaises(TypeError, ctor, arg, lineterminator=None) + self.assertRaises(TypeError, ctor, arg, lineterminator=1) + self.assertRaises(TypeError, ctor, arg, quoting=None) + self.assertRaises(TypeError, ctor, arg, + quoting=csv.QUOTE_ALL, quotechar='') + self.assertRaises(TypeError, ctor, arg, + quoting=csv.QUOTE_ALL, quotechar=None) + self.assertRaises(TypeError, ctor, arg, + quoting=csv.QUOTE_NONE, quotechar='') + + def test_reader_arg_valid(self): + self._test_arg_valid(csv.reader, []) + self.assertRaises(OSError, csv.reader, BadIterable()) + + def test_writer_arg_valid(self): + self._test_arg_valid(csv.writer, StringIO()) + class BadWriter: + @property + def write(self): + raise OSError + self.assertRaises(OSError, csv.writer, BadWriter()) + + def _test_default_attrs(self, ctor, *args): + obj = ctor(*args) + # Check defaults + self.assertEqual(obj.dialect.delimiter, ',') + self.assertIs(obj.dialect.doublequote, True) + self.assertEqual(obj.dialect.escapechar, None) + self.assertEqual(obj.dialect.lineterminator, "\r\n") + self.assertEqual(obj.dialect.quotechar, '"') + self.assertEqual(obj.dialect.quoting, csv.QUOTE_MINIMAL) + self.assertIs(obj.dialect.skipinitialspace, False) + self.assertIs(obj.dialect.strict, False) + # Try deleting or changing attributes (they are read-only) + self.assertRaises(AttributeError, delattr, obj.dialect, 'delimiter') + self.assertRaises(AttributeError, setattr, obj.dialect, 'delimiter', ':') + self.assertRaises(AttributeError, delattr, obj.dialect, 'quoting') + self.assertRaises(AttributeError, setattr, obj.dialect, + 'quoting', None) + + def test_reader_attrs(self): + self._test_default_attrs(csv.reader, []) + + def test_writer_attrs(self): + self._test_default_attrs(csv.writer, StringIO()) + + def _test_kw_attrs(self, ctor, *args): + # Now try with alternate options + kwargs = dict(delimiter=':', doublequote=False, escapechar='\\', + lineterminator='\r', quotechar='*', + quoting=csv.QUOTE_NONE, skipinitialspace=True, + strict=True) + obj = ctor(*args, **kwargs) + self.assertEqual(obj.dialect.delimiter, ':') + self.assertIs(obj.dialect.doublequote, False) + self.assertEqual(obj.dialect.escapechar, '\\') + self.assertEqual(obj.dialect.lineterminator, "\r") + self.assertEqual(obj.dialect.quotechar, '*') + self.assertEqual(obj.dialect.quoting, csv.QUOTE_NONE) + self.assertIs(obj.dialect.skipinitialspace, True) + self.assertIs(obj.dialect.strict, True) + + def test_reader_kw_attrs(self): + self._test_kw_attrs(csv.reader, []) + + def test_writer_kw_attrs(self): + self._test_kw_attrs(csv.writer, StringIO()) + + def _test_dialect_attrs(self, ctor, *args): + # Now try with dialect-derived options + class dialect: + delimiter='-' + doublequote=False + escapechar='^' + lineterminator='$' + quotechar='#' + quoting=csv.QUOTE_ALL + skipinitialspace=True + strict=False + args = args + (dialect,) + obj = ctor(*args) + self.assertEqual(obj.dialect.delimiter, '-') + self.assertIs(obj.dialect.doublequote, False) + self.assertEqual(obj.dialect.escapechar, '^') + self.assertEqual(obj.dialect.lineterminator, "$") + self.assertEqual(obj.dialect.quotechar, '#') + self.assertEqual(obj.dialect.quoting, csv.QUOTE_ALL) + self.assertIs(obj.dialect.skipinitialspace, True) + self.assertIs(obj.dialect.strict, False) + + def test_reader_dialect_attrs(self): + self._test_dialect_attrs(csv.reader, []) + + def test_writer_dialect_attrs(self): + self._test_dialect_attrs(csv.writer, StringIO()) + + + def _write_test(self, fields, expect, **kwargs): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, **kwargs) + writer.writerow(fields) + fileobj.seek(0) + self.assertEqual(fileobj.read(), + expect + writer.dialect.lineterminator) + + def _write_error_test(self, exc, fields, **kwargs): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, **kwargs) + with self.assertRaises(exc): + writer.writerow(fields) + fileobj.seek(0) + self.assertEqual(fileobj.read(), '') + + def test_write_arg_valid(self): + self._write_error_test(csv.Error, None) + self._write_test((), '') + self._write_test([None], '""') + self._write_error_test(csv.Error, [None], quoting = csv.QUOTE_NONE) + # Check that exceptions are passed up the chain + self._write_error_test(OSError, BadIterable()) + class BadList: + def __len__(self): + return 10 + def __getitem__(self, i): + if i > 2: + raise OSError + self._write_error_test(OSError, BadList()) + class BadItem: + def __str__(self): + raise OSError + self._write_error_test(OSError, [BadItem()]) + + def test_write_bigfield(self): + # This exercises the buffer realloc functionality + bigstring = 'X' * 50000 + self._write_test([bigstring,bigstring], '%s,%s' % \ + (bigstring, bigstring)) + + def test_write_quoting(self): + self._write_test(['a',1,'p,q'], 'a,1,"p,q"') + self._write_error_test(csv.Error, ['a',1,'p,q'], + quoting = csv.QUOTE_NONE) + self._write_test(['a',1,'p,q'], 'a,1,"p,q"', + quoting = csv.QUOTE_MINIMAL) + self._write_test(['a',1,'p,q'], '"a",1,"p,q"', + quoting = csv.QUOTE_NONNUMERIC) + self._write_test(['a',1,'p,q'], '"a","1","p,q"', + quoting = csv.QUOTE_ALL) + self._write_test(['a\nb',1], '"a\nb","1"', + quoting = csv.QUOTE_ALL) + self._write_test(['a','',None,1], '"a","",,1', + quoting = csv.QUOTE_STRINGS) + self._write_test(['a','',None,1], '"a","",,"1"', + quoting = csv.QUOTE_NOTNULL) + + def test_write_escape(self): + self._write_test(['a',1,'p,q'], 'a,1,"p,q"', + escapechar='\\') + self._write_error_test(csv.Error, ['a',1,'p,"q"'], + escapechar=None, doublequote=False) + self._write_test(['a',1,'p,"q"'], 'a,1,"p,\\"q\\""', + escapechar='\\', doublequote = False) + self._write_test(['"'], '""""', + escapechar='\\', quoting = csv.QUOTE_MINIMAL) + self._write_test(['"'], '\\"', + escapechar='\\', quoting = csv.QUOTE_MINIMAL, + doublequote = False) + self._write_test(['"'], '\\"', + escapechar='\\', quoting = csv.QUOTE_NONE) + self._write_test(['a',1,'p,q'], 'a,1,p\\,q', + escapechar='\\', quoting = csv.QUOTE_NONE) + self._write_test(['\\', 'a'], '\\\\,a', + escapechar='\\', quoting=csv.QUOTE_NONE) + self._write_test(['\\', 'a'], '\\\\,a', + escapechar='\\', quoting=csv.QUOTE_MINIMAL) + self._write_test(['\\', 'a'], '"\\\\","a"', + escapechar='\\', quoting=csv.QUOTE_ALL) + self._write_test(['\\ ', 'a'], '\\\\ ,a', + escapechar='\\', quoting=csv.QUOTE_MINIMAL) + self._write_test(['\\,', 'a'], '\\\\\\,,a', + escapechar='\\', quoting=csv.QUOTE_NONE) + self._write_test([',\\', 'a'], '",\\\\",a', + escapechar='\\', quoting=csv.QUOTE_MINIMAL) + self._write_test(['C\\', '6', '7', 'X"'], 'C\\\\,6,7,"X"""', + escapechar='\\', quoting=csv.QUOTE_MINIMAL) + + def test_write_lineterminator(self): + for lineterminator in '\r\n', '\n', '\r', '!@#', '\0': + with self.subTest(lineterminator=lineterminator): + with StringIO() as sio: + writer = csv.writer(sio, lineterminator=lineterminator) + writer.writerow(['a', 'b']) + writer.writerow([1, 2]) + self.assertEqual(sio.getvalue(), + f'a,b{lineterminator}' + f'1,2{lineterminator}') + + def test_write_iterable(self): + self._write_test(iter(['a', 1, 'p,q']), 'a,1,"p,q"') + self._write_test(iter(['a', 1, None]), 'a,1,') + self._write_test(iter([]), '') + self._write_test(iter([None]), '""') + self._write_error_test(csv.Error, iter([None]), quoting=csv.QUOTE_NONE) + self._write_test(iter([None, None]), ',') + + def test_writerows(self): + class BrokenFile: + def write(self, buf): + raise OSError + writer = csv.writer(BrokenFile()) + self.assertRaises(OSError, writer.writerows, [['a']]) + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + self.assertRaises(TypeError, writer.writerows, None) + writer.writerows([['a', 'b'], ['c', 'd']]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), "a,b\r\nc,d\r\n") + + def test_writerows_with_none(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + writer.writerows([['a', None], [None, 'd']]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), "a,\r\n,d\r\n") + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + writer.writerows([[None], ['a']]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), '""\r\na\r\n') + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + writer.writerows([['a'], [None]]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), 'a\r\n""\r\n') + + def test_writerows_errors(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + self.assertRaises(TypeError, writer.writerows, None) + self.assertRaises(OSError, writer.writerows, BadIterable()) + + @support.cpython_only + @support.requires_legacy_unicode_capi() + @warnings_helper.ignore_warnings(category=DeprecationWarning) + def test_writerows_legacy_strings(self): + import _testcapi + c = _testcapi.unicode_legacy_string('a') + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + writer.writerows([[c]]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), "a\r\n") + + def _read_test(self, input, expect, **kwargs): + reader = csv.reader(input, **kwargs) + result = list(reader) + self.assertEqual(result, expect) + + def test_read_oddinputs(self): + self._read_test([], []) + self._read_test([''], [[]]) + self.assertRaises(csv.Error, self._read_test, + ['"ab"c'], None, strict = 1) + self._read_test(['"ab"c'], [['abc']], doublequote = 0) + + self.assertRaises(csv.Error, self._read_test, + [b'abc'], None) + + def test_read_eol(self): + self._read_test(['a,b'], [['a','b']]) + self._read_test(['a,b\n'], [['a','b']]) + self._read_test(['a,b\r\n'], [['a','b']]) + self._read_test(['a,b\r'], [['a','b']]) + self.assertRaises(csv.Error, self._read_test, ['a,b\rc,d'], []) + self.assertRaises(csv.Error, self._read_test, ['a,b\nc,d'], []) + self.assertRaises(csv.Error, self._read_test, ['a,b\r\nc,d'], []) + + def test_read_eof(self): + self._read_test(['a,"'], [['a', '']]) + self._read_test(['"a'], [['a']]) + self._read_test(['^'], [['\n']], escapechar='^') + self.assertRaises(csv.Error, self._read_test, ['a,"'], [], strict=True) + self.assertRaises(csv.Error, self._read_test, ['"a'], [], strict=True) + self.assertRaises(csv.Error, self._read_test, + ['^'], [], escapechar='^', strict=True) + + def test_read_nul(self): + self._read_test(['\0'], [['\0']]) + self._read_test(['a,\0b,c'], [['a', '\0b', 'c']]) + self._read_test(['a,b\0,c'], [['a', 'b\0', 'c']]) + self._read_test(['a,b\\\0,c'], [['a', 'b\0', 'c']], escapechar='\\') + self._read_test(['a,"\0b",c'], [['a', '\0b', 'c']]) + + def test_read_delimiter(self): + self._read_test(['a,b,c'], [['a', 'b', 'c']]) + self._read_test(['a;b;c'], [['a', 'b', 'c']], delimiter=';') + self._read_test(['a\0b\0c'], [['a', 'b', 'c']], delimiter='\0') + + def test_read_escape(self): + self._read_test(['a,\\b,c'], [['a', 'b', 'c']], escapechar='\\') + self._read_test(['a,b\\,c'], [['a', 'b,c']], escapechar='\\') + self._read_test(['a,"b\\,c"'], [['a', 'b,c']], escapechar='\\') + self._read_test(['a,"b,\\c"'], [['a', 'b,c']], escapechar='\\') + self._read_test(['a,"b,c\\""'], [['a', 'b,c"']], escapechar='\\') + self._read_test(['a,"b,c"\\'], [['a', 'b,c\\']], escapechar='\\') + self._read_test(['a,^b,c'], [['a', 'b', 'c']], escapechar='^') + self._read_test(['a,\0b,c'], [['a', 'b', 'c']], escapechar='\0') + self._read_test(['a,\\b,c'], [['a', '\\b', 'c']], escapechar=None) + self._read_test(['a,\\b,c'], [['a', '\\b', 'c']]) + + def test_read_quoting(self): + self._read_test(['1,",3,",5'], [['1', ',3,', '5']]) + self._read_test(['1,",3,",5'], [['1', '"', '3', '"', '5']], + quotechar=None, escapechar='\\') + self._read_test(['1,",3,",5'], [['1', '"', '3', '"', '5']], + quoting=csv.QUOTE_NONE, escapechar='\\') + # will this fail where locale uses comma for decimals? + self._read_test([',3,"5",7.3, 9'], [['', 3, '5', 7.3, 9]], + quoting=csv.QUOTE_NONNUMERIC) + self._read_test(['"a\nb", 7'], [['a\nb', ' 7']]) + self.assertRaises(ValueError, self._read_test, + ['abc,3'], [[]], + quoting=csv.QUOTE_NONNUMERIC) + self._read_test(['1,@,3,@,5'], [['1', ',3,', '5']], quotechar='@') + self._read_test(['1,\0,3,\0,5'], [['1', ',3,', '5']], quotechar='\0') + + def test_read_skipinitialspace(self): + self._read_test(['no space, space, spaces,\ttab'], + [['no space', 'space', 'spaces', '\ttab']], + skipinitialspace=True) + + def test_read_bigfield(self): + # This exercises the buffer realloc functionality and field size + # limits. + limit = csv.field_size_limit() + try: + size = 50000 + bigstring = 'X' * size + bigline = '%s,%s' % (bigstring, bigstring) + self._read_test([bigline], [[bigstring, bigstring]]) + csv.field_size_limit(size) + self._read_test([bigline], [[bigstring, bigstring]]) + self.assertEqual(csv.field_size_limit(), size) + csv.field_size_limit(size-1) + self.assertRaises(csv.Error, self._read_test, [bigline], []) + self.assertRaises(TypeError, csv.field_size_limit, None) + self.assertRaises(TypeError, csv.field_size_limit, 1, None) + finally: + csv.field_size_limit(limit) + + def test_read_linenum(self): + r = csv.reader(['line,1', 'line,2', 'line,3']) + self.assertEqual(r.line_num, 0) + next(r) + self.assertEqual(r.line_num, 1) + next(r) + self.assertEqual(r.line_num, 2) + next(r) + self.assertEqual(r.line_num, 3) + self.assertRaises(StopIteration, next, r) + self.assertEqual(r.line_num, 3) + + def test_roundtrip_quoteed_newlines(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj) + rows = [['a\nb','b'],['c','x\r\nd']] + writer.writerows(rows) + fileobj.seek(0) + for i, row in enumerate(csv.reader(fileobj)): + self.assertEqual(row, rows[i]) + + def test_roundtrip_escaped_unquoted_newlines(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\") + rows = [['a\nb','b'],['c','x\r\nd']] + writer.writerows(rows) + fileobj.seek(0) + for i, row in enumerate(csv.reader(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\")): + self.assertEqual(row,rows[i]) + +class TestDialectRegistry(unittest.TestCase): + def test_registry_badargs(self): + self.assertRaises(TypeError, csv.list_dialects, None) + self.assertRaises(TypeError, csv.get_dialect) + self.assertRaises(csv.Error, csv.get_dialect, None) + self.assertRaises(csv.Error, csv.get_dialect, "nonesuch") + self.assertRaises(TypeError, csv.unregister_dialect) + self.assertRaises(csv.Error, csv.unregister_dialect, None) + self.assertRaises(csv.Error, csv.unregister_dialect, "nonesuch") + self.assertRaises(TypeError, csv.register_dialect, None) + self.assertRaises(TypeError, csv.register_dialect, None, None) + self.assertRaises(TypeError, csv.register_dialect, "nonesuch", 0, 0) + self.assertRaises(TypeError, csv.register_dialect, "nonesuch", + badargument=None) + self.assertRaises(TypeError, csv.register_dialect, "nonesuch", + quoting=None) + self.assertRaises(TypeError, csv.register_dialect, []) + + def test_registry(self): + class myexceltsv(csv.excel): + delimiter = "\t" + name = "myexceltsv" + expected_dialects = csv.list_dialects() + [name] + expected_dialects.sort() + csv.register_dialect(name, myexceltsv) + self.addCleanup(csv.unregister_dialect, name) + self.assertEqual(csv.get_dialect(name).delimiter, '\t') + got_dialects = sorted(csv.list_dialects()) + self.assertEqual(expected_dialects, got_dialects) + + def test_register_kwargs(self): + name = 'fedcba' + csv.register_dialect(name, delimiter=';') + self.addCleanup(csv.unregister_dialect, name) + self.assertEqual(csv.get_dialect(name).delimiter, ';') + self.assertEqual([['X', 'Y', 'Z']], list(csv.reader(['X;Y;Z'], name))) + + def test_register_kwargs_override(self): + class mydialect(csv.Dialect): + delimiter = "\t" + quotechar = '"' + doublequote = True + skipinitialspace = False + lineterminator = '\r\n' + quoting = csv.QUOTE_MINIMAL + + name = 'test_dialect' + csv.register_dialect(name, mydialect, + delimiter=';', + quotechar="'", + doublequote=False, + skipinitialspace=True, + lineterminator='\n', + quoting=csv.QUOTE_ALL) + self.addCleanup(csv.unregister_dialect, name) + + # Ensure that kwargs do override attributes of a dialect class: + dialect = csv.get_dialect(name) + self.assertEqual(dialect.delimiter, ';') + self.assertEqual(dialect.quotechar, "'") + self.assertEqual(dialect.doublequote, False) + self.assertEqual(dialect.skipinitialspace, True) + self.assertEqual(dialect.lineterminator, '\n') + self.assertEqual(dialect.quoting, csv.QUOTE_ALL) + + def test_incomplete_dialect(self): + class myexceltsv(csv.Dialect): + delimiter = "\t" + self.assertRaises(csv.Error, myexceltsv) + + def test_space_dialect(self): + class space(csv.excel): + delimiter = " " + quoting = csv.QUOTE_NONE + escapechar = "\\" + + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("abc def\nc1ccccc1 benzene\n") + fileobj.seek(0) + reader = csv.reader(fileobj, dialect=space()) + self.assertEqual(next(reader), ["abc", "def"]) + self.assertEqual(next(reader), ["c1ccccc1", "benzene"]) + + def compare_dialect_123(self, expected, *writeargs, **kwwriteargs): + + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: + + writer = csv.writer(fileobj, *writeargs, **kwwriteargs) + writer.writerow([1,2,3]) + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + + def test_dialect_apply(self): + class testA(csv.excel): + delimiter = "\t" + class testB(csv.excel): + delimiter = ":" + class testC(csv.excel): + delimiter = "|" + class testUni(csv.excel): + delimiter = "\u039B" + + class unspecified(): + # A class to pass as dialect but with no dialect attributes. + pass + + csv.register_dialect('testC', testC) + try: + self.compare_dialect_123("1,2,3\r\n") + self.compare_dialect_123("1,2,3\r\n", dialect=None) + self.compare_dialect_123("1,2,3\r\n", dialect=unspecified) + self.compare_dialect_123("1\t2\t3\r\n", testA) + self.compare_dialect_123("1:2:3\r\n", dialect=testB()) + self.compare_dialect_123("1|2|3\r\n", dialect='testC') + self.compare_dialect_123("1;2;3\r\n", dialect=testA, + delimiter=';') + self.compare_dialect_123("1\u039B2\u039B3\r\n", + dialect=testUni) + + finally: + csv.unregister_dialect('testC') + + def test_bad_dialect(self): + # Unknown parameter + self.assertRaises(TypeError, csv.reader, [], bad_attr = 0) + # Bad values + self.assertRaises(TypeError, csv.reader, [], delimiter = None) + self.assertRaises(TypeError, csv.reader, [], quoting = -1) + self.assertRaises(TypeError, csv.reader, [], quoting = 100) + + def test_copy(self): + for name in csv.list_dialects(): + dialect = csv.get_dialect(name) + self.assertRaises(TypeError, copy.copy, dialect) + + def test_pickle(self): + for name in csv.list_dialects(): + dialect = csv.get_dialect(name) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + self.assertRaises(TypeError, pickle.dumps, dialect, proto) + +class TestCsvBase(unittest.TestCase): + def readerAssertEqual(self, input, expected_result): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + fileobj.write(input) + fileobj.seek(0) + reader = csv.reader(fileobj, dialect = self.dialect) + fields = list(reader) + self.assertEqual(fields, expected_result) + + def writerAssertEqual(self, input, expected_result): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect = self.dialect) + writer.writerows(input) + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected_result) + +class TestDialectExcel(TestCsvBase): + dialect = 'excel' + + def test_single(self): + self.readerAssertEqual('abc', [['abc']]) + + def test_simple(self): + self.readerAssertEqual('1,2,3,4,5', [['1','2','3','4','5']]) + + def test_blankline(self): + self.readerAssertEqual('', []) + + def test_empty_fields(self): + self.readerAssertEqual(',', [['', '']]) + + def test_singlequoted(self): + self.readerAssertEqual('""', [['']]) + + def test_singlequoted_left_empty(self): + self.readerAssertEqual('"",', [['','']]) + + def test_singlequoted_right_empty(self): + self.readerAssertEqual(',""', [['','']]) + + def test_single_quoted_quote(self): + self.readerAssertEqual('""""', [['"']]) + + def test_quoted_quotes(self): + self.readerAssertEqual('""""""', [['""']]) + + def test_inline_quote(self): + self.readerAssertEqual('a""b', [['a""b']]) + + def test_inline_quotes(self): + self.readerAssertEqual('a"b"c', [['a"b"c']]) + + def test_quotes_and_more(self): + # Excel would never write a field containing '"a"b', but when + # reading one, it will return 'ab'. + self.readerAssertEqual('"a"b', [['ab']]) + + def test_lone_quote(self): + self.readerAssertEqual('a"b', [['a"b']]) + + def test_quote_and_quote(self): + # Excel would never write a field containing '"a" "b"', but when + # reading one, it will return 'a "b"'. + self.readerAssertEqual('"a" "b"', [['a "b"']]) + + def test_space_and_quote(self): + self.readerAssertEqual(' "a"', [[' "a"']]) + + def test_quoted(self): + self.readerAssertEqual('1,2,3,"I think, therefore I am",5,6', + [['1', '2', '3', + 'I think, therefore I am', + '5', '6']]) + + def test_quoted_quote(self): + self.readerAssertEqual('1,2,3,"""I see,"" said the blind man","as he picked up his hammer and saw"', + [['1', '2', '3', + '"I see," said the blind man', + 'as he picked up his hammer and saw']]) + + def test_quoted_nl(self): + input = '''\ +1,2,3,"""I see,"" +said the blind man","as he picked up his +hammer and saw" +9,8,7,6''' + self.readerAssertEqual(input, + [['1', '2', '3', + '"I see,"\nsaid the blind man', + 'as he picked up his\nhammer and saw'], + ['9','8','7','6']]) + + def test_dubious_quote(self): + self.readerAssertEqual('12,12,1",', [['12', '12', '1"', '']]) + + def test_null(self): + self.writerAssertEqual([], '') + + def test_single_writer(self): + self.writerAssertEqual([['abc']], 'abc\r\n') + + def test_simple_writer(self): + self.writerAssertEqual([[1, 2, 'abc', 3, 4]], '1,2,abc,3,4\r\n') + + def test_quotes(self): + self.writerAssertEqual([[1, 2, 'a"bc"', 3, 4]], '1,2,"a""bc""",3,4\r\n') + + def test_quote_fieldsep(self): + self.writerAssertEqual([['abc,def']], '"abc,def"\r\n') + + def test_newlines(self): + self.writerAssertEqual([[1, 2, 'a\nbc', 3, 4]], '1,2,"a\nbc",3,4\r\n') + +class EscapedExcel(csv.excel): + quoting = csv.QUOTE_NONE + escapechar = '\\' + +class TestEscapedExcel(TestCsvBase): + dialect = EscapedExcel() + + def test_escape_fieldsep(self): + self.writerAssertEqual([['abc,def']], 'abc\\,def\r\n') + + def test_read_escape_fieldsep(self): + self.readerAssertEqual('abc\\,def\r\n', [['abc,def']]) + +class TestDialectUnix(TestCsvBase): + dialect = 'unix' + + def test_simple_writer(self): + self.writerAssertEqual([[1, 'abc def', 'abc']], '"1","abc def","abc"\n') + + def test_simple_reader(self): + self.readerAssertEqual('"1","abc def","abc"\n', [['1', 'abc def', 'abc']]) + +class QuotedEscapedExcel(csv.excel): + quoting = csv.QUOTE_NONNUMERIC + escapechar = '\\' + +class TestQuotedEscapedExcel(TestCsvBase): + dialect = QuotedEscapedExcel() + + def test_write_escape_fieldsep(self): + self.writerAssertEqual([['abc,def']], '"abc,def"\r\n') + + def test_read_escape_fieldsep(self): + self.readerAssertEqual('"abc\\,def"\r\n', [['abc,def']]) + +class TestDictFields(unittest.TestCase): + ### "long" means the row is longer than the number of fieldnames + ### "short" means there are fewer elements in the row than fieldnames + def test_writeheader_return_value(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"]) + writeheader_return_value = writer.writeheader() + self.assertEqual(writeheader_return_value, 10) + + def test_write_simple_dict(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"]) + writer.writeheader() + fileobj.seek(0) + self.assertEqual(fileobj.readline(), "f1,f2,f3\r\n") + writer.writerow({"f1": 10, "f3": "abc"}) + fileobj.seek(0) + fileobj.readline() # header + self.assertEqual(fileobj.read(), "10,,abc\r\n") + + def test_write_multiple_dict_rows(self): + fileobj = StringIO() + writer = csv.DictWriter(fileobj, fieldnames=["f1", "f2", "f3"]) + writer.writeheader() + self.assertEqual(fileobj.getvalue(), "f1,f2,f3\r\n") + writer.writerows([{"f1": 1, "f2": "abc", "f3": "f"}, + {"f1": 2, "f2": 5, "f3": "xyz"}]) + self.assertEqual(fileobj.getvalue(), + "f1,f2,f3\r\n1,abc,f\r\n2,5,xyz\r\n") + + def test_write_no_fields(self): + fileobj = StringIO() + self.assertRaises(TypeError, csv.DictWriter, fileobj) + + def test_write_fields_not_in_fieldnames(self): + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"]) + # Of special note is the non-string key (issue 19449) + with self.assertRaises(ValueError) as cx: + writer.writerow({"f4": 10, "f2": "spam", 1: "abc"}) + exception = str(cx.exception) + self.assertIn("fieldnames", exception) + self.assertIn("'f4'", exception) + self.assertNotIn("'f2'", exception) + self.assertIn("1", exception) + + def test_typo_in_extrasaction_raises_error(self): + fileobj = StringIO() + self.assertRaises(ValueError, csv.DictWriter, fileobj, ['f1', 'f2'], + extrasaction="raised") + + def test_write_field_not_in_field_names_raise(self): + fileobj = StringIO() + writer = csv.DictWriter(fileobj, ['f1', 'f2'], extrasaction="raise") + dictrow = {'f0': 0, 'f1': 1, 'f2': 2, 'f3': 3} + self.assertRaises(ValueError, csv.DictWriter.writerow, writer, dictrow) + + # see bpo-44512 (differently cased 'raise' should not result in 'ignore') + writer = csv.DictWriter(fileobj, ['f1', 'f2'], extrasaction="RAISE") + self.assertRaises(ValueError, csv.DictWriter.writerow, writer, dictrow) + + def test_write_field_not_in_field_names_ignore(self): + fileobj = StringIO() + writer = csv.DictWriter(fileobj, ['f1', 'f2'], extrasaction="ignore") + dictrow = {'f0': 0, 'f1': 1, 'f2': 2, 'f3': 3} + csv.DictWriter.writerow(writer, dictrow) + self.assertEqual(fileobj.getvalue(), "1,2\r\n") + + # bpo-44512 + writer = csv.DictWriter(fileobj, ['f1', 'f2'], extrasaction="IGNORE") + csv.DictWriter.writerow(writer, dictrow) + + def test_dict_reader_fieldnames_accepts_iter(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + reader = csv.DictReader(f, iter(fieldnames)) + self.assertEqual(reader.fieldnames, fieldnames) + + def test_dict_reader_fieldnames_accepts_list(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + reader = csv.DictReader(f, fieldnames) + self.assertEqual(reader.fieldnames, fieldnames) + + def test_dict_writer_fieldnames_rejects_iter(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + writer = csv.DictWriter(f, iter(fieldnames)) + self.assertEqual(writer.fieldnames, fieldnames) + + def test_dict_writer_fieldnames_accepts_list(self): + fieldnames = ["a", "b", "c"] + f = StringIO() + writer = csv.DictWriter(f, fieldnames) + self.assertEqual(writer.fieldnames, fieldnames) + + def test_dict_reader_fieldnames_is_optional(self): + f = StringIO() + reader = csv.DictReader(f, fieldnames=None) + + def test_read_dict_fields(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=["f1", "f2", "f3"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) + + def test_read_dict_no_fieldnames(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + + # Two test cases to make sure existing ways of implicitly setting + # fieldnames continue to work. Both arise from discussion in issue3436. + def test_read_dict_fieldnames_from_file(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=next(csv.reader(fileobj))) + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) + + def test_read_dict_fieldnames_chain(self): + import itertools + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj) + first = next(reader) + for row in itertools.chain([first], reader): + self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) + self.assertEqual(row, {"f1": '1', "f2": '2', "f3": 'abc'}) + + def test_read_long(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("1,2,abc,4,5,6\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=["f1", "f2"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', + None: ["abc", "4", "5", "6"]}) + + def test_read_long_with_rest(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("1,2,abc,4,5,6\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames=["f1", "f2"], restkey="_rest") + self.assertEqual(next(reader), {"f1": '1', "f2": '2', + "_rest": ["abc", "4", "5", "6"]}) + + def test_read_long_with_rest_no_fieldnames(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("f1,f2\r\n1,2,abc,4,5,6\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, restkey="_rest") + self.assertEqual(reader.fieldnames, ["f1", "f2"]) + self.assertEqual(next(reader), {"f1": '1', "f2": '2', + "_rest": ["abc", "4", "5", "6"]}) + + def test_read_short(self): + with TemporaryFile("w+", encoding="utf-8") as fileobj: + fileobj.write("1,2,abc,4,5,6\r\n1,2,abc\r\n") + fileobj.seek(0) + reader = csv.DictReader(fileobj, + fieldnames="1 2 3 4 5 6".split(), + restval="DEFAULT") + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": '4', "5": '5', "6": '6'}) + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": 'DEFAULT', "5": 'DEFAULT', + "6": 'DEFAULT'}) + + def test_read_multi(self): + sample = [ + '2147483648,43.0e12,17,abc,def\r\n', + '147483648,43.0e2,17,abc,def\r\n', + '47483648,43.0,170,abc,def\r\n' + ] + + reader = csv.DictReader(sample, + fieldnames="i1 float i2 s1 s2".split()) + self.assertEqual(next(reader), {"i1": '2147483648', + "float": '43.0e12', + "i2": '17', + "s1": 'abc', + "s2": 'def'}) + + def test_read_with_blanks(self): + reader = csv.DictReader(["1,2,abc,4,5,6\r\n","\r\n", + "1,2,abc,4,5,6\r\n"], + fieldnames="1 2 3 4 5 6".split()) + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": '4', "5": '5', "6": '6'}) + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": '4', "5": '5', "6": '6'}) + + def test_read_semi_sep(self): + reader = csv.DictReader(["1;2;abc;4;5;6\r\n"], + fieldnames="1 2 3 4 5 6".split(), + delimiter=';') + self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', + "4": '4', "5": '5', "6": '6'}) + +class TestArrayWrites(unittest.TestCase): + def test_int_write(self): + import array + contents = [(20-i) for i in range(20)] + a = array.array('i', contents) + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect="excel") + writer.writerow(a) + expected = ",".join([str(i) for i in a])+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + + def test_double_write(self): + import array + contents = [(20-i)*0.1 for i in range(20)] + a = array.array('d', contents) + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect="excel") + writer.writerow(a) + expected = ",".join([str(i) for i in a])+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + + def test_float_write(self): + import array + contents = [(20-i)*0.1 for i in range(20)] + a = array.array('f', contents) + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect="excel") + writer.writerow(a) + expected = ",".join([str(i) for i in a])+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + + def test_char_write(self): + import array, string + a = array.array('u', string.ascii_letters) + + with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: + writer = csv.writer(fileobj, dialect="excel") + writer.writerow(a) + expected = ",".join(a)+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + +class TestDialectValidity(unittest.TestCase): + def test_quoting(self): + class mydialect(csv.Dialect): + delimiter = ";" + escapechar = '\\' + doublequote = False + skipinitialspace = True + lineterminator = '\r\n' + quoting = csv.QUOTE_NONE + d = mydialect() + self.assertEqual(d.quoting, csv.QUOTE_NONE) + + mydialect.quoting = None + self.assertRaises(csv.Error, mydialect) + + mydialect.doublequote = True + mydialect.quoting = csv.QUOTE_ALL + mydialect.quotechar = '"' + d = mydialect() + self.assertEqual(d.quoting, csv.QUOTE_ALL) + self.assertEqual(d.quotechar, '"') + self.assertTrue(d.doublequote) + + mydialect.quotechar = "" + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"quotechar" must be a 1-character string') + + mydialect.quotechar = "''" + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"quotechar" must be a 1-character string') + + mydialect.quotechar = 4 + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"quotechar" must be string or None, not int') + + def test_delimiter(self): + class mydialect(csv.Dialect): + delimiter = ";" + escapechar = '\\' + doublequote = False + skipinitialspace = True + lineterminator = '\r\n' + quoting = csv.QUOTE_NONE + d = mydialect() + self.assertEqual(d.delimiter, ";") + + mydialect.delimiter = ":::" + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be a 1-character string') + + mydialect.delimiter = "" + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be a 1-character string') + + mydialect.delimiter = b"," + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be string, not bytes') + + mydialect.delimiter = 4 + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be string, not int') + + mydialect.delimiter = None + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"delimiter" must be string, not NoneType') + + def test_escapechar(self): + class mydialect(csv.Dialect): + delimiter = ";" + escapechar = '\\' + doublequote = False + skipinitialspace = True + lineterminator = '\r\n' + quoting = csv.QUOTE_NONE + d = mydialect() + self.assertEqual(d.escapechar, "\\") + + mydialect.escapechar = "" + with self.assertRaisesRegex(csv.Error, '"escapechar" must be a 1-character string'): + mydialect() + + mydialect.escapechar = "**" + with self.assertRaisesRegex(csv.Error, '"escapechar" must be a 1-character string'): + mydialect() + + mydialect.escapechar = b"*" + with self.assertRaisesRegex(csv.Error, '"escapechar" must be string or None, not bytes'): + mydialect() + + mydialect.escapechar = 4 + with self.assertRaisesRegex(csv.Error, '"escapechar" must be string or None, not int'): + mydialect() + + def test_lineterminator(self): + class mydialect(csv.Dialect): + delimiter = ";" + escapechar = '\\' + doublequote = False + skipinitialspace = True + lineterminator = '\r\n' + quoting = csv.QUOTE_NONE + d = mydialect() + self.assertEqual(d.lineterminator, '\r\n') + + mydialect.lineterminator = ":::" + d = mydialect() + self.assertEqual(d.lineterminator, ":::") + + mydialect.lineterminator = 4 + with self.assertRaises(csv.Error) as cm: + mydialect() + self.assertEqual(str(cm.exception), + '"lineterminator" must be a string') + + def test_invalid_chars(self): + def create_invalid(field_name, value): + class mydialect(csv.Dialect): + pass + setattr(mydialect, field_name, value) + d = mydialect() + + for field_name in ("delimiter", "escapechar", "quotechar"): + with self.subTest(field_name=field_name): + self.assertRaises(csv.Error, create_invalid, field_name, "") + self.assertRaises(csv.Error, create_invalid, field_name, "abc") + self.assertRaises(csv.Error, create_invalid, field_name, b'x') + self.assertRaises(csv.Error, create_invalid, field_name, 5) + + +class TestSniffer(unittest.TestCase): + sample1 = """\ +Harry's, Arlington Heights, IL, 2/1/03, Kimi Hayes +Shark City, Glendale Heights, IL, 12/28/02, Prezence +Tommy's Place, Blue Island, IL, 12/28/02, Blue Sunday/White Crow +Stonecutters Seafood and Chop House, Lemont, IL, 12/19/02, Week Back +""" + sample2 = """\ +'Harry''s':'Arlington Heights':'IL':'2/1/03':'Kimi Hayes' +'Shark City':'Glendale Heights':'IL':'12/28/02':'Prezence' +'Tommy''s Place':'Blue Island':'IL':'12/28/02':'Blue Sunday/White Crow' +'Stonecutters ''Seafood'' and Chop House':'Lemont':'IL':'12/19/02':'Week Back' +""" + header1 = '''\ +"venue","city","state","date","performers" +''' + sample3 = '''\ +05/05/03?05/05/03?05/05/03?05/05/03?05/05/03?05/05/03 +05/05/03?05/05/03?05/05/03?05/05/03?05/05/03?05/05/03 +05/05/03?05/05/03?05/05/03?05/05/03?05/05/03?05/05/03 +''' + + sample4 = '''\ +2147483648;43.0e12;17;abc;def +147483648;43.0e2;17;abc;def +47483648;43.0;170;abc;def +''' + + sample5 = "aaa\tbbb\r\nAAA\t\r\nBBB\t\r\n" + sample6 = "a|b|c\r\nd|e|f\r\n" + sample7 = "'a'|'b'|'c'\r\n'd'|e|f\r\n" + +# Issue 18155: Use a delimiter that is a special char to regex: + + header2 = '''\ +"venue"+"city"+"state"+"date"+"performers" +''' + sample8 = """\ +Harry's+ Arlington Heights+ IL+ 2/1/03+ Kimi Hayes +Shark City+ Glendale Heights+ IL+ 12/28/02+ Prezence +Tommy's Place+ Blue Island+ IL+ 12/28/02+ Blue Sunday/White Crow +Stonecutters Seafood and Chop House+ Lemont+ IL+ 12/19/02+ Week Back +""" + sample9 = """\ +'Harry''s'+ Arlington Heights'+ 'IL'+ '2/1/03'+ 'Kimi Hayes' +'Shark City'+ Glendale Heights'+' IL'+ '12/28/02'+ 'Prezence' +'Tommy''s Place'+ Blue Island'+ 'IL'+ '12/28/02'+ 'Blue Sunday/White Crow' +'Stonecutters ''Seafood'' and Chop House'+ 'Lemont'+ 'IL'+ '12/19/02'+ 'Week Back' +""" + + sample10 = dedent(""" + abc,def + ghijkl,mno + ghi,jkl + """) + + sample11 = dedent(""" + abc,def + ghijkl,mnop + ghi,jkl + """) + + sample12 = dedent(""""time","forces" + 1,1.5 + 0.5,5+0j + 0,0 + 1+1j,6 + """) + + sample13 = dedent(""""time","forces" + 0,0 + 1,2 + a,b + """) + + sample14 = """\ +abc\0def +ghijkl\0mno +ghi\0jkl +""" + + def test_issue43625(self): + sniffer = csv.Sniffer() + self.assertTrue(sniffer.has_header(self.sample12)) + self.assertFalse(sniffer.has_header(self.sample13)) + + def test_has_header_strings(self): + "More to document existing (unexpected?) behavior than anything else." + sniffer = csv.Sniffer() + self.assertFalse(sniffer.has_header(self.sample10)) + self.assertFalse(sniffer.has_header(self.sample11)) + + def test_has_header(self): + sniffer = csv.Sniffer() + self.assertIs(sniffer.has_header(self.sample1), False) + self.assertIs(sniffer.has_header(self.header1 + self.sample1), True) + + def test_has_header_regex_special_delimiter(self): + sniffer = csv.Sniffer() + self.assertIs(sniffer.has_header(self.sample8), False) + self.assertIs(sniffer.has_header(self.header2 + self.sample8), True) + + def test_guess_quote_and_delimiter(self): + sniffer = csv.Sniffer() + for header in (";'123;4';", "'123;4';", ";'123;4'", "'123;4'"): + with self.subTest(header): + dialect = sniffer.sniff(header, ",;") + self.assertEqual(dialect.delimiter, ';') + self.assertEqual(dialect.quotechar, "'") + self.assertIs(dialect.doublequote, False) + self.assertIs(dialect.skipinitialspace, False) + + def test_sniff(self): + sniffer = csv.Sniffer() + dialect = sniffer.sniff(self.sample1) + self.assertEqual(dialect.delimiter, ",") + self.assertEqual(dialect.quotechar, '"') + self.assertIs(dialect.skipinitialspace, True) + + dialect = sniffer.sniff(self.sample2) + self.assertEqual(dialect.delimiter, ":") + self.assertEqual(dialect.quotechar, "'") + self.assertIs(dialect.skipinitialspace, False) + + def test_delimiters(self): + sniffer = csv.Sniffer() + dialect = sniffer.sniff(self.sample3) + # given that all three lines in sample3 are equal, + # I think that any character could have been 'guessed' as the + # delimiter, depending on dictionary order + self.assertIn(dialect.delimiter, self.sample3) + dialect = sniffer.sniff(self.sample3, delimiters="?,") + self.assertEqual(dialect.delimiter, "?") + dialect = sniffer.sniff(self.sample3, delimiters="/,") + self.assertEqual(dialect.delimiter, "/") + dialect = sniffer.sniff(self.sample4) + self.assertEqual(dialect.delimiter, ";") + dialect = sniffer.sniff(self.sample5) + self.assertEqual(dialect.delimiter, "\t") + dialect = sniffer.sniff(self.sample6) + self.assertEqual(dialect.delimiter, "|") + dialect = sniffer.sniff(self.sample7) + self.assertEqual(dialect.delimiter, "|") + self.assertEqual(dialect.quotechar, "'") + dialect = sniffer.sniff(self.sample8) + self.assertEqual(dialect.delimiter, '+') + dialect = sniffer.sniff(self.sample9) + self.assertEqual(dialect.delimiter, '+') + self.assertEqual(dialect.quotechar, "'") + dialect = sniffer.sniff(self.sample14) + self.assertEqual(dialect.delimiter, '\0') + + def test_doublequote(self): + sniffer = csv.Sniffer() + dialect = sniffer.sniff(self.header1) + self.assertFalse(dialect.doublequote) + dialect = sniffer.sniff(self.header2) + self.assertFalse(dialect.doublequote) + dialect = sniffer.sniff(self.sample2) + self.assertTrue(dialect.doublequote) + dialect = sniffer.sniff(self.sample8) + self.assertFalse(dialect.doublequote) + dialect = sniffer.sniff(self.sample9) + self.assertTrue(dialect.doublequote) + +class NUL: + def write(s, *args): + pass + writelines = write + +@unittest.skipUnless(hasattr(sys, "gettotalrefcount"), + 'requires sys.gettotalrefcount()') +class TestLeaks(unittest.TestCase): + def test_create_read(self): + delta = 0 + lastrc = sys.gettotalrefcount() + for i in range(20): + gc.collect() + self.assertEqual(gc.garbage, []) + rc = sys.gettotalrefcount() + csv.reader(["a,b,c\r\n"]) + csv.reader(["a,b,c\r\n"]) + csv.reader(["a,b,c\r\n"]) + delta = rc-lastrc + lastrc = rc + # if csv.reader() leaks, last delta should be 3 or more + self.assertLess(delta, 3) + + def test_create_write(self): + delta = 0 + lastrc = sys.gettotalrefcount() + s = NUL() + for i in range(20): + gc.collect() + self.assertEqual(gc.garbage, []) + rc = sys.gettotalrefcount() + csv.writer(s) + csv.writer(s) + csv.writer(s) + delta = rc-lastrc + lastrc = rc + # if csv.writer() leaks, last delta should be 3 or more + self.assertLess(delta, 3) + + def test_read(self): + delta = 0 + rows = ["a,b,c\r\n"]*5 + lastrc = sys.gettotalrefcount() + for i in range(20): + gc.collect() + self.assertEqual(gc.garbage, []) + rc = sys.gettotalrefcount() + rdr = csv.reader(rows) + for row in rdr: + pass + delta = rc-lastrc + lastrc = rc + # if reader leaks during read, delta should be 5 or more + self.assertLess(delta, 5) + + def test_write(self): + delta = 0 + rows = [[1,2,3]]*5 + s = NUL() + lastrc = sys.gettotalrefcount() + for i in range(20): + gc.collect() + self.assertEqual(gc.garbage, []) + rc = sys.gettotalrefcount() + writer = csv.writer(s) + for row in rows: + writer.writerow(row) + delta = rc-lastrc + lastrc = rc + # if writer leaks during write, last delta should be 5 or more + self.assertLess(delta, 5) + +class TestUnicode(unittest.TestCase): + + names = ["Martin von Löwis", + "Marc André Lemburg", + "Guido van Rossum", + "François Pinard"] + + def test_unicode_read(self): + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: + fileobj.write(",".join(self.names) + "\r\n") + fileobj.seek(0) + reader = csv.reader(fileobj) + self.assertEqual(list(reader), [self.names]) + + + def test_unicode_write(self): + with TemporaryFile("w+", newline='', encoding="utf-8") as fileobj: + writer = csv.writer(fileobj) + writer.writerow(self.names) + expected = ",".join(self.names)+"\r\n" + fileobj.seek(0) + self.assertEqual(fileobj.read(), expected) + +class KeyOrderingTest(unittest.TestCase): + + def test_ordering_for_the_dict_reader_and_writer(self): + resultset = set() + for keys in permutations("abcde"): + with TemporaryFile('w+', newline='', encoding="utf-8") as fileobject: + dw = csv.DictWriter(fileobject, keys) + dw.writeheader() + fileobject.seek(0) + dr = csv.DictReader(fileobject) + kt = tuple(dr.fieldnames) + self.assertEqual(keys, kt) + resultset.add(kt) + # Final sanity check: were all permutations unique? + self.assertEqual(len(resultset), 120, "Key ordering: some key permutations not collected (expected 120)") + + def test_ordered_dict_reader(self): + data = dedent('''\ + FirstName,LastName + Eric,Idle + Graham,Chapman,Over1,Over2 + + Under1 + John,Cleese + ''').splitlines() + + self.assertEqual(list(csv.DictReader(data)), + [OrderedDict([('FirstName', 'Eric'), ('LastName', 'Idle')]), + OrderedDict([('FirstName', 'Graham'), ('LastName', 'Chapman'), + (None, ['Over1', 'Over2'])]), + OrderedDict([('FirstName', 'Under1'), ('LastName', None)]), + OrderedDict([('FirstName', 'John'), ('LastName', 'Cleese')]), + ]) + + self.assertEqual(list(csv.DictReader(data, restkey='OtherInfo')), + [OrderedDict([('FirstName', 'Eric'), ('LastName', 'Idle')]), + OrderedDict([('FirstName', 'Graham'), ('LastName', 'Chapman'), + ('OtherInfo', ['Over1', 'Over2'])]), + OrderedDict([('FirstName', 'Under1'), ('LastName', None)]), + OrderedDict([('FirstName', 'John'), ('LastName', 'Cleese')]), + ]) + + del data[0] # Remove the header row + self.assertEqual(list(csv.DictReader(data, fieldnames=['fname', 'lname'])), + [OrderedDict([('fname', 'Eric'), ('lname', 'Idle')]), + OrderedDict([('fname', 'Graham'), ('lname', 'Chapman'), + (None, ['Over1', 'Over2'])]), + OrderedDict([('fname', 'Under1'), ('lname', None)]), + OrderedDict([('fname', 'John'), ('lname', 'Cleese')]), + ]) + + +class MiscTestCase(unittest.TestCase): + def test__all__(self): + extra = {'__doc__', '__version__'} + support.check__all__(self, csv, ('csv', '_csv'), extra=extra) + + def test_subclassable(self): + # issue 44089 + class Foo(csv.Error): ... + + @support.cpython_only + def test_disallow_instantiation(self): + _csv = import_helper.import_module("_csv") + for tp in _csv.Reader, _csv.Writer: + with self.subTest(tp=tp): + check_disallow_instantiation(self, tp) + +if __name__ == '__main__': + unittest.main() From e4be47a08b2f1a2592e527083fa481462a679a4e Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Tue, 5 Mar 2024 15:09:24 +0900 Subject: [PATCH 276/893] Mark failing tests as expectedFailure --- Lib/test/test_csv.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index bc9961e083..2646be086c 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -150,6 +150,8 @@ def _write_error_test(self, exc, fields, **kwargs): fileobj.seek(0) self.assertEqual(fileobj.read(), '') + # TODO: RUSTPYTHON ''\r\n to ""\r\n unsupported + @unittest.expectedFailure def test_write_arg_valid(self): self._write_error_test(csv.Error, None) self._write_test((), '') @@ -175,6 +177,8 @@ def test_write_bigfield(self): self._write_test([bigstring,bigstring], '%s,%s' % \ (bigstring, bigstring)) + # TODO: RUSTPYTHON quoting style check is unsupported + @unittest.expectedFailure def test_write_quoting(self): self._write_test(['a',1,'p,q'], 'a,1,"p,q"') self._write_error_test(csv.Error, ['a',1,'p,q'], @@ -192,6 +196,8 @@ def test_write_quoting(self): self._write_test(['a','',None,1], '"a","",,"1"', quoting = csv.QUOTE_NOTNULL) + # TODO: RUSTPYTHON doublequote check is unsupported + @unittest.expectedFailure def test_write_escape(self): self._write_test(['a',1,'p,q'], 'a,1,"p,q"', escapechar='\\') @@ -223,6 +229,8 @@ def test_write_escape(self): self._write_test(['C\\', '6', '7', 'X"'], 'C\\\\,6,7,"X"""', escapechar='\\', quoting=csv.QUOTE_MINIMAL) + # TODO: RUSTPYTHON lineterminator double char unsupported + @unittest.expectedFailure def test_write_lineterminator(self): for lineterminator in '\r\n', '\n', '\r', '!@#', '\0': with self.subTest(lineterminator=lineterminator): @@ -234,6 +242,8 @@ def test_write_lineterminator(self): f'a,b{lineterminator}' f'1,2{lineterminator}') + # TODO: RUSTPYTHON ''\r\n to ""\r\n unspported + @unittest.expectedFailure def test_write_iterable(self): self._write_test(iter(['a', 1, 'p,q']), 'a,1,"p,q"') self._write_test(iter(['a', 1, None]), 'a,1,') @@ -298,6 +308,8 @@ def _read_test(self, input, expect, **kwargs): result = list(reader) self.assertEqual(result, expect) + # TODO RUSTPYTHON strict mode is unsupported + @unittest.expectedFailure def test_read_oddinputs(self): self._read_test([], []) self._read_test([''], [[]]) @@ -317,6 +329,8 @@ def test_read_eol(self): self.assertRaises(csv.Error, self._read_test, ['a,b\nc,d'], []) self.assertRaises(csv.Error, self._read_test, ['a,b\r\nc,d'], []) + # TODO RUSTPYTHON double quote umimplement + @unittest.expectedFailure def test_read_eof(self): self._read_test(['a,"'], [['a', '']]) self._read_test(['"a'], [['a']]) @@ -326,6 +340,8 @@ def test_read_eof(self): self.assertRaises(csv.Error, self._read_test, ['^'], [], escapechar='^', strict=True) + # TODO RUSTPYTHON + @unittest.expectedFailure def test_read_nul(self): self._read_test(['\0'], [['\0']]) self._read_test(['a,\0b,c'], [['a', '\0b', 'c']]) @@ -338,6 +354,8 @@ def test_read_delimiter(self): self._read_test(['a;b;c'], [['a', 'b', 'c']], delimiter=';') self._read_test(['a\0b\0c'], [['a', 'b', 'c']], delimiter='\0') + # TODO RUSTPYTHON + @unittest.expectedFailure def test_read_escape(self): self._read_test(['a,\\b,c'], [['a', 'b', 'c']], escapechar='\\') self._read_test(['a,b\\,c'], [['a', 'b,c']], escapechar='\\') @@ -350,6 +368,8 @@ def test_read_escape(self): self._read_test(['a,\\b,c'], [['a', '\\b', 'c']], escapechar=None) self._read_test(['a,\\b,c'], [['a', '\\b', 'c']]) + # TODO RUSTPYTHON escapechar unsupported + @unittest.expectedFailure def test_read_quoting(self): self._read_test(['1,",3,",5'], [['1', ',3,', '5']]) self._read_test(['1,",3,",5'], [['1', '"', '3', '"', '5']], @@ -402,6 +422,8 @@ def test_read_linenum(self): self.assertRaises(StopIteration, next, r) self.assertEqual(r.line_num, 3) + # TODO: RUSTPYTHON only '\r\n' unsupported + @unittest.expectedFailure def test_roundtrip_quoteed_newlines(self): with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: writer = csv.writer(fileobj) @@ -411,6 +433,8 @@ def test_roundtrip_quoteed_newlines(self): for i, row in enumerate(csv.reader(fileobj)): self.assertEqual(row, rows[i]) + # TODO: RUSTPYTHON only '\r\n' unsupported + @unittest.expectedFailure def test_roundtrip_escaped_unquoted_newlines(self): with TemporaryFile("w+", encoding="utf-8", newline='') as fileobj: writer = csv.writer(fileobj,quoting=csv.QUOTE_NONE,escapechar="\\") @@ -512,6 +536,8 @@ def compare_dialect_123(self, expected, *writeargs, **kwwriteargs): fileobj.seek(0) self.assertEqual(fileobj.read(), expected) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_dialect_apply(self): class testA(csv.excel): delimiter = "\t" @@ -555,6 +581,8 @@ def test_copy(self): dialect = csv.get_dialect(name) self.assertRaises(TypeError, copy.copy, dialect) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_pickle(self): for name in csv.list_dialects(): dialect = csv.get_dialect(name) @@ -641,6 +669,8 @@ def test_quoted_quote(self): '"I see," said the blind man', 'as he picked up his hammer and saw']]) + # Rustpython TODO + @unittest.expectedFailure def test_quoted_nl(self): input = '''\ 1,2,3,"""I see,"" @@ -681,15 +711,21 @@ class EscapedExcel(csv.excel): class TestEscapedExcel(TestCsvBase): dialect = EscapedExcel() + # TODO RUSTPYTHON + @unittest.expectedFailure def test_escape_fieldsep(self): self.writerAssertEqual([['abc,def']], 'abc\\,def\r\n') + # TODO RUSTPYTHON + @unittest.expectedFailure def test_read_escape_fieldsep(self): self.readerAssertEqual('abc\\,def\r\n', [['abc,def']]) class TestDialectUnix(TestCsvBase): dialect = 'unix' + # TODO RUSTPYTHON + @unittest.expectedFailure def test_simple_writer(self): self.writerAssertEqual([[1, 'abc def', 'abc']], '"1","abc def","abc"\n') @@ -706,6 +742,8 @@ class TestQuotedEscapedExcel(TestCsvBase): def test_write_escape_fieldsep(self): self.writerAssertEqual([['abc,def']], '"abc,def"\r\n') + # TODO RUSTPYTHON + @unittest.expectedFailure def test_read_escape_fieldsep(self): self.readerAssertEqual('"abc\\,def"\r\n', [['abc,def']]) @@ -902,6 +940,8 @@ def test_read_multi(self): "s1": 'abc', "s2": 'def'}) + # TODO RUSTPYTHON + @unittest.expectedFailure def test_read_with_blanks(self): reader = csv.DictReader(["1,2,abc,4,5,6\r\n","\r\n", "1,2,abc,4,5,6\r\n"], From 54247df801dff727a8a674f2b996cdcf8f0ad572 Mon Sep 17 00:00:00 2001 From: Blues-star <469830014@qq.com> Date: Tue, 5 Mar 2024 15:09:47 +0900 Subject: [PATCH 277/893] implement more csv features Co-authored-by: Jeong, YunWon --- stdlib/src/csv.rs | 946 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 886 insertions(+), 60 deletions(-) diff --git a/stdlib/src/csv.rs b/stdlib/src/csv.rs index bee3fd5faa..96aa1c1fe0 100644 --- a/stdlib/src/csv.rs +++ b/stdlib/src/csv.rs @@ -4,15 +4,18 @@ pub(crate) use _csv::make_module; mod _csv { use crate::common::lock::PyMutex; use crate::vm::{ - builtins::{PyStr, PyTypeRef}, - function::{ArgIterable, ArgumentError, FromArgs, FuncArgs}, - match_class, + builtins::{PyBaseExceptionRef, PyInt, PyNone, PyStr, PyType, PyTypeError, PyTypeRef}, + function::{ArgIterable, ArgumentError, FromArgs, FuncArgs, OptionalArg}, protocol::{PyIter, PyIterReturn}, - types::{IterNext, Iterable, SelfIter}, - AsObject, Py, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, + types::{Constructor, IterNext, Iterable, SelfIter}, + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, }; + use csv_core::Terminator; use itertools::{self, Itertools}; - use std::fmt; + use once_cell::sync::Lazy; + use parking_lot::Mutex; + use rustpython_vm::match_class; + use std::{collections::HashMap, fmt}; #[pyattr] const QUOTE_MINIMAL: i32 = QuoteStyle::Minimal as i32; @@ -22,6 +25,12 @@ mod _csv { const QUOTE_NONNUMERIC: i32 = QuoteStyle::Nonnumeric as i32; #[pyattr] const QUOTE_NONE: i32 = QuoteStyle::None as i32; + #[pyattr] + const QUOTE_STRINGS: i32 = QuoteStyle::Strings as i32; + #[pyattr] + const QUOTE_NOTNULL: i32 = QuoteStyle::Notnull as i32; + #[pyattr(name = "__version__")] + const __VERSION__: &str = "1.0"; #[pyattr(name = "Error", once)] fn error(vm: &VirtualMachine) -> PyTypeRef { @@ -32,13 +41,334 @@ mod _csv { ) } + static GLOBAL_HASHMAP: Lazy>> = Lazy::new(|| { + let m = HashMap::new(); + Mutex::new(m) + }); + static GLOBAL_FIELD_LIMIT: Lazy> = Lazy::new(|| Mutex::new(131072)); + + fn new_csv_error(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef { + vm.new_exception_msg(super::_csv::error(vm), msg) + } + + #[pyattr] + #[pyclass(module = "csv", name = "Dialect")] + #[derive(Debug, PyPayload, Clone, Copy)] + struct PyDialect { + delimiter: u8, + quotechar: Option, + escapechar: Option, + doublequote: bool, + skipinitialspace: bool, + lineterminator: csv_core::Terminator, + quoting: QuoteStyle, + strict: bool, + } + impl Constructor for PyDialect { + type Args = PyObjectRef; + + fn py_new(cls: PyTypeRef, ctx: Self::Args, vm: &VirtualMachine) -> PyResult { + PyDialect::try_from_object(vm, ctx)? + .into_ref_with_type(vm, cls) + .map(Into::into) + } + } + #[pyclass(with(Constructor))] + impl PyDialect { + #[pygetset] + fn delimiter(&self, vm: &VirtualMachine) -> PyRef { + vm.ctx.new_str(format!("{}", self.delimiter as char)) + } + #[pygetset] + fn quotechar(&self, vm: &VirtualMachine) -> Option> { + Some(vm.ctx.new_str(format!("{}", self.quotechar? as char))) + } + #[pygetset] + fn doublequote(&self) -> bool { + self.doublequote + } + #[pygetset] + fn skipinitialspace(&self) -> bool { + self.skipinitialspace + } + #[pygetset] + fn lineterminator(&self, vm: &VirtualMachine) -> PyRef { + match self.lineterminator { + Terminator::CRLF => vm.ctx.new_str("\r\n".to_string()).to_owned(), + Terminator::Any(t) => vm.ctx.new_str(format!("{}", t as char)).to_owned(), + _ => unreachable!(), + } + } + #[pygetset] + fn quoting(&self) -> isize { + self.quoting.into() + } + #[pygetset] + fn escapechar(&self, vm: &VirtualMachine) -> Option> { + Some(vm.ctx.new_str(format!("{}", self.escapechar? as char))) + } + #[pygetset(name = "strict")] + fn get_strict(&self) -> bool { + self.strict + } + } + /// Parses the delimiter from a Python object and returns its ASCII value. + /// + /// This function attempts to extract the 'delimiter' attribute from the given Python object and ensures that the attribute is a single-character string. If successful, it returns the ASCII value of the character. If the attribute is not a single-character string, an error is returned. + /// + /// # Arguments + /// + /// * `vm` - A reference to the VirtualMachine, used for executing Python code and manipulating Python objects. + /// * `obj` - A reference to the PyObjectRef from which the 'delimiter' attribute is to be parsed. + /// + /// # Returns + /// + /// If successful, returns a `PyResult` representing the ASCII value of the 'delimiter' attribute. If unsuccessful, returns a `PyResult` containing an error message. + /// + /// # Errors + /// + /// This function can return the following errors: + /// + /// * If the 'delimiter' attribute is not a single-character string, a type error is returned. + /// * If the 'obj' is not of string type and does not have a 'delimiter' attribute, a type error is returned. + fn parse_delimiter_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { + if let Ok(attr) = obj.get_attr("delimiter", vm) { + parse_delimiter_from_obj(vm, &attr) + } else { + match_class!(match obj.clone() { + s @ PyStr => { + Ok(s.as_str().bytes().exactly_one().map_err(|_| { + let msg = r#""delimiter" must be a 1-character string"#; + vm.new_type_error(msg.to_owned()) + })?) + } + attr => { + let msg = format!("\"delimiter\" must be string, not {}", attr.class()); + Err(vm.new_type_error(msg)) + } + }) + } + } + fn parse_quotechar_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult> { + match_class!(match obj.get_attr("quotechar", vm)? { + s @ PyStr => { + Ok(Some(s.as_str().bytes().exactly_one().map_err(|_| { + vm.new_exception_msg( + super::_csv::error(vm), + r#""quotechar" must be a 1-character string"#.to_owned(), + ) + })?)) + } + _n @ PyNone => { + Ok(None) + } + _ => { + Err(vm.new_exception_msg( + super::_csv::error(vm), + r#""quotechar" must be string or None, not int"#.to_owned(), + )) + } + }) + } + fn parse_escapechar_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult> { + match_class!(match obj.get_attr("escapechar", vm)? { + s @ PyStr => { + Ok(Some(s.as_str().bytes().exactly_one().map_err(|_| { + vm.new_exception_msg( + super::_csv::error(vm), + r#""escapechar" must be a 1-character string"#.to_owned(), + ) + })?)) + } + _n @ PyNone => { + Ok(None) + } + attr => { + let msg = format!( + "\"escapechar\" must be string or None, not {}", + attr.class() + ); + Err(vm.new_type_error(msg.to_owned())) + } + }) + } + fn prase_lineterminator_from_obj( + vm: &VirtualMachine, + obj: &PyObjectRef, + ) -> PyResult { + match_class!(match obj.get_attr("lineterminator", vm)? { + s @ PyStr => { + Ok(if s.as_str().as_bytes().eq(b"\r\n") { + csv_core::Terminator::CRLF + } else if let Some(t) = s.as_str().as_bytes().first() { + // Due to limitations in the current implementation within csv_core + // the support for multiple characters in lineterminator is not complete. + // only capture the first character + csv_core::Terminator::Any(*t) + } else { + return Err(vm.new_exception_msg( + super::_csv::error(vm), + r#""lineterminator" must be a string"#.to_owned(), + )); + }) + } + _ => { + let msg = "\"lineterminator\" must be a string".to_string(); + Err(vm.new_type_error(msg.to_owned())) + } + }) + } + fn prase_quoting_from_obj(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult { + match_class!(match obj.get_attr("quoting", vm)? { + i @ PyInt => { + Ok(i.try_to_primitive::(vm)?.try_into().map_err(|_| { + let msg = r#"bad "quoting" value"#; + vm.new_type_error(msg.to_owned()) + })?) + } + attr => { + let msg = format!("\"quoting\" must be string or None, not {}", attr.class()); + Err(vm.new_type_error(msg.to_owned())) + } + }) + } + impl TryFromObject for PyDialect { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let delimiter = parse_delimiter_from_obj(vm, &obj)?; + let quotechar = parse_quotechar_from_obj(vm, &obj)?; + let escapechar = parse_escapechar_from_obj(vm, &obj)?; + let doublequote = obj.get_attr("doublequote", vm)?.try_to_bool(vm)?; + let skipinitialspace = obj.get_attr("skipinitialspace", vm)?.try_to_bool(vm)?; + let lineterminator = prase_lineterminator_from_obj(vm, &obj)?; + let quoting = prase_quoting_from_obj(vm, &obj)?; + let strict = if let Ok(t) = obj.get_attr("strict", vm) { + t.try_to_bool(vm).unwrap_or(false) + } else { + false + }; + + Ok(Self { + delimiter, + quotechar, + escapechar, + doublequote, + skipinitialspace, + lineterminator, + quoting, + strict, + }) + } + } + + #[pyfunction] + fn register_dialect( + name: PyObjectRef, + dialect: OptionalArg, + opts: FormatOptions, + // TODO: handle quote style, etc + mut _rest: FuncArgs, + vm: &VirtualMachine, + ) -> PyResult<()> { + let Some(name) = name.payload_if_subclass::(vm) else { + return Err(vm.new_type_error("argument 0 must be a string".to_string())); + }; + let mut dialect = match dialect { + OptionalArg::Present(d) => PyDialect::try_from_object(vm, d) + .map_err(|_| vm.new_type_error("argument 1 must be a dialect object".to_owned()))?, + OptionalArg::Missing => opts.result(vm)?, + }; + opts.update_pydialect(&mut dialect); + GLOBAL_HASHMAP + .lock() + .insert(name.as_str().to_owned(), dialect); + Ok(()) + } + + #[pyfunction] + fn get_dialect( + name: PyObjectRef, + mut _rest: FuncArgs, + vm: &VirtualMachine, + ) -> PyResult { + let Some(name) = name.payload_if_subclass::(vm) else { + return Err(vm.new_exception_msg( + super::_csv::error(vm), + format!("argument 0 must be a string, not '{}'", name.class()), + )); + }; + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name.as_str()) { + return Ok(*dialect); + } + Err(vm.new_exception_msg(super::_csv::error(vm), "unknown dialect".to_string())) + } + + #[pyfunction] + fn unregister_dialect( + name: PyObjectRef, + mut _rest: FuncArgs, + vm: &VirtualMachine, + ) -> PyResult<()> { + let Some(name) = name.payload_if_subclass::(vm) else { + return Err(vm.new_exception_msg( + super::_csv::error(vm), + format!("argument 0 must be a string, not '{}'", name.class()), + )); + }; + let mut g = GLOBAL_HASHMAP.lock(); + if let Some(_removed) = g.remove(name.as_str()) { + return Ok(()); + } + Err(vm.new_exception_msg(super::_csv::error(vm), "unknown dialect".to_string())) + } + + #[pyfunction] + fn list_dialects( + rest: FuncArgs, + vm: &VirtualMachine, + ) -> PyResult { + if !rest.args.is_empty() || !rest.kwargs.is_empty() { + return Err(vm.new_type_error("too many argument".to_string())); + } + let g = GLOBAL_HASHMAP.lock(); + let t = g + .keys() + .cloned() + .map(|x| vm.ctx.new_str(x).into()) + .collect_vec(); + // .iter().map(|x| vm.ctx.new_str(x.clone()).into_pyobject(vm)).collect_vec(); + Ok(vm.ctx.new_list(t)) + } + + #[pyfunction] + fn field_size_limit(rest: FuncArgs, vm: &VirtualMachine) -> PyResult { + let old_size = GLOBAL_FIELD_LIMIT.lock().to_owned(); + if !rest.args.is_empty() { + let arg_len = rest.args.len(); + if arg_len != 1 { + return Err(vm.new_type_error( + format!( + "field_size_limit() takes at most 1 argument ({} given)", + arg_len + ) + .to_string(), + )); + } + let Ok(new_size) = rest.args.first().unwrap().try_int(vm) else { + return Err(vm.new_type_error("limit must be an integer".to_string())); + }; + *GLOBAL_FIELD_LIMIT.lock() = new_size.try_to_primitive::(vm)?; + } + Ok(old_size) + } + #[pyfunction] fn reader( iter: PyIter, options: FormatOptions, // TODO: handle quote style, etc _rest: FuncArgs, - _vm: &VirtualMachine, + vm: &VirtualMachine, ) -> PyResult { Ok(Reader { iter, @@ -46,7 +376,11 @@ mod _csv { buffer: vec![0; 1024], output_ends: vec![0; 16], reader: options.to_reader(), + skipinitialspace: options.get_skipinitialspace(), + delimiter: options.get_delimiter(), + line_num: 0, }), + dialect: options.result(vm)?, }) } @@ -72,6 +406,7 @@ mod _csv { buffer: vec![0; 1024], writer: options.to_writer(), }), + dialect: options.result(vm)?, }) } @@ -82,67 +417,482 @@ mod _csv { } #[repr(i32)] + #[derive(Debug, Clone, Copy)] pub enum QuoteStyle { Minimal = 0, All = 1, Nonnumeric = 2, None = 3, + Strings = 4, + Notnull = 5, + } + impl From for csv_core::QuoteStyle { + fn from(val: QuoteStyle) -> Self { + match val { + QuoteStyle::Minimal => csv_core::QuoteStyle::Always, + QuoteStyle::All => csv_core::QuoteStyle::Always, + QuoteStyle::Nonnumeric => csv_core::QuoteStyle::NonNumeric, + QuoteStyle::None => csv_core::QuoteStyle::Never, + QuoteStyle::Strings => todo!(), + QuoteStyle::Notnull => todo!(), + } + } + } + impl TryFromObject for QuoteStyle { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let num = obj.try_int(vm)?.try_to_primitive::(vm)?; + num.try_into().map_err(|_| { + vm.new_value_error( + "can not convert to QuoteStyle enum from input argument".to_string(), + ) + }) + } + } + impl TryFrom for QuoteStyle { + type Error = PyTypeError; + fn try_from(num: isize) -> Result { + match num { + 0 => Ok(QuoteStyle::Minimal), + 1 => Ok(QuoteStyle::All), + 2 => Ok(QuoteStyle::Nonnumeric), + 3 => Ok(QuoteStyle::None), + 4 => Ok(QuoteStyle::Strings), + 5 => Ok(QuoteStyle::Notnull), + _ => Err(PyTypeError {}), + } + } + } + impl From for isize { + fn from(val: QuoteStyle) -> Self { + match val { + QuoteStyle::Minimal => 0, + QuoteStyle::All => 1, + QuoteStyle::Nonnumeric => 2, + QuoteStyle::None => 3, + QuoteStyle::Strings => 4, + QuoteStyle::Notnull => 5, + } + } + } + + enum DialectItem { + Str(String), + Obj(PyDialect), + None, } struct FormatOptions { - delimiter: u8, - quotechar: u8, + dialect: DialectItem, + delimiter: Option, + quotechar: Option>, + escapechar: Option, + doublequote: Option, + skipinitialspace: Option, + lineterminator: Option, + quoting: Option, + strict: Option, + } + impl Default for FormatOptions { + fn default() -> Self { + FormatOptions { + dialect: DialectItem::None, + delimiter: None, + quotechar: None, + escapechar: None, + doublequote: None, + skipinitialspace: None, + lineterminator: None, + quoting: None, + strict: None, + } + } + } + /// prase a dialect item from a Python argument and returns a `DialectItem` or an `ArgumentError`. + /// + /// This function takes a reference to the VirtualMachine and a PyObjectRef as input and attempts to parse a dialect item from the provided Python argument. It returns a `DialectItem` if successful, or an `ArgumentError` if unsuccessful. + /// + /// # Arguments + /// + /// * `vm` - A reference to the VirtualMachine, used for executing Python code and manipulating Python objects. + /// * `obj` - The PyObjectRef from which the dialect item is to be parsed. + /// + /// # Returns + /// + /// If successful, returns a `Result` representing the parsed dialect item. If unsuccessful, returns an `ArgumentError`. + /// + /// # Errors + /// + /// This function can return the following errors: + /// + /// * If the provided object is a PyStr, it returns a `DialectItem::Str` containing the string value. + /// * If the provided object is PyNone, it returns an `ArgumentError` with the message "InvalidKeywordArgument('dialect')". + /// * If the provided object is a PyType, it attempts to create a PyDialect from the object and returns a `DialectItem::Obj` containing the PyDialect if successful. If unsuccessful, it returns an `ArgumentError` with the message "InvalidKeywordArgument('dialect')". + /// * If the provided object is none of the above types, it attempts to create a PyDialect from the object and returns a `DialectItem::Obj` containing the PyDialect if successful. If unsuccessful, it returns an `ArgumentError` with the message "InvalidKeywordArgument('dialect')". + fn prase_dialect_item_from_arg( + vm: &VirtualMachine, + obj: PyObjectRef, + ) -> Result { + match_class!(match obj { + s @ PyStr => { + Ok(DialectItem::Str(s.as_str().to_string())) + } + PyNone => { + Err(ArgumentError::InvalidKeywordArgument("dialect".to_string())) + } + t @ PyType => { + let temp = t + .as_object() + .call(vec![], vm) + .map_err(|_e| ArgumentError::InvalidKeywordArgument("dialect".to_string()))?; + Ok(DialectItem::Obj( + PyDialect::try_from_object(vm, temp).map_err(|_| { + ArgumentError::InvalidKeywordArgument("dialect".to_string()) + })?, + )) + } + obj => { + if let Ok(cur_dialect_item) = PyDialect::try_from_object(vm, obj) { + Ok(DialectItem::Obj(cur_dialect_item)) + } else { + let msg = "dialect".to_string(); + Err(ArgumentError::InvalidKeywordArgument(msg)) + } + } + }) } impl FromArgs for FormatOptions { fn from_args(vm: &VirtualMachine, args: &mut FuncArgs) -> Result { - let delimiter = if let Some(delimiter) = args.kwargs.remove("delimiter") { - delimiter - .try_to_value::<&str>(vm)? - .bytes() - .exactly_one() - .map_err(|_| { - let msg = r#""delimiter" must be a 1-character string"#; - vm.new_type_error(msg.to_owned()) - })? + let mut res = FormatOptions::default(); + if let Some(dialect) = args.kwargs.remove("dialect") { + res.dialect = prase_dialect_item_from_arg(vm, dialect)?; + } else if let Some(dialect) = args.args.first() { + res.dialect = prase_dialect_item_from_arg(vm, dialect.clone())?; } else { - b',' + res.dialect = DialectItem::None; }; - let quotechar = if let Some(quotechar) = args.kwargs.remove("quotechar") { - quotechar - .try_to_value::<&str>(vm)? - .bytes() - .exactly_one() - .map_err(|_| { + if let Some(delimiter) = args.kwargs.remove("delimiter") { + res.delimiter = Some(parse_delimiter_from_obj(vm, &delimiter)?); + } + + if let Some(escapechar) = args.kwargs.remove("escapechar") { + res.escapechar = match_class!(match escapechar { + s @ PyStr => Some(s.as_str().bytes().exactly_one().map_err(|_| { + let msg = r#""escapechar" must be a 1-character string"#; + vm.new_type_error(msg.to_owned()) + })?), + _ => None, + }) + }; + if let Some(lineterminator) = args.kwargs.remove("lineterminator") { + res.lineterminator = Some(csv_core::Terminator::Any( + lineterminator + .try_to_value::<&str>(vm)? + .bytes() + .exactly_one() + .map_err(|_| { + let msg = r#""lineterminator" must be a 1-character string"#; + vm.new_type_error(msg.to_owned()) + })?, + )) + }; + if let Some(doublequote) = args.kwargs.remove("doublequote") { + res.doublequote = Some(doublequote.try_to_bool(vm).map_err(|_| { + let msg = r#""doublequote" must be a bool"#; + vm.new_type_error(msg.to_owned()) + })?) + }; + if let Some(skipinitialspace) = args.kwargs.remove("skipinitialspace") { + res.skipinitialspace = Some(skipinitialspace.try_to_bool(vm).map_err(|_| { + let msg = r#""skipinitialspace" must be a bool"#; + vm.new_type_error(msg.to_owned()) + })?) + }; + if let Some(quoting) = args.kwargs.remove("quoting") { + res.quoting = match_class!(match quoting { + i @ PyInt => + Some(i.try_to_primitive::(vm)?.try_into().map_err(|_e| { + ArgumentError::InvalidKeywordArgument("quoting".to_string()) + })?), + _ => { + // let msg = r#""quoting" must be a int enum"#; + return Err(ArgumentError::InvalidKeywordArgument("quoting".to_string())); + } + }); + }; + if let Some(quotechar) = args.kwargs.remove("quotechar") { + res.quotechar = match_class!(match quotechar { + s @ PyStr => Some(Some(s.as_str().bytes().exactly_one().map_err(|_| { let msg = r#""quotechar" must be a 1-character string"#; vm.new_type_error(msg.to_owned()) - })? - } else { - b'"' + })?)), + PyNone => { + if let Some(QuoteStyle::All) = res.quoting { + let msg = "quotechar must be set if quoting enabled"; + return Err(ArgumentError::Exception( + vm.new_type_error(msg.to_owned()), + )); + } + Some(None) + } + _o => { + let msg = r#"quotechar"#; + return Err( + rustpython_vm::function::ArgumentError::InvalidKeywordArgument( + msg.to_string(), + ), + ); + } + }) + }; + if let Some(strict) = args.kwargs.remove("strict") { + res.strict = Some(strict.try_to_bool(vm).map_err(|_| { + let msg = r#""strict" must be a int enum"#; + vm.new_type_error(msg.to_owned()) + })?) }; - Ok(FormatOptions { - delimiter, - quotechar, - }) + if let Some(last_arg) = args.kwargs.pop() { + let msg = format!( + r#"'{}' is an invalid keyword argument for this function"#, + last_arg.0 + ); + return Err(rustpython_vm::function::ArgumentError::InvalidKeywordArgument(msg)); + } + Ok(res) } } impl FormatOptions { + fn update_pydialect<'b>(&self, res: &'b mut PyDialect) -> &'b mut PyDialect { + macro_rules! check_and_fill { + ($res:ident, $e:ident) => {{ + if let Some(t) = self.$e { + $res.$e = t; + } + }}; + } + check_and_fill!(res, delimiter); + // check_and_fill!(res, quotechar); + check_and_fill!(res, delimiter); + check_and_fill!(res, doublequote); + check_and_fill!(res, skipinitialspace); + if let Some(t) = self.escapechar { + res.escapechar = Some(t); + }; + if let Some(t) = self.quotechar { + if let Some(u) = t { + res.quotechar = Some(u); + } else { + res.quotechar = None; + } + }; + check_and_fill!(res, quoting); + check_and_fill!(res, lineterminator); + check_and_fill!(res, strict); + res + } + + fn result(&self, vm: &VirtualMachine) -> PyResult { + match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + let mut dialect = *dialect; + self.update_pydialect(&mut dialect); + Ok(dialect) + } else { + Err(new_csv_error(vm, format!("{} is not registed.", name))) + } + // TODO + // Maybe need to update the obj from HashMap + } + DialectItem::Obj(mut o) => { + self.update_pydialect(&mut o); + Ok(o) + } + DialectItem::None => { + let g = GLOBAL_HASHMAP.lock(); + let mut res = *g.get("excel").unwrap(); + self.update_pydialect(&mut res); + Ok(res) + } + } + } + fn get_skipinitialspace(&self) -> bool { + let mut skipinitialspace = match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + dialect.skipinitialspace + // RustPython todo + // todo! Perfecting the remaining attributes. + } else { + false + } + } + DialectItem::Obj(obj) => obj.skipinitialspace, + _ => false, + }; + if let Some(attr) = self.skipinitialspace { + skipinitialspace = attr + } + skipinitialspace + } + fn get_delimiter(&self) -> u8 { + let mut delimiter = match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + dialect.delimiter + // RustPython todo + // todo! Perfecting the remaining attributes. + } else { + b',' + } + } + DialectItem::Obj(obj) => obj.delimiter, + _ => b',', + }; + if let Some(attr) = self.delimiter { + delimiter = attr + } + delimiter + } fn to_reader(&self) -> csv_core::Reader { - csv_core::ReaderBuilder::new() - .delimiter(self.delimiter) - .quote(self.quotechar) - .terminator(csv_core::Terminator::CRLF) - .build() + let mut builder = csv_core::ReaderBuilder::new(); + let mut reader = match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + let mut builder = builder + .delimiter(dialect.delimiter) + .double_quote(dialect.doublequote); + if let Some(t) = dialect.quotechar { + builder = builder.quote(t); + } + builder + // RustPython todo + // todo! Perfecting the remaining attributes. + } else { + &mut builder + } + } + DialectItem::Obj(obj) => { + let mut builder = builder + .delimiter(obj.delimiter) + .double_quote(obj.doublequote); + if let Some(t) = obj.quotechar { + builder = builder.quote(t); + } + builder + } + _ => { + let name = "excel"; + let g = GLOBAL_HASHMAP.lock(); + let dialect = g.get(name).unwrap(); + let mut builder = builder + .delimiter(dialect.delimiter) + .double_quote(dialect.doublequote); + if let Some(quotechar) = dialect.quotechar { + builder = builder.quote(quotechar); + } + builder + } + }; + + if let Some(t) = self.delimiter { + reader = reader.delimiter(t); + } + if let Some(t) = self.quotechar { + if let Some(u) = t { + reader = reader.quote(u); + } else { + reader = reader.quoting(false); + } + } else { + match self.quoting { + Some(QuoteStyle::None) => { + reader = reader.quoting(false); + } + // None => reader = reader.quoting(true), + _ => reader = reader.quoting(true), + } + } + + if let Some(t) = self.lineterminator { + reader = reader.terminator(t); + } + if let Some(t) = self.doublequote { + reader = reader.double_quote(t); + } + if self.escapechar.is_some() { + reader = reader.escape(self.escapechar); + } + reader = match self.lineterminator { + Some(u) => reader.terminator(u), + None => reader.terminator(Terminator::CRLF), + }; + reader.build() } fn to_writer(&self) -> csv_core::Writer { - csv_core::WriterBuilder::new() - .delimiter(self.delimiter) - .quote(self.quotechar) - .terminator(csv_core::Terminator::CRLF) - .build() + let mut builder = csv_core::WriterBuilder::new(); + let mut writer = match &self.dialect { + DialectItem::Str(name) => { + let g = GLOBAL_HASHMAP.lock(); + if let Some(dialect) = g.get(name) { + let mut builder = builder + .delimiter(dialect.delimiter) + .double_quote(dialect.doublequote) + .terminator(dialect.lineterminator); + if let Some(t) = dialect.quotechar { + builder = builder.quote(t); + } + builder + + // RustPython todo + // todo! Perfecting the remaining attributes. + } else { + &mut builder + } + } + DialectItem::Obj(obj) => { + let mut builder = builder + .delimiter(obj.delimiter) + .double_quote(obj.doublequote) + .terminator(obj.lineterminator); + if let Some(t) = obj.quotechar { + builder = builder.quote(t); + } + builder + } + _ => &mut builder, + }; + if let Some(t) = self.delimiter { + writer = writer.delimiter(t); + } + if let Some(t) = self.quotechar { + if let Some(u) = t { + writer = writer.quote(u); + } else { + todo!() + } + } + if let Some(t) = self.doublequote { + writer = writer.double_quote(t); + } + writer = match self.lineterminator { + Some(u) => writer.terminator(u), + None => writer.terminator(Terminator::CRLF), + }; + if let Some(e) = self.escapechar { + writer = writer.escape(e); + } + if let Some(e) = self.quoting { + writer = writer.quote_style(e.into()); + } + writer.build() } } @@ -150,6 +900,9 @@ mod _csv { buffer: Vec, output_ends: Vec, reader: csv_core::Reader, + skipinitialspace: bool, + delimiter: u8, + line_num: u64, } #[pyclass(no_attr, module = "_csv", name = "reader", traverse)] @@ -158,6 +911,8 @@ mod _csv { iter: PyIter, #[pytraverse(skip)] state: PyMutex, + #[pytraverse(skip)] + dialect: PyDialect, } impl fmt::Debug for Reader { @@ -167,7 +922,16 @@ mod _csv { } #[pyclass(with(IterNext, Iterable))] - impl Reader {} + impl Reader { + #[pygetset] + fn line_num(&self) -> u64 { + self.state.lock().line_num + } + #[pygetset] + fn dialect(&self, _vm: &VirtualMachine) -> PyDialect { + self.dialect + } + } impl SelfIter for Reader {} impl IterNext for Reader { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -176,27 +940,55 @@ mod _csv { PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), }; let string = string.downcast::().map_err(|obj| { - vm.new_type_error(format!( + new_csv_error( + vm, + format!( "iterator should return strings, not {} (the file should be opened in text mode)", obj.class().name() - )) + ), + ) })?; let input = string.as_str().as_bytes(); - + if input.is_empty() || input.starts_with(b"\n") { + return Ok(PyIterReturn::Return(vm.ctx.new_list(vec![]).into())); + } let mut state = zelf.state.lock(); let ReadState { buffer, output_ends, reader, + skipinitialspace, + delimiter, + line_num, } = &mut *state; let mut input_offset = 0; let mut output_offset = 0; let mut output_ends_offset = 0; - + let field_limit = GLOBAL_FIELD_LIMIT.lock().to_owned(); + #[inline] + fn trim_spaces(input: &[u8]) -> &[u8] { + let trimmed_start = input.iter().position(|&x| x != b' ').unwrap_or(input.len()); + let trimmed_end = input + .iter() + .rposition(|&x| x != b' ') + .map(|i| i + 1) + .unwrap_or(0); + &input[trimmed_start..trimmed_end] + } + let input = if *skipinitialspace { + let t = input.split(|x| x == delimiter); + t.map(|x| { + let trimmed = trim_spaces(x); + String::from_utf8(trimmed.to_vec()).unwrap() + }) + .join(format!("{}", *delimiter as char).as_str()) + } else { + String::from_utf8(input.to_vec()).unwrap() + }; loop { let (res, nread, nwritten, nends) = reader.read_record( - &input[input_offset..], + input[input_offset..].as_bytes(), &mut buffer[output_offset..], &mut output_ends[output_ends_offset..], ); @@ -213,9 +1005,10 @@ mod _csv { } } } - let rest = &input[input_offset..]; + let rest = input[input_offset..].as_bytes(); if !rest.iter().all(|&c| matches!(c, b'\r' | b'\n')) { - return Err(vm.new_value_error( + return Err(new_csv_error( + vm, "new-line character seen in unquoted field - \ do you need to open the file in universal-newline mode?" .to_owned(), @@ -223,17 +1016,40 @@ mod _csv { } let mut prev_end = 0; - let out = output_ends[..output_ends_offset] + let out: Vec = output_ends[..output_ends_offset] .iter() .map(|&end| { let range = prev_end..end; + if range.len() > field_limit as usize { + return Err(new_csv_error(vm, "filed too long to read".to_string())); + } prev_end = end; - let s = std::str::from_utf8(&buffer[range]) + let s = std::str::from_utf8(&buffer[range.clone()]) // not sure if this is possible - the input was all strings .map_err(|_e| vm.new_unicode_decode_error("csv not utf8".to_owned()))?; - Ok(vm.ctx.new_str(s).into()) + // Rustpython TODO! + // Incomplete implementation + if let QuoteStyle::Nonnumeric = zelf.dialect.quoting { + if let Ok(t) = + String::from_utf8(trim_spaces(&buffer[range.clone()]).to_vec()) + .unwrap() + .parse::() + { + Ok(vm.ctx.new_int(t).into()) + } else { + Ok(vm.ctx.new_str(s).into()) + } + } else { + Ok(vm.ctx.new_str(s).into()) + } }) .collect::>()?; + // Removes the last null item before the line terminator, if there is a separator before the line terminator, + // todo! + // if out.last().unwrap().length(vm).unwrap() == 0 { + // out.pop(); + // } + *line_num += 1; Ok(PyIterReturn::Return(vm.ctx.new_list(out).into())) } } @@ -249,6 +1065,8 @@ mod _csv { write: PyObjectRef, #[pytraverse(skip)] state: PyMutex, + #[pytraverse(skip)] + dialect: PyDialect, } impl fmt::Debug for Writer { @@ -259,6 +1077,10 @@ mod _csv { #[pyclass] impl Writer { + #[pygetset(name = "dialect")] + fn get_dialect(&self, _vm: &VirtualMachine) -> PyDialect { + self.dialect + } #[pymethod] fn writerow(&self, row: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut state = self.state.lock(); @@ -277,7 +1099,10 @@ mod _csv { }}; } - let row = ArgIterable::try_from_object(vm, row)?; + let row = ArgIterable::try_from_object(vm, row.clone()).map_err(|_e| { + new_csv_error(vm, format!("\'{}\' object is not iterable", row.class())) + })?; + let mut first_flag = true; for field in row.iter(vm)? { let field: PyObjectRef = field?; let stringified; @@ -289,8 +1114,14 @@ mod _csv { stringified.as_str().as_bytes() } }); - let mut input_offset = 0; + if first_flag { + first_flag = false; + } else { + loop { + handle_res!(writer.delimiter(&mut buffer[buffer_offset..])); + } + } loop { let (res, nread, nwritten) = @@ -298,16 +1129,11 @@ mod _csv { input_offset += nread; handle_res!((res, nwritten)); } - - loop { - handle_res!(writer.delimiter(&mut buffer[buffer_offset..])); - } } loop { handle_res!(writer.terminator(&mut buffer[buffer_offset..])); } - let s = std::str::from_utf8(&buffer[..buffer_offset]) .map_err(|_| vm.new_unicode_decode_error("csv not utf8".to_owned()))?; From 2f8e5189d3c0abc70251fd6d928d0176e7b7f3ab Mon Sep 17 00:00:00 2001 From: rrupy <147432801+rrupy@users.noreply.github.com> Date: Tue, 5 Mar 2024 16:37:42 +0300 Subject: [PATCH 278/893] Use ast::Suite::parse instead of deprecated parse_program. (#5186) --- compiler/codegen/src/compile.rs | 5 +++-- examples/parse_folder.rs | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 0b0f2877c7..bf0fbe1dec 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -2941,7 +2941,8 @@ impl ToU32 for usize { #[cfg(test)] mod tests { use super::*; - use rustpython_parser as parser; + use rustpython_parser::ast::Suite; + use rustpython_parser::Parse; use rustpython_parser_core::source_code::LinearLocator; fn compile_exec(source: &str) -> CodeObject { @@ -2952,7 +2953,7 @@ mod tests { "source_path".to_owned(), "".to_owned(), ); - let ast = parser::parse_program(source, "").unwrap(); + let ast = Suite::parse(source, "").unwrap(); let ast = locator.fold(ast).unwrap(); let symbol_scope = SymbolTable::scan_program(&ast).unwrap(); compiler.compile_program(&ast, symbol_scope).unwrap(); diff --git a/examples/parse_folder.rs b/examples/parse_folder.rs index 7774b8afbf..7055a6f831 100644 --- a/examples/parse_folder.rs +++ b/examples/parse_folder.rs @@ -12,7 +12,7 @@ extern crate env_logger; extern crate log; use clap::{App, Arg}; -use rustpython_parser::{self as parser, ast}; +use rustpython_parser::{self as parser, ast, Parse}; use std::{ path::Path, time::{Duration, Instant}, @@ -85,8 +85,8 @@ fn parse_python_file(filename: &Path) -> ParsedFile { }, Ok(source) => { let num_lines = source.lines().count(); - let result = parser::parse_program(&source, &filename.to_string_lossy()) - .map_err(|e| e.to_string()); + let result = + ast::Suite::parse(&source, &filename.to_string_lossy()).map_err(|e| e.to_string()); ParsedFile { // filename: Box::new(filename.to_path_buf()), // code: source.to_string(), From 35229721ead1e96a9135e26d74ae9b0d36f7efa4 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sat, 24 Feb 2024 12:32:16 -0500 Subject: [PATCH 279/893] Fix test_cmd_line.py The failing test was unsetting `PYTHONPATH`, but neglecting to unset `RUSTPYTHONPATH`, which obviously was not significant for the original CPython test. Including `RUSTPYTHONPATH` in the test fixes it. --- Lib/test/test_cmd_line.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index 02f060ba2c..88ff71726f 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -411,7 +411,8 @@ def test_empty_PYTHONPATH_issue16309(self): path = ":".join(sys.path) path = path.encode("ascii", "backslashreplace") sys.stdout.buffer.write(path)""" - rc1, out1, err1 = assert_python_ok('-c', code, PYTHONPATH="") + # TODO: RUSTPYTHON we must unset RUSTPYTHONPATH as well + rc1, out1, err1 = assert_python_ok('-c', code, PYTHONPATH="", RUSTPYTHONPATH="") rc2, out2, err2 = assert_python_ok('-c', code, __isolated=False) # regarding to Posix specification, outputs should be equal # for empty and unset PYTHONPATH From ead42beff6265d95c4dc18016735229daef1a56d Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sat, 24 Feb 2024 13:25:44 -0500 Subject: [PATCH 280/893] Disable test_locale in test_format.py See https://github.com/RustPython/RustPython/issues/5181 --- Lib/test/test_format.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py index f6c11a4aad..66e2b077bf 100644 --- a/Lib/test/test_format.py +++ b/Lib/test/test_format.py @@ -420,6 +420,9 @@ def test_non_ascii(self): self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007") self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007") + # TODO: RUSTPYTHON formatting does not support locales + # See https://github.com/RustPython/RustPython/issues/5181 + @unittest.expectedFailure def test_locale(self): try: oldloc = locale.setlocale(locale.LC_ALL) From de7e4e49dab823358e37823ae547cd66b8e7bfaf Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sat, 24 Feb 2024 13:46:22 -0500 Subject: [PATCH 281/893] Disable broken test_socket.py tests There are a substantial number of socket tests that are disabled due to `bind(): bad family` errors. It seems like RustPython only supports a small subset of the required connection families, so the failing tests are broken for the same reasons. --- Lib/test/test_socket.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 35f94a4e22..17e9dae8c0 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -2161,12 +2161,16 @@ def testCreateISOTPSocket(self): with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: pass + # TODO: RUSTPYTHON, OSError: bind(): bad family + @unittest.expectedFailure def testTooLongInterfaceName(self): # most systems limit IFNAMSIZ to 16, take 1024 to be sure with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: with self.assertRaisesRegex(OSError, 'interface name too long'): s.bind(('x' * 1024, 1, 2)) + # TODO: RUSTPYTHON, OSError: bind(): bad family + @unittest.expectedFailure def testBind(self): try: with socket.socket(socket.PF_CAN, socket.SOCK_DGRAM, socket.CAN_ISOTP) as s: From 9b974bda0d8792dfafcc25af5a09d6bd6386704a Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sun, 10 Mar 2024 22:23:46 -0400 Subject: [PATCH 282/893] Re-enable test_format.test_locale Technically speaking, my system was misconfigured, leading me to disable the test in the first place. `test_locale` calls `locale.setlocale(locale.LC_ALL, '')`, which reads the value of the `LANG` environment variable and uses that to look up and reset all the locale settings. My system has `LANG=en_US.UTF-8`, which is apparently not what this test was expecting. If `LANG` is unset or set to `C`, the test passes, as it does in CI. --- Lib/test/test_format.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py index 66e2b077bf..f6c11a4aad 100644 --- a/Lib/test/test_format.py +++ b/Lib/test/test_format.py @@ -420,9 +420,6 @@ def test_non_ascii(self): self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007") self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007") - # TODO: RUSTPYTHON formatting does not support locales - # See https://github.com/RustPython/RustPython/issues/5181 - @unittest.expectedFailure def test_locale(self): try: oldloc = locale.setlocale(locale.LC_ALL) From 23ebbd021b7f65883e05c6430aa80108d5cb3461 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Sun, 10 Mar 2024 22:53:39 -0400 Subject: [PATCH 283/893] Skip test_format.test_locale I had previously `test_locale` as expected to fail, as it did indeed fail on my system due to unimplemented functionality. As it happens, it passes in CI because the locale settings used there (`C`, I believe) just happen to format integers the same with "%d" as "%n". I mistakenly un-marked it because I thought I misunderstood the problem. --- Lib/test/test_format.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py index f6c11a4aad..187270d5b6 100644 --- a/Lib/test/test_format.py +++ b/Lib/test/test_format.py @@ -420,6 +420,9 @@ def test_non_ascii(self): self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007") self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007") + # TODO: RUSTPYTHON formatting does not support locales + # See https://github.com/RustPython/RustPython/issues/5181 + @unittest.skip("formatting does not support locales") def test_locale(self): try: oldloc = locale.setlocale(locale.LC_ALL) From 2fde8e91e5dbc070b06060c5b356a60f7e85fb3d Mon Sep 17 00:00:00 2001 From: wellweek <148746285+wellweek@users.noreply.github.com> Date: Mon, 11 Mar 2024 14:01:37 +0800 Subject: [PATCH 284/893] fix some typos (#5187) Signed-off-by: wellweek --- Lib/test/test_unpack.py | 2 +- architecture/architecture.md | 2 +- benches/benchmarks/pystone.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_unpack.py b/Lib/test/test_unpack.py index f5ca1d455b..515ec128a0 100644 --- a/Lib/test/test_unpack.py +++ b/Lib/test/test_unpack.py @@ -162,7 +162,7 @@ def test_extended_oparg_not_ignored(self): ns = {} exec(code, ns) unpack_400 = ns["unpack_400"] - # Warm up the the function for quickening (PEP 659) + # Warm up the function for quickening (PEP 659) for _ in range(30): y = unpack_400(range(400)) self.assertEqual(y, 399) diff --git a/architecture/architecture.md b/architecture/architecture.md index 5b1ae9cc68..a59b6498bf 100644 --- a/architecture/architecture.md +++ b/architecture/architecture.md @@ -101,7 +101,7 @@ Part of the Python standard library that's implemented in Rust. The modules that ### Lib -Python side of the standard libary, copied over (with care) from CPython sourcecode. +Python side of the standard library, copied over (with care) from CPython sourcecode. #### Lib/test diff --git a/benches/benchmarks/pystone.py b/benches/benchmarks/pystone.py index 3faf675ae7..755b4ba85c 100644 --- a/benches/benchmarks/pystone.py +++ b/benches/benchmarks/pystone.py @@ -16,7 +16,7 @@ Version History: - Inofficial version 1.1.1 by Chris Arndt: + Unofficial version 1.1.1 by Chris Arndt: - Make it run under Python 2 and 3 by using "from __future__ import print_function". From 7f02324dcec46bf97ccf0e9cc0b94a8ad5057abb Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Mon, 11 Mar 2024 15:04:35 +0200 Subject: [PATCH 285/893] Update Lib/test/test_hmac.py to 3.12 version (#5188) --- Lib/test/test_hmac.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py index bc2e02528d..8e1a4a204c 100644 --- a/Lib/test/test_hmac.py +++ b/Lib/test/test_hmac.py @@ -389,6 +389,18 @@ def test_with_digestmod_no_default(self): with self.assertRaisesRegex(TypeError, r'required.*digestmod'): hmac.HMAC(key, msg=data, digestmod='') + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_with_fallback(self): + cache = getattr(hashlib, '__builtin_constructor_cache') + try: + cache['foo'] = hashlib.sha256 + hexdigest = hmac.digest(b'key', b'message', 'foo').hex() + expected = '6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a' + self.assertEqual(hexdigest, expected) + finally: + cache.pop('foo') + class ConstructorTestCase(unittest.TestCase): From 83d1ad8a2cfaa6cb96d7d1646923ed58c030b8df Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Tue, 12 Mar 2024 15:35:21 +0200 Subject: [PATCH 286/893] Update test_operator.py to 3.12 (#5194) --- Lib/test/test_operator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py index b7e38c2334..1db738d228 100644 --- a/Lib/test/test_operator.py +++ b/Lib/test/test_operator.py @@ -208,6 +208,9 @@ def test_indexOf(self): nan = float("nan") self.assertEqual(operator.indexOf([nan, nan, 21], nan), 0) self.assertEqual(operator.indexOf([{}, 1, {}, 2], {}), 0) + it = iter('leave the iterator at exactly the position after the match') + self.assertEqual(operator.indexOf(it, 'a'), 2) + self.assertEqual(next(it), 'v') def test_invert(self): operator = self.module From 4e7b3bc8f247bf8a31ba1bd791844e83307dde09 Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Tue, 12 Mar 2024 15:36:10 +0200 Subject: [PATCH 287/893] Update pprint.py and test_pprint.py to 3.12 (#5195) --- Lib/pprint.py | 16 ---------------- Lib/test/test_pprint.py | 2 +- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/Lib/pprint.py b/Lib/pprint.py index 575688d8eb..34ed12637e 100644 --- a/Lib/pprint.py +++ b/Lib/pprint.py @@ -637,19 +637,6 @@ def _recursion(object): % (type(object).__name__, id(object))) -def _perfcheck(object=None): - import time - if object is None: - object = [("string", (1, 2), [3, 4], {5: 6, 7: 8})] * 100000 - p = PrettyPrinter() - t1 = time.perf_counter() - p._safe_repr(object, {}, None, 0, True) - t2 = time.perf_counter() - p.pformat(object) - t3 = time.perf_counter() - print("_safe_repr:", t2 - t1) - print("pformat:", t3 - t2) - def _wrap_bytes_repr(object, width, allowance): current = b'' last = len(object) // 4 * 4 @@ -666,6 +653,3 @@ def _wrap_bytes_repr(object, width, allowance): current = candidate if current: yield repr(current) - -if __name__ == "__main__": - _perfcheck() diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py index c7b9893943..6ea7e7db2c 100644 --- a/Lib/test/test_pprint.py +++ b/Lib/test/test_pprint.py @@ -203,7 +203,7 @@ def test_knotted(self): def test_unreadable(self): # Not recursive but not readable anyway pp = pprint.PrettyPrinter() - for unreadable in type(3), pprint, pprint.isrecursive: + for unreadable in object(), int, pprint, pprint.isrecursive: # module-level convenience functions self.assertFalse(pprint.isrecursive(unreadable), "expected not isrecursive for %r" % (unreadable,)) From 855fa1411fc20de7c7cfd76c807bcb435e6873f7 Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Wed, 13 Mar 2024 00:35:16 +0200 Subject: [PATCH 288/893] Update ftplib and test_ftplib to 3.12 (#5196) --- Lib/ftplib.py | 36 +++++++----------------------------- Lib/test/test_ftplib.py | 22 ++++++++-------------- 2 files changed, 15 insertions(+), 43 deletions(-) diff --git a/Lib/ftplib.py b/Lib/ftplib.py index 7c5a50715f..a56e0c3085 100644 --- a/Lib/ftplib.py +++ b/Lib/ftplib.py @@ -434,10 +434,7 @@ def retrbinary(self, cmd, callback, blocksize=8192, rest=None): """ self.voidcmd('TYPE I') with self.transfercmd(cmd, rest) as conn: - while 1: - data = conn.recv(blocksize) - if not data: - break + while data := conn.recv(blocksize): callback(data) # shutdown ssl layer if _SSLSocket is not None and isinstance(conn, _SSLSocket): @@ -496,10 +493,7 @@ def storbinary(self, cmd, fp, blocksize=8192, callback=None, rest=None): """ self.voidcmd('TYPE I') with self.transfercmd(cmd, rest) as conn: - while 1: - buf = fp.read(blocksize) - if not buf: - break + while buf := fp.read(blocksize): conn.sendall(buf) if callback: callback(buf) @@ -561,7 +555,7 @@ def dir(self, *args): LIST command. (This *should* only be used for a pathname.)''' cmd = 'LIST' func = None - if args[-1:] and type(args[-1]) != type(''): + if args[-1:] and not isinstance(args[-1], str): args, func = args[:-1], args[-1] for arg in args: if arg: @@ -713,28 +707,12 @@ class FTP_TLS(FTP): '221 Goodbye.' >>> ''' - ssl_version = ssl.PROTOCOL_TLS_CLIENT def __init__(self, host='', user='', passwd='', acct='', - keyfile=None, certfile=None, context=None, - timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None, *, - encoding='utf-8'): - if context is not None and keyfile is not None: - raise ValueError("context and keyfile arguments are mutually " - "exclusive") - if context is not None and certfile is not None: - raise ValueError("context and certfile arguments are mutually " - "exclusive") - if keyfile is not None or certfile is not None: - import warnings - warnings.warn("keyfile and certfile are deprecated, use a " - "custom context instead", DeprecationWarning, 2) - self.keyfile = keyfile - self.certfile = certfile + *, context=None, timeout=_GLOBAL_DEFAULT_TIMEOUT, + source_address=None, encoding='utf-8'): if context is None: - context = ssl._create_stdlib_context(self.ssl_version, - certfile=certfile, - keyfile=keyfile) + context = ssl._create_stdlib_context() self.context = context self._prot_p = False super().__init__(host, user, passwd, acct, @@ -749,7 +727,7 @@ def auth(self): '''Set up secure control connection by using TLS/SSL.''' if isinstance(self.sock, ssl.SSLSocket): raise ValueError("Already using TLS") - if self.ssl_version >= ssl.PROTOCOL_TLS: + if self.context.protocol >= ssl.PROTOCOL_TLS: resp = self.voidcmd('AUTH TLS') else: resp = self.voidcmd('AUTH SSL') diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index e8c126ddc4..7e632efa4c 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -21,6 +21,8 @@ from test.support import threading_helper from test.support import socket_helper from test.support import warnings_helper +from test.support import asynchat +from test.support import asyncore from test.support.socket_helper import HOST, HOSTv6 import sys @@ -992,11 +994,11 @@ def test_context(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE - self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE, + self.assertRaises(TypeError, ftplib.FTP_TLS, keyfile=CERTFILE, context=ctx) - self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, + self.assertRaises(TypeError, ftplib.FTP_TLS, certfile=CERTFILE, context=ctx) - self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, + self.assertRaises(TypeError, ftplib.FTP_TLS, certfile=CERTFILE, keyfile=CERTFILE, context=ctx) self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT) @@ -1160,18 +1162,10 @@ def test__all__(self): support.check__all__(self, ftplib, not_exported=not_exported) -def test_main(): - tests = [TestFTPClass, TestTimeouts, - TestIPv6Environment, - TestTLS_FTPClassMixin, TestTLS_FTPClass, - MiscTestCase] - +def setUpModule(): thread_info = threading_helper.threading_setup() - try: - support.run_unittest(*tests) - finally: - threading_helper.threading_cleanup(*thread_info) + unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info) if __name__ == '__main__': - test_main() + unittest.main() From d8f2bd04ace5c0f07e3855f633c4bfdb0564e713 Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Wed, 13 Mar 2024 08:22:24 +0200 Subject: [PATCH 289/893] Update cgitb.py to 3.12 (#5197) --- Lib/cgitb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/cgitb.py b/Lib/cgitb.py index 8ce0e833a9..f6b97f25c5 100644 --- a/Lib/cgitb.py +++ b/Lib/cgitb.py @@ -74,7 +74,7 @@ def lookup(name, frame, locals): return 'global', frame.f_globals[name] if '__builtins__' in frame.f_globals: builtins = frame.f_globals['__builtins__'] - if type(builtins) is type({}): + if isinstance(builtins, dict): if name in builtins: return 'builtin', builtins[name] else: From 92c8b371ae5db0d95bd8199bc42b08af115bb88a Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Wed, 13 Mar 2024 08:22:57 +0200 Subject: [PATCH 290/893] Update colorsys.py and test_colorsys.py to 3.12 (#5198) --- Lib/colorsys.py | 2 +- Lib/test/test_colorsys.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/Lib/colorsys.py b/Lib/colorsys.py index 9bdc83e377..bc897bd0f9 100644 --- a/Lib/colorsys.py +++ b/Lib/colorsys.py @@ -83,7 +83,7 @@ def rgb_to_hls(r, g, b): if l <= 0.5: s = rangec / sumc else: - s = rangec / (2.0-sumc) + s = rangec / (2.0-maxc-minc) # Not always 2.0-sumc: gh-106498. rc = (maxc-r) / rangec gc = (maxc-g) / rangec bc = (maxc-b) / rangec diff --git a/Lib/test/test_colorsys.py b/Lib/test/test_colorsys.py index a24e3adcb4..74d76294b0 100644 --- a/Lib/test/test_colorsys.py +++ b/Lib/test/test_colorsys.py @@ -69,6 +69,16 @@ def test_hls_values(self): self.assertTripleEqual(hls, colorsys.rgb_to_hls(*rgb)) self.assertTripleEqual(rgb, colorsys.hls_to_rgb(*hls)) + def test_hls_nearwhite(self): # gh-106498 + values = ( + # rgb, hls: these do not work in reverse + ((0.9999999999999999, 1, 1), (0.5, 1.0, 1.0)), + ((1, 0.9999999999999999, 0.9999999999999999), (0.0, 1.0, 1.0)), + ) + for rgb, hls in values: + self.assertTripleEqual(hls, colorsys.rgb_to_hls(*rgb)) + self.assertTripleEqual((1.0, 1.0, 1.0), colorsys.hls_to_rgb(*hls)) + def test_yiq_roundtrip(self): for r in frange(0.0, 1.0, 0.2): for g in frange(0.0, 1.0, 0.2): From 426e582ba039492b789b114d26b29f3dfa86c56e Mon Sep 17 00:00:00 2001 From: Nikita Sobolev Date: Fri, 15 Mar 2024 16:15:45 +0300 Subject: [PATCH 291/893] Remove incorrect `@expectedFailure`s from `test_cmd_line` (#5201) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After you suggestion in https://github.com/python/cpython/issues/116504#issuecomment-1999239012 I went to take a look at `test_cmd_line` in RustPython (it was so long ago I contributed to this amazing project, so may thing had changed!), and I've noticed this. This is a problem, here' the simplest demo: ```python import unittest class TestMe(unittest.TestCase): @unittest.expectedFailure def test_me(self): def run(): raise ValueError with self.subTest(run=run): run() if __name__ == '__main__': unittest.main() ``` This works as expected: ``` » ./python.exe ex.py x ---------------------------------------------------------------------- Ran 1 test in 0.001s OK (expected failures=1) ``` This does not: ```python import unittest class TestMe(unittest.TestCase): def test_me(self): @unittest.expectedFailure def run(): raise ValueError with self.subTest(run=run): run() if __name__ == '__main__': unittest.main() ``` Produces: ``` » ./python.exe ex.py E ====================================================================== ERROR: test_me (__main__.TestMe.test_me) (run=.run at 0x1057a2150>) ---------------------------------------------------------------------- Traceback (most recent call last): File "/Users/sobolev/Desktop/cpython2/ex.py", line 10, in test_me run() ~~~^^ File "/Users/sobolev/Desktop/cpython2/ex.py", line 7, in run raise ValueError ValueError ---------------------------------------------------------------------- Ran 1 test in 0.001s FAILED (errors=1) ``` So, I propose to remove these decorators, let's only keep `TODO` comments to indicate separate failures. --- Lib/test/test_cmd_line.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index 88ff71726f..6644a3cd5c 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -278,13 +278,11 @@ def test_invalid_utf8_arg(self): code = 'import sys, os; s=os.fsencode(sys.argv[1]); print(ascii(s))' # TODO: RUSTPYTHON - @unittest.expectedFailure def run_default(arg): cmd = [sys.executable, '-c', code, arg] return subprocess.run(cmd, stdout=subprocess.PIPE, text=True) # TODO: RUSTPYTHON - @unittest.expectedFailure def run_c_locale(arg): cmd = [sys.executable, '-c', code, arg] env = dict(os.environ) @@ -293,7 +291,6 @@ def run_c_locale(arg): text=True, env=env) # TODO: RUSTPYTHON - @unittest.expectedFailure def run_utf8_mode(arg): cmd = [sys.executable, '-X', 'utf8', '-c', code, arg] return subprocess.run(cmd, stdout=subprocess.PIPE, text=True) From 12601d0b44c719183d559fa73d76ab6561255ed9 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Mon, 18 Mar 2024 16:57:28 +0900 Subject: [PATCH 292/893] integrate sre_engine crate to workspace --- Cargo.lock | 38 ++++++++++++++++++++++++++++++++---- Cargo.toml | 5 +++-- vm/sre_engine/Cargo.toml | 8 ++++---- vm/sre_engine/tests/tests.rs | 2 +- 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c8d0342708..52afbb053f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1428,7 +1428,16 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d829733185c1ca374f17e52b762f24f535ec625d2cc1f070e34c8a9068f341b" dependencies = [ - "num_enum_derive", + "num_enum_derive 0.5.9", +] + +[[package]] +name = "num_enum" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02339744ee7253741199f897151b38e72257d13802d4ee837285cc2990a90845" +dependencies = [ + "num_enum_derive 0.7.2", ] [[package]] @@ -1443,6 +1452,18 @@ dependencies = [ "syn 1.0.107", ] +[[package]] +name = "num_enum_derive" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "681030a937600a36906c185595136d26abfebb4aa9c65701cefcaf8578bb982b" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.32", +] + [[package]] name = "once_cell" version = "1.18.0" @@ -2165,6 +2186,15 @@ dependencies = [ "rustpython-derive", ] +[[package]] +name = "rustpython-sre_engine" +version = "0.6.0" +dependencies = [ + "bitflags 2.4.0", + "num_enum 0.7.2", + "optional", +] + [[package]] name = "rustpython-stdlib" version = "0.3.0" @@ -2200,7 +2230,7 @@ dependencies = [ "num-complex", "num-integer", "num-traits", - "num_enum", + "num_enum 0.7.2", "once_cell", "openssl", "openssl-probe", @@ -2270,7 +2300,7 @@ dependencies = [ "num-integer", "num-traits", "num_cpus", - "num_enum", + "num_enum 0.7.2", "once_cell", "optional", "parking_lot", @@ -2527,7 +2557,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a490c5c46c35dba9a6f5e7ee8e4d67e775eb2d2da0f115750b8d10e1c1ac2d28" dependencies = [ "bitflags 1.3.2", - "num_enum", + "num_enum 0.5.9", "optional", ] diff --git a/Cargo.toml b/Cargo.toml index bfc882fdc5..0f4fb49dc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ include = ["LICENSE", "Cargo.toml", "src/**/*.rs"] resolver = "2" members = [ "compiler", "compiler/core", "compiler/codegen", - ".", "common", "derive", "jit", "vm", "pylib", "stdlib", "wasm/lib", "derive-impl", + ".", "common", "derive", "jit", "vm", "vm/sre_engine", "pylib", "stdlib", "wasm/lib", "derive-impl", ] [workspace.dependencies] @@ -27,6 +27,7 @@ rustpython-jit = { path = "jit", version = "0.3.0" } rustpython-vm = { path = "vm", default-features = false, version = "0.3.0" } rustpython-pylib = { path = "pylib", version = "0.3.0" } rustpython-stdlib = { path = "stdlib", default-features = false, version = "0.3.0" } +rustpython-sre_engine = { path = "vm/sre_engine", version = "0.6.0" } rustpython-doc = { git = "https://github.com/RustPython/__doc__", tag = "0.3.0", version = "0.3.0" } rustpython-literal = { git = "https://github.com/RustPython/Parser.git", rev = "29c4728dbedc7e69cc2560b9b34058bbba9b1303" } @@ -64,7 +65,7 @@ malachite-base = "0.4.4" num-complex = "0.4.0" num-integer = "0.1.44" num-traits = "0.2" -num_enum = "0.5.7" +num_enum = "0.7" once_cell = "1.18" parking_lot = "0.12.1" paste = "1.0.7" diff --git a/vm/sre_engine/Cargo.toml b/vm/sre_engine/Cargo.toml index e54f124ac0..2caa8b73e5 100644 --- a/vm/sre_engine/Cargo.toml +++ b/vm/sre_engine/Cargo.toml @@ -1,15 +1,15 @@ [package] -name = "sre-engine" +name = "rustpython-sre_engine" version = "0.6.0" authors = ["Kangzhi Shi ", "RustPython Team"] description = "A low-level implementation of Python's SRE regex engine" -repository = "https://github.com/RustPython/sre-engine" +repository = "https://github.com/RustPython/RustPython" license = "MIT" edition = "2021" keywords = ["regex"] include = ["LICENSE", "src/**/*.rs"] [dependencies] -num_enum = "0.7" -bitflags = "2" +num_enum = { workspace = true } +bitflags = { workspace = true } optional = "0.5" diff --git a/vm/sre_engine/tests/tests.rs b/vm/sre_engine/tests/tests.rs index f589c62e6e..53494c5e3d 100644 --- a/vm/sre_engine/tests/tests.rs +++ b/vm/sre_engine/tests/tests.rs @@ -1,4 +1,4 @@ -use sre_engine::{Request, State, StrDrive}; +use rustpython_sre_engine::{Request, State, StrDrive}; struct Pattern { code: &'static [u32], From ac7851704487ba7f37479946ef8957ead2e68097 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Thu, 21 Mar 2024 01:44:03 -0400 Subject: [PATCH 293/893] Skip TestScander.test_uninstantiable (#5204) This test was marked as an expected failure. Because the garbage collector is missing, that meant that the `os.scandir` object went unclosed. This object was squatting on the file descriptors of all the files contained in the test directory, which was breaking test_zipfile. --- Lib/test/test_os.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 097124b7b5..c880a9b902 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -4249,7 +4249,8 @@ def assert_stat_equal(self, stat1, stat2, skip_fields): self.assertEqual(stat1, stat2) # TODO: RUSTPPYTHON (AssertionError: TypeError not raised by ScandirIter) - @unittest.expectedFailure + # TODO: See https://github.com/RustPython/RustPython/issues/5190 for skip rationale + @unittest.skip("skipping to avoid the unclosed scandir from squatting on file descriptors") def test_uninstantiable(self): scandir_iter = os.scandir(self.path) self.assertRaises(TypeError, type(scandir_iter)) From 3737f2a0918d94312b9e40b974f3939ce0f44013 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" <69878+youknowone@users.noreply.github.com> Date: Thu, 21 Mar 2024 21:48:29 +0900 Subject: [PATCH 294/893] make adding a single module simpler for interpreter users (#4792) --- src/interpreter.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/interpreter.rs b/src/interpreter.rs index b84f167ae4..6b0d2b5bde 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -1,4 +1,4 @@ -use rustpython_vm::{Interpreter, Settings, VirtualMachine}; +use rustpython_vm::{builtins::PyModule, Interpreter, PyRef, Settings, VirtualMachine}; pub type InitHook = Box; @@ -63,6 +63,15 @@ impl InterpreterConfig { self.init_hooks.push(hook); self } + pub fn add_native_module( + self, + name: String, + make_module: fn(&VirtualMachine) -> PyRef, + ) -> Self { + self.init_hook(Box::new(move |vm| { + vm.add_native_module(name, Box::new(make_module)) + })) + } #[cfg(feature = "stdlib")] pub fn init_stdlib(self) -> Self { self.init_hook(Box::new(init_stdlib)) From 5ee5531f327d7505739276686607c3b32eaeff63 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Thu, 21 Mar 2024 08:51:57 -0400 Subject: [PATCH 295/893] Properly unload modules between tests (#5192) There seems to have been a bug in the libregrtest code which unloaded modules between tests. The previous state was calculated using `sys.modules.keys()`, which is actually a mutable object that is updated as the underlying `sys.modules` is updated. The result was that modules were not unloaded between tests, which is the root cause for `test_unittest` failing when run after `test_import` and `test_importlib`. This code is copied from 3.12. Ideally all of `libregrtest` should probably be updated as it seems wildly out of date, but that's a lot more work. --- Lib/test/libregrtest/main.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/Lib/test/libregrtest/main.py b/Lib/test/libregrtest/main.py index fba24e4f32..e1d19e1e4a 100644 --- a/Lib/test/libregrtest/main.py +++ b/Lib/test/libregrtest/main.py @@ -373,7 +373,7 @@ def run_tests_sequential(self): import trace self.tracer = trace.Trace(trace=False, count=True) - save_modules = sys.modules.keys() + save_modules = set(sys.modules) print("Run tests sequentially") @@ -409,10 +409,18 @@ def run_tests_sequential(self): # be quiet: say nothing if the test passed shortly previous_test = None - # Unload the newly imported modules (best effort finalization) - for module in sys.modules.keys(): - if module not in save_modules and module.startswith("test."): - import_helper.unload(module) + # Unload the newly imported test modules (best effort finalization) + new_modules = [module for module in sys.modules + if module not in save_modules and + module.startswith(("test.", "test_"))] + for module in new_modules: + sys.modules.pop(module, None) + # Remove the attribute of the parent module. + parent, _, name = module.rpartition('.') + try: + delattr(sys.modules[parent], name) + except (KeyError, AttributeError): + pass if previous_test: print(previous_test) From 85c427b8423bc8620d9b39bf22743da4ea05cb03 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Thu, 21 Mar 2024 11:12:01 -0400 Subject: [PATCH 296/893] Reset exception in WithCleanupFinish (#5203) Context managers have an `__exit__` function that returns a boolean-like object. If the object is truthy, then exceptions are suppressed. If an exception was thrown while resolving that boolean, it would leak and live on in the error stack, getting tacked on to all future exceptions. This caused several mysterious test failures which would only trigger after this very specific event was tested in `test_with`. The solution is to move a call to `vm.set_exception()` before attempting the `try_to_bool()` which threw the error. Minimal example to reproduce the bug: ```py import sys import traceback class cm(object): def __init__(self): pass def __enter__(self): return 3 def __exit__(self, a, b, c): class Bool: def __bool__(self): 1 // 0 return Bool() try: with cm(): raise Exception("Should NOT see this") except ZeroDivisionError: print("exception caught, as expected") print("There should now be no exception") traceback.print_exc() print(sys.exc_info()) ``` --- vm/src/frame.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vm/src/frame.rs b/vm/src/frame.rs index a0af23336f..56e0f433da 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -939,10 +939,10 @@ impl ExecutingFrame<'_> { _ => self.fatal("WithCleanupFinish expects a FinallyHandler block on stack"), }; - let suppress_exception = self.pop_value().try_to_bool(vm)?; - vm.set_exception(prev_exc); + let suppress_exception = self.pop_value().try_to_bool(vm)?; + if suppress_exception { Ok(None) } else if let Some(reason) = reason { From e3150776300045da5ddf74afbd97abe89d7481bb Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Thu, 21 Mar 2024 11:28:33 -0400 Subject: [PATCH 297/893] Add TODO: RUSTPYTHON to skip reason --- Lib/test/test_os.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index c880a9b902..21ce9bc329 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -4250,7 +4250,7 @@ def assert_stat_equal(self, stat1, stat2, skip_fields): # TODO: RUSTPPYTHON (AssertionError: TypeError not raised by ScandirIter) # TODO: See https://github.com/RustPython/RustPython/issues/5190 for skip rationale - @unittest.skip("skipping to avoid the unclosed scandir from squatting on file descriptors") + @unittest.skip("TODO: RUSTPYTHON, avoid the unclosed scandir from squatting on file descriptors") def test_uninstantiable(self): scandir_iter = os.scandir(self.path) self.assertRaises(TypeError, type(scandir_iter)) From e6c73883eadb1c3feb6c8422826107a98efa0ecb Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Thu, 21 Mar 2024 13:36:28 -0400 Subject: [PATCH 298/893] Revert test skip --- Lib/test/test_os.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 21ce9bc329..e9ab681357 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -4248,9 +4248,8 @@ def assert_stat_equal(self, stat1, stat2, skip_fields): else: self.assertEqual(stat1, stat2) - # TODO: RUSTPPYTHON (AssertionError: TypeError not raised by ScandirIter) - # TODO: See https://github.com/RustPython/RustPython/issues/5190 for skip rationale - @unittest.skip("TODO: RUSTPYTHON, avoid the unclosed scandir from squatting on file descriptors") + # TODO: RUSTPYTHON (AssertionError: TypeError not raised by ScandirIter) + @unittest.expectedFailure def test_uninstantiable(self): scandir_iter = os.scandir(self.path) self.assertRaises(TypeError, type(scandir_iter)) From 0a24e106baa25db64783ff25534f4e53a29005f9 Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Thu, 21 Mar 2024 13:37:46 -0400 Subject: [PATCH 299/893] Increase threshold for zipfile test_many_opens It turns out that there are many other tests that can impact test_many_opens by leaving unclosed file handles. Rather than fix them all, it is easier to simply increase the threshold for the problematic test. --- Lib/test/test_zipfile.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/Lib/test/test_zipfile.py b/Lib/test/test_zipfile.py index fd4a3918e6..43178ca26b 100644 --- a/Lib/test/test_zipfile.py +++ b/Lib/test/test_zipfile.py @@ -2561,17 +2561,22 @@ def test_write_after_read(self): self.assertEqual(data1, self.data1) self.assertEqual(data2, self.data2) + # TODO: RUSTPYTHON other tests can impact the file descriptor incrementor + # by leaving file handles unclosed. If there are more than 100 files in + # TESTFN and references to them are left unclosed and ungarbage collected + # in another test, then fileno() will always be too high for this test to + # pass. The solution is to increase the number of files from 100 to 200 def test_many_opens(self): # Verify that read() and open() promptly close the file descriptor, # and don't rely on the garbage collector to free resources. self.make_test_archive(TESTFN2) with zipfile.ZipFile(TESTFN2, mode="r") as zipf: - for x in range(100): + for x in range(200): zipf.read('ones') with zipf.open('ones') as zopen1: pass with open(os.devnull, "rb") as f: - self.assertLess(f.fileno(), 100) + self.assertLess(f.fileno(), 200) def test_write_while_reading(self): with zipfile.ZipFile(TESTFN2, 'w', zipfile.ZIP_DEFLATED) as zipf: From 90724b32ec1283c941f992132baf398fa87a1a2b Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Thu, 21 Mar 2024 21:25:53 -0400 Subject: [PATCH 300/893] Implement new clippy lints (#5208) * Implement new clippy lints clippy was just updated and has a few minor issues with the code base. * Forgotten lint hidden behind feature --- common/src/static_cell.rs | 2 +- common/src/str.rs | 2 +- vm/src/builtins/memory.rs | 2 +- vm/src/stdlib/os.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/common/src/static_cell.rs b/common/src/static_cell.rs index 01a54db29c..7f16dad399 100644 --- a/common/src/static_cell.rs +++ b/common/src/static_cell.rs @@ -46,7 +46,7 @@ mod non_threading { F: FnOnce() -> Result, { self.inner - .with(|x| x.get_or_try_init(|| f().map(leak)).map(|&x| x)) + .with(|x| x.get_or_try_init(|| f().map(leak)).copied()) } } diff --git a/common/src/str.rs b/common/src/str.rs index cdee03f14f..48fdb0f95a 100644 --- a/common/src/str.rs +++ b/common/src/str.rs @@ -250,7 +250,7 @@ pub mod levenshtein { pub fn levenshtein_distance(a: &str, b: &str, max_cost: usize) -> usize { thread_local! { - static BUFFER: RefCell<[usize; MAX_STRING_SIZE]> = RefCell::new([0usize; MAX_STRING_SIZE]); + static BUFFER: RefCell<[usize; MAX_STRING_SIZE]> = const { RefCell::new([0usize; MAX_STRING_SIZE]) }; } if a == b { diff --git a/vm/src/builtins/memory.rs b/vm/src/builtins/memory.rs index 2c436ca316..aca2114bf0 100644 --- a/vm/src/builtins/memory.rs +++ b/vm/src/builtins/memory.rs @@ -1047,7 +1047,7 @@ impl Hashable for PyMemoryView { } Ok(zelf.contiguous_or_collect(|bytes| vm.state.hash_secret.hash_bytes(bytes))) }) - .map(|&x| x) + .copied() } } diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 376c18fb3a..bd76c8ed95 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -698,7 +698,7 @@ pub(super) mod _os { if self.is_symlink(vm)? { do_stat(true) } else { - lstat().map(Clone::clone) + lstat().cloned() } })? } else { From df363c0ba7613e54588f1f4b4b60981eb49d516f Mon Sep 17 00:00:00 2001 From: Daniel Chiquito Date: Thu, 21 Mar 2024 21:26:40 -0400 Subject: [PATCH 301/893] Skip typing test which causes other failures (#5207) --- Lib/test/test_typing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 95fd3748e6..b6a167f998 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -1466,6 +1466,10 @@ def __new__(cls, *args): with self.assertRaises(TypeError): C[int](a=42) + # TODO: RUSTPYTHON the last line breaks any tests that use unittest.mock + # See https://github.com/RustPython/RustPython/issues/5190#issuecomment-2010535802 + # It's possible that updating typing to 3.12 will resolve this + @unittest.skip("TODO: RUSTPYTHON this test breaks other tests that use unittest.mock") def test_protocols_bad_subscripts(self): T = TypeVar('T') S = TypeVar('S') From 1dd9a2fbe45f7f34f5c505d5e9774603bd34d6bb Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Fri, 22 Mar 2024 11:28:49 +0900 Subject: [PATCH 302/893] suppress clippy warnings --- vm/sre_engine/src/engine.rs | 5 +++++ vm/sre_engine/src/string.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vm/sre_engine/src/engine.rs b/vm/sre_engine/src/engine.rs index 34f00234e5..fb7d766e29 100644 --- a/vm/sre_engine/src/engine.rs +++ b/vm/sre_engine/src/engine.rs @@ -283,6 +283,8 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex let mut context_stack = vec![]; let mut popped_result = false; + // NOTE: 'result loop is not an actual loop but break label + #[allow(clippy::never_loop)] 'coro: loop { popped_result = 'result: loop { let yielded = 'context: loop { @@ -513,6 +515,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex loop { macro_rules! general_op_literal { ($f:expr) => {{ + #[allow(clippy::redundant_closure_call)] if ctx.at_end(req) || !$f(ctx.peek_code(req, 1), ctx.peek_char::()) { break 'result false; } @@ -523,6 +526,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex macro_rules! general_op_in { ($f:expr) => {{ + #[allow(clippy::redundant_closure_call)] if ctx.at_end(req) || !$f(&ctx.pattern(req)[2..], ctx.peek_char::()) { break 'result false; @@ -551,6 +555,7 @@ fn _match(req: &Request, state: &mut State, mut ctx: MatchContex }; for _ in group_start..group_end { + #[allow(clippy::redundant_closure_call)] if ctx.at_end(req) || $f(ctx.peek_char::()) != $f(gctx.peek_char::()) { diff --git a/vm/sre_engine/src/string.rs b/vm/sre_engine/src/string.rs index 1340c37423..e3f14ef019 100644 --- a/vm/sre_engine/src/string.rs +++ b/vm/sre_engine/src/string.rs @@ -101,7 +101,7 @@ impl StrDrive for &str { #[inline] fn adjust_cursor(&self, cursor: &mut StringCursor, n: usize) { if cursor.ptr.is_null() || cursor.position > n { - *cursor = Self::create_cursor(&self, n); + *cursor = Self::create_cursor(self, n); } else if cursor.position < n { Self::skip(cursor, n - cursor.position); } From 280337a305b66d7bb9da13efa48f33dba9980766 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 18 Nov 2023 15:29:58 +0200 Subject: [PATCH 303/893] Add Lib/re/* from CPython 3.12 --- Lib/{re.py => re/__init__.py} | 221 ++++--- Lib/re/_casefix.py | 106 ++++ Lib/re/_compiler.py | 763 +++++++++++++++++++++++ Lib/re/_constants.py | 219 +++++++ Lib/re/_parser.py | 1081 +++++++++++++++++++++++++++++++++ 5 files changed, 2301 insertions(+), 89 deletions(-) rename Lib/{re.py => re/__init__.py} (69%) create mode 100644 Lib/re/_casefix.py create mode 100644 Lib/re/_compiler.py create mode 100644 Lib/re/_constants.py create mode 100644 Lib/re/_parser.py diff --git a/Lib/re.py b/Lib/re/__init__.py similarity index 69% rename from Lib/re.py rename to Lib/re/__init__.py index bfb7b1ccd9..428d1b0d5f 100644 --- a/Lib/re.py +++ b/Lib/re/__init__.py @@ -122,65 +122,40 @@ """ import enum -import sre_compile -import sre_parse +from . import _compiler, _parser import functools -try: - import _locale -except ImportError: - _locale = None +import _sre # public symbols __all__ = [ "match", "fullmatch", "search", "sub", "subn", "split", - "findall", "finditer", "compile", "purge", "template", "escape", + "findall", "finditer", "compile", "purge", "escape", "error", "Pattern", "Match", "A", "I", "L", "M", "S", "X", "U", "ASCII", "IGNORECASE", "LOCALE", "MULTILINE", "DOTALL", "VERBOSE", - "UNICODE", + "UNICODE", "NOFLAG", "RegexFlag", ] __version__ = "2.2.1" -class RegexFlag(enum.IntFlag): - ASCII = A = sre_compile.SRE_FLAG_ASCII # assume ascii "locale" - IGNORECASE = I = sre_compile.SRE_FLAG_IGNORECASE # ignore case - LOCALE = L = sre_compile.SRE_FLAG_LOCALE # assume current 8-bit locale - UNICODE = U = sre_compile.SRE_FLAG_UNICODE # assume unicode "locale" - MULTILINE = M = sre_compile.SRE_FLAG_MULTILINE # make anchors look for newline - DOTALL = S = sre_compile.SRE_FLAG_DOTALL # make dot match newline - VERBOSE = X = sre_compile.SRE_FLAG_VERBOSE # ignore whitespace and comments +@enum.global_enum +@enum._simple_enum(enum.IntFlag, boundary=enum.KEEP) +class RegexFlag: + NOFLAG = 0 + ASCII = A = _compiler.SRE_FLAG_ASCII # assume ascii "locale" + IGNORECASE = I = _compiler.SRE_FLAG_IGNORECASE # ignore case + LOCALE = L = _compiler.SRE_FLAG_LOCALE # assume current 8-bit locale + UNICODE = U = _compiler.SRE_FLAG_UNICODE # assume unicode "locale" + MULTILINE = M = _compiler.SRE_FLAG_MULTILINE # make anchors look for newline + DOTALL = S = _compiler.SRE_FLAG_DOTALL # make dot match newline + VERBOSE = X = _compiler.SRE_FLAG_VERBOSE # ignore whitespace and comments # sre extensions (experimental, don't rely on these) - TEMPLATE = T = sre_compile.SRE_FLAG_TEMPLATE # disable backtracking - DEBUG = sre_compile.SRE_FLAG_DEBUG # dump pattern after compilation - - def __repr__(self): - if self._name_ is not None: - return f're.{self._name_}' - value = self._value_ - members = [] - negative = value < 0 - if negative: - value = ~value - for m in self.__class__: - if value & m._value_: - value &= ~m._value_ - members.append(f're.{m._name_}') - if value: - members.append(hex(value)) - res = '|'.join(members) - if negative: - if len(members) > 1: - res = f'~({res})' - else: - res = f'~{res}' - return res + DEBUG = _compiler.SRE_FLAG_DEBUG # dump pattern after compilation __str__ = object.__str__ - -globals().update(RegexFlag.__members__) + _numeric_repr_ = hex # sre exception -error = sre_compile.error +error = _compiler.error # -------------------------------------------------------------------- # public interface @@ -200,16 +175,39 @@ def search(pattern, string, flags=0): a Match object, or None if no match was found.""" return _compile(pattern, flags).search(string) -def sub(pattern, repl, string, count=0, flags=0): +class _ZeroSentinel(int): + pass +_zero_sentinel = _ZeroSentinel() + +def sub(pattern, repl, string, *args, count=_zero_sentinel, flags=_zero_sentinel): """Return the string obtained by replacing the leftmost non-overlapping occurrences of the pattern in string by the replacement repl. repl can be either a string or a callable; if a string, backslash escapes in it are processed. If it is a callable, it's passed the Match object and must return a replacement string to be used.""" + if args: + if count is not _zero_sentinel: + raise TypeError("sub() got multiple values for argument 'count'") + count, *args = args + if args: + if flags is not _zero_sentinel: + raise TypeError("sub() got multiple values for argument 'flags'") + flags, *args = args + if args: + raise TypeError("sub() takes from 3 to 5 positional arguments " + "but %d were given" % (5 + len(args))) + + import warnings + warnings.warn( + "'count' is passed as positional argument", + DeprecationWarning, stacklevel=2 + ) + return _compile(pattern, flags).sub(repl, string, count) +sub.__text_signature__ = '(pattern, repl, string, count=0, flags=0)' -def subn(pattern, repl, string, count=0, flags=0): +def subn(pattern, repl, string, *args, count=_zero_sentinel, flags=_zero_sentinel): """Return a 2-tuple containing (new_string, number). new_string is the string obtained by replacing the leftmost non-overlapping occurrences of the pattern in the source @@ -218,9 +216,28 @@ def subn(pattern, repl, string, count=0, flags=0): callable; if a string, backslash escapes in it are processed. If it is a callable, it's passed the Match object and must return a replacement string to be used.""" + if args: + if count is not _zero_sentinel: + raise TypeError("subn() got multiple values for argument 'count'") + count, *args = args + if args: + if flags is not _zero_sentinel: + raise TypeError("subn() got multiple values for argument 'flags'") + flags, *args = args + if args: + raise TypeError("subn() takes from 3 to 5 positional arguments " + "but %d were given" % (5 + len(args))) + + import warnings + warnings.warn( + "'count' is passed as positional argument", + DeprecationWarning, stacklevel=2 + ) + return _compile(pattern, flags).subn(repl, string, count) +subn.__text_signature__ = '(pattern, repl, string, count=0, flags=0)' -def split(pattern, string, maxsplit=0, flags=0): +def split(pattern, string, *args, maxsplit=_zero_sentinel, flags=_zero_sentinel): """Split the source string by the occurrences of the pattern, returning a list containing the resulting substrings. If capturing parentheses are used in pattern, then the text of all @@ -228,7 +245,26 @@ def split(pattern, string, maxsplit=0, flags=0): list. If maxsplit is nonzero, at most maxsplit splits occur, and the remainder of the string is returned as the final element of the list.""" + if args: + if maxsplit is not _zero_sentinel: + raise TypeError("split() got multiple values for argument 'maxsplit'") + maxsplit, *args = args + if args: + if flags is not _zero_sentinel: + raise TypeError("split() got multiple values for argument 'flags'") + flags, *args = args + if args: + raise TypeError("split() takes from 2 to 4 positional arguments " + "but %d were given" % (4 + len(args))) + + import warnings + warnings.warn( + "'maxsplit' is passed as positional argument", + DeprecationWarning, stacklevel=2 + ) + return _compile(pattern, flags).split(string, maxsplit) +split.__text_signature__ = '(pattern, string, maxsplit=0, flags=0)' def findall(pattern, string, flags=0): """Return a list of all non-overlapping matches in the string. @@ -254,11 +290,9 @@ def compile(pattern, flags=0): def purge(): "Clear the regular expression caches" _cache.clear() - _compile_repl.cache_clear() + _cache2.clear() + _compile_template.cache_clear() -def template(pattern, flags=0): - "Compile a template pattern, returning a Pattern object" - return _compile(pattern, flags|T) # SPECIAL_CHARS # closing ')', '}' and ']' @@ -277,60 +311,69 @@ def escape(pattern): pattern = str(pattern, 'latin1') return pattern.translate(_special_chars_map).encode('latin1') -Pattern = type(sre_compile.compile('', 0)) -Match = type(sre_compile.compile('', 0).match('')) +Pattern = type(_compiler.compile('', 0)) +Match = type(_compiler.compile('', 0).match('')) # -------------------------------------------------------------------- # internals -_cache = {} # ordered! - +# Use the fact that dict keeps the insertion order. +# _cache2 uses the simple FIFO policy which has better latency. +# _cache uses the LRU policy which has better hit rate. +_cache = {} # LRU +_cache2 = {} # FIFO _MAXCACHE = 512 +_MAXCACHE2 = 256 +assert _MAXCACHE2 < _MAXCACHE + def _compile(pattern, flags): # internal: compile pattern if isinstance(flags, RegexFlag): flags = flags.value try: - return _cache[type(pattern), pattern, flags] + return _cache2[type(pattern), pattern, flags] except KeyError: pass - if isinstance(pattern, Pattern): - if flags: - raise ValueError( - "cannot process flags argument with a compiled pattern") - return pattern - if not sre_compile.isstring(pattern): - raise TypeError("first argument must be string or compiled pattern") - p = sre_compile.compile(pattern, flags) - if not (flags & DEBUG): + + key = (type(pattern), pattern, flags) + # Item in _cache should be moved to the end if found. + p = _cache.pop(key, None) + if p is None: + if isinstance(pattern, Pattern): + if flags: + raise ValueError( + "cannot process flags argument with a compiled pattern") + return pattern + if not _compiler.isstring(pattern): + raise TypeError("first argument must be string or compiled pattern") + p = _compiler.compile(pattern, flags) + if flags & DEBUG: + return p if len(_cache) >= _MAXCACHE: - # Drop the oldest item + # Drop the least recently used item. + # next(iter(_cache)) is known to have linear amortized time, + # but it is used here to avoid a dependency from using OrderedDict. + # For the small _MAXCACHE value it doesn't make much of a difference. try: del _cache[next(iter(_cache))] except (StopIteration, RuntimeError, KeyError): pass - _cache[type(pattern), pattern, flags] = p + # Append to the end. + _cache[key] = p + + if len(_cache2) >= _MAXCACHE2: + # Drop the oldest item. + try: + del _cache2[next(iter(_cache2))] + except (StopIteration, RuntimeError, KeyError): + pass + _cache2[key] = p return p @functools.lru_cache(_MAXCACHE) -def _compile_repl(repl, pattern): +def _compile_template(pattern, repl): # internal: compile replacement pattern - return sre_parse.parse_template(repl, pattern) - -def _expand(pattern, match, template): - # internal: Match.expand implementation hook - template = sre_parse.parse_template(template, pattern) - return sre_parse.expand_template(template, match) - -def _subx(pattern, template): - # internal: Pattern.sub/subn implementation helper - template = _compile_repl(template, pattern) - if not template[0] and len(template[1]) == 1: - # literal replacement - return template[1][0] - def filter(match, template=template): - return sre_parse.expand_template(template, match) - return filter + return _sre.template(pattern, _parser.parse_template(repl, pattern)) # register myself for pickling @@ -346,22 +389,22 @@ def _pickle(p): class Scanner: def __init__(self, lexicon, flags=0): - from sre_constants import BRANCH, SUBPATTERN + from ._constants import BRANCH, SUBPATTERN if isinstance(flags, RegexFlag): flags = flags.value self.lexicon = lexicon # combine phrases into a compound pattern p = [] - s = sre_parse.State() + s = _parser.State() s.flags = flags for phrase, action in lexicon: gid = s.opengroup() - p.append(sre_parse.SubPattern(s, [ - (SUBPATTERN, (gid, 0, 0, sre_parse.parse(phrase, flags))), + p.append(_parser.SubPattern(s, [ + (SUBPATTERN, (gid, 0, 0, _parser.parse(phrase, flags))), ])) s.closegroup(gid, p[-1]) - p = sre_parse.SubPattern(s, [(BRANCH, (None, p))]) - self.scanner = sre_compile.compile(p) + p = _parser.SubPattern(s, [(BRANCH, (None, p))]) + self.scanner = _compiler.compile(p) def scan(self, string): result = [] append = result.append diff --git a/Lib/re/_casefix.py b/Lib/re/_casefix.py new file mode 100644 index 0000000000..06507d08be --- /dev/null +++ b/Lib/re/_casefix.py @@ -0,0 +1,106 @@ +# Auto-generated by Tools/scripts/generate_re_casefix.py. + +# Maps the code of lowercased character to codes of different lowercased +# characters which have the same uppercase. +_EXTRA_CASES = { + # LATIN SMALL LETTER I: LATIN SMALL LETTER DOTLESS I + 0x0069: (0x0131,), # 'i': 'ı' + # LATIN SMALL LETTER S: LATIN SMALL LETTER LONG S + 0x0073: (0x017f,), # 's': 'ſ' + # MICRO SIGN: GREEK SMALL LETTER MU + 0x00b5: (0x03bc,), # 'µ': 'μ' + # LATIN SMALL LETTER DOTLESS I: LATIN SMALL LETTER I + 0x0131: (0x0069,), # 'ı': 'i' + # LATIN SMALL LETTER LONG S: LATIN SMALL LETTER S + 0x017f: (0x0073,), # 'ſ': 's' + # COMBINING GREEK YPOGEGRAMMENI: GREEK SMALL LETTER IOTA, GREEK PROSGEGRAMMENI + 0x0345: (0x03b9, 0x1fbe), # '\u0345': 'ιι' + # GREEK SMALL LETTER IOTA WITH DIALYTIKA AND TONOS: GREEK SMALL LETTER IOTA WITH DIALYTIKA AND OXIA + 0x0390: (0x1fd3,), # 'ΐ': 'ΐ' + # GREEK SMALL LETTER UPSILON WITH DIALYTIKA AND TONOS: GREEK SMALL LETTER UPSILON WITH DIALYTIKA AND OXIA + 0x03b0: (0x1fe3,), # 'ΰ': 'ΰ' + # GREEK SMALL LETTER BETA: GREEK BETA SYMBOL + 0x03b2: (0x03d0,), # 'β': 'ϐ' + # GREEK SMALL LETTER EPSILON: GREEK LUNATE EPSILON SYMBOL + 0x03b5: (0x03f5,), # 'ε': 'ϵ' + # GREEK SMALL LETTER THETA: GREEK THETA SYMBOL + 0x03b8: (0x03d1,), # 'θ': 'ϑ' + # GREEK SMALL LETTER IOTA: COMBINING GREEK YPOGEGRAMMENI, GREEK PROSGEGRAMMENI + 0x03b9: (0x0345, 0x1fbe), # 'ι': '\u0345ι' + # GREEK SMALL LETTER KAPPA: GREEK KAPPA SYMBOL + 0x03ba: (0x03f0,), # 'κ': 'ϰ' + # GREEK SMALL LETTER MU: MICRO SIGN + 0x03bc: (0x00b5,), # 'μ': 'µ' + # GREEK SMALL LETTER PI: GREEK PI SYMBOL + 0x03c0: (0x03d6,), # 'π': 'ϖ' + # GREEK SMALL LETTER RHO: GREEK RHO SYMBOL + 0x03c1: (0x03f1,), # 'ρ': 'ϱ' + # GREEK SMALL LETTER FINAL SIGMA: GREEK SMALL LETTER SIGMA + 0x03c2: (0x03c3,), # 'ς': 'σ' + # GREEK SMALL LETTER SIGMA: GREEK SMALL LETTER FINAL SIGMA + 0x03c3: (0x03c2,), # 'σ': 'ς' + # GREEK SMALL LETTER PHI: GREEK PHI SYMBOL + 0x03c6: (0x03d5,), # 'φ': 'ϕ' + # GREEK BETA SYMBOL: GREEK SMALL LETTER BETA + 0x03d0: (0x03b2,), # 'ϐ': 'β' + # GREEK THETA SYMBOL: GREEK SMALL LETTER THETA + 0x03d1: (0x03b8,), # 'ϑ': 'θ' + # GREEK PHI SYMBOL: GREEK SMALL LETTER PHI + 0x03d5: (0x03c6,), # 'ϕ': 'φ' + # GREEK PI SYMBOL: GREEK SMALL LETTER PI + 0x03d6: (0x03c0,), # 'ϖ': 'π' + # GREEK KAPPA SYMBOL: GREEK SMALL LETTER KAPPA + 0x03f0: (0x03ba,), # 'ϰ': 'κ' + # GREEK RHO SYMBOL: GREEK SMALL LETTER RHO + 0x03f1: (0x03c1,), # 'ϱ': 'ρ' + # GREEK LUNATE EPSILON SYMBOL: GREEK SMALL LETTER EPSILON + 0x03f5: (0x03b5,), # 'ϵ': 'ε' + # CYRILLIC SMALL LETTER VE: CYRILLIC SMALL LETTER ROUNDED VE + 0x0432: (0x1c80,), # 'в': 'ᲀ' + # CYRILLIC SMALL LETTER DE: CYRILLIC SMALL LETTER LONG-LEGGED DE + 0x0434: (0x1c81,), # 'д': 'ᲁ' + # CYRILLIC SMALL LETTER O: CYRILLIC SMALL LETTER NARROW O + 0x043e: (0x1c82,), # 'о': 'ᲂ' + # CYRILLIC SMALL LETTER ES: CYRILLIC SMALL LETTER WIDE ES + 0x0441: (0x1c83,), # 'с': 'ᲃ' + # CYRILLIC SMALL LETTER TE: CYRILLIC SMALL LETTER TALL TE, CYRILLIC SMALL LETTER THREE-LEGGED TE + 0x0442: (0x1c84, 0x1c85), # 'т': 'ᲄᲅ' + # CYRILLIC SMALL LETTER HARD SIGN: CYRILLIC SMALL LETTER TALL HARD SIGN + 0x044a: (0x1c86,), # 'ъ': 'ᲆ' + # CYRILLIC SMALL LETTER YAT: CYRILLIC SMALL LETTER TALL YAT + 0x0463: (0x1c87,), # 'ѣ': 'ᲇ' + # CYRILLIC SMALL LETTER ROUNDED VE: CYRILLIC SMALL LETTER VE + 0x1c80: (0x0432,), # 'ᲀ': 'в' + # CYRILLIC SMALL LETTER LONG-LEGGED DE: CYRILLIC SMALL LETTER DE + 0x1c81: (0x0434,), # 'ᲁ': 'д' + # CYRILLIC SMALL LETTER NARROW O: CYRILLIC SMALL LETTER O + 0x1c82: (0x043e,), # 'ᲂ': 'о' + # CYRILLIC SMALL LETTER WIDE ES: CYRILLIC SMALL LETTER ES + 0x1c83: (0x0441,), # 'ᲃ': 'с' + # CYRILLIC SMALL LETTER TALL TE: CYRILLIC SMALL LETTER TE, CYRILLIC SMALL LETTER THREE-LEGGED TE + 0x1c84: (0x0442, 0x1c85), # 'ᲄ': 'тᲅ' + # CYRILLIC SMALL LETTER THREE-LEGGED TE: CYRILLIC SMALL LETTER TE, CYRILLIC SMALL LETTER TALL TE + 0x1c85: (0x0442, 0x1c84), # 'ᲅ': 'тᲄ' + # CYRILLIC SMALL LETTER TALL HARD SIGN: CYRILLIC SMALL LETTER HARD SIGN + 0x1c86: (0x044a,), # 'ᲆ': 'ъ' + # CYRILLIC SMALL LETTER TALL YAT: CYRILLIC SMALL LETTER YAT + 0x1c87: (0x0463,), # 'ᲇ': 'ѣ' + # CYRILLIC SMALL LETTER UNBLENDED UK: CYRILLIC SMALL LETTER MONOGRAPH UK + 0x1c88: (0xa64b,), # 'ᲈ': 'ꙋ' + # LATIN SMALL LETTER S WITH DOT ABOVE: LATIN SMALL LETTER LONG S WITH DOT ABOVE + 0x1e61: (0x1e9b,), # 'ṡ': 'ẛ' + # LATIN SMALL LETTER LONG S WITH DOT ABOVE: LATIN SMALL LETTER S WITH DOT ABOVE + 0x1e9b: (0x1e61,), # 'ẛ': 'ṡ' + # GREEK PROSGEGRAMMENI: COMBINING GREEK YPOGEGRAMMENI, GREEK SMALL LETTER IOTA + 0x1fbe: (0x0345, 0x03b9), # 'ι': '\u0345ι' + # GREEK SMALL LETTER IOTA WITH DIALYTIKA AND OXIA: GREEK SMALL LETTER IOTA WITH DIALYTIKA AND TONOS + 0x1fd3: (0x0390,), # 'ΐ': 'ΐ' + # GREEK SMALL LETTER UPSILON WITH DIALYTIKA AND OXIA: GREEK SMALL LETTER UPSILON WITH DIALYTIKA AND TONOS + 0x1fe3: (0x03b0,), # 'ΰ': 'ΰ' + # CYRILLIC SMALL LETTER MONOGRAPH UK: CYRILLIC SMALL LETTER UNBLENDED UK + 0xa64b: (0x1c88,), # 'ꙋ': 'ᲈ' + # LATIN SMALL LIGATURE LONG S T: LATIN SMALL LIGATURE ST + 0xfb05: (0xfb06,), # 'ſt': 'st' + # LATIN SMALL LIGATURE ST: LATIN SMALL LIGATURE LONG S T + 0xfb06: (0xfb05,), # 'st': 'ſt' +} diff --git a/Lib/re/_compiler.py b/Lib/re/_compiler.py new file mode 100644 index 0000000000..f87712d6d6 --- /dev/null +++ b/Lib/re/_compiler.py @@ -0,0 +1,763 @@ +# +# Secret Labs' Regular Expression Engine +# +# convert template to internal format +# +# Copyright (c) 1997-2001 by Secret Labs AB. All rights reserved. +# +# See the __init__.py file for information on usage and redistribution. +# + +"""Internal support module for sre""" + +import _sre +from . import _parser +from ._constants import * +from ._casefix import _EXTRA_CASES + +assert _sre.MAGIC == MAGIC, "SRE module mismatch" + +_LITERAL_CODES = {LITERAL, NOT_LITERAL} +_SUCCESS_CODES = {SUCCESS, FAILURE} +_ASSERT_CODES = {ASSERT, ASSERT_NOT} +_UNIT_CODES = _LITERAL_CODES | {ANY, IN} + +_REPEATING_CODES = { + MIN_REPEAT: (REPEAT, MIN_UNTIL, MIN_REPEAT_ONE), + MAX_REPEAT: (REPEAT, MAX_UNTIL, REPEAT_ONE), + POSSESSIVE_REPEAT: (POSSESSIVE_REPEAT, SUCCESS, POSSESSIVE_REPEAT_ONE), +} + +def _combine_flags(flags, add_flags, del_flags, + TYPE_FLAGS=_parser.TYPE_FLAGS): + if add_flags & TYPE_FLAGS: + flags &= ~TYPE_FLAGS + return (flags | add_flags) & ~del_flags + +def _compile(code, pattern, flags): + # internal: compile a (sub)pattern + emit = code.append + _len = len + LITERAL_CODES = _LITERAL_CODES + REPEATING_CODES = _REPEATING_CODES + SUCCESS_CODES = _SUCCESS_CODES + ASSERT_CODES = _ASSERT_CODES + iscased = None + tolower = None + fixes = None + if flags & SRE_FLAG_IGNORECASE and not flags & SRE_FLAG_LOCALE: + if flags & SRE_FLAG_UNICODE: + iscased = _sre.unicode_iscased + tolower = _sre.unicode_tolower + fixes = _EXTRA_CASES + else: + iscased = _sre.ascii_iscased + tolower = _sre.ascii_tolower + for op, av in pattern: + if op in LITERAL_CODES: + if not flags & SRE_FLAG_IGNORECASE: + emit(op) + emit(av) + elif flags & SRE_FLAG_LOCALE: + emit(OP_LOCALE_IGNORE[op]) + emit(av) + elif not iscased(av): + emit(op) + emit(av) + else: + lo = tolower(av) + if not fixes: # ascii + emit(OP_IGNORE[op]) + emit(lo) + elif lo not in fixes: + emit(OP_UNICODE_IGNORE[op]) + emit(lo) + else: + emit(IN_UNI_IGNORE) + skip = _len(code); emit(0) + if op is NOT_LITERAL: + emit(NEGATE) + for k in (lo,) + fixes[lo]: + emit(LITERAL) + emit(k) + emit(FAILURE) + code[skip] = _len(code) - skip + elif op is IN: + charset, hascased = _optimize_charset(av, iscased, tolower, fixes) + if flags & SRE_FLAG_IGNORECASE and flags & SRE_FLAG_LOCALE: + emit(IN_LOC_IGNORE) + elif not hascased: + emit(IN) + elif not fixes: # ascii + emit(IN_IGNORE) + else: + emit(IN_UNI_IGNORE) + skip = _len(code); emit(0) + _compile_charset(charset, flags, code) + code[skip] = _len(code) - skip + elif op is ANY: + if flags & SRE_FLAG_DOTALL: + emit(ANY_ALL) + else: + emit(ANY) + elif op in REPEATING_CODES: + if _simple(av[2]): + emit(REPEATING_CODES[op][2]) + skip = _len(code); emit(0) + emit(av[0]) + emit(av[1]) + _compile(code, av[2], flags) + emit(SUCCESS) + code[skip] = _len(code) - skip + else: + emit(REPEATING_CODES[op][0]) + skip = _len(code); emit(0) + emit(av[0]) + emit(av[1]) + _compile(code, av[2], flags) + code[skip] = _len(code) - skip + emit(REPEATING_CODES[op][1]) + elif op is SUBPATTERN: + group, add_flags, del_flags, p = av + if group: + emit(MARK) + emit((group-1)*2) + # _compile_info(code, p, _combine_flags(flags, add_flags, del_flags)) + _compile(code, p, _combine_flags(flags, add_flags, del_flags)) + if group: + emit(MARK) + emit((group-1)*2+1) + elif op is ATOMIC_GROUP: + # Atomic Groups are handled by starting with an Atomic + # Group op code, then putting in the atomic group pattern + # and finally a success op code to tell any repeat + # operations within the Atomic Group to stop eating and + # pop their stack if they reach it + emit(ATOMIC_GROUP) + skip = _len(code); emit(0) + _compile(code, av, flags) + emit(SUCCESS) + code[skip] = _len(code) - skip + elif op in SUCCESS_CODES: + emit(op) + elif op in ASSERT_CODES: + emit(op) + skip = _len(code); emit(0) + if av[0] >= 0: + emit(0) # look ahead + else: + lo, hi = av[1].getwidth() + if lo > MAXCODE: + raise error("looks too much behind") + if lo != hi: + raise error("look-behind requires fixed-width pattern") + emit(lo) # look behind + _compile(code, av[1], flags) + emit(SUCCESS) + code[skip] = _len(code) - skip + elif op is AT: + emit(op) + if flags & SRE_FLAG_MULTILINE: + av = AT_MULTILINE.get(av, av) + if flags & SRE_FLAG_LOCALE: + av = AT_LOCALE.get(av, av) + elif flags & SRE_FLAG_UNICODE: + av = AT_UNICODE.get(av, av) + emit(av) + elif op is BRANCH: + emit(op) + tail = [] + tailappend = tail.append + for av in av[1]: + skip = _len(code); emit(0) + # _compile_info(code, av, flags) + _compile(code, av, flags) + emit(JUMP) + tailappend(_len(code)); emit(0) + code[skip] = _len(code) - skip + emit(FAILURE) # end of branch + for tail in tail: + code[tail] = _len(code) - tail + elif op is CATEGORY: + emit(op) + if flags & SRE_FLAG_LOCALE: + av = CH_LOCALE[av] + elif flags & SRE_FLAG_UNICODE: + av = CH_UNICODE[av] + emit(av) + elif op is GROUPREF: + if not flags & SRE_FLAG_IGNORECASE: + emit(op) + elif flags & SRE_FLAG_LOCALE: + emit(GROUPREF_LOC_IGNORE) + elif not fixes: # ascii + emit(GROUPREF_IGNORE) + else: + emit(GROUPREF_UNI_IGNORE) + emit(av-1) + elif op is GROUPREF_EXISTS: + emit(op) + emit(av[0]-1) + skipyes = _len(code); emit(0) + _compile(code, av[1], flags) + if av[2]: + emit(JUMP) + skipno = _len(code); emit(0) + code[skipyes] = _len(code) - skipyes + 1 + _compile(code, av[2], flags) + code[skipno] = _len(code) - skipno + else: + code[skipyes] = _len(code) - skipyes + 1 + else: + raise error("internal: unsupported operand type %r" % (op,)) + +def _compile_charset(charset, flags, code): + # compile charset subprogram + emit = code.append + for op, av in charset: + emit(op) + if op is NEGATE: + pass + elif op is LITERAL: + emit(av) + elif op is RANGE or op is RANGE_UNI_IGNORE: + emit(av[0]) + emit(av[1]) + elif op is CHARSET: + code.extend(av) + elif op is BIGCHARSET: + code.extend(av) + elif op is CATEGORY: + if flags & SRE_FLAG_LOCALE: + emit(CH_LOCALE[av]) + elif flags & SRE_FLAG_UNICODE: + emit(CH_UNICODE[av]) + else: + emit(av) + else: + raise error("internal: unsupported set operator %r" % (op,)) + emit(FAILURE) + +def _optimize_charset(charset, iscased=None, fixup=None, fixes=None): + # internal: optimize character set + out = [] + tail = [] + charmap = bytearray(256) + hascased = False + for op, av in charset: + while True: + try: + if op is LITERAL: + if fixup: + lo = fixup(av) + charmap[lo] = 1 + if fixes and lo in fixes: + for k in fixes[lo]: + charmap[k] = 1 + if not hascased and iscased(av): + hascased = True + else: + charmap[av] = 1 + elif op is RANGE: + r = range(av[0], av[1]+1) + if fixup: + if fixes: + for i in map(fixup, r): + charmap[i] = 1 + if i in fixes: + for k in fixes[i]: + charmap[k] = 1 + else: + for i in map(fixup, r): + charmap[i] = 1 + if not hascased: + hascased = any(map(iscased, r)) + else: + for i in r: + charmap[i] = 1 + elif op is NEGATE: + out.append((op, av)) + else: + tail.append((op, av)) + except IndexError: + if len(charmap) == 256: + # character set contains non-UCS1 character codes + charmap += b'\0' * 0xff00 + continue + # Character set contains non-BMP character codes. + # For range, all BMP characters in the range are already + # proceeded. + if fixup: + hascased = True + # For now, IN_UNI_IGNORE+LITERAL and + # IN_UNI_IGNORE+RANGE_UNI_IGNORE work for all non-BMP + # characters, because two characters (at least one of + # which is not in the BMP) match case-insensitively + # if and only if: + # 1) c1.lower() == c2.lower() + # 2) c1.lower() == c2 or c1.lower().upper() == c2 + # Also, both c.lower() and c.lower().upper() are single + # characters for every non-BMP character. + if op is RANGE: + op = RANGE_UNI_IGNORE + tail.append((op, av)) + break + + # compress character map + runs = [] + q = 0 + while True: + p = charmap.find(1, q) + if p < 0: + break + if len(runs) >= 2: + runs = None + break + q = charmap.find(0, p) + if q < 0: + runs.append((p, len(charmap))) + break + runs.append((p, q)) + if runs is not None: + # use literal/range + for p, q in runs: + if q - p == 1: + out.append((LITERAL, p)) + else: + out.append((RANGE, (p, q - 1))) + out += tail + # if the case was changed or new representation is more compact + if hascased or len(out) < len(charset): + return out, hascased + # else original character set is good enough + return charset, hascased + + # use bitmap + if len(charmap) == 256: + data = _mk_bitmap(charmap) + out.append((CHARSET, data)) + out += tail + return out, hascased + + # To represent a big charset, first a bitmap of all characters in the + # set is constructed. Then, this bitmap is sliced into chunks of 256 + # characters, duplicate chunks are eliminated, and each chunk is + # given a number. In the compiled expression, the charset is + # represented by a 32-bit word sequence, consisting of one word for + # the number of different chunks, a sequence of 256 bytes (64 words) + # of chunk numbers indexed by their original chunk position, and a + # sequence of 256-bit chunks (8 words each). + + # Compression is normally good: in a typical charset, large ranges of + # Unicode will be either completely excluded (e.g. if only cyrillic + # letters are to be matched), or completely included (e.g. if large + # subranges of Kanji match). These ranges will be represented by + # chunks of all one-bits or all zero-bits. + + # Matching can be also done efficiently: the more significant byte of + # the Unicode character is an index into the chunk number, and the + # less significant byte is a bit index in the chunk (just like the + # CHARSET matching). + + charmap = bytes(charmap) # should be hashable + comps = {} + mapping = bytearray(256) + block = 0 + data = bytearray() + for i in range(0, 65536, 256): + chunk = charmap[i: i + 256] + if chunk in comps: + mapping[i // 256] = comps[chunk] + else: + mapping[i // 256] = comps[chunk] = block + block += 1 + data += chunk + data = _mk_bitmap(data) + data[0:0] = [block] + _bytes_to_codes(mapping) + out.append((BIGCHARSET, data)) + out += tail + return out, hascased + +_CODEBITS = _sre.CODESIZE * 8 +MAXCODE = (1 << _CODEBITS) - 1 +_BITS_TRANS = b'0' + b'1' * 255 +def _mk_bitmap(bits, _CODEBITS=_CODEBITS, _int=int): + s = bits.translate(_BITS_TRANS)[::-1] + return [_int(s[i - _CODEBITS: i], 2) + for i in range(len(s), 0, -_CODEBITS)] + +def _bytes_to_codes(b): + # Convert block indices to word array + a = memoryview(b).cast('I') + assert a.itemsize == _sre.CODESIZE + assert len(a) * a.itemsize == len(b) + return a.tolist() + +def _simple(p): + # check if this subpattern is a "simple" operator + if len(p) != 1: + return False + op, av = p[0] + if op is SUBPATTERN: + return av[0] is None and _simple(av[-1]) + return op in _UNIT_CODES + +def _generate_overlap_table(prefix): + """ + Generate an overlap table for the following prefix. + An overlap table is a table of the same size as the prefix which + informs about the potential self-overlap for each index in the prefix: + - if overlap[i] == 0, prefix[i:] can't overlap prefix[0:...] + - if overlap[i] == k with 0 < k <= i, prefix[i-k+1:i+1] overlaps with + prefix[0:k] + """ + table = [0] * len(prefix) + for i in range(1, len(prefix)): + idx = table[i - 1] + while prefix[i] != prefix[idx]: + if idx == 0: + table[i] = 0 + break + idx = table[idx - 1] + else: + table[i] = idx + 1 + return table + +def _get_iscased(flags): + if not flags & SRE_FLAG_IGNORECASE: + return None + elif flags & SRE_FLAG_UNICODE: + return _sre.unicode_iscased + else: + return _sre.ascii_iscased + +def _get_literal_prefix(pattern, flags): + # look for literal prefix + prefix = [] + prefixappend = prefix.append + prefix_skip = None + iscased = _get_iscased(flags) + for op, av in pattern.data: + if op is LITERAL: + if iscased and iscased(av): + break + prefixappend(av) + elif op is SUBPATTERN: + group, add_flags, del_flags, p = av + flags1 = _combine_flags(flags, add_flags, del_flags) + if flags1 & SRE_FLAG_IGNORECASE and flags1 & SRE_FLAG_LOCALE: + break + prefix1, prefix_skip1, got_all = _get_literal_prefix(p, flags1) + if prefix_skip is None: + if group is not None: + prefix_skip = len(prefix) + elif prefix_skip1 is not None: + prefix_skip = len(prefix) + prefix_skip1 + prefix.extend(prefix1) + if not got_all: + break + else: + break + else: + return prefix, prefix_skip, True + return prefix, prefix_skip, False + +def _get_charset_prefix(pattern, flags): + while True: + if not pattern.data: + return None + op, av = pattern.data[0] + if op is not SUBPATTERN: + break + group, add_flags, del_flags, pattern = av + flags = _combine_flags(flags, add_flags, del_flags) + if flags & SRE_FLAG_IGNORECASE and flags & SRE_FLAG_LOCALE: + return None + + iscased = _get_iscased(flags) + if op is LITERAL: + if iscased and iscased(av): + return None + return [(op, av)] + elif op is BRANCH: + charset = [] + charsetappend = charset.append + for p in av[1]: + if not p: + return None + op, av = p[0] + if op is LITERAL and not (iscased and iscased(av)): + charsetappend((op, av)) + else: + return None + return charset + elif op is IN: + charset = av + if iscased: + for op, av in charset: + if op is LITERAL: + if iscased(av): + return None + elif op is RANGE: + if av[1] > 0xffff: + return None + if any(map(iscased, range(av[0], av[1]+1))): + return None + return charset + return None + +def _compile_info(code, pattern, flags): + # internal: compile an info block. in the current version, + # this contains min/max pattern width, and an optional literal + # prefix or a character map + lo, hi = pattern.getwidth() + if hi > MAXCODE: + hi = MAXCODE + if lo == 0: + code.extend([INFO, 4, 0, lo, hi]) + return + # look for a literal prefix + prefix = [] + prefix_skip = 0 + charset = [] # not used + if not (flags & SRE_FLAG_IGNORECASE and flags & SRE_FLAG_LOCALE): + # look for literal prefix + prefix, prefix_skip, got_all = _get_literal_prefix(pattern, flags) + # if no prefix, look for charset prefix + if not prefix: + charset = _get_charset_prefix(pattern, flags) +## if prefix: +## print("*** PREFIX", prefix, prefix_skip) +## if charset: +## print("*** CHARSET", charset) + # add an info block + emit = code.append + emit(INFO) + skip = len(code); emit(0) + # literal flag + mask = 0 + if prefix: + mask = SRE_INFO_PREFIX + if prefix_skip is None and got_all: + mask = mask | SRE_INFO_LITERAL + elif charset: + mask = mask | SRE_INFO_CHARSET + emit(mask) + # pattern length + if lo < MAXCODE: + emit(lo) + else: + emit(MAXCODE) + prefix = prefix[:MAXCODE] + emit(hi) + # add literal prefix + if prefix: + emit(len(prefix)) # length + if prefix_skip is None: + prefix_skip = len(prefix) + emit(prefix_skip) # skip + code.extend(prefix) + # generate overlap table + code.extend(_generate_overlap_table(prefix)) + elif charset: + charset, hascased = _optimize_charset(charset) + assert not hascased + _compile_charset(charset, flags, code) + code[skip] = len(code) - skip + +def isstring(obj): + return isinstance(obj, (str, bytes)) + +def _code(p, flags): + + flags = p.state.flags | flags + code = [] + + # compile info block + _compile_info(code, p, flags) + + # compile the pattern + _compile(code, p.data, flags) + + code.append(SUCCESS) + + return code + +def _hex_code(code): + return '[%s]' % ', '.join('%#0*x' % (_sre.CODESIZE*2+2, x) for x in code) + +def dis(code): + import sys + + labels = set() + level = 0 + offset_width = len(str(len(code) - 1)) + + def dis_(start, end): + def print_(*args, to=None): + if to is not None: + labels.add(to) + args += ('(to %d)' % (to,),) + print('%*d%s ' % (offset_width, start, ':' if start in labels else '.'), + end=' '*(level-1)) + print(*args) + + def print_2(*args): + print(end=' '*(offset_width + 2*level)) + print(*args) + + nonlocal level + level += 1 + i = start + while i < end: + start = i + op = code[i] + i += 1 + op = OPCODES[op] + if op in (SUCCESS, FAILURE, ANY, ANY_ALL, + MAX_UNTIL, MIN_UNTIL, NEGATE): + print_(op) + elif op in (LITERAL, NOT_LITERAL, + LITERAL_IGNORE, NOT_LITERAL_IGNORE, + LITERAL_UNI_IGNORE, NOT_LITERAL_UNI_IGNORE, + LITERAL_LOC_IGNORE, NOT_LITERAL_LOC_IGNORE): + arg = code[i] + i += 1 + print_(op, '%#02x (%r)' % (arg, chr(arg))) + elif op is AT: + arg = code[i] + i += 1 + arg = str(ATCODES[arg]) + assert arg[:3] == 'AT_' + print_(op, arg[3:]) + elif op is CATEGORY: + arg = code[i] + i += 1 + arg = str(CHCODES[arg]) + assert arg[:9] == 'CATEGORY_' + print_(op, arg[9:]) + elif op in (IN, IN_IGNORE, IN_UNI_IGNORE, IN_LOC_IGNORE): + skip = code[i] + print_(op, skip, to=i+skip) + dis_(i+1, i+skip) + i += skip + elif op in (RANGE, RANGE_UNI_IGNORE): + lo, hi = code[i: i+2] + i += 2 + print_(op, '%#02x %#02x (%r-%r)' % (lo, hi, chr(lo), chr(hi))) + elif op is CHARSET: + print_(op, _hex_code(code[i: i + 256//_CODEBITS])) + i += 256//_CODEBITS + elif op is BIGCHARSET: + arg = code[i] + i += 1 + mapping = list(b''.join(x.to_bytes(_sre.CODESIZE, sys.byteorder) + for x in code[i: i + 256//_sre.CODESIZE])) + print_(op, arg, mapping) + i += 256//_sre.CODESIZE + level += 1 + for j in range(arg): + print_2(_hex_code(code[i: i + 256//_CODEBITS])) + i += 256//_CODEBITS + level -= 1 + elif op in (MARK, GROUPREF, GROUPREF_IGNORE, GROUPREF_UNI_IGNORE, + GROUPREF_LOC_IGNORE): + arg = code[i] + i += 1 + print_(op, arg) + elif op is JUMP: + skip = code[i] + print_(op, skip, to=i+skip) + i += 1 + elif op is BRANCH: + skip = code[i] + print_(op, skip, to=i+skip) + while skip: + dis_(i+1, i+skip) + i += skip + start = i + skip = code[i] + if skip: + print_('branch', skip, to=i+skip) + else: + print_(FAILURE) + i += 1 + elif op in (REPEAT, REPEAT_ONE, MIN_REPEAT_ONE, + POSSESSIVE_REPEAT, POSSESSIVE_REPEAT_ONE): + skip, min, max = code[i: i+3] + if max == MAXREPEAT: + max = 'MAXREPEAT' + print_(op, skip, min, max, to=i+skip) + dis_(i+3, i+skip) + i += skip + elif op is GROUPREF_EXISTS: + arg, skip = code[i: i+2] + print_(op, arg, skip, to=i+skip) + i += 2 + elif op in (ASSERT, ASSERT_NOT): + skip, arg = code[i: i+2] + print_(op, skip, arg, to=i+skip) + dis_(i+2, i+skip) + i += skip + elif op is ATOMIC_GROUP: + skip = code[i] + print_(op, skip, to=i+skip) + dis_(i+1, i+skip) + i += skip + elif op is INFO: + skip, flags, min, max = code[i: i+4] + if max == MAXREPEAT: + max = 'MAXREPEAT' + print_(op, skip, bin(flags), min, max, to=i+skip) + start = i+4 + if flags & SRE_INFO_PREFIX: + prefix_len, prefix_skip = code[i+4: i+6] + print_2(' prefix_skip', prefix_skip) + start = i + 6 + prefix = code[start: start+prefix_len] + print_2(' prefix', + '[%s]' % ', '.join('%#02x' % x for x in prefix), + '(%r)' % ''.join(map(chr, prefix))) + start += prefix_len + print_2(' overlap', code[start: start+prefix_len]) + start += prefix_len + if flags & SRE_INFO_CHARSET: + level += 1 + print_2('in') + dis_(start, i+skip) + level -= 1 + i += skip + else: + raise ValueError(op) + + level -= 1 + + dis_(0, len(code)) + + +def compile(p, flags=0): + # internal: convert pattern list to internal format + + if isstring(p): + pattern = p + p = _parser.parse(p, flags) + else: + pattern = None + + code = _code(p, flags) + + if flags & SRE_FLAG_DEBUG: + print() + dis(code) + + # map in either direction + groupindex = p.state.groupdict + indexgroup = [None] * p.state.groups + for k, i in groupindex.items(): + indexgroup[i] = k + + return _sre.compile( + pattern, flags | p.state.flags, code, + p.state.groups-1, + groupindex, tuple(indexgroup) + ) diff --git a/Lib/re/_constants.py b/Lib/re/_constants.py new file mode 100644 index 0000000000..d8e483ac4f --- /dev/null +++ b/Lib/re/_constants.py @@ -0,0 +1,219 @@ +# +# Secret Labs' Regular Expression Engine +# +# various symbols used by the regular expression engine. +# run this script to update the _sre include files! +# +# Copyright (c) 1998-2001 by Secret Labs AB. All rights reserved. +# +# See the __init__.py file for information on usage and redistribution. +# + +"""Internal support module for sre""" + +# update when constants are added or removed + +MAGIC = 20230612 + +from _sre import MAXREPEAT, MAXGROUPS + +# SRE standard exception (access as sre.error) +# should this really be here? + +class error(Exception): + """Exception raised for invalid regular expressions. + + Attributes: + + msg: The unformatted error message + pattern: The regular expression pattern + pos: The index in the pattern where compilation failed (may be None) + lineno: The line corresponding to pos (may be None) + colno: The column corresponding to pos (may be None) + """ + + __module__ = 're' + + def __init__(self, msg, pattern=None, pos=None): + self.msg = msg + self.pattern = pattern + self.pos = pos + if pattern is not None and pos is not None: + msg = '%s at position %d' % (msg, pos) + if isinstance(pattern, str): + newline = '\n' + else: + newline = b'\n' + self.lineno = pattern.count(newline, 0, pos) + 1 + self.colno = pos - pattern.rfind(newline, 0, pos) + if newline in pattern: + msg = '%s (line %d, column %d)' % (msg, self.lineno, self.colno) + else: + self.lineno = self.colno = None + super().__init__(msg) + + +class _NamedIntConstant(int): + def __new__(cls, value, name): + self = super(_NamedIntConstant, cls).__new__(cls, value) + self.name = name + return self + + def __repr__(self): + return self.name + + __reduce__ = None + +MAXREPEAT = _NamedIntConstant(MAXREPEAT, 'MAXREPEAT') + +def _makecodes(*names): + items = [_NamedIntConstant(i, name) for i, name in enumerate(names)] + globals().update({item.name: item for item in items}) + return items + +# operators +OPCODES = _makecodes( + # failure=0 success=1 (just because it looks better that way :-) + 'FAILURE', 'SUCCESS', + + 'ANY', 'ANY_ALL', + 'ASSERT', 'ASSERT_NOT', + 'AT', + 'BRANCH', + 'CATEGORY', + 'CHARSET', 'BIGCHARSET', + 'GROUPREF', 'GROUPREF_EXISTS', + 'IN', + 'INFO', + 'JUMP', + 'LITERAL', + 'MARK', + 'MAX_UNTIL', + 'MIN_UNTIL', + 'NOT_LITERAL', + 'NEGATE', + 'RANGE', + 'REPEAT', + 'REPEAT_ONE', + 'SUBPATTERN', + 'MIN_REPEAT_ONE', + 'ATOMIC_GROUP', + 'POSSESSIVE_REPEAT', + 'POSSESSIVE_REPEAT_ONE', + + 'GROUPREF_IGNORE', + 'IN_IGNORE', + 'LITERAL_IGNORE', + 'NOT_LITERAL_IGNORE', + + 'GROUPREF_LOC_IGNORE', + 'IN_LOC_IGNORE', + 'LITERAL_LOC_IGNORE', + 'NOT_LITERAL_LOC_IGNORE', + + 'GROUPREF_UNI_IGNORE', + 'IN_UNI_IGNORE', + 'LITERAL_UNI_IGNORE', + 'NOT_LITERAL_UNI_IGNORE', + 'RANGE_UNI_IGNORE', + + # The following opcodes are only occurred in the parser output, + # but not in the compiled code. + 'MIN_REPEAT', 'MAX_REPEAT', +) +del OPCODES[-2:] # remove MIN_REPEAT and MAX_REPEAT + +# positions +ATCODES = _makecodes( + 'AT_BEGINNING', 'AT_BEGINNING_LINE', 'AT_BEGINNING_STRING', + 'AT_BOUNDARY', 'AT_NON_BOUNDARY', + 'AT_END', 'AT_END_LINE', 'AT_END_STRING', + + 'AT_LOC_BOUNDARY', 'AT_LOC_NON_BOUNDARY', + + 'AT_UNI_BOUNDARY', 'AT_UNI_NON_BOUNDARY', +) + +# categories +CHCODES = _makecodes( + 'CATEGORY_DIGIT', 'CATEGORY_NOT_DIGIT', + 'CATEGORY_SPACE', 'CATEGORY_NOT_SPACE', + 'CATEGORY_WORD', 'CATEGORY_NOT_WORD', + 'CATEGORY_LINEBREAK', 'CATEGORY_NOT_LINEBREAK', + + 'CATEGORY_LOC_WORD', 'CATEGORY_LOC_NOT_WORD', + + 'CATEGORY_UNI_DIGIT', 'CATEGORY_UNI_NOT_DIGIT', + 'CATEGORY_UNI_SPACE', 'CATEGORY_UNI_NOT_SPACE', + 'CATEGORY_UNI_WORD', 'CATEGORY_UNI_NOT_WORD', + 'CATEGORY_UNI_LINEBREAK', 'CATEGORY_UNI_NOT_LINEBREAK', +) + + +# replacement operations for "ignore case" mode +OP_IGNORE = { + LITERAL: LITERAL_IGNORE, + NOT_LITERAL: NOT_LITERAL_IGNORE, +} + +OP_LOCALE_IGNORE = { + LITERAL: LITERAL_LOC_IGNORE, + NOT_LITERAL: NOT_LITERAL_LOC_IGNORE, +} + +OP_UNICODE_IGNORE = { + LITERAL: LITERAL_UNI_IGNORE, + NOT_LITERAL: NOT_LITERAL_UNI_IGNORE, +} + +AT_MULTILINE = { + AT_BEGINNING: AT_BEGINNING_LINE, + AT_END: AT_END_LINE +} + +AT_LOCALE = { + AT_BOUNDARY: AT_LOC_BOUNDARY, + AT_NON_BOUNDARY: AT_LOC_NON_BOUNDARY +} + +AT_UNICODE = { + AT_BOUNDARY: AT_UNI_BOUNDARY, + AT_NON_BOUNDARY: AT_UNI_NON_BOUNDARY +} + +CH_LOCALE = { + CATEGORY_DIGIT: CATEGORY_DIGIT, + CATEGORY_NOT_DIGIT: CATEGORY_NOT_DIGIT, + CATEGORY_SPACE: CATEGORY_SPACE, + CATEGORY_NOT_SPACE: CATEGORY_NOT_SPACE, + CATEGORY_WORD: CATEGORY_LOC_WORD, + CATEGORY_NOT_WORD: CATEGORY_LOC_NOT_WORD, + CATEGORY_LINEBREAK: CATEGORY_LINEBREAK, + CATEGORY_NOT_LINEBREAK: CATEGORY_NOT_LINEBREAK +} + +CH_UNICODE = { + CATEGORY_DIGIT: CATEGORY_UNI_DIGIT, + CATEGORY_NOT_DIGIT: CATEGORY_UNI_NOT_DIGIT, + CATEGORY_SPACE: CATEGORY_UNI_SPACE, + CATEGORY_NOT_SPACE: CATEGORY_UNI_NOT_SPACE, + CATEGORY_WORD: CATEGORY_UNI_WORD, + CATEGORY_NOT_WORD: CATEGORY_UNI_NOT_WORD, + CATEGORY_LINEBREAK: CATEGORY_UNI_LINEBREAK, + CATEGORY_NOT_LINEBREAK: CATEGORY_UNI_NOT_LINEBREAK +} + +# flags +SRE_FLAG_IGNORECASE = 2 # case insensitive +SRE_FLAG_LOCALE = 4 # honour system locale +SRE_FLAG_MULTILINE = 8 # treat target as multiline string +SRE_FLAG_DOTALL = 16 # treat target as a single string +SRE_FLAG_UNICODE = 32 # use unicode "locale" +SRE_FLAG_VERBOSE = 64 # ignore whitespace and comments +SRE_FLAG_DEBUG = 128 # debugging +SRE_FLAG_ASCII = 256 # use ascii "locale" + +# flags for INFO primitive +SRE_INFO_PREFIX = 1 # has prefix +SRE_INFO_LITERAL = 2 # entire pattern is literal (given by prefix) +SRE_INFO_CHARSET = 4 # pattern starts with character from given set diff --git a/Lib/re/_parser.py b/Lib/re/_parser.py new file mode 100644 index 0000000000..f3c779340f --- /dev/null +++ b/Lib/re/_parser.py @@ -0,0 +1,1081 @@ +# +# Secret Labs' Regular Expression Engine +# +# convert re-style regular expression to sre pattern +# +# Copyright (c) 1998-2001 by Secret Labs AB. All rights reserved. +# +# See the __init__.py file for information on usage and redistribution. +# + +"""Internal support module for sre""" + +# XXX: show string offset and offending character for all errors + +from ._constants import * + +SPECIAL_CHARS = ".\\[{()*+?^$|" +REPEAT_CHARS = "*+?{" + +DIGITS = frozenset("0123456789") + +OCTDIGITS = frozenset("01234567") +HEXDIGITS = frozenset("0123456789abcdefABCDEF") +ASCIILETTERS = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +WHITESPACE = frozenset(" \t\n\r\v\f") + +_REPEATCODES = frozenset({MIN_REPEAT, MAX_REPEAT, POSSESSIVE_REPEAT}) +_UNITCODES = frozenset({ANY, RANGE, IN, LITERAL, NOT_LITERAL, CATEGORY}) + +ESCAPES = { + r"\a": (LITERAL, ord("\a")), + r"\b": (LITERAL, ord("\b")), + r"\f": (LITERAL, ord("\f")), + r"\n": (LITERAL, ord("\n")), + r"\r": (LITERAL, ord("\r")), + r"\t": (LITERAL, ord("\t")), + r"\v": (LITERAL, ord("\v")), + r"\\": (LITERAL, ord("\\")) +} + +CATEGORIES = { + r"\A": (AT, AT_BEGINNING_STRING), # start of string + r"\b": (AT, AT_BOUNDARY), + r"\B": (AT, AT_NON_BOUNDARY), + r"\d": (IN, [(CATEGORY, CATEGORY_DIGIT)]), + r"\D": (IN, [(CATEGORY, CATEGORY_NOT_DIGIT)]), + r"\s": (IN, [(CATEGORY, CATEGORY_SPACE)]), + r"\S": (IN, [(CATEGORY, CATEGORY_NOT_SPACE)]), + r"\w": (IN, [(CATEGORY, CATEGORY_WORD)]), + r"\W": (IN, [(CATEGORY, CATEGORY_NOT_WORD)]), + r"\Z": (AT, AT_END_STRING), # end of string +} + +FLAGS = { + # standard flags + "i": SRE_FLAG_IGNORECASE, + "L": SRE_FLAG_LOCALE, + "m": SRE_FLAG_MULTILINE, + "s": SRE_FLAG_DOTALL, + "x": SRE_FLAG_VERBOSE, + # extensions + "a": SRE_FLAG_ASCII, + "u": SRE_FLAG_UNICODE, +} + +TYPE_FLAGS = SRE_FLAG_ASCII | SRE_FLAG_LOCALE | SRE_FLAG_UNICODE +GLOBAL_FLAGS = SRE_FLAG_DEBUG + +# Maximal value returned by SubPattern.getwidth(). +# Must be larger than MAXREPEAT, MAXCODE and sys.maxsize. +MAXWIDTH = 1 << 64 + +class State: + # keeps track of state for parsing + def __init__(self): + self.flags = 0 + self.groupdict = {} + self.groupwidths = [None] # group 0 + self.lookbehindgroups = None + self.grouprefpos = {} + @property + def groups(self): + return len(self.groupwidths) + def opengroup(self, name=None): + gid = self.groups + self.groupwidths.append(None) + if self.groups > MAXGROUPS: + raise error("too many groups") + if name is not None: + ogid = self.groupdict.get(name, None) + if ogid is not None: + raise error("redefinition of group name %r as group %d; " + "was group %d" % (name, gid, ogid)) + self.groupdict[name] = gid + return gid + def closegroup(self, gid, p): + self.groupwidths[gid] = p.getwidth() + def checkgroup(self, gid): + return gid < self.groups and self.groupwidths[gid] is not None + + def checklookbehindgroup(self, gid, source): + if self.lookbehindgroups is not None: + if not self.checkgroup(gid): + raise source.error('cannot refer to an open group') + if gid >= self.lookbehindgroups: + raise source.error('cannot refer to group defined in the same ' + 'lookbehind subpattern') + +class SubPattern: + # a subpattern, in intermediate form + def __init__(self, state, data=None): + self.state = state + if data is None: + data = [] + self.data = data + self.width = None + + def dump(self, level=0): + seqtypes = (tuple, list) + for op, av in self.data: + print(level*" " + str(op), end='') + if op is IN: + # member sublanguage + print() + for op, a in av: + print((level+1)*" " + str(op), a) + elif op is BRANCH: + print() + for i, a in enumerate(av[1]): + if i: + print(level*" " + "OR") + a.dump(level+1) + elif op is GROUPREF_EXISTS: + condgroup, item_yes, item_no = av + print('', condgroup) + item_yes.dump(level+1) + if item_no: + print(level*" " + "ELSE") + item_no.dump(level+1) + elif isinstance(av, SubPattern): + print() + av.dump(level+1) + elif isinstance(av, seqtypes): + nl = False + for a in av: + if isinstance(a, SubPattern): + if not nl: + print() + a.dump(level+1) + nl = True + else: + if not nl: + print(' ', end='') + print(a, end='') + nl = False + if not nl: + print() + else: + print('', av) + def __repr__(self): + return repr(self.data) + def __len__(self): + return len(self.data) + def __delitem__(self, index): + del self.data[index] + def __getitem__(self, index): + if isinstance(index, slice): + return SubPattern(self.state, self.data[index]) + return self.data[index] + def __setitem__(self, index, code): + self.data[index] = code + def insert(self, index, code): + self.data.insert(index, code) + def append(self, code): + self.data.append(code) + def getwidth(self): + # determine the width (min, max) for this subpattern + if self.width is not None: + return self.width + lo = hi = 0 + for op, av in self.data: + if op is BRANCH: + i = MAXWIDTH + j = 0 + for av in av[1]: + l, h = av.getwidth() + i = min(i, l) + j = max(j, h) + lo = lo + i + hi = hi + j + elif op is ATOMIC_GROUP: + i, j = av.getwidth() + lo = lo + i + hi = hi + j + elif op is SUBPATTERN: + i, j = av[-1].getwidth() + lo = lo + i + hi = hi + j + elif op in _REPEATCODES: + i, j = av[2].getwidth() + lo = lo + i * av[0] + if av[1] == MAXREPEAT and j: + hi = MAXWIDTH + else: + hi = hi + j * av[1] + elif op in _UNITCODES: + lo = lo + 1 + hi = hi + 1 + elif op is GROUPREF: + i, j = self.state.groupwidths[av] + lo = lo + i + hi = hi + j + elif op is GROUPREF_EXISTS: + i, j = av[1].getwidth() + if av[2] is not None: + l, h = av[2].getwidth() + i = min(i, l) + j = max(j, h) + else: + i = 0 + lo = lo + i + hi = hi + j + elif op is SUCCESS: + break + self.width = min(lo, MAXWIDTH), min(hi, MAXWIDTH) + return self.width + +class Tokenizer: + def __init__(self, string): + self.istext = isinstance(string, str) + self.string = string + if not self.istext: + string = str(string, 'latin1') + self.decoded_string = string + self.index = 0 + self.next = None + self.__next() + def __next(self): + index = self.index + try: + char = self.decoded_string[index] + except IndexError: + self.next = None + return + if char == "\\": + index += 1 + try: + char += self.decoded_string[index] + except IndexError: + raise error("bad escape (end of pattern)", + self.string, len(self.string) - 1) from None + self.index = index + 1 + self.next = char + def match(self, char): + if char == self.next: + self.__next() + return True + return False + def get(self): + this = self.next + self.__next() + return this + def getwhile(self, n, charset): + result = '' + for _ in range(n): + c = self.next + if c not in charset: + break + result += c + self.__next() + return result + def getuntil(self, terminator, name): + result = '' + while True: + c = self.next + self.__next() + if c is None: + if not result: + raise self.error("missing " + name) + raise self.error("missing %s, unterminated name" % terminator, + len(result)) + if c == terminator: + if not result: + raise self.error("missing " + name, 1) + break + result += c + return result + @property + def pos(self): + return self.index - len(self.next or '') + def tell(self): + return self.index - len(self.next or '') + def seek(self, index): + self.index = index + self.__next() + + def error(self, msg, offset=0): + if not self.istext: + msg = msg.encode('ascii', 'backslashreplace').decode('ascii') + return error(msg, self.string, self.tell() - offset) + + def checkgroupname(self, name, offset): + if not (self.istext or name.isascii()): + msg = "bad character in group name %a" % name + raise self.error(msg, len(name) + offset) + if not name.isidentifier(): + msg = "bad character in group name %r" % name + raise self.error(msg, len(name) + offset) + +def _class_escape(source, escape): + # handle escape code inside character class + code = ESCAPES.get(escape) + if code: + return code + code = CATEGORIES.get(escape) + if code and code[0] is IN: + return code + try: + c = escape[1:2] + if c == "x": + # hexadecimal escape (exactly two digits) + escape += source.getwhile(2, HEXDIGITS) + if len(escape) != 4: + raise source.error("incomplete escape %s" % escape, len(escape)) + return LITERAL, int(escape[2:], 16) + elif c == "u" and source.istext: + # unicode escape (exactly four digits) + escape += source.getwhile(4, HEXDIGITS) + if len(escape) != 6: + raise source.error("incomplete escape %s" % escape, len(escape)) + return LITERAL, int(escape[2:], 16) + elif c == "U" and source.istext: + # unicode escape (exactly eight digits) + escape += source.getwhile(8, HEXDIGITS) + if len(escape) != 10: + raise source.error("incomplete escape %s" % escape, len(escape)) + c = int(escape[2:], 16) + chr(c) # raise ValueError for invalid code + return LITERAL, c + elif c == "N" and source.istext: + import unicodedata + # named unicode escape e.g. \N{EM DASH} + if not source.match('{'): + raise source.error("missing {") + charname = source.getuntil('}', 'character name') + try: + c = ord(unicodedata.lookup(charname)) + except (KeyError, TypeError): + raise source.error("undefined character name %r" % charname, + len(charname) + len(r'\N{}')) from None + return LITERAL, c + elif c in OCTDIGITS: + # octal escape (up to three digits) + escape += source.getwhile(2, OCTDIGITS) + c = int(escape[1:], 8) + if c > 0o377: + raise source.error('octal escape value %s outside of ' + 'range 0-0o377' % escape, len(escape)) + return LITERAL, c + elif c in DIGITS: + raise ValueError + if len(escape) == 2: + if c in ASCIILETTERS: + raise source.error('bad escape %s' % escape, len(escape)) + return LITERAL, ord(escape[1]) + except ValueError: + pass + raise source.error("bad escape %s" % escape, len(escape)) + +def _escape(source, escape, state): + # handle escape code in expression + code = CATEGORIES.get(escape) + if code: + return code + code = ESCAPES.get(escape) + if code: + return code + try: + c = escape[1:2] + if c == "x": + # hexadecimal escape + escape += source.getwhile(2, HEXDIGITS) + if len(escape) != 4: + raise source.error("incomplete escape %s" % escape, len(escape)) + return LITERAL, int(escape[2:], 16) + elif c == "u" and source.istext: + # unicode escape (exactly four digits) + escape += source.getwhile(4, HEXDIGITS) + if len(escape) != 6: + raise source.error("incomplete escape %s" % escape, len(escape)) + return LITERAL, int(escape[2:], 16) + elif c == "U" and source.istext: + # unicode escape (exactly eight digits) + escape += source.getwhile(8, HEXDIGITS) + if len(escape) != 10: + raise source.error("incomplete escape %s" % escape, len(escape)) + c = int(escape[2:], 16) + chr(c) # raise ValueError for invalid code + return LITERAL, c + elif c == "N" and source.istext: + import unicodedata + # named unicode escape e.g. \N{EM DASH} + if not source.match('{'): + raise source.error("missing {") + charname = source.getuntil('}', 'character name') + try: + c = ord(unicodedata.lookup(charname)) + except (KeyError, TypeError): + raise source.error("undefined character name %r" % charname, + len(charname) + len(r'\N{}')) from None + return LITERAL, c + elif c == "0": + # octal escape + escape += source.getwhile(2, OCTDIGITS) + return LITERAL, int(escape[1:], 8) + elif c in DIGITS: + # octal escape *or* decimal group reference (sigh) + if source.next in DIGITS: + escape += source.get() + if (escape[1] in OCTDIGITS and escape[2] in OCTDIGITS and + source.next in OCTDIGITS): + # got three octal digits; this is an octal escape + escape += source.get() + c = int(escape[1:], 8) + if c > 0o377: + raise source.error('octal escape value %s outside of ' + 'range 0-0o377' % escape, + len(escape)) + return LITERAL, c + # not an octal escape, so this is a group reference + group = int(escape[1:]) + if group < state.groups: + if not state.checkgroup(group): + raise source.error("cannot refer to an open group", + len(escape)) + state.checklookbehindgroup(group, source) + return GROUPREF, group + raise source.error("invalid group reference %d" % group, len(escape) - 1) + if len(escape) == 2: + if c in ASCIILETTERS: + raise source.error("bad escape %s" % escape, len(escape)) + return LITERAL, ord(escape[1]) + except ValueError: + pass + raise source.error("bad escape %s" % escape, len(escape)) + +def _uniq(items): + return list(dict.fromkeys(items)) + +def _parse_sub(source, state, verbose, nested): + # parse an alternation: a|b|c + + items = [] + itemsappend = items.append + sourcematch = source.match + start = source.tell() + while True: + itemsappend(_parse(source, state, verbose, nested + 1, + not nested and not items)) + if not sourcematch("|"): + break + if not nested: + verbose = state.flags & SRE_FLAG_VERBOSE + + if len(items) == 1: + return items[0] + + subpattern = SubPattern(state) + + # check if all items share a common prefix + while True: + prefix = None + for item in items: + if not item: + break + if prefix is None: + prefix = item[0] + elif item[0] != prefix: + break + else: + # all subitems start with a common "prefix". + # move it out of the branch + for item in items: + del item[0] + subpattern.append(prefix) + continue # check next one + break + + # check if the branch can be replaced by a character set + set = [] + for item in items: + if len(item) != 1: + break + op, av = item[0] + if op is LITERAL: + set.append((op, av)) + elif op is IN and av[0][0] is not NEGATE: + set.extend(av) + else: + break + else: + # we can store this as a character set instead of a + # branch (the compiler may optimize this even more) + subpattern.append((IN, _uniq(set))) + return subpattern + + subpattern.append((BRANCH, (None, items))) + return subpattern + +def _parse(source, state, verbose, nested, first=False): + # parse a simple pattern + subpattern = SubPattern(state) + + # precompute constants into local variables + subpatternappend = subpattern.append + sourceget = source.get + sourcematch = source.match + _len = len + _ord = ord + + while True: + + this = source.next + if this is None: + break # end of pattern + if this in "|)": + break # end of subpattern + sourceget() + + if verbose: + # skip whitespace and comments + if this in WHITESPACE: + continue + if this == "#": + while True: + this = sourceget() + if this is None or this == "\n": + break + continue + + if this[0] == "\\": + code = _escape(source, this, state) + subpatternappend(code) + + elif this not in SPECIAL_CHARS: + subpatternappend((LITERAL, _ord(this))) + + elif this == "[": + here = source.tell() - 1 + # character set + set = [] + setappend = set.append +## if sourcematch(":"): +## pass # handle character classes + if source.next == '[': + import warnings + warnings.warn( + 'Possible nested set at position %d' % source.tell(), + FutureWarning, stacklevel=nested + 6 + ) + negate = sourcematch("^") + # check remaining characters + while True: + this = sourceget() + if this is None: + raise source.error("unterminated character set", + source.tell() - here) + if this == "]" and set: + break + elif this[0] == "\\": + code1 = _class_escape(source, this) + else: + if set and this in '-&~|' and source.next == this: + import warnings + warnings.warn( + 'Possible set %s at position %d' % ( + 'difference' if this == '-' else + 'intersection' if this == '&' else + 'symmetric difference' if this == '~' else + 'union', + source.tell() - 1), + FutureWarning, stacklevel=nested + 6 + ) + code1 = LITERAL, _ord(this) + if sourcematch("-"): + # potential range + that = sourceget() + if that is None: + raise source.error("unterminated character set", + source.tell() - here) + if that == "]": + if code1[0] is IN: + code1 = code1[1][0] + setappend(code1) + setappend((LITERAL, _ord("-"))) + break + if that[0] == "\\": + code2 = _class_escape(source, that) + else: + if that == '-': + import warnings + warnings.warn( + 'Possible set difference at position %d' % ( + source.tell() - 2), + FutureWarning, stacklevel=nested + 6 + ) + code2 = LITERAL, _ord(that) + if code1[0] != LITERAL or code2[0] != LITERAL: + msg = "bad character range %s-%s" % (this, that) + raise source.error(msg, len(this) + 1 + len(that)) + lo = code1[1] + hi = code2[1] + if hi < lo: + msg = "bad character range %s-%s" % (this, that) + raise source.error(msg, len(this) + 1 + len(that)) + setappend((RANGE, (lo, hi))) + else: + if code1[0] is IN: + code1 = code1[1][0] + setappend(code1) + + set = _uniq(set) + # XXX: should move set optimization to compiler! + if _len(set) == 1 and set[0][0] is LITERAL: + # optimization + if negate: + subpatternappend((NOT_LITERAL, set[0][1])) + else: + subpatternappend(set[0]) + else: + if negate: + set.insert(0, (NEGATE, None)) + # charmap optimization can't be added here because + # global flags still are not known + subpatternappend((IN, set)) + + elif this in REPEAT_CHARS: + # repeat previous item + here = source.tell() + if this == "?": + min, max = 0, 1 + elif this == "*": + min, max = 0, MAXREPEAT + + elif this == "+": + min, max = 1, MAXREPEAT + elif this == "{": + if source.next == "}": + subpatternappend((LITERAL, _ord(this))) + continue + + min, max = 0, MAXREPEAT + lo = hi = "" + while source.next in DIGITS: + lo += sourceget() + if sourcematch(","): + while source.next in DIGITS: + hi += sourceget() + else: + hi = lo + if not sourcematch("}"): + subpatternappend((LITERAL, _ord(this))) + source.seek(here) + continue + + if lo: + min = int(lo) + if min >= MAXREPEAT: + raise OverflowError("the repetition number is too large") + if hi: + max = int(hi) + if max >= MAXREPEAT: + raise OverflowError("the repetition number is too large") + if max < min: + raise source.error("min repeat greater than max repeat", + source.tell() - here) + else: + raise AssertionError("unsupported quantifier %r" % (char,)) + # figure out which item to repeat + if subpattern: + item = subpattern[-1:] + else: + item = None + if not item or item[0][0] is AT: + raise source.error("nothing to repeat", + source.tell() - here + len(this)) + if item[0][0] in _REPEATCODES: + raise source.error("multiple repeat", + source.tell() - here + len(this)) + if item[0][0] is SUBPATTERN: + group, add_flags, del_flags, p = item[0][1] + if group is None and not add_flags and not del_flags: + item = p + if sourcematch("?"): + # Non-Greedy Match + subpattern[-1] = (MIN_REPEAT, (min, max, item)) + elif sourcematch("+"): + # Possessive Match (Always Greedy) + subpattern[-1] = (POSSESSIVE_REPEAT, (min, max, item)) + else: + # Greedy Match + subpattern[-1] = (MAX_REPEAT, (min, max, item)) + + elif this == ".": + subpatternappend((ANY, None)) + + elif this == "(": + start = source.tell() - 1 + capture = True + atomic = False + name = None + add_flags = 0 + del_flags = 0 + if sourcematch("?"): + # options + char = sourceget() + if char is None: + raise source.error("unexpected end of pattern") + if char == "P": + # python extensions + if sourcematch("<"): + # named group: skip forward to end of name + name = source.getuntil(">", "group name") + source.checkgroupname(name, 1) + elif sourcematch("="): + # named backreference + name = source.getuntil(")", "group name") + source.checkgroupname(name, 1) + gid = state.groupdict.get(name) + if gid is None: + msg = "unknown group name %r" % name + raise source.error(msg, len(name) + 1) + if not state.checkgroup(gid): + raise source.error("cannot refer to an open group", + len(name) + 1) + state.checklookbehindgroup(gid, source) + subpatternappend((GROUPREF, gid)) + continue + + else: + char = sourceget() + if char is None: + raise source.error("unexpected end of pattern") + raise source.error("unknown extension ?P" + char, + len(char) + 2) + elif char == ":": + # non-capturing group + capture = False + elif char == "#": + # comment + while True: + if source.next is None: + raise source.error("missing ), unterminated comment", + source.tell() - start) + if sourceget() == ")": + break + continue + + elif char in "=!<": + # lookahead assertions + dir = 1 + if char == "<": + char = sourceget() + if char is None: + raise source.error("unexpected end of pattern") + if char not in "=!": + raise source.error("unknown extension ?<" + char, + len(char) + 2) + dir = -1 # lookbehind + lookbehindgroups = state.lookbehindgroups + if lookbehindgroups is None: + state.lookbehindgroups = state.groups + p = _parse_sub(source, state, verbose, nested + 1) + if dir < 0: + if lookbehindgroups is None: + state.lookbehindgroups = None + if not sourcematch(")"): + raise source.error("missing ), unterminated subpattern", + source.tell() - start) + if char == "=": + subpatternappend((ASSERT, (dir, p))) + elif p: + subpatternappend((ASSERT_NOT, (dir, p))) + else: + subpatternappend((FAILURE, ())) + continue + + elif char == "(": + # conditional backreference group + condname = source.getuntil(")", "group name") + if not (condname.isdecimal() and condname.isascii()): + source.checkgroupname(condname, 1) + condgroup = state.groupdict.get(condname) + if condgroup is None: + msg = "unknown group name %r" % condname + raise source.error(msg, len(condname) + 1) + else: + condgroup = int(condname) + if not condgroup: + raise source.error("bad group number", + len(condname) + 1) + if condgroup >= MAXGROUPS: + msg = "invalid group reference %d" % condgroup + raise source.error(msg, len(condname) + 1) + if condgroup not in state.grouprefpos: + state.grouprefpos[condgroup] = ( + source.tell() - len(condname) - 1 + ) + if not (condname.isdecimal() and condname.isascii()): + import warnings + warnings.warn( + "bad character in group name %s at position %d" % + (repr(condname) if source.istext else ascii(condname), + source.tell() - len(condname) - 1), + DeprecationWarning, stacklevel=nested + 6 + ) + state.checklookbehindgroup(condgroup, source) + item_yes = _parse(source, state, verbose, nested + 1) + if source.match("|"): + item_no = _parse(source, state, verbose, nested + 1) + if source.next == "|": + raise source.error("conditional backref with more than two branches") + else: + item_no = None + if not source.match(")"): + raise source.error("missing ), unterminated subpattern", + source.tell() - start) + subpatternappend((GROUPREF_EXISTS, (condgroup, item_yes, item_no))) + continue + + elif char == ">": + # non-capturing, atomic group + capture = False + atomic = True + elif char in FLAGS or char == "-": + # flags + flags = _parse_flags(source, state, char) + if flags is None: # global flags + if not first or subpattern: + raise source.error('global flags not at the start ' + 'of the expression', + source.tell() - start) + verbose = state.flags & SRE_FLAG_VERBOSE + continue + + add_flags, del_flags = flags + capture = False + else: + raise source.error("unknown extension ?" + char, + len(char) + 1) + + # parse group contents + if capture: + try: + group = state.opengroup(name) + except error as err: + raise source.error(err.msg, len(name) + 1) from None + else: + group = None + sub_verbose = ((verbose or (add_flags & SRE_FLAG_VERBOSE)) and + not (del_flags & SRE_FLAG_VERBOSE)) + p = _parse_sub(source, state, sub_verbose, nested + 1) + if not source.match(")"): + raise source.error("missing ), unterminated subpattern", + source.tell() - start) + if group is not None: + state.closegroup(group, p) + if atomic: + assert group is None + subpatternappend((ATOMIC_GROUP, p)) + else: + subpatternappend((SUBPATTERN, (group, add_flags, del_flags, p))) + + elif this == "^": + subpatternappend((AT, AT_BEGINNING)) + + elif this == "$": + subpatternappend((AT, AT_END)) + + else: + raise AssertionError("unsupported special character %r" % (char,)) + + # unpack non-capturing groups + for i in range(len(subpattern))[::-1]: + op, av = subpattern[i] + if op is SUBPATTERN: + group, add_flags, del_flags, p = av + if group is None and not add_flags and not del_flags: + subpattern[i: i+1] = p + + return subpattern + +def _parse_flags(source, state, char): + sourceget = source.get + add_flags = 0 + del_flags = 0 + if char != "-": + while True: + flag = FLAGS[char] + if source.istext: + if char == 'L': + msg = "bad inline flags: cannot use 'L' flag with a str pattern" + raise source.error(msg) + else: + if char == 'u': + msg = "bad inline flags: cannot use 'u' flag with a bytes pattern" + raise source.error(msg) + add_flags |= flag + if (flag & TYPE_FLAGS) and (add_flags & TYPE_FLAGS) != flag: + msg = "bad inline flags: flags 'a', 'u' and 'L' are incompatible" + raise source.error(msg) + char = sourceget() + if char is None: + raise source.error("missing -, : or )") + if char in ")-:": + break + if char not in FLAGS: + msg = "unknown flag" if char.isalpha() else "missing -, : or )" + raise source.error(msg, len(char)) + if char == ")": + state.flags |= add_flags + return None + if add_flags & GLOBAL_FLAGS: + raise source.error("bad inline flags: cannot turn on global flag", 1) + if char == "-": + char = sourceget() + if char is None: + raise source.error("missing flag") + if char not in FLAGS: + msg = "unknown flag" if char.isalpha() else "missing flag" + raise source.error(msg, len(char)) + while True: + flag = FLAGS[char] + if flag & TYPE_FLAGS: + msg = "bad inline flags: cannot turn off flags 'a', 'u' and 'L'" + raise source.error(msg) + del_flags |= flag + char = sourceget() + if char is None: + raise source.error("missing :") + if char == ":": + break + if char not in FLAGS: + msg = "unknown flag" if char.isalpha() else "missing :" + raise source.error(msg, len(char)) + assert char == ":" + if del_flags & GLOBAL_FLAGS: + raise source.error("bad inline flags: cannot turn off global flag", 1) + if add_flags & del_flags: + raise source.error("bad inline flags: flag turned on and off", 1) + return add_flags, del_flags + +def fix_flags(src, flags): + # Check and fix flags according to the type of pattern (str or bytes) + if isinstance(src, str): + if flags & SRE_FLAG_LOCALE: + raise ValueError("cannot use LOCALE flag with a str pattern") + if not flags & SRE_FLAG_ASCII: + flags |= SRE_FLAG_UNICODE + elif flags & SRE_FLAG_UNICODE: + raise ValueError("ASCII and UNICODE flags are incompatible") + else: + if flags & SRE_FLAG_UNICODE: + raise ValueError("cannot use UNICODE flag with a bytes pattern") + if flags & SRE_FLAG_LOCALE and flags & SRE_FLAG_ASCII: + raise ValueError("ASCII and LOCALE flags are incompatible") + return flags + +def parse(str, flags=0, state=None): + # parse 're' pattern into list of (opcode, argument) tuples + + source = Tokenizer(str) + + if state is None: + state = State() + state.flags = flags + state.str = str + + p = _parse_sub(source, state, flags & SRE_FLAG_VERBOSE, 0) + p.state.flags = fix_flags(str, p.state.flags) + + if source.next is not None: + assert source.next == ")" + raise source.error("unbalanced parenthesis") + + for g in p.state.grouprefpos: + if g >= p.state.groups: + msg = "invalid group reference %d" % g + raise error(msg, str, p.state.grouprefpos[g]) + + if flags & SRE_FLAG_DEBUG: + p.dump() + + return p + +def parse_template(source, pattern): + # parse 're' replacement string into list of literals and + # group references + s = Tokenizer(source) + sget = s.get + result = [] + literal = [] + lappend = literal.append + def addliteral(): + if s.istext: + result.append(''.join(literal)) + else: + # The tokenizer implicitly decodes bytes objects as latin-1, we must + # therefore re-encode the final representation. + result.append(''.join(literal).encode('latin-1')) + del literal[:] + def addgroup(index, pos): + if index > pattern.groups: + raise s.error("invalid group reference %d" % index, pos) + addliteral() + result.append(index) + groupindex = pattern.groupindex + while True: + this = sget() + if this is None: + break # end of replacement string + if this[0] == "\\": + # group + c = this[1] + if c == "g": + if not s.match("<"): + raise s.error("missing <") + name = s.getuntil(">", "group name") + if not (name.isdecimal() and name.isascii()): + s.checkgroupname(name, 1) + try: + index = groupindex[name] + except KeyError: + raise IndexError("unknown group name %r" % name) from None + else: + index = int(name) + if index >= MAXGROUPS: + raise s.error("invalid group reference %d" % index, + len(name) + 1) + if not (name.isdecimal() and name.isascii()): + import warnings + warnings.warn( + "bad character in group name %s at position %d" % + (repr(name) if s.istext else ascii(name), + s.tell() - len(name) - 1), + DeprecationWarning, stacklevel=5 + ) + addgroup(index, len(name) + 1) + elif c == "0": + if s.next in OCTDIGITS: + this += sget() + if s.next in OCTDIGITS: + this += sget() + lappend(chr(int(this[1:], 8) & 0xff)) + elif c in DIGITS: + isoctal = False + if s.next in DIGITS: + this += sget() + if (c in OCTDIGITS and this[2] in OCTDIGITS and + s.next in OCTDIGITS): + this += sget() + isoctal = True + c = int(this[1:], 8) + if c > 0o377: + raise s.error('octal escape value %s outside of ' + 'range 0-0o377' % this, len(this)) + lappend(chr(c)) + if not isoctal: + addgroup(int(this[1:]), len(this) - 1) + else: + try: + this = chr(ESCAPES[this][1]) + except KeyError: + if c in ASCIILETTERS: + raise s.error('bad escape %s' % this, len(this)) from None + lappend(this) + else: + lappend(this) + addliteral() + return result From ebe555203a09cf55c981177c9b1cbd6489076a19 Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 18 Nov 2023 20:47:42 +0200 Subject: [PATCH 304/893] Replace Lib/sre_* from CPython --- Lib/re/_compiler.py | 3 + Lib/re/_constants.py | 4 +- Lib/re/_parser.py | 7 +- Lib/sre_compile.py | 789 +------------------------------ Lib/sre_constants.py | 220 +-------- Lib/sre_parse.py | 1069 +----------------------------------------- 6 files changed, 27 insertions(+), 2065 deletions(-) diff --git a/Lib/re/_compiler.py b/Lib/re/_compiler.py index f87712d6d6..861bbdb130 100644 --- a/Lib/re/_compiler.py +++ b/Lib/re/_compiler.py @@ -101,6 +101,8 @@ def _compile(code, pattern, flags): else: emit(ANY) elif op in REPEATING_CODES: + if flags & SRE_FLAG_TEMPLATE: + raise error("internal: unsupported template operator %r" % (op,)) if _simple(av[2]): emit(REPEATING_CODES[op][2]) skip = _len(code); emit(0) @@ -761,3 +763,4 @@ def compile(p, flags=0): p.state.groups-1, groupindex, tuple(indexgroup) ) + diff --git a/Lib/re/_constants.py b/Lib/re/_constants.py index d8e483ac4f..92494e385c 100644 --- a/Lib/re/_constants.py +++ b/Lib/re/_constants.py @@ -13,7 +13,7 @@ # update when constants are added or removed -MAGIC = 20230612 +MAGIC = 20221023 from _sre import MAXREPEAT, MAXGROUPS @@ -204,6 +204,7 @@ def _makecodes(*names): } # flags +SRE_FLAG_TEMPLATE = 1 # template mode (unknown purpose, deprecated) SRE_FLAG_IGNORECASE = 2 # case insensitive SRE_FLAG_LOCALE = 4 # honour system locale SRE_FLAG_MULTILINE = 8 # treat target as multiline string @@ -217,3 +218,4 @@ def _makecodes(*names): SRE_INFO_PREFIX = 1 # has prefix SRE_INFO_LITERAL = 2 # entire pattern is literal (given by prefix) SRE_INFO_CHARSET = 4 # pattern starts with character from given set +RE_INFO_CHARSET = 4 # pattern starts with character from given set diff --git a/Lib/re/_parser.py b/Lib/re/_parser.py index f3c779340f..4a492b79e8 100644 --- a/Lib/re/_parser.py +++ b/Lib/re/_parser.py @@ -61,11 +61,12 @@ "x": SRE_FLAG_VERBOSE, # extensions "a": SRE_FLAG_ASCII, + "t": SRE_FLAG_TEMPLATE, "u": SRE_FLAG_UNICODE, } TYPE_FLAGS = SRE_FLAG_ASCII | SRE_FLAG_LOCALE | SRE_FLAG_UNICODE -GLOBAL_FLAGS = SRE_FLAG_DEBUG +GLOBAL_FLAGS = SRE_FLAG_DEBUG | SRE_FLAG_TEMPLATE # Maximal value returned by SubPattern.getwidth(). # Must be larger than MAXREPEAT, MAXCODE and sys.maxsize. @@ -780,10 +781,8 @@ def _parse(source, state, verbose, nested, first=False): source.tell() - start) if char == "=": subpatternappend((ASSERT, (dir, p))) - elif p: - subpatternappend((ASSERT_NOT, (dir, p))) else: - subpatternappend((FAILURE, ())) + subpatternappend((ASSERT_NOT, (dir, p))) continue elif char == "(": diff --git a/Lib/sre_compile.py b/Lib/sre_compile.py index c6398bfb83..f9da61e648 100644 --- a/Lib/sre_compile.py +++ b/Lib/sre_compile.py @@ -1,784 +1,7 @@ -# -# Secret Labs' Regular Expression Engine -# -# convert template to internal format -# -# Copyright (c) 1997-2001 by Secret Labs AB. All rights reserved. -# -# See the sre.py file for information on usage and redistribution. -# +import warnings +warnings.warn(f"module {__name__!r} is deprecated", + DeprecationWarning, + stacklevel=2) -"""Internal support module for sre""" - -import _sre -import sre_parse -from sre_constants import * - -assert _sre.MAGIC == MAGIC, "SRE module mismatch" - -_LITERAL_CODES = {LITERAL, NOT_LITERAL} -_REPEATING_CODES = {REPEAT, MIN_REPEAT, MAX_REPEAT} -_SUCCESS_CODES = {SUCCESS, FAILURE} -_ASSERT_CODES = {ASSERT, ASSERT_NOT} -_UNIT_CODES = _LITERAL_CODES | {ANY, IN} - -# Sets of lowercase characters which have the same uppercase. -_equivalences = ( - # LATIN SMALL LETTER I, LATIN SMALL LETTER DOTLESS I - (0x69, 0x131), # iı - # LATIN SMALL LETTER S, LATIN SMALL LETTER LONG S - (0x73, 0x17f), # sſ - # MICRO SIGN, GREEK SMALL LETTER MU - (0xb5, 0x3bc), # µμ - # COMBINING GREEK YPOGEGRAMMENI, GREEK SMALL LETTER IOTA, GREEK PROSGEGRAMMENI - (0x345, 0x3b9, 0x1fbe), # \u0345ιι - # GREEK SMALL LETTER IOTA WITH DIALYTIKA AND TONOS, GREEK SMALL LETTER IOTA WITH DIALYTIKA AND OXIA - (0x390, 0x1fd3), # ΐΐ - # GREEK SMALL LETTER UPSILON WITH DIALYTIKA AND TONOS, GREEK SMALL LETTER UPSILON WITH DIALYTIKA AND OXIA - (0x3b0, 0x1fe3), # ΰΰ - # GREEK SMALL LETTER BETA, GREEK BETA SYMBOL - (0x3b2, 0x3d0), # βϐ - # GREEK SMALL LETTER EPSILON, GREEK LUNATE EPSILON SYMBOL - (0x3b5, 0x3f5), # εϵ - # GREEK SMALL LETTER THETA, GREEK THETA SYMBOL - (0x3b8, 0x3d1), # θϑ - # GREEK SMALL LETTER KAPPA, GREEK KAPPA SYMBOL - (0x3ba, 0x3f0), # κϰ - # GREEK SMALL LETTER PI, GREEK PI SYMBOL - (0x3c0, 0x3d6), # πϖ - # GREEK SMALL LETTER RHO, GREEK RHO SYMBOL - (0x3c1, 0x3f1), # ρϱ - # GREEK SMALL LETTER FINAL SIGMA, GREEK SMALL LETTER SIGMA - (0x3c2, 0x3c3), # ςσ - # GREEK SMALL LETTER PHI, GREEK PHI SYMBOL - (0x3c6, 0x3d5), # φϕ - # LATIN SMALL LETTER S WITH DOT ABOVE, LATIN SMALL LETTER LONG S WITH DOT ABOVE - (0x1e61, 0x1e9b), # ṡẛ - # LATIN SMALL LIGATURE LONG S T, LATIN SMALL LIGATURE ST - (0xfb05, 0xfb06), # ſtst -) - -# Maps the lowercase code to lowercase codes which have the same uppercase. -_ignorecase_fixes = {i: tuple(j for j in t if i != j) - for t in _equivalences for i in t} - -def _combine_flags(flags, add_flags, del_flags, - TYPE_FLAGS=sre_parse.TYPE_FLAGS): - if add_flags & TYPE_FLAGS: - flags &= ~TYPE_FLAGS - return (flags | add_flags) & ~del_flags - -def _compile(code, pattern, flags): - # internal: compile a (sub)pattern - emit = code.append - _len = len - LITERAL_CODES = _LITERAL_CODES - REPEATING_CODES = _REPEATING_CODES - SUCCESS_CODES = _SUCCESS_CODES - ASSERT_CODES = _ASSERT_CODES - iscased = None - tolower = None - fixes = None - if flags & SRE_FLAG_IGNORECASE and not flags & SRE_FLAG_LOCALE: - if flags & SRE_FLAG_UNICODE: - iscased = _sre.unicode_iscased - tolower = _sre.unicode_tolower - fixes = _ignorecase_fixes - else: - iscased = _sre.ascii_iscased - tolower = _sre.ascii_tolower - for op, av in pattern: - if op in LITERAL_CODES: - if not flags & SRE_FLAG_IGNORECASE: - emit(op) - emit(av) - elif flags & SRE_FLAG_LOCALE: - emit(OP_LOCALE_IGNORE[op]) - emit(av) - elif not iscased(av): - emit(op) - emit(av) - else: - lo = tolower(av) - if not fixes: # ascii - emit(OP_IGNORE[op]) - emit(lo) - elif lo not in fixes: - emit(OP_UNICODE_IGNORE[op]) - emit(lo) - else: - emit(IN_UNI_IGNORE) - skip = _len(code); emit(0) - if op is NOT_LITERAL: - emit(NEGATE) - for k in (lo,) + fixes[lo]: - emit(LITERAL) - emit(k) - emit(FAILURE) - code[skip] = _len(code) - skip - elif op is IN: - charset, hascased = _optimize_charset(av, iscased, tolower, fixes) - if flags & SRE_FLAG_IGNORECASE and flags & SRE_FLAG_LOCALE: - emit(IN_LOC_IGNORE) - elif not hascased: - emit(IN) - elif not fixes: # ascii - emit(IN_IGNORE) - else: - emit(IN_UNI_IGNORE) - skip = _len(code); emit(0) - _compile_charset(charset, flags, code) - code[skip] = _len(code) - skip - elif op is ANY: - if flags & SRE_FLAG_DOTALL: - emit(ANY_ALL) - else: - emit(ANY) - elif op in REPEATING_CODES: - if flags & SRE_FLAG_TEMPLATE: - raise error("internal: unsupported template operator %r" % (op,)) - if _simple(av[2]): - if op is MAX_REPEAT: - emit(REPEAT_ONE) - else: - emit(MIN_REPEAT_ONE) - skip = _len(code); emit(0) - emit(av[0]) - emit(av[1]) - _compile(code, av[2], flags) - emit(SUCCESS) - code[skip] = _len(code) - skip - else: - emit(REPEAT) - skip = _len(code); emit(0) - emit(av[0]) - emit(av[1]) - _compile(code, av[2], flags) - code[skip] = _len(code) - skip - if op is MAX_REPEAT: - emit(MAX_UNTIL) - else: - emit(MIN_UNTIL) - elif op is SUBPATTERN: - group, add_flags, del_flags, p = av - if group: - emit(MARK) - emit((group-1)*2) - # _compile_info(code, p, _combine_flags(flags, add_flags, del_flags)) - _compile(code, p, _combine_flags(flags, add_flags, del_flags)) - if group: - emit(MARK) - emit((group-1)*2+1) - elif op in SUCCESS_CODES: - emit(op) - elif op in ASSERT_CODES: - emit(op) - skip = _len(code); emit(0) - if av[0] >= 0: - emit(0) # look ahead - else: - lo, hi = av[1].getwidth() - if lo != hi: - raise error("look-behind requires fixed-width pattern") - emit(lo) # look behind - _compile(code, av[1], flags) - emit(SUCCESS) - code[skip] = _len(code) - skip - elif op is CALL: - emit(op) - skip = _len(code); emit(0) - _compile(code, av, flags) - emit(SUCCESS) - code[skip] = _len(code) - skip - elif op is AT: - emit(op) - if flags & SRE_FLAG_MULTILINE: - av = AT_MULTILINE.get(av, av) - if flags & SRE_FLAG_LOCALE: - av = AT_LOCALE.get(av, av) - elif flags & SRE_FLAG_UNICODE: - av = AT_UNICODE.get(av, av) - emit(av) - elif op is BRANCH: - emit(op) - tail = [] - tailappend = tail.append - for av in av[1]: - skip = _len(code); emit(0) - # _compile_info(code, av, flags) - _compile(code, av, flags) - emit(JUMP) - tailappend(_len(code)); emit(0) - code[skip] = _len(code) - skip - emit(FAILURE) # end of branch - for tail in tail: - code[tail] = _len(code) - tail - elif op is CATEGORY: - emit(op) - if flags & SRE_FLAG_LOCALE: - av = CH_LOCALE[av] - elif flags & SRE_FLAG_UNICODE: - av = CH_UNICODE[av] - emit(av) - elif op is GROUPREF: - if not flags & SRE_FLAG_IGNORECASE: - emit(op) - elif flags & SRE_FLAG_LOCALE: - emit(GROUPREF_LOC_IGNORE) - elif not fixes: # ascii - emit(GROUPREF_IGNORE) - else: - emit(GROUPREF_UNI_IGNORE) - emit(av-1) - elif op is GROUPREF_EXISTS: - emit(op) - emit(av[0]-1) - skipyes = _len(code); emit(0) - _compile(code, av[1], flags) - if av[2]: - emit(JUMP) - skipno = _len(code); emit(0) - code[skipyes] = _len(code) - skipyes + 1 - _compile(code, av[2], flags) - code[skipno] = _len(code) - skipno - else: - code[skipyes] = _len(code) - skipyes + 1 - else: - raise error("internal: unsupported operand type %r" % (op,)) - -def _compile_charset(charset, flags, code): - # compile charset subprogram - emit = code.append - for op, av in charset: - emit(op) - if op is NEGATE: - pass - elif op is LITERAL: - emit(av) - elif op is RANGE or op is RANGE_UNI_IGNORE: - emit(av[0]) - emit(av[1]) - elif op is CHARSET: - code.extend(av) - elif op is BIGCHARSET: - code.extend(av) - elif op is CATEGORY: - if flags & SRE_FLAG_LOCALE: - emit(CH_LOCALE[av]) - elif flags & SRE_FLAG_UNICODE: - emit(CH_UNICODE[av]) - else: - emit(av) - else: - raise error("internal: unsupported set operator %r" % (op,)) - emit(FAILURE) - -def _optimize_charset(charset, iscased=None, fixup=None, fixes=None): - # internal: optimize character set - out = [] - tail = [] - charmap = bytearray(256) - hascased = False - for op, av in charset: - while True: - try: - if op is LITERAL: - if fixup: - lo = fixup(av) - charmap[lo] = 1 - if fixes and lo in fixes: - for k in fixes[lo]: - charmap[k] = 1 - if not hascased and iscased(av): - hascased = True - else: - charmap[av] = 1 - elif op is RANGE: - r = range(av[0], av[1]+1) - if fixup: - if fixes: - for i in map(fixup, r): - charmap[i] = 1 - if i in fixes: - for k in fixes[i]: - charmap[k] = 1 - else: - for i in map(fixup, r): - charmap[i] = 1 - if not hascased: - hascased = any(map(iscased, r)) - else: - for i in r: - charmap[i] = 1 - elif op is NEGATE: - out.append((op, av)) - else: - tail.append((op, av)) - except IndexError: - if len(charmap) == 256: - # character set contains non-UCS1 character codes - charmap += b'\0' * 0xff00 - continue - # Character set contains non-BMP character codes. - if fixup: - hascased = True - # There are only two ranges of cased non-BMP characters: - # 10400-1044F (Deseret) and 118A0-118DF (Warang Citi), - # and for both ranges RANGE_UNI_IGNORE works. - if op is RANGE: - op = RANGE_UNI_IGNORE - tail.append((op, av)) - break - - # compress character map - runs = [] - q = 0 - while True: - p = charmap.find(1, q) - if p < 0: - break - if len(runs) >= 2: - runs = None - break - q = charmap.find(0, p) - if q < 0: - runs.append((p, len(charmap))) - break - runs.append((p, q)) - if runs is not None: - # use literal/range - for p, q in runs: - if q - p == 1: - out.append((LITERAL, p)) - else: - out.append((RANGE, (p, q - 1))) - out += tail - # if the case was changed or new representation is more compact - if hascased or len(out) < len(charset): - return out, hascased - # else original character set is good enough - return charset, hascased - - # use bitmap - if len(charmap) == 256: - data = _mk_bitmap(charmap) - out.append((CHARSET, data)) - out += tail - return out, hascased - - # To represent a big charset, first a bitmap of all characters in the - # set is constructed. Then, this bitmap is sliced into chunks of 256 - # characters, duplicate chunks are eliminated, and each chunk is - # given a number. In the compiled expression, the charset is - # represented by a 32-bit word sequence, consisting of one word for - # the number of different chunks, a sequence of 256 bytes (64 words) - # of chunk numbers indexed by their original chunk position, and a - # sequence of 256-bit chunks (8 words each). - - # Compression is normally good: in a typical charset, large ranges of - # Unicode will be either completely excluded (e.g. if only cyrillic - # letters are to be matched), or completely included (e.g. if large - # subranges of Kanji match). These ranges will be represented by - # chunks of all one-bits or all zero-bits. - - # Matching can be also done efficiently: the more significant byte of - # the Unicode character is an index into the chunk number, and the - # less significant byte is a bit index in the chunk (just like the - # CHARSET matching). - - charmap = bytes(charmap) # should be hashable - comps = {} - mapping = bytearray(256) - block = 0 - data = bytearray() - for i in range(0, 65536, 256): - chunk = charmap[i: i + 256] - if chunk in comps: - mapping[i // 256] = comps[chunk] - else: - mapping[i // 256] = comps[chunk] = block - block += 1 - data += chunk - data = _mk_bitmap(data) - data[0:0] = [block] + _bytes_to_codes(mapping) - out.append((BIGCHARSET, data)) - out += tail - return out, hascased - -_CODEBITS = _sre.CODESIZE * 8 -MAXCODE = (1 << _CODEBITS) - 1 -_BITS_TRANS = b'0' + b'1' * 255 -def _mk_bitmap(bits, _CODEBITS=_CODEBITS, _int=int): - s = bits.translate(_BITS_TRANS)[::-1] - return [_int(s[i - _CODEBITS: i], 2) - for i in range(len(s), 0, -_CODEBITS)] - -def _bytes_to_codes(b): - # Convert block indices to word array - a = memoryview(b).cast('I') - assert a.itemsize == _sre.CODESIZE - assert len(a) * a.itemsize == len(b) - return a.tolist() - -def _simple(p): - # check if this subpattern is a "simple" operator - if len(p) != 1: - return False - op, av = p[0] - if op is SUBPATTERN: - return av[0] is None and _simple(av[-1]) - return op in _UNIT_CODES - -def _generate_overlap_table(prefix): - """ - Generate an overlap table for the following prefix. - An overlap table is a table of the same size as the prefix which - informs about the potential self-overlap for each index in the prefix: - - if overlap[i] == 0, prefix[i:] can't overlap prefix[0:...] - - if overlap[i] == k with 0 < k <= i, prefix[i-k+1:i+1] overlaps with - prefix[0:k] - """ - table = [0] * len(prefix) - for i in range(1, len(prefix)): - idx = table[i - 1] - while prefix[i] != prefix[idx]: - if idx == 0: - table[i] = 0 - break - idx = table[idx - 1] - else: - table[i] = idx + 1 - return table - -def _get_iscased(flags): - if not flags & SRE_FLAG_IGNORECASE: - return None - elif flags & SRE_FLAG_UNICODE: - return _sre.unicode_iscased - else: - return _sre.ascii_iscased - -def _get_literal_prefix(pattern, flags): - # look for literal prefix - prefix = [] - prefixappend = prefix.append - prefix_skip = None - iscased = _get_iscased(flags) - for op, av in pattern.data: - if op is LITERAL: - if iscased and iscased(av): - break - prefixappend(av) - elif op is SUBPATTERN: - group, add_flags, del_flags, p = av - flags1 = _combine_flags(flags, add_flags, del_flags) - if flags1 & SRE_FLAG_IGNORECASE and flags1 & SRE_FLAG_LOCALE: - break - prefix1, prefix_skip1, got_all = _get_literal_prefix(p, flags1) - if prefix_skip is None: - if group is not None: - prefix_skip = len(prefix) - elif prefix_skip1 is not None: - prefix_skip = len(prefix) + prefix_skip1 - prefix.extend(prefix1) - if not got_all: - break - else: - break - else: - return prefix, prefix_skip, True - return prefix, prefix_skip, False - -def _get_charset_prefix(pattern, flags): - while True: - if not pattern.data: - return None - op, av = pattern.data[0] - if op is not SUBPATTERN: - break - group, add_flags, del_flags, pattern = av - flags = _combine_flags(flags, add_flags, del_flags) - if flags & SRE_FLAG_IGNORECASE and flags & SRE_FLAG_LOCALE: - return None - - iscased = _get_iscased(flags) - if op is LITERAL: - if iscased and iscased(av): - return None - return [(op, av)] - elif op is BRANCH: - charset = [] - charsetappend = charset.append - for p in av[1]: - if not p: - return None - op, av = p[0] - if op is LITERAL and not (iscased and iscased(av)): - charsetappend((op, av)) - else: - return None - return charset - elif op is IN: - charset = av - if iscased: - for op, av in charset: - if op is LITERAL: - if iscased(av): - return None - elif op is RANGE: - if av[1] > 0xffff: - return None - if any(map(iscased, range(av[0], av[1]+1))): - return None - return charset - return None - -def _compile_info(code, pattern, flags): - # internal: compile an info block. in the current version, - # this contains min/max pattern width, and an optional literal - # prefix or a character map - lo, hi = pattern.getwidth() - if hi > MAXCODE: - hi = MAXCODE - if lo == 0: - code.extend([INFO, 4, 0, lo, hi]) - return - # look for a literal prefix - prefix = [] - prefix_skip = 0 - charset = [] # not used - if not (flags & SRE_FLAG_IGNORECASE and flags & SRE_FLAG_LOCALE): - # look for literal prefix - prefix, prefix_skip, got_all = _get_literal_prefix(pattern, flags) - # if no prefix, look for charset prefix - if not prefix: - charset = _get_charset_prefix(pattern, flags) -## if prefix: -## print("*** PREFIX", prefix, prefix_skip) -## if charset: -## print("*** CHARSET", charset) - # add an info block - emit = code.append - emit(INFO) - skip = len(code); emit(0) - # literal flag - mask = 0 - if prefix: - mask = SRE_INFO_PREFIX - if prefix_skip is None and got_all: - mask = mask | SRE_INFO_LITERAL - elif charset: - mask = mask | SRE_INFO_CHARSET - emit(mask) - # pattern length - if lo < MAXCODE: - emit(lo) - else: - emit(MAXCODE) - prefix = prefix[:MAXCODE] - emit(min(hi, MAXCODE)) - # add literal prefix - if prefix: - emit(len(prefix)) # length - if prefix_skip is None: - prefix_skip = len(prefix) - emit(prefix_skip) # skip - code.extend(prefix) - # generate overlap table - code.extend(_generate_overlap_table(prefix)) - elif charset: - charset, hascased = _optimize_charset(charset) - assert not hascased - _compile_charset(charset, flags, code) - code[skip] = len(code) - skip - -def isstring(obj): - return isinstance(obj, (str, bytes)) - -def _code(p, flags): - - flags = p.state.flags | flags - code = [] - - # compile info block - _compile_info(code, p, flags) - - # compile the pattern - _compile(code, p.data, flags) - - code.append(SUCCESS) - - return code - -def _hex_code(code): - return '[%s]' % ', '.join('%#0*x' % (_sre.CODESIZE*2+2, x) for x in code) - -def dis(code): - import sys - - labels = set() - level = 0 - offset_width = len(str(len(code) - 1)) - - def dis_(start, end): - def print_(*args, to=None): - if to is not None: - labels.add(to) - args += ('(to %d)' % (to,),) - print('%*d%s ' % (offset_width, start, ':' if start in labels else '.'), - end=' '*(level-1)) - print(*args) - - def print_2(*args): - print(end=' '*(offset_width + 2*level)) - print(*args) - - nonlocal level - level += 1 - i = start - while i < end: - start = i - op = code[i] - i += 1 - op = OPCODES[op] - if op in (SUCCESS, FAILURE, ANY, ANY_ALL, - MAX_UNTIL, MIN_UNTIL, NEGATE): - print_(op) - elif op in (LITERAL, NOT_LITERAL, - LITERAL_IGNORE, NOT_LITERAL_IGNORE, - LITERAL_UNI_IGNORE, NOT_LITERAL_UNI_IGNORE, - LITERAL_LOC_IGNORE, NOT_LITERAL_LOC_IGNORE): - arg = code[i] - i += 1 - print_(op, '%#02x (%r)' % (arg, chr(arg))) - elif op is AT: - arg = code[i] - i += 1 - arg = str(ATCODES[arg]) - assert arg[:3] == 'AT_' - print_(op, arg[3:]) - elif op is CATEGORY: - arg = code[i] - i += 1 - arg = str(CHCODES[arg]) - assert arg[:9] == 'CATEGORY_' - print_(op, arg[9:]) - elif op in (IN, IN_IGNORE, IN_UNI_IGNORE, IN_LOC_IGNORE): - skip = code[i] - print_(op, skip, to=i+skip) - dis_(i+1, i+skip) - i += skip - elif op in (RANGE, RANGE_UNI_IGNORE): - lo, hi = code[i: i+2] - i += 2 - print_(op, '%#02x %#02x (%r-%r)' % (lo, hi, chr(lo), chr(hi))) - elif op is CHARSET: - print_(op, _hex_code(code[i: i + 256//_CODEBITS])) - i += 256//_CODEBITS - elif op is BIGCHARSET: - arg = code[i] - i += 1 - mapping = list(b''.join(x.to_bytes(_sre.CODESIZE, sys.byteorder) - for x in code[i: i + 256//_sre.CODESIZE])) - print_(op, arg, mapping) - i += 256//_sre.CODESIZE - level += 1 - for j in range(arg): - print_2(_hex_code(code[i: i + 256//_CODEBITS])) - i += 256//_CODEBITS - level -= 1 - elif op in (MARK, GROUPREF, GROUPREF_IGNORE, GROUPREF_UNI_IGNORE, - GROUPREF_LOC_IGNORE): - arg = code[i] - i += 1 - print_(op, arg) - elif op is JUMP: - skip = code[i] - print_(op, skip, to=i+skip) - i += 1 - elif op is BRANCH: - skip = code[i] - print_(op, skip, to=i+skip) - while skip: - dis_(i+1, i+skip) - i += skip - start = i - skip = code[i] - if skip: - print_('branch', skip, to=i+skip) - else: - print_(FAILURE) - i += 1 - elif op in (REPEAT, REPEAT_ONE, MIN_REPEAT_ONE): - skip, min, max = code[i: i+3] - if max == MAXREPEAT: - max = 'MAXREPEAT' - print_(op, skip, min, max, to=i+skip) - dis_(i+3, i+skip) - i += skip - elif op is GROUPREF_EXISTS: - arg, skip = code[i: i+2] - print_(op, arg, skip, to=i+skip) - i += 2 - elif op in (ASSERT, ASSERT_NOT): - skip, arg = code[i: i+2] - print_(op, skip, arg, to=i+skip) - dis_(i+2, i+skip) - i += skip - elif op is INFO: - skip, flags, min, max = code[i: i+4] - if max == MAXREPEAT: - max = 'MAXREPEAT' - print_(op, skip, bin(flags), min, max, to=i+skip) - start = i+4 - if flags & SRE_INFO_PREFIX: - prefix_len, prefix_skip = code[i+4: i+6] - print_2(' prefix_skip', prefix_skip) - start = i + 6 - prefix = code[start: start+prefix_len] - print_2(' prefix', - '[%s]' % ', '.join('%#02x' % x for x in prefix), - '(%r)' % ''.join(map(chr, prefix))) - start += prefix_len - print_2(' overlap', code[start: start+prefix_len]) - start += prefix_len - if flags & SRE_INFO_CHARSET: - level += 1 - print_2('in') - dis_(start, i+skip) - level -= 1 - i += skip - else: - raise ValueError(op) - - level -= 1 - - dis_(0, len(code)) - - -def compile(p, flags=0): - # internal: convert pattern list to internal format - - if isstring(p): - pattern = p - p = sre_parse.parse(p, flags) - else: - pattern = None - - code = _code(p, flags) - - if flags & SRE_FLAG_DEBUG: - print() - dis(code) - - # map in either direction - groupindex = p.state.groupdict - indexgroup = [None] * p.state.groups - for k, i in groupindex.items(): - indexgroup[i] = k - - return _sre.compile( - pattern, flags | p.state.flags, code, - p.state.groups-1, - groupindex, tuple(indexgroup) - ) +from re import _compiler as _ +globals().update({k: v for k, v in vars(_).items() if k[:2] != '__'}) diff --git a/Lib/sre_constants.py b/Lib/sre_constants.py index 8360acb695..8543e2bc8c 100644 --- a/Lib/sre_constants.py +++ b/Lib/sre_constants.py @@ -1,218 +1,10 @@ -# -# Secret Labs' Regular Expression Engine -# -# various symbols used by the regular expression engine. -# run this script to update the _sre include files! -# -# Copyright (c) 1998-2001 by Secret Labs AB. All rights reserved. -# -# See the sre.py file for information on usage and redistribution. -# +import warnings +warnings.warn(f"module {__name__!r} is deprecated", + DeprecationWarning, + stacklevel=2) -"""Internal support module for sre""" - -# update when constants are added or removed - -MAGIC = 20171005 - -from _sre import MAXREPEAT, MAXGROUPS - -# SRE standard exception (access as sre.error) -# should this really be here? - -class error(Exception): - """Exception raised for invalid regular expressions. - - Attributes: - - msg: The unformatted error message - pattern: The regular expression pattern - pos: The index in the pattern where compilation failed (may be None) - lineno: The line corresponding to pos (may be None) - colno: The column corresponding to pos (may be None) - """ - - __module__ = 're' - - def __init__(self, msg, pattern=None, pos=None): - self.msg = msg - self.pattern = pattern - self.pos = pos - if pattern is not None and pos is not None: - msg = '%s at position %d' % (msg, pos) - if isinstance(pattern, str): - newline = '\n' - else: - newline = b'\n' - self.lineno = pattern.count(newline, 0, pos) + 1 - self.colno = pos - pattern.rfind(newline, 0, pos) - if newline in pattern: - msg = '%s (line %d, column %d)' % (msg, self.lineno, self.colno) - else: - self.lineno = self.colno = None - super().__init__(msg) - - -class _NamedIntConstant(int): - def __new__(cls, value, name): - self = super(_NamedIntConstant, cls).__new__(cls, value) - self.name = name - return self - - def __repr__(self): - return self.name - -MAXREPEAT = _NamedIntConstant(MAXREPEAT, 'MAXREPEAT') - -def _makecodes(names): - names = names.strip().split() - items = [_NamedIntConstant(i, name) for i, name in enumerate(names)] - globals().update({item.name: item for item in items}) - return items - -# operators -# failure=0 success=1 (just because it looks better that way :-) -OPCODES = _makecodes(""" - FAILURE SUCCESS - - ANY ANY_ALL - ASSERT ASSERT_NOT - AT - BRANCH - CALL - CATEGORY - CHARSET BIGCHARSET - GROUPREF GROUPREF_EXISTS - IN - INFO - JUMP - LITERAL - MARK - MAX_UNTIL - MIN_UNTIL - NOT_LITERAL - NEGATE - RANGE - REPEAT - REPEAT_ONE - SUBPATTERN - MIN_REPEAT_ONE - - GROUPREF_IGNORE - IN_IGNORE - LITERAL_IGNORE - NOT_LITERAL_IGNORE - - GROUPREF_LOC_IGNORE - IN_LOC_IGNORE - LITERAL_LOC_IGNORE - NOT_LITERAL_LOC_IGNORE - - GROUPREF_UNI_IGNORE - IN_UNI_IGNORE - LITERAL_UNI_IGNORE - NOT_LITERAL_UNI_IGNORE - RANGE_UNI_IGNORE - - MIN_REPEAT MAX_REPEAT -""") -del OPCODES[-2:] # remove MIN_REPEAT and MAX_REPEAT - -# positions -ATCODES = _makecodes(""" - AT_BEGINNING AT_BEGINNING_LINE AT_BEGINNING_STRING - AT_BOUNDARY AT_NON_BOUNDARY - AT_END AT_END_LINE AT_END_STRING - - AT_LOC_BOUNDARY AT_LOC_NON_BOUNDARY - - AT_UNI_BOUNDARY AT_UNI_NON_BOUNDARY -""") - -# categories -CHCODES = _makecodes(""" - CATEGORY_DIGIT CATEGORY_NOT_DIGIT - CATEGORY_SPACE CATEGORY_NOT_SPACE - CATEGORY_WORD CATEGORY_NOT_WORD - CATEGORY_LINEBREAK CATEGORY_NOT_LINEBREAK - - CATEGORY_LOC_WORD CATEGORY_LOC_NOT_WORD - - CATEGORY_UNI_DIGIT CATEGORY_UNI_NOT_DIGIT - CATEGORY_UNI_SPACE CATEGORY_UNI_NOT_SPACE - CATEGORY_UNI_WORD CATEGORY_UNI_NOT_WORD - CATEGORY_UNI_LINEBREAK CATEGORY_UNI_NOT_LINEBREAK -""") - - -# replacement operations for "ignore case" mode -OP_IGNORE = { - LITERAL: LITERAL_IGNORE, - NOT_LITERAL: NOT_LITERAL_IGNORE, -} - -OP_LOCALE_IGNORE = { - LITERAL: LITERAL_LOC_IGNORE, - NOT_LITERAL: NOT_LITERAL_LOC_IGNORE, -} - -OP_UNICODE_IGNORE = { - LITERAL: LITERAL_UNI_IGNORE, - NOT_LITERAL: NOT_LITERAL_UNI_IGNORE, -} - -AT_MULTILINE = { - AT_BEGINNING: AT_BEGINNING_LINE, - AT_END: AT_END_LINE -} - -AT_LOCALE = { - AT_BOUNDARY: AT_LOC_BOUNDARY, - AT_NON_BOUNDARY: AT_LOC_NON_BOUNDARY -} - -AT_UNICODE = { - AT_BOUNDARY: AT_UNI_BOUNDARY, - AT_NON_BOUNDARY: AT_UNI_NON_BOUNDARY -} - -CH_LOCALE = { - CATEGORY_DIGIT: CATEGORY_DIGIT, - CATEGORY_NOT_DIGIT: CATEGORY_NOT_DIGIT, - CATEGORY_SPACE: CATEGORY_SPACE, - CATEGORY_NOT_SPACE: CATEGORY_NOT_SPACE, - CATEGORY_WORD: CATEGORY_LOC_WORD, - CATEGORY_NOT_WORD: CATEGORY_LOC_NOT_WORD, - CATEGORY_LINEBREAK: CATEGORY_LINEBREAK, - CATEGORY_NOT_LINEBREAK: CATEGORY_NOT_LINEBREAK -} - -CH_UNICODE = { - CATEGORY_DIGIT: CATEGORY_UNI_DIGIT, - CATEGORY_NOT_DIGIT: CATEGORY_UNI_NOT_DIGIT, - CATEGORY_SPACE: CATEGORY_UNI_SPACE, - CATEGORY_NOT_SPACE: CATEGORY_UNI_NOT_SPACE, - CATEGORY_WORD: CATEGORY_UNI_WORD, - CATEGORY_NOT_WORD: CATEGORY_UNI_NOT_WORD, - CATEGORY_LINEBREAK: CATEGORY_UNI_LINEBREAK, - CATEGORY_NOT_LINEBREAK: CATEGORY_UNI_NOT_LINEBREAK -} - -# flags -SRE_FLAG_TEMPLATE = 1 # template mode (disable backtracking) -SRE_FLAG_IGNORECASE = 2 # case insensitive -SRE_FLAG_LOCALE = 4 # honour system locale -SRE_FLAG_MULTILINE = 8 # treat target as multiline string -SRE_FLAG_DOTALL = 16 # treat target as a single string -SRE_FLAG_UNICODE = 32 # use unicode "locale" -SRE_FLAG_VERBOSE = 64 # ignore whitespace and comments -SRE_FLAG_DEBUG = 128 # debugging -SRE_FLAG_ASCII = 256 # use ascii "locale" - -# flags for INFO primitive -SRE_INFO_PREFIX = 1 # has prefix -SRE_INFO_LITERAL = 2 # entire pattern is literal (given by prefix) -SRE_INFO_CHARSET = 4 # pattern starts with character from given set +from re import _constants as _ +globals().update({k: v for k, v in vars(_).items() if k[:2] != '__'}) if __name__ == "__main__": def dump(f, d, typ, int_t, prefix): diff --git a/Lib/sre_parse.py b/Lib/sre_parse.py index 83119168e6..25a3f557d4 100644 --- a/Lib/sre_parse.py +++ b/Lib/sre_parse.py @@ -1,1064 +1,7 @@ -# -# Secret Labs' Regular Expression Engine -# -# convert re-style regular expression to sre pattern -# -# Copyright (c) 1998-2001 by Secret Labs AB. All rights reserved. -# -# See the sre.py file for information on usage and redistribution. -# +import warnings +warnings.warn(f"module {__name__!r} is deprecated", + DeprecationWarning, + stacklevel=2) -"""Internal support module for sre""" - -# XXX: show string offset and offending character for all errors - -from sre_constants import * - -SPECIAL_CHARS = ".\\[{()*+?^$|" -REPEAT_CHARS = "*+?{" - -DIGITS = frozenset("0123456789") - -OCTDIGITS = frozenset("01234567") -HEXDIGITS = frozenset("0123456789abcdefABCDEF") -ASCIILETTERS = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") - -WHITESPACE = frozenset(" \t\n\r\v\f") - -_REPEATCODES = frozenset({MIN_REPEAT, MAX_REPEAT}) -_UNITCODES = frozenset({ANY, RANGE, IN, LITERAL, NOT_LITERAL, CATEGORY}) - -ESCAPES = { - r"\a": (LITERAL, ord("\a")), - r"\b": (LITERAL, ord("\b")), - r"\f": (LITERAL, ord("\f")), - r"\n": (LITERAL, ord("\n")), - r"\r": (LITERAL, ord("\r")), - r"\t": (LITERAL, ord("\t")), - r"\v": (LITERAL, ord("\v")), - r"\\": (LITERAL, ord("\\")) -} - -CATEGORIES = { - r"\A": (AT, AT_BEGINNING_STRING), # start of string - r"\b": (AT, AT_BOUNDARY), - r"\B": (AT, AT_NON_BOUNDARY), - r"\d": (IN, [(CATEGORY, CATEGORY_DIGIT)]), - r"\D": (IN, [(CATEGORY, CATEGORY_NOT_DIGIT)]), - r"\s": (IN, [(CATEGORY, CATEGORY_SPACE)]), - r"\S": (IN, [(CATEGORY, CATEGORY_NOT_SPACE)]), - r"\w": (IN, [(CATEGORY, CATEGORY_WORD)]), - r"\W": (IN, [(CATEGORY, CATEGORY_NOT_WORD)]), - r"\Z": (AT, AT_END_STRING), # end of string -} - -FLAGS = { - # standard flags - "i": SRE_FLAG_IGNORECASE, - "L": SRE_FLAG_LOCALE, - "m": SRE_FLAG_MULTILINE, - "s": SRE_FLAG_DOTALL, - "x": SRE_FLAG_VERBOSE, - # extensions - "a": SRE_FLAG_ASCII, - "t": SRE_FLAG_TEMPLATE, - "u": SRE_FLAG_UNICODE, -} - -TYPE_FLAGS = SRE_FLAG_ASCII | SRE_FLAG_LOCALE | SRE_FLAG_UNICODE -GLOBAL_FLAGS = SRE_FLAG_DEBUG | SRE_FLAG_TEMPLATE - -class Verbose(Exception): - pass - -class State: - # keeps track of state for parsing - def __init__(self): - self.flags = 0 - self.groupdict = {} - self.groupwidths = [None] # group 0 - self.lookbehindgroups = None - @property - def groups(self): - return len(self.groupwidths) - def opengroup(self, name=None): - gid = self.groups - self.groupwidths.append(None) - if self.groups > MAXGROUPS: - raise error("too many groups") - if name is not None: - ogid = self.groupdict.get(name, None) - if ogid is not None: - raise error("redefinition of group name %r as group %d; " - "was group %d" % (name, gid, ogid)) - self.groupdict[name] = gid - return gid - def closegroup(self, gid, p): - self.groupwidths[gid] = p.getwidth() - def checkgroup(self, gid): - return gid < self.groups and self.groupwidths[gid] is not None - - def checklookbehindgroup(self, gid, source): - if self.lookbehindgroups is not None: - if not self.checkgroup(gid): - raise source.error('cannot refer to an open group') - if gid >= self.lookbehindgroups: - raise source.error('cannot refer to group defined in the same ' - 'lookbehind subpattern') - -class SubPattern: - # a subpattern, in intermediate form - def __init__(self, state, data=None): - self.state = state - if data is None: - data = [] - self.data = data - self.width = None - - def dump(self, level=0): - nl = True - seqtypes = (tuple, list) - for op, av in self.data: - print(level*" " + str(op), end='') - if op is IN: - # member sublanguage - print() - for op, a in av: - print((level+1)*" " + str(op), a) - elif op is BRANCH: - print() - for i, a in enumerate(av[1]): - if i: - print(level*" " + "OR") - a.dump(level+1) - elif op is GROUPREF_EXISTS: - condgroup, item_yes, item_no = av - print('', condgroup) - item_yes.dump(level+1) - if item_no: - print(level*" " + "ELSE") - item_no.dump(level+1) - elif isinstance(av, seqtypes): - nl = False - for a in av: - if isinstance(a, SubPattern): - if not nl: - print() - a.dump(level+1) - nl = True - else: - if not nl: - print(' ', end='') - print(a, end='') - nl = False - if not nl: - print() - else: - print('', av) - def __repr__(self): - return repr(self.data) - def __len__(self): - return len(self.data) - def __delitem__(self, index): - del self.data[index] - def __getitem__(self, index): - if isinstance(index, slice): - return SubPattern(self.state, self.data[index]) - return self.data[index] - def __setitem__(self, index, code): - self.data[index] = code - def insert(self, index, code): - self.data.insert(index, code) - def append(self, code): - self.data.append(code) - def getwidth(self): - # determine the width (min, max) for this subpattern - if self.width is not None: - return self.width - lo = hi = 0 - for op, av in self.data: - if op is BRANCH: - i = MAXREPEAT - 1 - j = 0 - for av in av[1]: - l, h = av.getwidth() - i = min(i, l) - j = max(j, h) - lo = lo + i - hi = hi + j - elif op is CALL: - i, j = av.getwidth() - lo = lo + i - hi = hi + j - elif op is SUBPATTERN: - i, j = av[-1].getwidth() - lo = lo + i - hi = hi + j - elif op in _REPEATCODES: - i, j = av[2].getwidth() - lo = lo + i * av[0] - hi = hi + j * av[1] - elif op in _UNITCODES: - lo = lo + 1 - hi = hi + 1 - elif op is GROUPREF: - i, j = self.state.groupwidths[av] - lo = lo + i - hi = hi + j - elif op is GROUPREF_EXISTS: - i, j = av[1].getwidth() - if av[2] is not None: - l, h = av[2].getwidth() - i = min(i, l) - j = max(j, h) - else: - i = 0 - lo = lo + i - hi = hi + j - elif op is SUCCESS: - break - self.width = min(lo, MAXREPEAT - 1), min(hi, MAXREPEAT) - return self.width - -class Tokenizer: - def __init__(self, string): - self.istext = isinstance(string, str) - self.string = string - if not self.istext: - string = str(string, 'latin1') - self.decoded_string = string - self.index = 0 - self.next = None - self.__next() - def __next(self): - index = self.index - try: - char = self.decoded_string[index] - except IndexError: - self.next = None - return - if char == "\\": - index += 1 - try: - char += self.decoded_string[index] - except IndexError: - raise error("bad escape (end of pattern)", - self.string, len(self.string) - 1) from None - self.index = index + 1 - self.next = char - def match(self, char): - if char == self.next: - self.__next() - return True - return False - def get(self): - this = self.next - self.__next() - return this - def getwhile(self, n, charset): - result = '' - for _ in range(n): - c = self.next - if c not in charset: - break - result += c - self.__next() - return result - def getuntil(self, terminator, name): - result = '' - while True: - c = self.next - self.__next() - if c is None: - if not result: - raise self.error("missing " + name) - raise self.error("missing %s, unterminated name" % terminator, - len(result)) - if c == terminator: - if not result: - raise self.error("missing " + name, 1) - break - result += c - return result - @property - def pos(self): - return self.index - len(self.next or '') - def tell(self): - return self.index - len(self.next or '') - def seek(self, index): - self.index = index - self.__next() - - def error(self, msg, offset=0): - return error(msg, self.string, self.tell() - offset) - -def _class_escape(source, escape): - # handle escape code inside character class - code = ESCAPES.get(escape) - if code: - return code - code = CATEGORIES.get(escape) - if code and code[0] is IN: - return code - try: - c = escape[1:2] - if c == "x": - # hexadecimal escape (exactly two digits) - escape += source.getwhile(2, HEXDIGITS) - if len(escape) != 4: - raise source.error("incomplete escape %s" % escape, len(escape)) - return LITERAL, int(escape[2:], 16) - elif c == "u" and source.istext: - # unicode escape (exactly four digits) - escape += source.getwhile(4, HEXDIGITS) - if len(escape) != 6: - raise source.error("incomplete escape %s" % escape, len(escape)) - return LITERAL, int(escape[2:], 16) - elif c == "U" and source.istext: - # unicode escape (exactly eight digits) - escape += source.getwhile(8, HEXDIGITS) - if len(escape) != 10: - raise source.error("incomplete escape %s" % escape, len(escape)) - c = int(escape[2:], 16) - chr(c) # raise ValueError for invalid code - return LITERAL, c - elif c == "N" and source.istext: - import unicodedata - # named unicode escape e.g. \N{EM DASH} - if not source.match('{'): - raise source.error("missing {") - charname = source.getuntil('}', 'character name') - try: - c = ord(unicodedata.lookup(charname)) - except KeyError: - raise source.error("undefined character name %r" % charname, - len(charname) + len(r'\N{}')) - return LITERAL, c - elif c in OCTDIGITS: - # octal escape (up to three digits) - escape += source.getwhile(2, OCTDIGITS) - c = int(escape[1:], 8) - if c > 0o377: - raise source.error('octal escape value %s outside of ' - 'range 0-0o377' % escape, len(escape)) - return LITERAL, c - elif c in DIGITS: - raise ValueError - if len(escape) == 2: - if c in ASCIILETTERS: - raise source.error('bad escape %s' % escape, len(escape)) - return LITERAL, ord(escape[1]) - except ValueError: - pass - raise source.error("bad escape %s" % escape, len(escape)) - -def _escape(source, escape, state): - # handle escape code in expression - code = CATEGORIES.get(escape) - if code: - return code - code = ESCAPES.get(escape) - if code: - return code - try: - c = escape[1:2] - if c == "x": - # hexadecimal escape - escape += source.getwhile(2, HEXDIGITS) - if len(escape) != 4: - raise source.error("incomplete escape %s" % escape, len(escape)) - return LITERAL, int(escape[2:], 16) - elif c == "u" and source.istext: - # unicode escape (exactly four digits) - escape += source.getwhile(4, HEXDIGITS) - if len(escape) != 6: - raise source.error("incomplete escape %s" % escape, len(escape)) - return LITERAL, int(escape[2:], 16) - elif c == "U" and source.istext: - # unicode escape (exactly eight digits) - escape += source.getwhile(8, HEXDIGITS) - if len(escape) != 10: - raise source.error("incomplete escape %s" % escape, len(escape)) - c = int(escape[2:], 16) - chr(c) # raise ValueError for invalid code - return LITERAL, c - elif c == "N" and source.istext: - import unicodedata - # named unicode escape e.g. \N{EM DASH} - if not source.match('{'): - raise source.error("missing {") - charname = source.getuntil('}', 'character name') - try: - c = ord(unicodedata.lookup(charname)) - except KeyError: - raise source.error("undefined character name %r" % charname, - len(charname) + len(r'\N{}')) - return LITERAL, c - elif c == "0": - # octal escape - escape += source.getwhile(2, OCTDIGITS) - return LITERAL, int(escape[1:], 8) - elif c in DIGITS: - # octal escape *or* decimal group reference (sigh) - if source.next in DIGITS: - escape += source.get() - if (escape[1] in OCTDIGITS and escape[2] in OCTDIGITS and - source.next in OCTDIGITS): - # got three octal digits; this is an octal escape - escape += source.get() - c = int(escape[1:], 8) - if c > 0o377: - raise source.error('octal escape value %s outside of ' - 'range 0-0o377' % escape, - len(escape)) - return LITERAL, c - # not an octal escape, so this is a group reference - group = int(escape[1:]) - if group < state.groups: - if not state.checkgroup(group): - raise source.error("cannot refer to an open group", - len(escape)) - state.checklookbehindgroup(group, source) - return GROUPREF, group - raise source.error("invalid group reference %d" % group, len(escape) - 1) - if len(escape) == 2: - if c in ASCIILETTERS: - raise source.error("bad escape %s" % escape, len(escape)) - return LITERAL, ord(escape[1]) - except ValueError: - pass - raise source.error("bad escape %s" % escape, len(escape)) - -def _uniq(items): - return list(dict.fromkeys(items)) - -def _parse_sub(source, state, verbose, nested): - # parse an alternation: a|b|c - - items = [] - itemsappend = items.append - sourcematch = source.match - start = source.tell() - while True: - itemsappend(_parse(source, state, verbose, nested + 1, - not nested and not items)) - if not sourcematch("|"): - break - - if len(items) == 1: - return items[0] - - subpattern = SubPattern(state) - - # check if all items share a common prefix - while True: - prefix = None - for item in items: - if not item: - break - if prefix is None: - prefix = item[0] - elif item[0] != prefix: - break - else: - # all subitems start with a common "prefix". - # move it out of the branch - for item in items: - del item[0] - subpattern.append(prefix) - continue # check next one - break - - # check if the branch can be replaced by a character set - set = [] - for item in items: - if len(item) != 1: - break - op, av = item[0] - if op is LITERAL: - set.append((op, av)) - elif op is IN and av[0][0] is not NEGATE: - set.extend(av) - else: - break - else: - # we can store this as a character set instead of a - # branch (the compiler may optimize this even more) - subpattern.append((IN, _uniq(set))) - return subpattern - - subpattern.append((BRANCH, (None, items))) - return subpattern - -def _parse(source, state, verbose, nested, first=False): - # parse a simple pattern - subpattern = SubPattern(state) - - # precompute constants into local variables - subpatternappend = subpattern.append - sourceget = source.get - sourcematch = source.match - _len = len - _ord = ord - - while True: - - this = source.next - if this is None: - break # end of pattern - if this in "|)": - break # end of subpattern - sourceget() - - if verbose: - # skip whitespace and comments - if this in WHITESPACE: - continue - if this == "#": - while True: - this = sourceget() - if this is None or this == "\n": - break - continue - - if this[0] == "\\": - code = _escape(source, this, state) - subpatternappend(code) - - elif this not in SPECIAL_CHARS: - subpatternappend((LITERAL, _ord(this))) - - elif this == "[": - here = source.tell() - 1 - # character set - set = [] - setappend = set.append -## if sourcematch(":"): -## pass # handle character classes - if source.next == '[': - import warnings - warnings.warn( - 'Possible nested set at position %d' % source.tell(), - FutureWarning, stacklevel=nested + 6 - ) - negate = sourcematch("^") - # check remaining characters - while True: - this = sourceget() - if this is None: - raise source.error("unterminated character set", - source.tell() - here) - if this == "]" and set: - break - elif this[0] == "\\": - code1 = _class_escape(source, this) - else: - if set and this in '-&~|' and source.next == this: - import warnings - warnings.warn( - 'Possible set %s at position %d' % ( - 'difference' if this == '-' else - 'intersection' if this == '&' else - 'symmetric difference' if this == '~' else - 'union', - source.tell() - 1), - FutureWarning, stacklevel=nested + 6 - ) - code1 = LITERAL, _ord(this) - if sourcematch("-"): - # potential range - that = sourceget() - if that is None: - raise source.error("unterminated character set", - source.tell() - here) - if that == "]": - if code1[0] is IN: - code1 = code1[1][0] - setappend(code1) - setappend((LITERAL, _ord("-"))) - break - if that[0] == "\\": - code2 = _class_escape(source, that) - else: - if that == '-': - import warnings - warnings.warn( - 'Possible set difference at position %d' % ( - source.tell() - 2), - FutureWarning, stacklevel=nested + 6 - ) - code2 = LITERAL, _ord(that) - if code1[0] != LITERAL or code2[0] != LITERAL: - msg = "bad character range %s-%s" % (this, that) - raise source.error(msg, len(this) + 1 + len(that)) - lo = code1[1] - hi = code2[1] - if hi < lo: - msg = "bad character range %s-%s" % (this, that) - raise source.error(msg, len(this) + 1 + len(that)) - setappend((RANGE, (lo, hi))) - else: - if code1[0] is IN: - code1 = code1[1][0] - setappend(code1) - - set = _uniq(set) - # XXX: should move set optimization to compiler! - if _len(set) == 1 and set[0][0] is LITERAL: - # optimization - if negate: - subpatternappend((NOT_LITERAL, set[0][1])) - else: - subpatternappend(set[0]) - else: - if negate: - set.insert(0, (NEGATE, None)) - # charmap optimization can't be added here because - # global flags still are not known - subpatternappend((IN, set)) - - elif this in REPEAT_CHARS: - # repeat previous item - here = source.tell() - if this == "?": - min, max = 0, 1 - elif this == "*": - min, max = 0, MAXREPEAT - - elif this == "+": - min, max = 1, MAXREPEAT - elif this == "{": - if source.next == "}": - subpatternappend((LITERAL, _ord(this))) - continue - - min, max = 0, MAXREPEAT - lo = hi = "" - while source.next in DIGITS: - lo += sourceget() - if sourcematch(","): - while source.next in DIGITS: - hi += sourceget() - else: - hi = lo - if not sourcematch("}"): - subpatternappend((LITERAL, _ord(this))) - source.seek(here) - continue - - if lo: - min = int(lo) - if min >= MAXREPEAT: - raise OverflowError("the repetition number is too large") - if hi: - max = int(hi) - if max >= MAXREPEAT: - raise OverflowError("the repetition number is too large") - if max < min: - raise source.error("min repeat greater than max repeat", - source.tell() - here) - else: - raise AssertionError("unsupported quantifier %r" % (char,)) - # figure out which item to repeat - if subpattern: - item = subpattern[-1:] - else: - item = None - if not item or item[0][0] is AT: - raise source.error("nothing to repeat", - source.tell() - here + len(this)) - if item[0][0] in _REPEATCODES: - raise source.error("multiple repeat", - source.tell() - here + len(this)) - if item[0][0] is SUBPATTERN: - group, add_flags, del_flags, p = item[0][1] - if group is None and not add_flags and not del_flags: - item = p - if sourcematch("?"): - subpattern[-1] = (MIN_REPEAT, (min, max, item)) - else: - subpattern[-1] = (MAX_REPEAT, (min, max, item)) - - elif this == ".": - subpatternappend((ANY, None)) - - elif this == "(": - start = source.tell() - 1 - group = True - name = None - add_flags = 0 - del_flags = 0 - if sourcematch("?"): - # options - char = sourceget() - if char is None: - raise source.error("unexpected end of pattern") - if char == "P": - # python extensions - if sourcematch("<"): - # named group: skip forward to end of name - name = source.getuntil(">", "group name") - if not name.isidentifier(): - msg = "bad character in group name %r" % name - raise source.error(msg, len(name) + 1) - elif sourcematch("="): - # named backreference - name = source.getuntil(")", "group name") - if not name.isidentifier(): - msg = "bad character in group name %r" % name - raise source.error(msg, len(name) + 1) - gid = state.groupdict.get(name) - if gid is None: - msg = "unknown group name %r" % name - raise source.error(msg, len(name) + 1) - if not state.checkgroup(gid): - raise source.error("cannot refer to an open group", - len(name) + 1) - state.checklookbehindgroup(gid, source) - subpatternappend((GROUPREF, gid)) - continue - - else: - char = sourceget() - if char is None: - raise source.error("unexpected end of pattern") - raise source.error("unknown extension ?P" + char, - len(char) + 2) - elif char == ":": - # non-capturing group - group = None - elif char == "#": - # comment - while True: - if source.next is None: - raise source.error("missing ), unterminated comment", - source.tell() - start) - if sourceget() == ")": - break - continue - - elif char in "=!<": - # lookahead assertions - dir = 1 - if char == "<": - char = sourceget() - if char is None: - raise source.error("unexpected end of pattern") - if char not in "=!": - raise source.error("unknown extension ?<" + char, - len(char) + 2) - dir = -1 # lookbehind - lookbehindgroups = state.lookbehindgroups - if lookbehindgroups is None: - state.lookbehindgroups = state.groups - p = _parse_sub(source, state, verbose, nested + 1) - if dir < 0: - if lookbehindgroups is None: - state.lookbehindgroups = None - if not sourcematch(")"): - raise source.error("missing ), unterminated subpattern", - source.tell() - start) - if char == "=": - subpatternappend((ASSERT, (dir, p))) - else: - subpatternappend((ASSERT_NOT, (dir, p))) - continue - - elif char == "(": - # conditional backreference group - condname = source.getuntil(")", "group name") - if condname.isidentifier(): - condgroup = state.groupdict.get(condname) - if condgroup is None: - msg = "unknown group name %r" % condname - raise source.error(msg, len(condname) + 1) - else: - try: - condgroup = int(condname) - if condgroup < 0: - raise ValueError - except ValueError: - msg = "bad character in group name %r" % condname - raise source.error(msg, len(condname) + 1) from None - if not condgroup: - raise source.error("bad group number", - len(condname) + 1) - if condgroup >= MAXGROUPS: - msg = "invalid group reference %d" % condgroup - raise source.error(msg, len(condname) + 1) - state.checklookbehindgroup(condgroup, source) - item_yes = _parse(source, state, verbose, nested + 1) - if source.match("|"): - item_no = _parse(source, state, verbose, nested + 1) - if source.next == "|": - raise source.error("conditional backref with more than two branches") - else: - item_no = None - if not source.match(")"): - raise source.error("missing ), unterminated subpattern", - source.tell() - start) - subpatternappend((GROUPREF_EXISTS, (condgroup, item_yes, item_no))) - continue - - elif char in FLAGS or char == "-": - # flags - flags = _parse_flags(source, state, char) - if flags is None: # global flags - if not first or subpattern: - import warnings - warnings.warn( - 'Flags not at the start of the expression %r%s' % ( - source.string[:20], # truncate long regexes - ' (truncated)' if len(source.string) > 20 else '', - ), - DeprecationWarning, stacklevel=nested + 6 - ) - if (state.flags & SRE_FLAG_VERBOSE) and not verbose: - raise Verbose - continue - - add_flags, del_flags = flags - group = None - else: - raise source.error("unknown extension ?" + char, - len(char) + 1) - - # parse group contents - if group is not None: - try: - group = state.opengroup(name) - except error as err: - raise source.error(err.msg, len(name) + 1) from None - sub_verbose = ((verbose or (add_flags & SRE_FLAG_VERBOSE)) and - not (del_flags & SRE_FLAG_VERBOSE)) - p = _parse_sub(source, state, sub_verbose, nested + 1) - if not source.match(")"): - raise source.error("missing ), unterminated subpattern", - source.tell() - start) - if group is not None: - state.closegroup(group, p) - subpatternappend((SUBPATTERN, (group, add_flags, del_flags, p))) - - elif this == "^": - subpatternappend((AT, AT_BEGINNING)) - - elif this == "$": - subpatternappend((AT, AT_END)) - - else: - raise AssertionError("unsupported special character %r" % (char,)) - - # unpack non-capturing groups - for i in range(len(subpattern))[::-1]: - op, av = subpattern[i] - if op is SUBPATTERN: - group, add_flags, del_flags, p = av - if group is None and not add_flags and not del_flags: - subpattern[i: i+1] = p - - return subpattern - -def _parse_flags(source, state, char): - sourceget = source.get - add_flags = 0 - del_flags = 0 - if char != "-": - while True: - flag = FLAGS[char] - if source.istext: - if char == 'L': - msg = "bad inline flags: cannot use 'L' flag with a str pattern" - raise source.error(msg) - else: - if char == 'u': - msg = "bad inline flags: cannot use 'u' flag with a bytes pattern" - raise source.error(msg) - add_flags |= flag - if (flag & TYPE_FLAGS) and (add_flags & TYPE_FLAGS) != flag: - msg = "bad inline flags: flags 'a', 'u' and 'L' are incompatible" - raise source.error(msg) - char = sourceget() - if char is None: - raise source.error("missing -, : or )") - if char in ")-:": - break - if char not in FLAGS: - msg = "unknown flag" if char.isalpha() else "missing -, : or )" - raise source.error(msg, len(char)) - if char == ")": - state.flags |= add_flags - return None - if add_flags & GLOBAL_FLAGS: - raise source.error("bad inline flags: cannot turn on global flag", 1) - if char == "-": - char = sourceget() - if char is None: - raise source.error("missing flag") - if char not in FLAGS: - msg = "unknown flag" if char.isalpha() else "missing flag" - raise source.error(msg, len(char)) - while True: - flag = FLAGS[char] - if flag & TYPE_FLAGS: - msg = "bad inline flags: cannot turn off flags 'a', 'u' and 'L'" - raise source.error(msg) - del_flags |= flag - char = sourceget() - if char is None: - raise source.error("missing :") - if char == ":": - break - if char not in FLAGS: - msg = "unknown flag" if char.isalpha() else "missing :" - raise source.error(msg, len(char)) - assert char == ":" - if del_flags & GLOBAL_FLAGS: - raise source.error("bad inline flags: cannot turn off global flag", 1) - if add_flags & del_flags: - raise source.error("bad inline flags: flag turned on and off", 1) - return add_flags, del_flags - -def fix_flags(src, flags): - # Check and fix flags according to the type of pattern (str or bytes) - if isinstance(src, str): - if flags & SRE_FLAG_LOCALE: - raise ValueError("cannot use LOCALE flag with a str pattern") - if not flags & SRE_FLAG_ASCII: - flags |= SRE_FLAG_UNICODE - elif flags & SRE_FLAG_UNICODE: - raise ValueError("ASCII and UNICODE flags are incompatible") - else: - if flags & SRE_FLAG_UNICODE: - raise ValueError("cannot use UNICODE flag with a bytes pattern") - if flags & SRE_FLAG_LOCALE and flags & SRE_FLAG_ASCII: - raise ValueError("ASCII and LOCALE flags are incompatible") - return flags - -def parse(str, flags=0, state=None): - # parse 're' pattern into list of (opcode, argument) tuples - - source = Tokenizer(str) - - if state is None: - state = State() - state.flags = flags - state.str = str - - try: - p = _parse_sub(source, state, flags & SRE_FLAG_VERBOSE, 0) - except Verbose: - # the VERBOSE flag was switched on inside the pattern. to be - # on the safe side, we'll parse the whole thing again... - state = State() - state.flags = flags | SRE_FLAG_VERBOSE - state.str = str - source.seek(0) - p = _parse_sub(source, state, True, 0) - - p.state.flags = fix_flags(str, p.state.flags) - - if source.next is not None: - assert source.next == ")" - raise source.error("unbalanced parenthesis") - - if flags & SRE_FLAG_DEBUG: - p.dump() - - return p - -def parse_template(source, state): - # parse 're' replacement string into list of literals and - # group references - s = Tokenizer(source) - sget = s.get - groups = [] - literals = [] - literal = [] - lappend = literal.append - def addgroup(index, pos): - if index > state.groups: - raise s.error("invalid group reference %d" % index, pos) - if literal: - literals.append(''.join(literal)) - del literal[:] - groups.append((len(literals), index)) - literals.append(None) - groupindex = state.groupindex - while True: - this = sget() - if this is None: - break # end of replacement string - if this[0] == "\\": - # group - c = this[1] - if c == "g": - name = "" - if not s.match("<"): - raise s.error("missing <") - name = s.getuntil(">", "group name") - if name.isidentifier(): - try: - index = groupindex[name] - except KeyError: - raise IndexError("unknown group name %r" % name) - else: - try: - index = int(name) - if index < 0: - raise ValueError - except ValueError: - raise s.error("bad character in group name %r" % name, - len(name) + 1) from None - if index >= MAXGROUPS: - raise s.error("invalid group reference %d" % index, - len(name) + 1) - addgroup(index, len(name) + 1) - elif c == "0": - if s.next in OCTDIGITS: - this += sget() - if s.next in OCTDIGITS: - this += sget() - lappend(chr(int(this[1:], 8) & 0xff)) - elif c in DIGITS: - isoctal = False - if s.next in DIGITS: - this += sget() - if (c in OCTDIGITS and this[2] in OCTDIGITS and - s.next in OCTDIGITS): - this += sget() - isoctal = True - c = int(this[1:], 8) - if c > 0o377: - raise s.error('octal escape value %s outside of ' - 'range 0-0o377' % this, len(this)) - lappend(chr(c)) - if not isoctal: - addgroup(int(this[1:]), len(this) - 1) - else: - try: - this = chr(ESCAPES[this][1]) - except KeyError: - if c in ASCIILETTERS: - raise s.error('bad escape %s' % this, len(this)) - lappend(this) - else: - lappend(this) - if literal: - literals.append(''.join(literal)) - if not isinstance(source, str): - # The tokenizer implicitly decodes bytes objects as latin-1, we must - # therefore re-encode the final representation. - literals = [None if s is None else s.encode('latin-1') for s in literals] - return groups, literals - -def expand_template(template, match): - g = match.group - empty = match.string[:0] - groups, literals = template - literals = literals[:] - try: - for index, group in groups: - literals[index] = g(group) or empty - except IndexError: - raise error("invalid group reference %d" % index) - return empty.join(literals) +from re import _parser as _ +globals().update({k: v for k, v in vars(_).items() if k[:2] != '__'}) From 1e3d57817cbb8a08b6b7c7dc04eefe0e566975ad Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Wed, 22 Nov 2023 23:20:04 +0200 Subject: [PATCH 305/893] Replace re_test.py from CPython 3.12 and mark failed tests --- Cargo.lock | 37 +- Lib/test/test_re.py | 844 ++++++++++++++++++++++++++++++++++---------- vm/Cargo.toml | 4 +- 3 files changed, 658 insertions(+), 227 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 52afbb053f..da13a2a33c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1422,34 +1422,13 @@ dependencies = [ "libc", ] -[[package]] -name = "num_enum" -version = "0.5.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d829733185c1ca374f17e52b762f24f535ec625d2cc1f070e34c8a9068f341b" -dependencies = [ - "num_enum_derive 0.5.9", -] - [[package]] name = "num_enum" version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02339744ee7253741199f897151b38e72257d13802d4ee837285cc2990a90845" dependencies = [ - "num_enum_derive 0.7.2", -] - -[[package]] -name = "num_enum_derive" -version = "0.5.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2be1598bf1c313dcdd12092e3f1920f463462525a21b7b4e11b4168353d0123e" -dependencies = [ - "proc-macro-crate", - "proc-macro2", - "quote", - "syn 1.0.107", + "num_enum_derive", ] [[package]] @@ -2191,7 +2170,7 @@ name = "rustpython-sre_engine" version = "0.6.0" dependencies = [ "bitflags 2.4.0", - "num_enum 0.7.2", + "num_enum", "optional", ] @@ -2230,7 +2209,7 @@ dependencies = [ "num-complex", "num-integer", "num-traits", - "num_enum 0.7.2", + "num_enum", "once_cell", "openssl", "openssl-probe", @@ -2300,7 +2279,7 @@ dependencies = [ "num-integer", "num-traits", "num_cpus", - "num_enum 0.7.2", + "num_enum", "once_cell", "optional", "parking_lot", @@ -2552,12 +2531,10 @@ dependencies = [ [[package]] name = "sre-engine" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a490c5c46c35dba9a6f5e7ee8e4d67e775eb2d2da0f115750b8d10e1c1ac2d28" +version = "0.6.0" dependencies = [ - "bitflags 1.3.2", - "num_enum 0.5.9", + "bitflags 2.4.0", + "num_enum", "optional", ] diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index 9b30b4137c..1fd2432aae 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -1,15 +1,25 @@ from test.support import (gc_collect, bigmemtest, _2G, cpython_only, captured_stdout, - check_disallow_instantiation) + check_disallow_instantiation, is_emscripten, is_wasi, + SHORT_TIMEOUT) import locale import re -import sre_compile import string +import sys +import time import unittest import warnings from re import Scanner from weakref import proxy +# some platforms lack working multiprocessing +try: + import _multiprocessing +except ImportError: + multiprocessing = None +else: + import multiprocessing + # Misc tests from Tim Peters' re.doc # WARNING: Don't change details in these tests if you don't know @@ -85,10 +95,29 @@ def test_search_star_plus(self): self.assertEqual(re.match('x*', 'xxxa').span(), (0, 3)) self.assertIsNone(re.match('a+', 'xxx')) + def test_branching(self): + """Test Branching + Test expressions using the OR ('|') operator.""" + self.assertEqual(re.match('(ab|ba)', 'ab').span(), (0, 2)) + self.assertEqual(re.match('(ab|ba)', 'ba').span(), (0, 2)) + self.assertEqual(re.match('(abc|bac|ca|cb)', 'abc').span(), + (0, 3)) + self.assertEqual(re.match('(abc|bac|ca|cb)', 'bac').span(), + (0, 3)) + self.assertEqual(re.match('(abc|bac|ca|cb)', 'ca').span(), + (0, 2)) + self.assertEqual(re.match('(abc|bac|ca|cb)', 'cb').span(), + (0, 2)) + self.assertEqual(re.match('((a)|(b)|(c))', 'a').span(), (0, 1)) + self.assertEqual(re.match('((a)|(b)|(c))', 'b').span(), (0, 1)) + self.assertEqual(re.match('((a)|(b)|(c))', 'c').span(), (0, 1)) + def bump_num(self, matchobj): int_value = int(matchobj.group(0)) return str(int_value + 1) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_basic_re_sub(self): self.assertTypedEqual(re.sub('y', 'a', 'xyz'), 'xaz') self.assertTypedEqual(re.sub('y', S('a'), S('xyz')), 'xaz') @@ -119,6 +148,7 @@ def test_basic_re_sub(self): self.assertEqual(re.sub('(?Px)', r'\g\g<1>', 'xx'), 'xxxx') self.assertEqual(re.sub('(?Px)', r'\g\g', 'xx'), 'xxxx') self.assertEqual(re.sub('(?Px)', r'\g<1>\g<1>', 'xx'), 'xxxx') + self.assertEqual(re.sub('()x', r'\g<0>\g<0>', 'xx'), 'xxxx') self.assertEqual(re.sub('a', r'\t\n\v\r\f\a\b', 'a'), '\t\n\v\r\f\a\b') self.assertEqual(re.sub('a', '\t\n\v\r\f\a\b', 'a'), '\t\n\v\r\f\a\b') @@ -131,11 +161,15 @@ def test_basic_re_sub(self): self.assertEqual(re.sub(r'^\s*', 'X', 'test'), 'Xtest') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_bug_449964(self): # fails for group followed by other escape self.assertEqual(re.sub(r'(?Px)', r'\g<1>\g<1>\b', 'xx'), 'xx\bxx\b') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_bug_449000(self): # Test for sub() on escaped characters self.assertEqual(re.sub(r'\r\n', r'\n', 'abc\r\ndef\r\n'), @@ -159,6 +193,8 @@ def test_bug_3629(self): # A regex that triggered a bug in the sre-code validator re.compile("(?P)(?(quote))") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_sub_template_numeric_escape(self): # bug 776311 and friends self.assertEqual(re.sub('x', r'\0', 'x'), '\0') @@ -212,6 +248,8 @@ def test_qualified_re_sub(self): self.assertEqual(re.sub('a', 'b', 'aaaaa', 1), 'baaaa') self.assertEqual(re.sub('a', 'b', 'aaaaa', count=1), 'baaaa') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_bug_114660(self): self.assertEqual(re.sub(r'(\S)\s+(\S)', r'\1 \2', 'hello there'), 'hello there') @@ -258,7 +296,15 @@ def test_symbolic_groups_errors(self): self.checkPatternError('(?P<©>x)', "bad character in group name '©'", 4) self.checkPatternError('(?P=©)', "bad character in group name '©'", 4) self.checkPatternError('(?(©)y)', "bad character in group name '©'", 3) + self.checkPatternError(b'(?P<\xc2\xb5>x)', + r"bad character in group name '\xc2\xb5'", 4) + self.checkPatternError(b'(?P=\xc2\xb5)', + r"bad character in group name '\xc2\xb5'", 4) + self.checkPatternError(b'(?(\xc2\xb5)y)', + r"bad character in group name '\xc2\xb5'", 3) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_symbolic_refs(self): self.assertEqual(re.sub('(?Px)|(?Py)', r'\g', 'xx'), '') self.assertEqual(re.sub('(?Px)|(?Py)', r'\2', 'xx'), '') @@ -270,6 +316,8 @@ def test_symbolic_refs(self): pat = '|'.join('x(?P%x)y' % (i, i) for i in range(1, 200 + 1)) self.assertEqual(re.sub(pat, r'\g<200>', 'xc8yzxc8y'), 'c8zc8') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_symbolic_refs_errors(self): self.checkTemplateError('(?Px)', r'\g, unterminated name', 3) @@ -290,12 +338,22 @@ def test_symbolic_refs_errors(self): re.sub('(?Px)', r'\g', 'xx') self.checkTemplateError('(?Px)', r'\g<-1>', 'xx', "bad character in group name '-1'", 3) + self.checkTemplateError('(?Px)', r'\g<+1>', 'xx', + "bad character in group name '+1'", 3) + self.checkTemplateError('()'*10, r'\g<1_0>', 'xx', + "bad character in group name '1_0'", 3) + self.checkTemplateError('(?Px)', r'\g< 1 >', 'xx', + "bad character in group name ' 1 '", 3) self.checkTemplateError('(?Px)', r'\g<©>', 'xx', "bad character in group name '©'", 3) + self.checkTemplateError(b'(?Px)', b'\\g<\xc2\xb5>', b'xx', + r"bad character in group name '\xc2\xb5'", 3) self.checkTemplateError('(?Px)', r'\g<㊀>', 'xx', "bad character in group name '㊀'", 3) self.checkTemplateError('(?Px)', r'\g<¹>', 'xx', "bad character in group name '¹'", 3) + self.checkTemplateError('(?Px)', r'\g<१>', 'xx', + "bad character in group name '१'", 3) def test_re_subn(self): self.assertEqual(re.subn("(?i)b+", "x", "bbbb BBBB"), ('x x', 2)) @@ -557,16 +615,22 @@ def test_re_groupref_exists(self): pat = '(?:%s)(?(200)z)' % pat self.assertEqual(re.match(pat, 'xc8yz').span(), (0, 5)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_re_groupref_exists_errors(self): self.checkPatternError(r'(?P)(?(0)a|b)', 'bad group number', 10) self.checkPatternError(r'()(?(-1)a|b)', "bad character in group name '-1'", 5) + self.checkPatternError(r'()(?(+1)a|b)', + "bad character in group name '+1'", 5) + self.checkPatternError(r'()'*10 + r'(?(1_0)a|b)', + "bad character in group name '1_0'", 23) + self.checkPatternError(r'()(?( 1 )a|b)', + "bad character in group name ' 1 '", 5) self.checkPatternError(r'()(?(㊀)a|b)', "bad character in group name '㊀'", 5) self.checkPatternError(r'()(?(¹)a|b)', "bad character in group name '¹'", 5) + self.checkPatternError(r'()(?(१)a|b)', + "bad character in group name '१'", 5) self.checkPatternError(r'()(?(1', "missing ), unterminated name", 5) self.checkPatternError(r'()(?(1)a', @@ -582,8 +646,15 @@ def test_re_groupref_exists_errors(self): self.checkPatternError(r'()(?(2)a)', "invalid group reference 2", 5) + def test_re_groupref_exists_validation_bug(self): + for i in range(256): + with self.subTest(code=i): + re.compile(r'()(?(1)\x%02x?)' % i) + + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_re_groupref_overflow(self): - from sre_constants import MAXGROUPS + from re._constants import MAXGROUPS self.checkTemplateError('()', r'\g<%s>' % MAXGROUPS, 'xx', 'invalid group reference %d' % MAXGROUPS, 3) self.checkPatternError(r'(?P)(?(%d))' % MAXGROUPS, @@ -608,6 +679,8 @@ def test_groupdict(self): 'first second').groupdict(), {'first':'first', 'second':'second'}) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_expand(self): self.assertEqual(re.match("(?Pfirst) (?Psecond)", "first second") @@ -871,8 +944,6 @@ def test_lookbehind(self): self.assertRaises(re.error, re.compile, r'(a)b(?<=(a)(?(2)b|x))(c)') self.assertRaises(re.error, re.compile, r'(a)b(?<=(.)(?<=\2))(c)') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_ignore_case(self): self.assertEqual(re.match("abc", "ABC", re.I).group(0), "ABC") self.assertEqual(re.match(b"abc", b"ABC", re.I).group(0), b"ABC") @@ -913,8 +984,6 @@ def test_ignore_case(self): self.assertTrue(re.match(r'\ufb05', '\ufb06', re.I)) self.assertTrue(re.match(r'\ufb06', '\ufb05', re.I)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_ignore_case_set(self): self.assertTrue(re.match(r'[19A]', 'A', re.I)) self.assertTrue(re.match(r'[19a]', 'a', re.I)) @@ -953,8 +1022,6 @@ def test_ignore_case_set(self): self.assertTrue(re.match(r'[19\ufb05]', '\ufb06', re.I)) self.assertTrue(re.match(r'[19\ufb06]', '\ufb05', re.I)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_ignore_case_range(self): # Issues #3511, #17381. self.assertTrue(re.match(r'[9-a]', '_', re.I)) @@ -1005,33 +1072,6 @@ def test_ignore_case_range(self): def test_category(self): self.assertEqual(re.match(r"(\s)", " ").group(1), " ") - @cpython_only - def test_case_helpers(self): - import _sre - for i in range(128): - c = chr(i) - lo = ord(c.lower()) - self.assertEqual(_sre.ascii_tolower(i), lo) - self.assertEqual(_sre.unicode_tolower(i), lo) - iscased = c in string.ascii_letters - self.assertEqual(_sre.ascii_iscased(i), iscased) - self.assertEqual(_sre.unicode_iscased(i), iscased) - - for i in list(range(128, 0x1000)) + [0x10400, 0x10428]: - c = chr(i) - self.assertEqual(_sre.ascii_tolower(i), i) - if i != 0x0130: - self.assertEqual(_sre.unicode_tolower(i), ord(c.lower())) - iscased = c != c.lower() or c != c.upper() - self.assertFalse(_sre.ascii_iscased(i)) - self.assertEqual(_sre.unicode_iscased(i), - c != c.lower() or c != c.upper()) - - self.assertEqual(_sre.ascii_tolower(0x0130), 0x0130) - self.assertEqual(_sre.unicode_tolower(0x0130), ord('i')) - self.assertFalse(_sre.ascii_iscased(0x0130)) - self.assertTrue(_sre.unicode_iscased(0x0130)) - def test_not_literal(self): self.assertEqual(re.search(r"\s([^a])", " b").group(1), "b") self.assertEqual(re.search(r"\s([^a]*)", " bb").group(1), "bb") @@ -1332,11 +1372,13 @@ def test_nothing_to_repeat(self): 'nothing to repeat', 3) def test_multiple_repeat(self): - for outer_reps in '*', '+', '{1,2}': - for outer_mod in '', '?': + for outer_reps in '*', '+', '?', '{1,2}': + for outer_mod in '', '?', '+': outer_op = outer_reps + outer_mod for inner_reps in '*', '+', '?', '{1,2}': - for inner_mod in '', '?': + for inner_mod in '', '?', '+': + if inner_mod + outer_reps in ('?', '+'): + continue inner_op = inner_reps + inner_mod self.checkPatternError(r'x%s%s' % (inner_op, outer_op), 'multiple repeat', 1 + len(inner_op)) @@ -1491,8 +1533,6 @@ def test_empty_array(self): self.assertIsNone(re.compile(b"bla").match(a)) self.assertEqual(re.compile(b"").match(a).groups(), ()) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_inline_flags(self): # Bug #1700 upper_char = '\u1ea0' # Latin Capital Letter A with Dot Below @@ -1536,70 +1576,27 @@ def test_inline_flags(self): self.assertTrue(re.match('(?x) (?i) ' + upper_char, lower_char)) self.assertTrue(re.match(' (?x) (?i) ' + upper_char, lower_char, re.X)) - p = upper_char + '(?i)' - with self.assertWarns(DeprecationWarning) as warns: - self.assertTrue(re.match(p, lower_char)) - self.assertEqual( - str(warns.warnings[0].message), - 'Flags not at the start of the expression %r' - ' but at position 1' % p - ) - self.assertEqual(warns.warnings[0].filename, __file__) - - p = upper_char + '(?i)%s' % ('.?' * 100) - with self.assertWarns(DeprecationWarning) as warns: - self.assertTrue(re.match(p, lower_char)) - self.assertEqual( - str(warns.warnings[0].message), - 'Flags not at the start of the expression %r (truncated)' - ' but at position 1' % p[:20] - ) - self.assertEqual(warns.warnings[0].filename, __file__) + msg = "global flags not at the start of the expression" + self.checkPatternError(upper_char + '(?i)', msg, 1) # bpo-30605: Compiling a bytes instance regex was throwing a BytesWarning with warnings.catch_warnings(): warnings.simplefilter('error', BytesWarning) - p = b'A(?i)' - with self.assertWarns(DeprecationWarning) as warns: - self.assertTrue(re.match(p, b'a')) - self.assertEqual( - str(warns.warnings[0].message), - 'Flags not at the start of the expression %r' - ' but at position 1' % p - ) - self.assertEqual(warns.warnings[0].filename, __file__) - - with self.assertWarns(DeprecationWarning): - self.assertTrue(re.match('(?s).(?i)' + upper_char, '\n' + lower_char)) - with self.assertWarns(DeprecationWarning): - self.assertTrue(re.match('(?i) ' + upper_char + ' (?x)', lower_char)) - with self.assertWarns(DeprecationWarning): - self.assertTrue(re.match(' (?x) (?i) ' + upper_char, lower_char)) - with self.assertWarns(DeprecationWarning): - self.assertTrue(re.match('^(?i)' + upper_char, lower_char)) - with self.assertWarns(DeprecationWarning): - self.assertTrue(re.match('$|(?i)' + upper_char, lower_char)) - with self.assertWarns(DeprecationWarning) as warns: - self.assertTrue(re.match('(?:(?i)' + upper_char + ')', lower_char)) - self.assertRegex(str(warns.warnings[0].message), - 'Flags not at the start') - self.assertEqual(warns.warnings[0].filename, __file__) - with self.assertWarns(DeprecationWarning) as warns: - self.assertTrue(re.fullmatch('(^)?(?(1)(?i)' + upper_char + ')', - lower_char)) - self.assertRegex(str(warns.warnings[0].message), - 'Flags not at the start') - self.assertEqual(warns.warnings[0].filename, __file__) - with self.assertWarns(DeprecationWarning) as warns: - self.assertTrue(re.fullmatch('($)?(?(1)|(?i)' + upper_char + ')', - lower_char)) - self.assertRegex(str(warns.warnings[0].message), - 'Flags not at the start') - self.assertEqual(warns.warnings[0].filename, __file__) + self.checkPatternError(b'A(?i)', msg, 1) + + self.checkPatternError('(?s).(?i)' + upper_char, msg, 5) + self.checkPatternError('(?i) ' + upper_char + ' (?x)', msg, 7) + self.checkPatternError(' (?x) (?i) ' + upper_char, msg, 1) + self.checkPatternError('^(?i)' + upper_char, msg, 1) + self.checkPatternError('$|(?i)' + upper_char, msg, 2) + self.checkPatternError('(?:(?i)' + upper_char + ')', msg, 3) + self.checkPatternError('(^)?(?(1)(?i)' + upper_char + ')', msg, 9) + self.checkPatternError('($)?(?(1)|(?i)' + upper_char + ')', msg, 10) def test_dollar_matches_twice(self): - "$ matches the end of string, and just before the terminating \n" + r"""Test that $ does not include \n + $ matches the end of string, and just before the terminating \n""" pattern = re.compile('$') self.assertEqual(pattern.sub('#', 'a\nb\n'), 'a\nb#\n#') self.assertEqual(pattern.sub('#', 'a\nb\nc'), 'a\nb\nc#') @@ -1757,6 +1754,8 @@ def test_comments(self): self.assertTrue(re.fullmatch('(?x)#x\na|#y\nb', 'a')) self.assertTrue(re.fullmatch('(?x)#x\na|#y\nb', 'b')) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_bug_6509(self): # Replacement strings of both types must parse properly. # all strings @@ -1775,24 +1774,6 @@ def test_bug_6509(self): pat = re.compile(b'..') self.assertEqual(pat.sub(lambda m: b'bytes', b'a5'), b'bytes') - # RUSTPYTHON: here in rustpython, we borrow the string only at the - # time of matching, so we will not check the string type when creating - # SRE_Scanner, expect this, other tests has passed - @cpython_only - def test_dealloc(self): - # issue 3299: check for segfault in debug build - import _sre - # the overflow limit is different on wide and narrow builds and it - # depends on the definition of SRE_CODE (see sre.h). - # 2**128 should be big enough to overflow on both. For smaller values - # a RuntimeError is raised instead of OverflowError. - long_overflow = 2**128 - self.assertRaises(TypeError, re.finditer, "a", {}) - with self.assertRaises(OverflowError): - _sre.compile("abc", 0, [long_overflow], 0, {}, ()) - with self.assertRaises(TypeError): - _sre.compile({}, 0, [], 0, [], []) - def test_search_dot_unicode(self): self.assertTrue(re.search("123.*-", '123abc-')) self.assertTrue(re.search("123.*-", '123\xe9-')) @@ -1850,20 +1831,28 @@ def test_repeat_minmax_overflow(self): self.assertRaises(OverflowError, re.compile, r".{%d,}?" % 2**128) self.assertRaises(OverflowError, re.compile, r".{%d,%d}" % (2**129, 2**128)) - @cpython_only - def test_repeat_minmax_overflow_maxrepeat(self): - try: - from _sre import MAXREPEAT - except ImportError: - self.skipTest('requires _sre.MAXREPEAT constant') - string = "x" * 100000 - self.assertIsNone(re.match(r".{%d}" % (MAXREPEAT - 1), string)) - self.assertEqual(re.match(r".{,%d}" % (MAXREPEAT - 1), string).span(), - (0, 100000)) - self.assertIsNone(re.match(r".{%d,}?" % (MAXREPEAT - 1), string)) - self.assertRaises(OverflowError, re.compile, r".{%d}" % MAXREPEAT) - self.assertRaises(OverflowError, re.compile, r".{,%d}" % MAXREPEAT) - self.assertRaises(OverflowError, re.compile, r".{%d,}?" % MAXREPEAT) + def test_look_behind_overflow(self): + string = "x" * 2_500_000 + p1 = r"(?<=((.{%d}){%d}){%d})" + p2 = r"(?...), which does + not maintain any stack point created within the group once the + group is finished being evaluated.""" + pattern1 = re.compile(r'a(?>bc|b)c') + self.assertIsNone(pattern1.match('abc')) + self.assertTrue(pattern1.match('abcc')) + self.assertIsNone(re.match(r'(?>.*).', 'abc')) + self.assertTrue(re.match(r'(?>x)++', 'xxx')) + self.assertTrue(re.match(r'(?>x++)', 'xxx')) + self.assertIsNone(re.match(r'(?>x)++x', 'xxx')) + self.assertIsNone(re.match(r'(?>x++)x', 'xxx')) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_fullmatch_atomic_grouping(self): + self.assertTrue(re.fullmatch(r'(?>a+)', 'a')) + self.assertTrue(re.fullmatch(r'(?>a*)', 'a')) + self.assertTrue(re.fullmatch(r'(?>a?)', 'a')) + self.assertTrue(re.fullmatch(r'(?>a{1,3})', 'a')) + self.assertIsNone(re.fullmatch(r'(?>a+)', 'ab')) + self.assertIsNone(re.fullmatch(r'(?>a*)', 'ab')) + self.assertIsNone(re.fullmatch(r'(?>a?)', 'ab')) + self.assertIsNone(re.fullmatch(r'(?>a{1,3})', 'ab')) + self.assertTrue(re.fullmatch(r'(?>a+)b', 'ab')) + self.assertTrue(re.fullmatch(r'(?>a*)b', 'ab')) + self.assertTrue(re.fullmatch(r'(?>a?)b', 'ab')) + self.assertTrue(re.fullmatch(r'(?>a{1,3})b', 'ab')) + + self.assertTrue(re.fullmatch(r'(?>(?:ab)+)', 'ab')) + self.assertTrue(re.fullmatch(r'(?>(?:ab)*)', 'ab')) + self.assertTrue(re.fullmatch(r'(?>(?:ab)?)', 'ab')) + self.assertTrue(re.fullmatch(r'(?>(?:ab){1,3})', 'ab')) + self.assertIsNone(re.fullmatch(r'(?>(?:ab)+)', 'abc')) + self.assertIsNone(re.fullmatch(r'(?>(?:ab)*)', 'abc')) + self.assertIsNone(re.fullmatch(r'(?>(?:ab)?)', 'abc')) + self.assertIsNone(re.fullmatch(r'(?>(?:ab){1,3})', 'abc')) + self.assertTrue(re.fullmatch(r'(?>(?:ab)+)c', 'abc')) + self.assertTrue(re.fullmatch(r'(?>(?:ab)*)c', 'abc')) + self.assertTrue(re.fullmatch(r'(?>(?:ab)?)c', 'abc')) + self.assertTrue(re.fullmatch(r'(?>(?:ab){1,3})c', 'abc')) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_findall_atomic_grouping(self): + self.assertEqual(re.findall(r'(?>a+)', 'aab'), ['aa']) + self.assertEqual(re.findall(r'(?>a*)', 'aab'), ['aa', '', '']) + self.assertEqual(re.findall(r'(?>a?)', 'aab'), ['a', 'a', '', '']) + self.assertEqual(re.findall(r'(?>a{1,3})', 'aab'), ['aa']) + + self.assertEqual(re.findall(r'(?>(?:ab)+)', 'ababc'), ['abab']) + self.assertEqual(re.findall(r'(?>(?:ab)*)', 'ababc'), ['abab', '', '']) + self.assertEqual(re.findall(r'(?>(?:ab)?)', 'ababc'), ['ab', 'ab', '', '']) + self.assertEqual(re.findall(r'(?>(?:ab){1,3})', 'ababc'), ['abab']) + + def test_bug_gh91616(self): + self.assertTrue(re.fullmatch(r'(?s:(?>.*?\.).*)\Z', "a.txt")) # reproducer + self.assertTrue(re.fullmatch(r'(?s:(?=(?P.*?\.))(?P=g0).*)\Z', "a.txt")) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_template_function_and_flag_is_deprecated(self): + with self.assertWarns(DeprecationWarning) as cm: + template_re1 = re.template(r'a') + self.assertIn('re.template()', str(cm.warning)) + self.assertIn('is deprecated', str(cm.warning)) + self.assertIn('function', str(cm.warning)) + self.assertNotIn('flag', str(cm.warning)) + + with self.assertWarns(DeprecationWarning) as cm: + # we deliberately use more flags here to test that that still + # triggers the warning + # if paranoid, we could test multiple different combinations, + # but it's probably not worth it + template_re2 = re.compile(r'a', flags=re.TEMPLATE|re.UNICODE) + self.assertIn('re.TEMPLATE', str(cm.warning)) + self.assertIn('is deprecated', str(cm.warning)) + self.assertIn('flag', str(cm.warning)) + self.assertNotIn('function', str(cm.warning)) + + # while deprecated, is should still function + self.assertEqual(template_re1, template_re2) + self.assertTrue(template_re1.match('ahoy')) + self.assertFalse(template_re1.match('nope')) + + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_bug_gh106052(self): + # gh-100061 + self.assertEqual(re.match('(?>(?:.(?!D))+)', 'ABCDE').span(), (0, 2)) + self.assertEqual(re.match('(?:.(?!D))++', 'ABCDE').span(), (0, 2)) + self.assertEqual(re.match('(?>(?:.(?!D))*)', 'ABCDE').span(), (0, 2)) + self.assertEqual(re.match('(?:.(?!D))*+', 'ABCDE').span(), (0, 2)) + self.assertEqual(re.match('(?>(?:.(?!D))?)', 'CDE').span(), (0, 0)) + self.assertEqual(re.match('(?:.(?!D))?+', 'CDE').span(), (0, 0)) + self.assertEqual(re.match('(?>(?:.(?!D)){1,3})', 'ABCDE').span(), (0, 2)) + self.assertEqual(re.match('(?:.(?!D)){1,3}+', 'ABCDE').span(), (0, 2)) + # gh-106052 + self.assertEqual(re.match("(?>(?:ab?c)+)", "aca").span(), (0, 2)) + self.assertEqual(re.match("(?:ab?c)++", "aca").span(), (0, 2)) + self.assertEqual(re.match("(?>(?:ab?c)*)", "aca").span(), (0, 2)) + self.assertEqual(re.match("(?:ab?c)*+", "aca").span(), (0, 2)) + self.assertEqual(re.match("(?>(?:ab?c)?)", "a").span(), (0, 0)) + self.assertEqual(re.match("(?:ab?c)?+", "a").span(), (0, 0)) + self.assertEqual(re.match("(?>(?:ab?c){1,3})", "aca").span(), (0, 2)) + self.assertEqual(re.match("(?:ab?c){1,3}+", "aca").span(), (0, 2)) + + # TODO: RUSTPYTHON + @unittest.skipUnless(sys.platform == 'linux', 'multiprocessing related issue') + @unittest.skipIf(multiprocessing is None, 'test requires multiprocessing') + def test_regression_gh94675(self): + pattern = re.compile(r'(?<=[({}])(((//[^\n]*)?[\n])([\000-\040])*)*' + r'((/[^/\[\n]*(([^\n]|(\[\n]*(]*)*\]))' + r'[^/\[]*)*/))((((//[^\n]*)?[\n])' + r'([\000-\040]|(/\*[^*]*\*+' + r'([^/*]\*+)*/))*)+(?=[^\000-\040);\]}]))') + input_js = '''a(function() { + /////////////////////////////////////////////////////////////////// + });''' + p = multiprocessing.Process(target=pattern.sub, args=('', input_js)) + p.start() + p.join(SHORT_TIMEOUT) + try: + self.assertFalse(p.is_alive(), 'pattern.sub() timed out') + finally: + if p.is_alive(): + p.terminate() + p.join() + + +def get_debug_out(pat): + with captured_stdout() as out: + re.compile(pat, re.DEBUG) + return out.getvalue() + + +@cpython_only +class DebugTests(unittest.TestCase): + maxDiff = None + + def test_debug_flag(self): + pat = r'(\.)(?:[ch]|py)(?(1)$|: )' + dump = '''\ +SUBPATTERN 1 0 0 + LITERAL 46 +BRANCH + IN + LITERAL 99 + LITERAL 104 +OR + LITERAL 112 + LITERAL 121 +GROUPREF_EXISTS 1 + AT AT_END +ELSE + LITERAL 58 + LITERAL 32 + + 0. INFO 8 0b1 2 5 (to 9) + prefix_skip 0 + prefix [0x2e] ('.') + overlap [0] + 9: MARK 0 +11. LITERAL 0x2e ('.') +13. MARK 1 +15. BRANCH 10 (to 26) +17. IN 6 (to 24) +19. LITERAL 0x63 ('c') +21. LITERAL 0x68 ('h') +23. FAILURE +24: JUMP 9 (to 34) +26: branch 7 (to 33) +27. LITERAL 0x70 ('p') +29. LITERAL 0x79 ('y') +31. JUMP 2 (to 34) +33: FAILURE +34: GROUPREF_EXISTS 0 6 (to 41) +37. AT END +39. JUMP 5 (to 45) +41: LITERAL 0x3a (':') +43. LITERAL 0x20 (' ') +45: SUCCESS +''' + self.assertEqual(get_debug_out(pat), dump) + # Debug output is output again even a second time (bypassing + # the cache -- issue #20426). + self.assertEqual(get_debug_out(pat), dump) + + def test_atomic_group(self): + self.assertEqual(get_debug_out(r'(?>ab?)'), '''\ +ATOMIC_GROUP + LITERAL 97 + MAX_REPEAT 0 1 + LITERAL 98 + + 0. INFO 4 0b0 1 2 (to 5) + 5: ATOMIC_GROUP 11 (to 17) + 7. LITERAL 0x61 ('a') + 9. REPEAT_ONE 6 0 1 (to 16) +13. LITERAL 0x62 ('b') +15. SUCCESS +16: SUCCESS +17: SUCCESS +''') + + def test_possesive_repeat_one(self): + self.assertEqual(get_debug_out(r'a?+'), '''\ +POSSESSIVE_REPEAT 0 1 + LITERAL 97 + + 0. INFO 4 0b0 0 1 (to 5) + 5: POSSESSIVE_REPEAT_ONE 6 0 1 (to 12) + 9. LITERAL 0x61 ('a') +11. SUCCESS +12: SUCCESS +''') + + def test_possesive_repeat(self): + self.assertEqual(get_debug_out(r'(?:ab)?+'), '''\ +POSSESSIVE_REPEAT 0 1 + LITERAL 97 + LITERAL 98 + + 0. INFO 4 0b0 0 2 (to 5) + 5: POSSESSIVE_REPEAT 7 0 1 (to 13) + 9. LITERAL 0x61 ('a') +11. LITERAL 0x62 ('b') +13: SUCCESS +14. SUCCESS +''') + class PatternReprTests(unittest.TestCase): def check(self, pattern, expected): @@ -2312,11 +2664,13 @@ def test_flags_repr(self): "re.IGNORECASE|re.DOTALL|re.VERBOSE") self.assertEqual(repr(re.I|re.S|re.X|(1<<20)), "re.IGNORECASE|re.DOTALL|re.VERBOSE|0x100000") - self.assertEqual(repr(~re.I), "~re.IGNORECASE") + self.assertEqual( + repr(~re.I), + "re.ASCII|re.LOCALE|re.UNICODE|re.MULTILINE|re.DOTALL|re.VERBOSE|re.TEMPLATE|re.DEBUG") self.assertEqual(repr(~(re.I|re.S|re.X)), - "~(re.IGNORECASE|re.DOTALL|re.VERBOSE)") + "re.ASCII|re.LOCALE|re.UNICODE|re.MULTILINE|re.TEMPLATE|re.DEBUG") self.assertEqual(repr(~(re.I|re.S|re.X|(1<<20))), - "~(re.IGNORECASE|re.DOTALL|re.VERBOSE|0x100000)") + "re.ASCII|re.LOCALE|re.UNICODE|re.MULTILINE|re.TEMPLATE|re.DEBUG|0xffe00") class ImplementationTest(unittest.TestCase): @@ -2337,7 +2691,7 @@ def test_immutable(self): tp.foo = 1 def test_overlap_table(self): - f = sre_compile._generate_overlap_table + f = re._compiler._generate_overlap_table self.assertEqual(f(""), []) self.assertEqual(f("a"), [0]) self.assertEqual(f("abcd"), [0, 0, 0, 0]) @@ -2346,8 +2700,8 @@ def test_overlap_table(self): self.assertEqual(f("abcabdac"), [0, 0, 0, 1, 2, 0, 1, 0]) def test_signedness(self): - self.assertGreaterEqual(sre_compile.MAXREPEAT, 0) - self.assertGreaterEqual(sre_compile.MAXGROUPS, 0) + self.assertGreaterEqual(re._compiler.MAXREPEAT, 0) + self.assertGreaterEqual(re._compiler.MAXGROUPS, 0) @cpython_only def test_disallow_instantiation(self): @@ -2357,6 +2711,106 @@ def test_disallow_instantiation(self): pat = re.compile("") check_disallow_instantiation(self, type(pat.scanner(""))) + # TODO: RUSTPYTHON + @unittest.expectedFailure + def test_deprecated_modules(self): + deprecated = { + 'sre_compile': ['compile', 'error', + 'SRE_FLAG_IGNORECASE', 'SUBPATTERN', + '_compile_info'], + 'sre_constants': ['error', 'SRE_FLAG_IGNORECASE', 'SUBPATTERN', + '_NamedIntConstant'], + 'sre_parse': ['SubPattern', 'parse', + 'SRE_FLAG_IGNORECASE', 'SUBPATTERN', + '_parse_sub'], + } + for name in deprecated: + with self.subTest(module=name): + sys.modules.pop(name, None) + with self.assertWarns(DeprecationWarning) as w: + __import__(name) + self.assertEqual(str(w.warning), + f"module {name!r} is deprecated") + self.assertEqual(w.filename, __file__) + self.assertIn(name, sys.modules) + mod = sys.modules[name] + self.assertEqual(mod.__name__, name) + self.assertEqual(mod.__package__, '') + for attr in deprecated[name]: + self.assertTrue(hasattr(mod, attr)) + del sys.modules[name] + + @cpython_only + def test_case_helpers(self): + import _sre + for i in range(128): + c = chr(i) + lo = ord(c.lower()) + self.assertEqual(_sre.ascii_tolower(i), lo) + self.assertEqual(_sre.unicode_tolower(i), lo) + iscased = c in string.ascii_letters + self.assertEqual(_sre.ascii_iscased(i), iscased) + self.assertEqual(_sre.unicode_iscased(i), iscased) + + for i in list(range(128, 0x1000)) + [0x10400, 0x10428]: + c = chr(i) + self.assertEqual(_sre.ascii_tolower(i), i) + if i != 0x0130: + self.assertEqual(_sre.unicode_tolower(i), ord(c.lower())) + iscased = c != c.lower() or c != c.upper() + self.assertFalse(_sre.ascii_iscased(i)) + self.assertEqual(_sre.unicode_iscased(i), + c != c.lower() or c != c.upper()) + + self.assertEqual(_sre.ascii_tolower(0x0130), 0x0130) + self.assertEqual(_sre.unicode_tolower(0x0130), ord('i')) + self.assertFalse(_sre.ascii_iscased(0x0130)) + self.assertTrue(_sre.unicode_iscased(0x0130)) + + @cpython_only + def test_dealloc(self): + # issue 3299: check for segfault in debug build + import _sre + # the overflow limit is different on wide and narrow builds and it + # depends on the definition of SRE_CODE (see sre.h). + # 2**128 should be big enough to overflow on both. For smaller values + # a RuntimeError is raised instead of OverflowError. + long_overflow = 2**128 + self.assertRaises(TypeError, re.finditer, "a", {}) + with self.assertRaises(OverflowError): + _sre.compile("abc", 0, [long_overflow], 0, {}, ()) + with self.assertRaises(TypeError): + _sre.compile({}, 0, [], 0, [], []) + # gh-110590: `TypeError` was overwritten with `OverflowError`: + with self.assertRaises(TypeError): + _sre.compile('', 0, ['abc'], 0, {}, ()) + + @cpython_only + def test_repeat_minmax_overflow_maxrepeat(self): + try: + from _sre import MAXREPEAT + except ImportError: + self.skipTest('requires _sre.MAXREPEAT constant') + string = "x" * 100000 + self.assertIsNone(re.match(r".{%d}" % (MAXREPEAT - 1), string)) + self.assertEqual(re.match(r".{,%d}" % (MAXREPEAT - 1), string).span(), + (0, 100000)) + self.assertIsNone(re.match(r".{%d,}?" % (MAXREPEAT - 1), string)) + self.assertRaises(OverflowError, re.compile, r".{%d}" % MAXREPEAT) + self.assertRaises(OverflowError, re.compile, r".{,%d}" % MAXREPEAT) + self.assertRaises(OverflowError, re.compile, r".{%d,}?" % MAXREPEAT) + + @cpython_only + def test_sre_template_invalid_group_index(self): + # see gh-106524 + import _sre + with self.assertRaises(TypeError) as cm: + _sre.template("", ["", -1, ""]) + self.assertIn("invalid template", str(cm.exception)) + with self.assertRaises(TypeError) as cm: + _sre.template("", ["", (), ""]) + self.assertIn("an integer is required", str(cm.exception)) + class ExternalTests(unittest.TestCase): diff --git a/vm/Cargo.toml b/vm/Cargo.toml index f061f54a85..aa8c6df307 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -78,11 +78,11 @@ result-like = "0.4.5" timsort = "0.1.2" # RustPython crates implementing functionality based on CPython -sre-engine = "0.4.1" +# sre-engine = "0.4.1" # to work on sre-engine locally or git version # sre-engine = { git = "https://github.com/RustPython/sre-engine", rev = "refs/pull/14/head" } # sre-engine = { git = "https://github.com/RustPython/sre-engine" } -# sre-engine = { path = "../../sre-engine" } +sre-engine = { path = "../../sre-engine" } ## unicode stuff unicode_names2 = { workspace = true } From d9375b9fe19c7f27ac3b71a7160f612acd7fbe2d Mon Sep 17 00:00:00 2001 From: Kangzhi Shi Date: Sat, 25 Nov 2023 22:17:24 +0200 Subject: [PATCH 306/893] impl re.template(), template_compile template_expand subx --- Cargo.lock | 1 + Lib/string.py | 31 ++++++- Lib/test/test_re.py | 38 +------- vm/Cargo.toml | 3 +- vm/src/stdlib/sre.rs | 215 +++++++++++++++++++++++++++++++------------ 5 files changed, 191 insertions(+), 97 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index da13a2a33c..0080135c26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2532,6 +2532,7 @@ dependencies = [ [[package]] name = "sre-engine" version = "0.6.0" +source = "git+https://github.com/RustPython/sre-engine?rev=refs/pull/17/head#9725808c302b13e873ff7f955dd4c1632f1137fd" dependencies = [ "bitflags 2.4.0", "num_enum", diff --git a/Lib/string.py b/Lib/string.py index 489777b10c..2eab6d4f59 100644 --- a/Lib/string.py +++ b/Lib/string.py @@ -45,7 +45,7 @@ def capwords(s, sep=None): sep is used to split and join the words. """ - return (sep or ' ').join(x.capitalize() for x in s.split(sep)) + return (sep or ' ').join(map(str.capitalize, s.split(sep))) #################################################################### @@ -141,6 +141,35 @@ def convert(mo): self.pattern) return self.pattern.sub(convert, self.template) + def is_valid(self): + for mo in self.pattern.finditer(self.template): + if mo.group('invalid') is not None: + return False + if (mo.group('named') is None + and mo.group('braced') is None + and mo.group('escaped') is None): + # If all the groups are None, there must be + # another group we're not expecting + raise ValueError('Unrecognized named group in pattern', + self.pattern) + return True + + def get_identifiers(self): + ids = [] + for mo in self.pattern.finditer(self.template): + named = mo.group('named') or mo.group('braced') + if named is not None and named not in ids: + # add a named group only the first time it appears + ids.append(named) + elif (named is None + and mo.group('invalid') is None + and mo.group('escaped') is None): + # If all the groups are None, there must be + # another group we're not expecting + raise ValueError('Unrecognized named group in pattern', + self.pattern) + return ids + # Initialize Template.pattern. __init_subclass__() is automatically called # only for subclasses, not for the Template class itself. Template.__init_subclass__() diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index 1fd2432aae..5b442acbb1 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -116,8 +116,6 @@ def bump_num(self, matchobj): int_value = int(matchobj.group(0)) return str(int_value + 1) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_basic_re_sub(self): self.assertTypedEqual(re.sub('y', 'a', 'xyz'), 'xaz') self.assertTypedEqual(re.sub('y', S('a'), S('xyz')), 'xaz') @@ -161,15 +159,11 @@ def test_basic_re_sub(self): self.assertEqual(re.sub(r'^\s*', 'X', 'test'), 'Xtest') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bug_449964(self): # fails for group followed by other escape self.assertEqual(re.sub(r'(?Px)', r'\g<1>\g<1>\b', 'xx'), 'xx\bxx\b') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bug_449000(self): # Test for sub() on escaped characters self.assertEqual(re.sub(r'\r\n', r'\n', 'abc\r\ndef\r\n'), @@ -193,8 +187,6 @@ def test_bug_3629(self): # A regex that triggered a bug in the sre-code validator re.compile("(?P)(?(quote))") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_sub_template_numeric_escape(self): # bug 776311 and friends self.assertEqual(re.sub('x', r'\0', 'x'), '\0') @@ -248,8 +240,6 @@ def test_qualified_re_sub(self): self.assertEqual(re.sub('a', 'b', 'aaaaa', 1), 'baaaa') self.assertEqual(re.sub('a', 'b', 'aaaaa', count=1), 'baaaa') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bug_114660(self): self.assertEqual(re.sub(r'(\S)\s+(\S)', r'\1 \2', 'hello there'), 'hello there') @@ -303,8 +293,6 @@ def test_symbolic_groups_errors(self): self.checkPatternError(b'(?(\xc2\xb5)y)', r"bad character in group name '\xc2\xb5'", 3) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_symbolic_refs(self): self.assertEqual(re.sub('(?Px)|(?Py)', r'\g', 'xx'), '') self.assertEqual(re.sub('(?Px)|(?Py)', r'\2', 'xx'), '') @@ -316,8 +304,6 @@ def test_symbolic_refs(self): pat = '|'.join('x(?P%x)y' % (i, i) for i in range(1, 200 + 1)) self.assertEqual(re.sub(pat, r'\g<200>', 'xc8yzxc8y'), 'c8zc8') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_symbolic_refs_errors(self): self.checkTemplateError('(?Px)', r'\g, unterminated name', 3) @@ -651,8 +637,6 @@ def test_re_groupref_exists_validation_bug(self): with self.subTest(code=i): re.compile(r'()(?(1)\x%02x?)' % i) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_re_groupref_overflow(self): from re._constants import MAXGROUPS self.checkTemplateError('()', r'\g<%s>' % MAXGROUPS, 'xx', @@ -1754,8 +1738,6 @@ def test_comments(self): self.assertTrue(re.fullmatch('(?x)#x\na|#y\nb', 'a')) self.assertTrue(re.fullmatch('(?x)#x\na|#y\nb', 'b')) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bug_6509(self): # Replacement strings of both types must parse properly. # all strings @@ -1902,8 +1884,6 @@ def test_match_repr(self): ) self.assertRegex(repr(second), pattern) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_zerowidth(self): # Issues 852532, 1647489, 3262, 25054. self.assertEqual(re.split(r"\b", "a::bc"), ['', 'a', '::', 'bc', '']) @@ -2235,8 +2215,6 @@ def test_MIN_REPEAT_ONE_mark_bug(self): p = r'(?:a*?(xx)??z)*' self.assertEqual(re.match(p, s).groups(), ('xx',)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_ASSERT_NOT_mark_bug(self): # Fixed in issue35859, reported in issue725149. # JUMP_ASSERT_NOT should LASTMARK_SAVE() @@ -2249,16 +2227,12 @@ def test_ASSERT_NOT_mark_bug(self): self.assertEqual(m.span(3), (3, 4)) self.assertEqual(m.groups(), ('b', None, 'b')) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bug_40736(self): with self.assertRaisesRegex(TypeError, "got 'int'"): re.search("x*", 5) with self.assertRaisesRegex(TypeError, "got 'type'"): re.search("x*", type) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_search_anchor_at_beginning(self): s = 'x'*10**7 start = time.perf_counter() @@ -2273,7 +2247,8 @@ def test_search_anchor_at_beginning(self): # With optimization -- 0.0003 seconds. self.assertLess(t, 0.1) - @unittest.skip('dead lock') + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_possessive_quantifiers(self): """Test Possessive Quantifiers Test quantifiers of the form @+ for some repetition operator @, @@ -2342,7 +2317,6 @@ def test_fullmatch_possessive_quantifiers(self): self.assertTrue(re.fullmatch(r'(?:ab)?+c', 'abc')) self.assertTrue(re.fullmatch(r'(?:ab){1,3}+c', 'abc')) - @unittest.skip("dead lock") def test_findall_possessive_quantifiers(self): self.assertEqual(re.findall(r'a++', 'aab'), ['aa']) self.assertEqual(re.findall(r'a*+', 'aab'), ['aa', '', '']) @@ -2354,8 +2328,6 @@ def test_findall_possessive_quantifiers(self): self.assertEqual(re.findall(r'(?:ab)?+', 'ababc'), ['ab', 'ab', '', '']) self.assertEqual(re.findall(r'(?:ab){1,3}+', 'ababc'), ['abab']) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_atomic_grouping(self): """Test Atomic Grouping Test non-capturing groups of the form (?>...), which does @@ -2399,8 +2371,6 @@ def test_fullmatch_atomic_grouping(self): self.assertTrue(re.fullmatch(r'(?>(?:ab)?)c', 'abc')) self.assertTrue(re.fullmatch(r'(?>(?:ab){1,3})c', 'abc')) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_findall_atomic_grouping(self): self.assertEqual(re.findall(r'(?>a+)', 'aab'), ['aa']) self.assertEqual(re.findall(r'(?>a*)', 'aab'), ['aa', '', '']) @@ -2412,6 +2382,8 @@ def test_findall_atomic_grouping(self): self.assertEqual(re.findall(r'(?>(?:ab)?)', 'ababc'), ['ab', 'ab', '', '']) self.assertEqual(re.findall(r'(?>(?:ab){1,3})', 'ababc'), ['abab']) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_bug_gh91616(self): self.assertTrue(re.fullmatch(r'(?s:(?>.*?\.).*)\Z', "a.txt")) # reproducer self.assertTrue(re.fullmatch(r'(?s:(?=(?P.*?\.))(?P=g0).*)\Z', "a.txt")) @@ -2442,8 +2414,6 @@ def test_template_function_and_flag_is_deprecated(self): self.assertTrue(template_re1.match('ahoy')) self.assertFalse(template_re1.match('nope')) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bug_gh106052(self): # gh-100061 self.assertEqual(re.match('(?>(?:.(?!D))+)', 'ABCDE').span(), (0, 2)) diff --git a/vm/Cargo.toml b/vm/Cargo.toml index aa8c6df307..5efd01d442 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -80,9 +80,8 @@ timsort = "0.1.2" # RustPython crates implementing functionality based on CPython # sre-engine = "0.4.1" # to work on sre-engine locally or git version -# sre-engine = { git = "https://github.com/RustPython/sre-engine", rev = "refs/pull/14/head" } +sre-engine = { git = "https://github.com/RustPython/sre-engine", rev = "refs/pull/17/head" } # sre-engine = { git = "https://github.com/RustPython/sre-engine" } -sre-engine = { path = "../../sre-engine" } ## unicode stuff unicode_names2 = { workspace = true } diff --git a/vm/src/stdlib/sre.rs b/vm/src/stdlib/sre.rs index 93ecd7c24e..b51320cfd2 100644 --- a/vm/src/stdlib/sre.rs +++ b/vm/src/stdlib/sre.rs @@ -5,13 +5,13 @@ mod _sre { use crate::{ atomic_func, builtins::{ - PyCallableIterator, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyStrRef, PyTuple, - PyTupleRef, PyTypeRef, + PyCallableIterator, PyDictRef, PyGenericAlias, PyInt, PyList, PyListRef, PyStr, + PyStrRef, PyTuple, PyTupleRef, PyTypeRef, }, common::{ascii, hash::PyHash}, convert::ToPyObject, function::{ArgCallable, OptionalArg, PosArgs, PyComparisonValue}, - protocol::{PyBuffer, PyMappingMethods}, + protocol::{PyBuffer, PyCallable, PyMappingMethods}, stdlib::sys, types::{AsMapping, Comparable, Hashable, Representable}, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, @@ -22,12 +22,12 @@ mod _sre { use itertools::Itertools; use num_traits::ToPrimitive; use sre_engine::{ - constants::SreFlag, - engine::{lower_ascii, lower_unicode, upper_unicode, Request, SearchIter, State, StrDrive}, + string::{lower_ascii, lower_unicode, upper_unicode}, + Request, SearchIter, SreFlag, State, StrDrive, }; #[pyattr] - pub use sre_engine::{constants::SRE_MAGIC as MAGIC, CODESIZE, MAXGROUPS, MAXREPEAT}; + pub use sre_engine::{CODESIZE, MAXGROUPS, MAXREPEAT, SRE_MAGIC as MAGIC}; #[pyfunction] fn getcodesize() -> usize { @@ -103,6 +103,58 @@ mod _sre { }) } + #[pyattr] + #[pyclass(name = "SRE_Template")] + #[derive(Debug, PyPayload)] + struct Template { + literal: PyObjectRef, + items: Vec<(usize, PyObjectRef)>, + } + + #[pyclass] + impl Template { + fn compile( + pattern: PyRef, + repl: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult> { + let re = vm.import("re", None, 0)?; + let func = re.get_attr("_compile_template", vm)?; + let result = func.call((pattern, repl.clone()), vm)?; + result + .downcast::() + .map_err(|_| vm.new_runtime_error("expected SRE_Template".to_owned())) + } + } + + #[pyfunction] + fn template( + _pattern: PyObjectRef, + template: PyListRef, + vm: &VirtualMachine, + ) -> PyResult