From ad09d17a718ef30dda4a30009225b68f2a73e380 Mon Sep 17 00:00:00 2001 From: Ashwin Naren Date: Tue, 25 Mar 2025 22:35:39 -0700 Subject: [PATCH 1/5] match statement rewrite with tests --- compiler/codegen/src/compile.rs | 1012 +++++++++++++++-- compiler/codegen/src/error.rs | 27 +- ...on_codegen__compile__tests__match.snap.new | 54 + compiler/core/src/bytecode.rs | 40 +- extra_tests/snippets/syntax_match.py | 66 ++ vm/src/frame.rs | 109 ++ 6 files changed, 1197 insertions(+), 111 deletions(-) create mode 100644 compiler/codegen/src/snapshots/rustpython_codegen__compile__tests__match.snap.new create mode 100644 extra_tests/snippets/syntax_match.py diff --git a/compiler/codegen/src/compile.rs b/compiler/codegen/src/compile.rs index 38d5b8fb12..7dfba53f18 100644 --- a/compiler/codegen/src/compile.rs +++ b/compiler/codegen/src/compile.rs @@ -9,8 +9,8 @@ use crate::{ IndexSet, ToPythonName, - error::{CodegenError, CodegenErrorType}, - ir, + error::{CodegenError, CodegenErrorType, PatternUnreachableReason}, + ir::{self, BlockIdx}, symboltable::{self, SymbolFlags, SymbolScope, SymbolTable}, unparse::unparse_expr, }; @@ -22,10 +22,11 @@ use ruff_python_ast::{ Alias, Arguments, BoolOp, CmpOp, Comprehension, ConversionFlag, DebugText, Decorator, DictItem, ExceptHandler, ExceptHandlerExceptHandler, Expr, ExprAttribute, ExprBoolOp, ExprFString, ExprList, ExprName, ExprStarred, ExprSubscript, ExprTuple, ExprUnaryOp, FString, - FStringElement, FStringElements, FStringFlags, FStringPart, Int, Keyword, MatchCase, - ModExpression, ModModule, Operator, Parameters, Pattern, PatternMatchAs, PatternMatchValue, - Stmt, StmtExpr, TypeParam, TypeParamParamSpec, TypeParamTypeVar, TypeParamTypeVarTuple, - TypeParams, UnaryOp, WithItem, + FStringElement, FStringElements, FStringFlags, FStringPart, Identifier, Int, Keyword, + MatchCase, ModExpression, ModModule, Operator, Parameters, Pattern, PatternMatchAs, + PatternMatchClass, PatternMatchOr, PatternMatchSequence, PatternMatchSingleton, + PatternMatchStar, PatternMatchValue, Singleton, Stmt, StmtExpr, TypeParam, TypeParamParamSpec, + TypeParamTypeVar, TypeParamTypeVarTuple, TypeParams, UnaryOp, WithItem, }; use ruff_source_file::OneIndexed; use ruff_text_size::{Ranged, TextRange}; @@ -33,7 +34,10 @@ use rustpython_wtf8::Wtf8Buf; // use rustpython_ast::located::{self as located_ast, Located}; use rustpython_compiler_core::{ Mode, - bytecode::{self, Arg as OpArgMarker, CodeObject, ConstantData, Instruction, OpArg, OpArgType}, + bytecode::{ + self, Arg as OpArgMarker, BinaryOperator, CodeObject, ComparisonOperator, ConstantData, + Instruction, OpArg, OpArgType, UnpackExArgs, + }, }; use rustpython_compiler_source::SourceCode; // use rustpython_parser_core::source_code::{LineNumber, SourceLocation}; @@ -206,10 +210,37 @@ macro_rules! emit { }; } -struct PatternContext { - current_block: usize, - blocks: Vec, - allow_irrefutable: bool, +/// The pattern context holds information about captured names and jump targets. +#[derive(Clone)] +pub struct PatternContext { + /// A list of names captured by the pattern. + pub stores: Vec, + /// If false, then any name captures against our subject will raise. + pub allow_irrefutable: bool, + /// A list of jump target labels used on pattern failure. + pub fail_pop: Vec, + /// The number of items on top of the stack that should remain. + pub on_top: usize, +} + +impl PatternContext { + pub fn new() -> Self { + PatternContext { + stores: Vec::new(), + allow_irrefutable: false, + fail_pop: Vec::new(), + on_top: 0, + } + } + + pub fn fail_pop_size(&self) -> usize { + self.fail_pop.len() + } +} + +enum JumpOp { + Jump, + PopJumpIfFalse, } impl<'src> Compiler<'src> { @@ -1800,73 +1831,817 @@ impl Compiler<'_> { Ok(()) } - fn compile_pattern_value( + fn forbidden_name(&mut self, name: &str, ctx: NameUsage) -> CompileResult { + if ctx == NameUsage::Store && name == "__debug__" { + return Err(self.error(CodegenErrorType::Assign("__debug__"))); + // return Ok(true); + } + if ctx == NameUsage::Delete && name == "__debug__" { + return Err(self.error(CodegenErrorType::Delete("__debug__"))); + // return Ok(true); + } + Ok(false) + } + + fn compile_error_forbidden_name(&mut self, name: &str) -> CodegenError { + // TODO: make into error (fine for now since it realistically errors out earlier) + panic!("Failing due to forbidden name {:?}", name); + } + + /// Ensures that `pc.fail_pop` has at least `n + 1` entries. + /// If not, new labels are generated and pushed until the required size is reached. + fn ensure_fail_pop(&mut self, pc: &mut PatternContext, n: usize) -> CompileResult<()> { + let required_size = n + 1; + if required_size <= pc.fail_pop.len() { + return Ok(()); + } + while pc.fail_pop.len() < required_size { + let new_block = self.new_block(); + pc.fail_pop.push(new_block); + } + Ok(()) + } + + fn jump_to_fail_pop(&mut self, pc: &mut PatternContext, op: JumpOp) -> CompileResult<()> { + // Compute the total number of items to pop: + // items on top plus the captured objects. + let pops = pc.on_top + pc.stores.len(); + // Ensure that the fail_pop vector has at least `pops + 1` elements. + self.ensure_fail_pop(pc, pops)?; + // Emit a jump using the jump target stored at index `pops`. + match op { + JumpOp::Jump => { + emit!( + self, + Instruction::Jump { + target: pc.fail_pop[pops] + } + ); + } + JumpOp::PopJumpIfFalse => { + emit!( + self, + Instruction::JumpIfFalse { + target: pc.fail_pop[pops] + } + ); + } + } + Ok(()) + } + + /// Emits the necessary POP instructions for all failure targets in the pattern context, + /// then resets the fail_pop vector. + fn emit_and_reset_fail_pop(&mut self, pc: &mut PatternContext) -> CompileResult<()> { + // If the fail_pop vector is empty, nothing needs to be done. + if pc.fail_pop.is_empty() { + debug_assert!(pc.fail_pop.is_empty()); + return Ok(()); + } + // Iterate over the fail_pop vector in reverse order, skipping the first label. + for &label in pc.fail_pop.iter().skip(1).rev() { + self.switch_to_block(label); + // Emit the POP instruction. + emit!(self, Instruction::Pop); + } + // Finally, use the first label. + self.switch_to_block(pc.fail_pop[0]); + pc.fail_pop.clear(); + // Free the memory used by the vector. + pc.fail_pop.shrink_to_fit(); + Ok(()) + } + + /// Duplicate the effect of Python 3.10's ROT_* instructions using SWAPs. + fn pattern_helper_rotate(&mut self, mut count: usize) -> CompileResult<()> { + while count > 1 { + // Emit a SWAP instruction with the current count. + emit!( + self, + Instruction::Swap { + index: count as u32 + } + ); + count -= 1; + } + Ok(()) + } + + /// Helper to store a captured name for a star pattern. + /// + /// If `n` is `None`, it emits a POP_TOP instruction. Otherwise, it first + /// checks that the name is allowed and not already stored. Then it rotates + /// the object on the stack beneath any preserved items and appends the name + /// to the list of captured names. + fn pattern_helper_store_name( &mut self, - value: &PatternMatchValue, - _pattern_context: &mut PatternContext, + n: Option<&Identifier>, + pc: &mut PatternContext, ) -> CompileResult<()> { - use crate::compile::bytecode::ComparisonOperator::*; + // If no name is provided, simply pop the top of the stack. + if n.is_none() { + emit!(self, Instruction::Pop); + return Ok(()); + } + let name = n.unwrap(); - self.compile_expression(&value.value)?; - emit!(self, Instruction::CompareOperation { op: Equal }); + // Check if the name is forbidden for storing. + if self.forbidden_name(name.as_str(), NameUsage::Store)? { + return Err(self.compile_error_forbidden_name(name.as_str())); + } + + // Ensure we don't store the same name twice. + if pc.stores.contains(&name.to_string()) { + return Err(self.error(CodegenErrorType::DuplicateStore(name.as_str().to_string()))); + } + + // Calculate how many items to rotate: + // the count is the number of items to preserve on top plus the current stored names, + // plus one for the new value. + let rotations = pc.on_top + pc.stores.len() + 1; + self.pattern_helper_rotate(rotations)?; + + // Append the name to the captured stores. + pc.stores.push(name.to_string()); + Ok(()) + } + + fn pattern_unpack_helper(&mut self, elts: &[Pattern]) -> CompileResult<()> { + let n = elts.len(); + let mut seen_star = false; + for (i, elt) in elts.iter().enumerate() { + if elt.is_match_star() { + if !seen_star { + if i >= (1 << 8) || (n - i - 1) >= ((i32::MAX as usize) >> 8) { + todo!(); + // return self.compiler_error(loc, "too many expressions in star-unpacking sequence pattern"); + } + let args = UnpackExArgs { + before: i as u8, + after: (n - i - 1) as u8, + }; + emit!(self, Instruction::UnpackEx { args }); + seen_star = true; + } else { + // TODO: Fix error msg + return Err(self.error(CodegenErrorType::MultipleStarArgs)); + // return self.compiler_error(loc, "multiple starred expressions in sequence pattern"); + } + } + } + if !seen_star { + emit!(self, Instruction::UnpackSequence { size: n as u32 }); + } + Ok(()) + } + + fn pattern_helper_sequence_unpack( + &mut self, + patterns: &[Pattern], + _star: Option, + pc: &mut PatternContext, + ) -> CompileResult<()> { + // Unpack the sequence into individual subjects. + self.pattern_unpack_helper(patterns)?; + let size = patterns.len(); + // Increase the on_top counter for the newly unpacked subjects. + pc.on_top += size as usize; + // For each unpacked subject, compile its subpattern. + for pattern in patterns { + // Decrement on_top for each subject as it is consumed. + pc.on_top -= 1; + self.compile_pattern_subpattern(pattern, pc)?; + } + Ok(()) + } + + fn pattern_helper_sequence_subscr( + &mut self, + patterns: &[Pattern], + star: usize, + pc: &mut PatternContext, + ) -> CompileResult<()> { + // Keep the subject around for extracting elements. + pc.on_top += 1; + let size = patterns.len(); + for i in 0..size { + let pattern = &patterns[i]; + // if pattern.is_wildcard() { + // continue; + // } + if i == star { + // This must be a starred wildcard. + // assert!(pattern.is_star_wildcard()); + continue; + } + // Duplicate the subject. + emit!(self, Instruction::CopyItem { index: 1 as u32 }); + if i < star { + // For indices before the star, use a nonnegative index equal to i. + self.emit_load_const(ConstantData::Integer { value: i.into() }); + } else { + // For indices after the star, compute a nonnegative index: + // index = len(subject) - (size - i) + emit!(self, Instruction::GetLen); + self.emit_load_const(ConstantData::Integer { + value: (size - 1).into(), + }); + // Subtract to compute the correct index. + emit!( + self, + Instruction::BinaryOperation { + op: BinaryOperator::Subtract + } + ); + } + // Use BINARY_OP/NB_SUBSCR to extract the element. + emit!(self, Instruction::BinarySubscript); + // Compile the subpattern in irrefutable mode. + self.compile_pattern_subpattern(pattern, pc)?; + } + // Pop the subject off the stack. + pc.on_top -= 1; + emit!(self, Instruction::Pop); + Ok(()) + } + + fn compile_pattern_subpattern( + &mut self, + p: &Pattern, + pc: &mut PatternContext, + ) -> CompileResult<()> { + // Save the current allow_irrefutable state. + let old_allow_irrefutable = pc.allow_irrefutable; + // Temporarily allow irrefutable patterns. + pc.allow_irrefutable = true; + // Compile the pattern. + self.compile_pattern(p, pc)?; + // Restore the original state. + pc.allow_irrefutable = old_allow_irrefutable; Ok(()) } fn compile_pattern_as( &mut self, - as_pattern: &PatternMatchAs, - pattern_context: &mut PatternContext, + p: &PatternMatchAs, + pc: &mut PatternContext, ) -> CompileResult<()> { - if as_pattern.pattern.is_none() && !pattern_context.allow_irrefutable { - // TODO: better error message - if let Some(_name) = as_pattern.name.as_ref() { - return Err(self.error_ranged(CodegenErrorType::InvalidMatchCase, as_pattern.range)); + // If there is no sub-pattern, then it's an irrefutable match. + if p.pattern.is_none() { + if !pc.allow_irrefutable { + if let Some(_name) = p.name.as_ref() { + // TODO: This error message does not match cpython exactly + // A name capture makes subsequent patterns unreachable. + return Err(self.error(CodegenErrorType::UnreachablePattern( + PatternUnreachableReason::NameCapture, + ))); + } else { + // A wildcard makes remaining patterns unreachable. + return Err(self.error(CodegenErrorType::UnreachablePattern( + PatternUnreachableReason::Wildcard, + ))); + } } - return Err(self.error_ranged(CodegenErrorType::InvalidMatchCase, as_pattern.range)); + // If irrefutable matches are allowed, store the name (if any). + return self.pattern_helper_store_name(p.name.as_ref(), pc); } - // Need to make a copy for (possibly) storing later: - emit!(self, Instruction::Duplicate); - if let Some(pattern) = &as_pattern.pattern { - self.compile_pattern_inner(pattern, pattern_context)?; + + // Otherwise, there is a sub-pattern. Duplicate the object on top of the stack. + pc.on_top += 1; + emit!(self, Instruction::CopyItem { index: 1 as u32 }); + // Compile the sub-pattern. + self.compile_pattern(p.pattern.as_ref().unwrap(), pc)?; + // After success, decrement the on_top counter. + pc.on_top -= 1; + // Store the captured name (if any). + self.pattern_helper_store_name(p.name.as_ref(), pc)?; + Ok(()) + } + + fn compile_pattern_star( + &mut self, + p: &PatternMatchStar, + pc: &mut PatternContext, + ) -> CompileResult<()> { + self.pattern_helper_store_name(p.name.as_ref(), pc)?; + Ok(()) + } + + /// Validates that keyword attributes in a class pattern are allowed + /// and not duplicated. + fn validate_kwd_attrs( + &mut self, + attrs: &[Identifier], + _patterns: &[Pattern], + ) -> CompileResult<()> { + let nattrs = attrs.len(); + for i in 0..nattrs { + let attr = attrs[i].as_str(); + // Check if the attribute name is forbidden in a Store context. + if self.forbidden_name(attr, NameUsage::Store)? { + // Return an error if the name is forbidden. + return Err(self.compile_error_forbidden_name(attr)); + } + // Check for duplicates: compare with every subsequent attribute. + for j in (i + 1)..nattrs { + let other = attrs[j].as_str(); + if attr == other { + todo!(); + // return Err(self.compiler_error( + // &format!("attribute name repeated in class pattern: {}", attr), + // )); + } + } } - if let Some(name) = as_pattern.name.as_ref() { - self.store_name(name.as_str())?; - } else { - emit!(self, Instruction::Pop); + Ok(()) + } + + fn compile_pattern_class( + &mut self, + p: &PatternMatchClass, + pc: &mut PatternContext, + ) -> CompileResult<()> { + dbg!(); + // Extract components from the MatchClass pattern. + let match_class = p; + let patterns = &match_class.arguments.patterns; + + // Extract keyword attributes and patterns. + // Capacity is pre-allocated based on the number of keyword arguments. + let mut kwd_attrs = Vec::with_capacity(match_class.arguments.keywords.len()); + let mut kwd_patterns = Vec::with_capacity(match_class.arguments.keywords.len()); + for kwd in &match_class.arguments.keywords { + kwd_attrs.push(kwd.attr.clone()); + kwd_patterns.push(kwd.pattern.clone()); + } + + let nargs = patterns.len(); + let nattrs = kwd_attrs.len(); + let nkwd_patterns = kwd_patterns.len(); + + // Validate that keyword attribute names and patterns match in length. + if nattrs != nkwd_patterns { + let msg = format!( + "kwd_attrs ({}) / kwd_patterns ({}) length mismatch in class pattern", + nattrs, nkwd_patterns + ); + unreachable!("{}", msg); + } + + // Check for too many sub-patterns. + if nargs > u32::MAX as usize || (nargs + nattrs).saturating_sub(1) > i32::MAX as usize { + let msg = format!( + "too many sub-patterns in class pattern {:?}", + match_class.cls + ); + panic!("{}", msg); + // return self.compiler_error(&msg); + } + + // Validate keyword attributes if any. + if nattrs != 0 { + self.validate_kwd_attrs(&kwd_attrs, &kwd_patterns)?; + } + + // Compile the class expression. + self.compile_expression(&match_class.cls)?; + + // Create a new tuple of attribute names. + let mut attr_names = vec![]; + for name in kwd_attrs.iter() { + // Py_NewRef(name) is emulated by cloning the name into a PyObject. + attr_names.push(ConstantData::Str { + value: name.as_str().to_string().into(), + }); + } + + // Emit instructions: + // 1. Load the new tuple of attribute names. + self.emit_load_const(ConstantData::Tuple { + elements: attr_names, + }); + // 2. Emit MATCH_CLASS with nargs. + emit!(self, Instruction::MatchClass(nargs as u32)); + // 3. Duplicate the top of the stack. + emit!(self, Instruction::CopyItem { index: 1_u32 }); + // 4. Load None. + self.emit_load_const(ConstantData::None); + // 5. Compare with IS_OP 1. + emit!(self, Instruction::IsOperation(true)); + + // At this point the TOS is a tuple of (nargs + nattrs) attributes (or None). + pc.on_top += 1; + self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; + + // Unpack the tuple into (nargs + nattrs) items. + let total = nargs + nattrs; + emit!(self, Instruction::UnpackSequence { size: total as u32 }); + pc.on_top += total; + pc.on_top -= 1; + + // Process each sub-pattern. + for i in 0..total { + // Decrement the on_top counter as each sub-pattern is processed. + pc.on_top -= 1; + let subpattern = if i < nargs { + // Positional sub-pattern. + &patterns[i] + } else { + // Keyword sub-pattern. + &kwd_patterns[i - nargs] + }; + if subpattern.is_wildcard() { + // For wildcard patterns, simply pop the top of the stack. + emit!(self, Instruction::Pop); + continue; + } + // Compile the subpattern without irrefutability checks. + self.compile_pattern_subpattern(subpattern, pc)?; } Ok(()) } - fn compile_pattern_inner( + // fn compile_pattern_mapping(&mut self, p: &PatternMatchMapping, pc: &mut PatternContext) -> CompileResult<()> { + // // Ensure the pattern is a mapping pattern. + // let mapping = p; // Extract MatchMapping-specific data. + // let keys = &mapping.keys; + // let patterns = &mapping.patterns; + // let size = keys.len(); + // let npatterns = patterns.len(); + + // if size != npatterns { + // panic!("keys ({}) / patterns ({}) length mismatch in mapping pattern", size, npatterns); + // // return self.compiler_error( + // // &format!("keys ({}) / patterns ({}) length mismatch in mapping pattern", size, npatterns) + // // ); + // } + + // // A double-star target is present if `rest` is set. + // let star_target = mapping.rest; + + // // Keep the subject on top during the mapping and length checks. + // pc.on_top += 1; + // emit!(self, Instruction::MatchMapping); + // self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; + + // // If the pattern is just "{}" (empty mapping) and there's no star target, + // // we're done—pop the subject. + // if size == 0 && star_target.is_none() { + // pc.on_top -= 1; + // emit!(self, Instruction::Pop); + // return Ok(()); + // } + + // // If there are any keys, perform a length check. + // if size != 0 { + // emit!(self, Instruction::GetLen); + // self.emit_load_const(ConstantData::Integer { value: size.into() }); + // emit!(self, Instruction::CompareOperation { op: ComparisonOperator::GreaterOrEqual }); + // self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; + // } + + // // Check that the number of subpatterns is not absurd. + // if size.saturating_sub(1) > (i32::MAX as usize) { + // panic!("too many sub-patterns in mapping pattern"); + // // return self.compiler_error("too many sub-patterns in mapping pattern"); + // } + + // // Collect all keys into a set for duplicate checking. + // let mut seen = HashSet::new(); + + // // For each key, validate it and check for duplicates. + // for (i, key) in keys.iter().enumerate() { + // if let Some(key_val) = key.as_literal_expr() { + // let in_seen = seen.contains(&key_val); + // if in_seen { + // panic!("mapping pattern checks duplicate key: {:?}", key_val); + // // return self.compiler_error(format!("mapping pattern checks duplicate key: {:?}", key_val)); + // } + // seen.insert(key_val); + // } else if !key.is_attribute_expr() { + // panic!("mapping pattern keys may only match literals and attribute lookups"); + // // return self.compiler_error("mapping pattern keys may only match literals and attribute lookups"); + // } + + // // Visit the key expression. + // self.compile_expression(key)?; + // } + // // Drop the set (its resources will be freed automatically). + + // // Build a tuple of keys and emit MATCH_KEYS. + // emit!(self, Instruction::BuildTuple { size: size as u32 }); + // emit!(self, Instruction::MatchKeys); + // // Now, on top of the subject there are two new tuples: one of keys and one of values. + // pc.on_top += 2; + + // // Prepare for matching the values. + // emit!(self, Instruction::CopyItem { index: 1_u32 }); + // self.emit_load_const(ConstantData::None); + // // TODO: should be is + // emit!(self, Instruction::IsOperation(true)); + // self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; + + // // Unpack the tuple of values. + // emit!(self, Instruction::UnpackSequence { size: size as u32 }); + // pc.on_top += size.saturating_sub(1); + + // // Compile each subpattern in "subpattern" mode. + // for pattern in patterns { + // pc.on_top = pc.on_top.saturating_sub(1); + // self.compile_pattern_subpattern(pattern, pc)?; + // } + + // // Consume the tuple of keys and the subject. + // pc.on_top = pc.on_top.saturating_sub(2); + // if let Some(star_target) = star_target { + // // If we have a starred name, bind a dict of remaining items to it. + // // This sequence of instructions performs: + // // rest = dict(subject) + // // for key in keys: del rest[key] + // emit!(self, Instruction::BuildMap { size: 0 }); // Build an empty dict. + // emit!(self, Instruction::Swap(3)); // Rearrange stack: [empty, keys, subject] + // emit!(self, Instruction::DictUpdate { size: 2 }); // Update dict with subject. + // emit!(self, Instruction::UnpackSequence { size: size as u32 }); // Unpack keys. + // let mut remaining = size; + // while remaining > 0 { + // emit!(self, Instruction::CopyItem { index: 1 + remaining as u32 }); // Duplicate subject copy. + // emit!(self, Instruction::Swap { index: 2_u32 }); // Bring key to top. + // emit!(self, Instruction::DeleteSubscript); // Delete key from dict. + // remaining -= 1; + // } + // // Bind the dict to the starred target. + // self.pattern_helper_store_name(Some(&star_target), pc)?; + // } else { + // // No starred target: just pop the tuple of keys and the subject. + // emit!(self, Instruction::Pop); + // emit!(self, Instruction::Pop); + // } + // Ok(()) + // } + + fn compile_pattern_or( &mut self, - pattern_type: &Pattern, - pattern_context: &mut PatternContext, + p: &PatternMatchOr, + pc: &mut PatternContext, ) -> CompileResult<()> { - match &pattern_type { - Pattern::MatchValue(value) => self.compile_pattern_value(value, pattern_context), - Pattern::MatchAs(as_pattern) => self.compile_pattern_as(as_pattern, pattern_context), - _ => { - eprintln!("not implemented pattern type: {pattern_type:?}"); - Err(self.error(CodegenErrorType::NotImplementedYet)) + // Ensure the pattern is a MatchOr. + let end = self.new_block(); // Create a new jump target label. + let size = p.patterns.len(); + assert!(size > 1, "MatchOr must have more than one alternative"); + + // Save the current pattern context. + let old_pc = pc.clone(); + // Simulate Py_INCREF on pc.stores by cloning it. + pc.stores = pc.stores.clone(); + let mut control: Option> = None; // Will hold the capture list of the first alternative. + + // Process each alternative. + for (i, alt) in p.patterns.iter().enumerate() { + // Create a fresh empty store for this alternative. + pc.stores = Vec::new(); + // An irrefutable subpattern must be last (if allowed). + pc.allow_irrefutable = (i == size - 1) && old_pc.allow_irrefutable; + // Reset failure targets and the on_top counter. + pc.fail_pop.clear(); + pc.on_top = 0; + // Emit a COPY(1) instruction before compiling the alternative. + emit!(self, Instruction::CopyItem { index: 1 as u32 }); + self.compile_pattern(alt, pc)?; + + let nstores = pc.stores.len(); + if i == 0 { + // Save the captured names from the first alternative. + control = Some(pc.stores.clone()); + } else { + let control_vec = control.as_ref().unwrap(); + if nstores != control_vec.len() { + todo!(); + // return self.compiler_error("alternative patterns bind different names"); + } else if nstores > 0 { + // Check that the names occur in the same order. + for icontrol in (0..nstores).rev() { + let name = &control_vec[icontrol]; + // Find the index of `name` in the current stores. + let istores = pc.stores.iter().position(|n| n == name).unwrap(); + // .ok_or_else(|| self.compiler_error("alternative patterns bind different names"))?; + if icontrol != istores { + // The orders differ; we must reorder. + assert!(istores < icontrol, "expected istores < icontrol"); + let rotations = istores + 1; + // Rotate pc.stores: take a slice of the first `rotations` items... + let rotated = pc.stores[0..rotations].to_vec(); + // Remove those elements. + for _ in 0..rotations { + pc.stores.remove(0); + } + // Insert the rotated slice at the appropriate index. + let insert_pos = icontrol - istores; + for (j, elem) in rotated.into_iter().enumerate() { + pc.stores.insert(insert_pos + j, elem); + } + // Also perform the same rotation on the evaluation stack. + for _ in 0..(istores + 1) { + self.pattern_helper_rotate(icontrol + 1)?; + } + } + } + } + } + // Emit a jump to the common end label and reset any failure jump targets. + emit!(self, Instruction::Jump { target: end }); + self.emit_and_reset_fail_pop(pc)?; + } + + // Restore the original pattern context. + *pc = old_pc.clone(); + // Simulate Py_INCREF on pc.stores. + pc.stores = pc.stores.clone(); + // In C, old_pc.fail_pop is set to NULL to avoid freeing it later. + // In Rust, old_pc is a local clone, so we need not worry about that. + + // No alternative matched: pop the subject and fail. + emit!(self, Instruction::Pop); + self.jump_to_fail_pop(pc, JumpOp::Jump)?; + + // Use the label "end". + self.switch_to_block(end); + + // Adjust the final captures. + let nstores = control.as_ref().unwrap().len(); + let nrots = nstores + 1 + pc.on_top + pc.stores.len(); + for i in 0..nstores { + // Rotate the capture to its proper place. + self.pattern_helper_rotate(nrots)?; + let name = &control.as_ref().unwrap()[i]; + // Check for duplicate binding. + if pc.stores.iter().any(|n| n == name) { + return Err(self.error(CodegenErrorType::DuplicateStore(name.to_string()))); } + pc.stores.push(name.clone()); } + + // Old context and control will be dropped automatically. + // Finally, pop the copy of the subject. + emit!(self, Instruction::Pop); + Ok(()) } - fn compile_pattern( + fn compile_pattern_sequence( &mut self, - pattern_type: &Pattern, - pattern_context: &mut PatternContext, + p: &PatternMatchSequence, + pc: &mut PatternContext, + ) -> CompileResult<()> { + // Ensure the pattern is a MatchSequence. + let patterns = &p.patterns; // a slice of Pattern + let size = patterns.len(); + let mut star: Option = None; + let mut only_wildcard = true; + let mut star_wildcard = false; + + // Find a starred pattern, if it exists. There may be at most one. + for (i, pattern) in patterns.iter().enumerate() { + if pattern.is_match_star() { + if star.is_some() { + // TODO: Fix error msg + return Err(self.error(CodegenErrorType::MultipleStarArgs)); + } + // star wildcard check + star_wildcard = pattern + .as_match_star() + .map(|m| m.name.is_none()) + .unwrap_or(false); + only_wildcard &= star_wildcard; + star = Some(i); + continue; + } + // wildcard check + only_wildcard &= pattern + .as_match_as() + .map(|m| m.name.is_none()) + .unwrap_or(false); + } + + // Keep the subject on top during the sequence and length checks. + pc.on_top += 1; + emit!(self, Instruction::MatchSequence); + self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; + + if star.is_none() { + // No star: len(subject) == size + emit!(self, Instruction::GetLen); + self.emit_load_const(ConstantData::Integer { value: size.into() }); + emit!( + self, + Instruction::CompareOperation { + op: ComparisonOperator::Equal + } + ); + self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; + } else if size > 1 { + // Star exists: len(subject) >= size - 1 + emit!(self, Instruction::GetLen); + self.emit_load_const(ConstantData::Integer { + value: (size - 1).into(), + }); + emit!( + self, + Instruction::CompareOperation { + op: ComparisonOperator::GreaterOrEqual + } + ); + self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; + } + + // Whatever comes next should consume the subject. + pc.on_top -= 1; + if only_wildcard { + // Patterns like: [] / [_] / [_, _] / [*_] / [_, *_] / [_, _, *_] / etc. + emit!(self, Instruction::Pop); + } else if star_wildcard { + self.pattern_helper_sequence_subscr(&patterns, star.unwrap(), pc)?; + } else { + self.pattern_helper_sequence_unpack(&patterns, star, pc)?; + } + Ok(()) + } + + fn compile_pattern_value( + &mut self, + p: &PatternMatchValue, + pc: &mut PatternContext, + ) -> CompileResult<()> { + // TODO: ensure literal or attribute lookup + self.compile_expression(&p.value)?; + emit!( + self, + Instruction::CompareOperation { + op: bytecode::ComparisonOperator::Equal + } + ); + // emit!(self, Instruction::ToBool); + self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; + Ok(()) + } + + fn compile_pattern_singleton( + &mut self, + p: &PatternMatchSingleton, + pc: &mut PatternContext, ) -> CompileResult<()> { - self.compile_pattern_inner(pattern_type, pattern_context)?; + // Load the singleton constant value. + self.emit_load_const(match p.value { + Singleton::None => ConstantData::None, + Singleton::False => ConstantData::Boolean { value: false }, + Singleton::True => ConstantData::Boolean { value: true }, + }); + // Compare using the "Is" operator. emit!( self, - Instruction::JumpIfFalse { - target: pattern_context.blocks[pattern_context.current_block + 1] + Instruction::CompareOperation { + op: bytecode::ComparisonOperator::Equal } ); + // Jump to the failure label if the comparison is false. + self.jump_to_fail_pop(pc, JumpOp::PopJumpIfFalse)?; Ok(()) } + fn compile_pattern( + &mut self, + pattern_type: &Pattern, + pattern_context: &mut PatternContext, + ) -> CompileResult<()> { + match &pattern_type { + Pattern::MatchValue(pattern_type) => { + self.compile_pattern_value(pattern_type, pattern_context) + } + Pattern::MatchSingleton(pattern_type) => { + self.compile_pattern_singleton(pattern_type, pattern_context) + } + Pattern::MatchSequence(pattern_type) => { + self.compile_pattern_sequence(pattern_type, pattern_context) + } + // Pattern::MatchMapping(pattern_type) => self.compile_pattern_mapping(pattern_type, pattern_context), + Pattern::MatchClass(pattern_type) => { + self.compile_pattern_class(pattern_type, pattern_context) + } + Pattern::MatchStar(pattern_type) => { + self.compile_pattern_star(pattern_type, pattern_context) + } + Pattern::MatchAs(pattern_type) => { + self.compile_pattern_as(pattern_type, pattern_context) + } + Pattern::MatchOr(pattern_type) => { + self.compile_pattern_or(pattern_type, pattern_context) + } + _ => { + // The eprintln gives context as to which pattern type is not implemented. + eprintln!("not implemented pattern type: {pattern_type:?}"); + Err(self.error(CodegenErrorType::NotImplementedYet)) + } + } + } + fn compile_match_inner( &mut self, subject: &Expr, @@ -1874,63 +2649,69 @@ impl Compiler<'_> { pattern_context: &mut PatternContext, ) -> CompileResult<()> { self.compile_expression(subject)?; - pattern_context.blocks = std::iter::repeat_with(|| self.new_block()) - .take(cases.len() + 1) - .collect::>(); - let end_block = *pattern_context.blocks.last().unwrap(); - - let _match_case_type = cases.last().expect("cases is not empty"); - // TODO: get proper check for default case - // let has_default = match_case_type.pattern.is_match_as() && 1 < cases.len(); - let has_default = false; - for i in 0..cases.len() - (has_default as usize) { - self.switch_to_block(pattern_context.blocks[i]); - pattern_context.current_block = i; - pattern_context.allow_irrefutable = cases[i].guard.is_some() || i == cases.len() - 1; + let end = self.new_block(); + + let num_cases = cases.len(); + assert!(num_cases > 0); + let has_default = cases.iter().last().unwrap().pattern.is_match_star() && num_cases > 1; + + let case_count = num_cases - if has_default { 1 } else { 0 }; + for i in 0..case_count { let m = &cases[i]; - // Only copy the subject if we're *not* on the last case: - if i != cases.len() - has_default as usize - 1 { - emit!(self, Instruction::Duplicate); + + // Only copy the subject if not on the last case + if i != case_count - 1 { + emit!(self, Instruction::CopyItem { index: 1 as u32 }); } + + pattern_context.stores = Vec::with_capacity(1); + pattern_context.allow_irrefutable = m.guard.is_some() || i == case_count - 1; + pattern_context.fail_pop.clear(); + pattern_context.on_top = 0; + self.compile_pattern(&m.pattern, pattern_context)?; + assert_eq!(pattern_context.on_top, 0); + + for name in &pattern_context.stores { + self.compile_name(name, NameUsage::Store)?; + } + + if let Some(ref _guard) = m.guard { + self.ensure_fail_pop(pattern_context, 0)?; + // TODO: Fix compile jump if call + return Err(self.error(CodegenErrorType::NotImplementedYet)); + // Jump if the guard fails. We assume that patter_context.fail_pop[0] is the jump target. + // self.compile_jump_if(&m.pattern, &guard, pattern_context.fail_pop[0])?; + } + + if i != case_count - 1 { + emit!(self, Instruction::Pop); + } + self.compile_statements(&m.body)?; - emit!(self, Instruction::Jump { target: end_block }); + emit!(self, Instruction::Jump { target: end }); + self.emit_and_reset_fail_pop(pattern_context)?; } - // TODO: below code is not called and does not work + if has_default { - // A trailing "case _" is common, and lets us save a bit of redundant - // pushing and popping in the loop above: - let m = &cases.last().unwrap(); - self.switch_to_block(*pattern_context.blocks.last().unwrap()); - if cases.len() == 1 { - // No matches. Done with the subject: + let m = &cases[num_cases - 1]; + if num_cases == 1 { emit!(self, Instruction::Pop); } else { - // Show line coverage for default case (it doesn't create bytecode) - // emit!(self, Instruction::Nop); + emit!(self, Instruction::Nop); + } + if let Some(ref _guard) = m.guard { + // TODO: Fix compile jump if call + return Err(self.error(CodegenErrorType::NotImplementedYet)); } self.compile_statements(&m.body)?; } - - self.switch_to_block(end_block); - - let code = self.current_code_info(); - pattern_context - .blocks - .iter() - .zip(pattern_context.blocks.iter().skip(1)) - .for_each(|(a, b)| { - code.blocks[a.0 as usize].next = *b; - }); + self.switch_to_block(end); Ok(()) } fn compile_match(&mut self, subject: &Expr, cases: &[MatchCase]) -> CompileResult<()> { - let mut pattern_context = PatternContext { - current_block: usize::MAX, - blocks: Vec::new(), - allow_irrefutable: false, - }; + let mut pattern_context = PatternContext::new(); self.compile_match_inner(subject, cases, &mut pattern_context)?; Ok(()) } @@ -3638,7 +4419,7 @@ impl ToU32 for usize { } #[cfg(test)] -mod tests { +mod ruff_tests { use super::*; use ruff_python_ast::name::Name; use ruff_python_ast::*; @@ -3741,26 +4522,29 @@ mod tests { } } -/* #[cfg(test)] mod tests { use super::*; - use rustpython_parser::Parse; - use rustpython_parser::ast::Suite; - use rustpython_parser_core::source_code::LinearLocator; fn compile_exec(source: &str) -> CodeObject { - let mut locator: LinearLocator<'_> = LinearLocator::new(source); - use rustpython_parser::ast::fold::Fold; - let mut compiler: Compiler = Compiler::new( - CompileOpts::default(), - "source_path".to_owned(), - "".to_owned(), - ); - let ast = Suite::parse(source, "").unwrap(); - let ast = locator.fold(ast).unwrap(); - let symbol_scope = SymbolTable::scan_program(&ast).unwrap(); - compiler.compile_program(&ast, symbol_scope).unwrap(); + let opts = CompileOpts::default(); + let mode = Mode::Exec; + let source_code = SourceCode::new("source_path", source); + let parsed = ruff_python_parser::parse( + source_code.text, + ruff_python_parser::Mode::from(mode).into(), + ) + .unwrap(); + let ast = parsed.into_syntax(); + let ast = match ast { + ruff_python_ast::Mod::Module(stmts) => stmts, + _ => unreachable!(), + }; + let symbol_table = SymbolTable::scan_program(&ast, source_code.clone()) + .map_err(|e| e.into_codegen_error(source_code.path.to_owned())) + .unwrap(); + let mut compiler = Compiler::new(opts, source_code, "".to_owned()); + compiler.compile_program(&ast, symbol_table).unwrap(); compiler.pop_code_object() } @@ -3817,8 +4601,24 @@ for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')): self.assertIs(ex, stop_exc) else: self.fail(f'{stop_exc} was suppressed') +" + )); + } + + #[test] + fn test_match() { + assert_dis_snapshot!(compile_exec( + "\ +class Test: + pass + +t = Test() +match t: + case Test(): + assert True + case _: + assert False " )); } } -*/ diff --git a/compiler/codegen/src/error.rs b/compiler/codegen/src/error.rs index 8f38680de0..b1b4f9379f 100644 --- a/compiler/codegen/src/error.rs +++ b/compiler/codegen/src/error.rs @@ -1,7 +1,22 @@ use ruff_source_file::SourceLocation; -use std::fmt; +use std::fmt::{self, Display}; use thiserror::Error; +#[derive(Debug)] +pub enum PatternUnreachableReason { + NameCapture, + Wildcard, +} + +impl Display for PatternUnreachableReason { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::NameCapture => write!(f, "name capture"), + Self::Wildcard => write!(f, "wildcard"), + } + } +} + // pub type CodegenError = rustpython_parser_core::source_code::LocatedError; #[derive(Error, Debug)] @@ -47,8 +62,9 @@ pub enum CodegenErrorType { TooManyStarUnpack, EmptyWithItems, EmptyWithBody, + ForbiddenName, DuplicateStore(String), - InvalidMatchCase, + UnreachablePattern(PatternUnreachableReason), NotImplementedYet, // RustPython marker for unimplemented features } @@ -94,11 +110,14 @@ impl fmt::Display for CodegenErrorType { EmptyWithBody => { write!(f, "empty body on With") } + ForbiddenName => { + write!(f, "forbidden attribute name") + } DuplicateStore(s) => { write!(f, "duplicate store {s}") } - InvalidMatchCase => { - write!(f, "invalid match case") + UnreachablePattern(reason) => { + write!(f, "{reason} makes remaining patterns unreachable") } NotImplementedYet => { write!(f, "RustPython does not implement this feature yet") diff --git a/compiler/codegen/src/snapshots/rustpython_codegen__compile__tests__match.snap.new b/compiler/codegen/src/snapshots/rustpython_codegen__compile__tests__match.snap.new new file mode 100644 index 0000000000..8aaaf9bf9e --- /dev/null +++ b/compiler/codegen/src/snapshots/rustpython_codegen__compile__tests__match.snap.new @@ -0,0 +1,54 @@ +--- +source: compiler/codegen/src/compile.rs +assertion_line: 4553 +expression: "compile_exec(\"\\\nclass Test:\n pass\n\nt = Test()\nmatch t:\n case Test():\n assert True\n case _:\n assert False\n\")" +--- + 2 0 LoadBuildClass + 1 LoadConst (): 1 0 LoadGlobal (0, __name__) + 1 StoreLocal (1, __module__) + 2 LoadConst ("Test") + 3 StoreLocal (2, __qualname__) + 4 LoadConst (None) + 5 StoreLocal (3, __doc__) + + 2 6 ReturnConst (None) + + 2 LoadConst ("Test") + 3 MakeFunction (MakeFunctionFlags(0x0)) + 4 LoadConst ("Test") + 5 CallFunctionPositional(2) + 6 StoreLocal (0, Test) + + 4 7 LoadNameAny (0, Test) + 8 CallFunctionPositional(0) + 9 StoreLocal (1, t) + + 5 10 LoadNameAny (1, t) + 11 CopyItem (1) + + 6 12 LoadNameAny (0, Test) + 13 LoadConst (()) + 14 MatchClass (0) + 15 CopyItem (1) + 16 LoadConst (None) + 17 IsOperation (true) + 18 JumpIfFalse (27) + 19 UnpackSequence (0) + 20 Pop + + 7 21 LoadConst (True) + 22 JumpIfTrue (26) + 23 LoadGlobal (2, AssertionError) + 24 CallFunctionPositional(0) + 25 Raise (Raise) + >> 26 Jump (35) + >> 27 Pop + 28 Pop + + 9 29 LoadConst (False) + 30 JumpIfTrue (34) + 31 LoadGlobal (2, AssertionError) + 32 CallFunctionPositional(0) + 33 Raise (Raise) + >> 34 Jump (35) + >> 35 ReturnConst (None) diff --git a/compiler/core/src/bytecode.rs b/compiler/core/src/bytecode.rs index 2e8ff29014..916077c57a 100644 --- a/compiler/core/src/bytecode.rs +++ b/compiler/core/src/bytecode.rs @@ -381,6 +381,7 @@ pub type NameIdx = u32; #[derive(Debug, Copy, Clone, PartialEq, Eq)] #[repr(u8)] pub enum Instruction { + Nop, /// Importing by name ImportName { idx: Arg, @@ -429,21 +430,32 @@ pub enum Instruction { BinaryOperationInplace { op: Arg, }, + BinarySubscript, LoadAttr { idx: Arg, }, TestOperation { op: Arg, }, + /// If the argument is true, perform IS NOT. Otherwise perform the IS operation. + IsOperation(Arg), CompareOperation { op: Arg, }, + CopyItem { + index: Arg, + }, Pop, + Swap { + index: Arg, + }, + // ToBool, Rotate2, Rotate3, Duplicate, Duplicate2, GetIter, + GetLen, Continue { target: Arg