Skip to content

ZJIT: Implement getspecial #13642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions test/ruby/test_zjit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,106 @@ def test = 1.nil?
}, insns: [:opt_nil_p]
end

def test_getspecial_last_match
assert_compiles '"hello"', %q{
def test(str)
str =~ /hello/
$&
end
test("hello world")
}, insns: [:getspecial]
end

def test_getspecial_match_pre
assert_compiles '"hello "', %q{
def test(str)
str =~ /world/
$`
end
test("hello world")
}, insns: [:getspecial]
end

def test_getspecial_match_post
assert_compiles '" world"', %q{
def test(str)
str =~ /hello/
$'
end
test("hello world")
}, insns: [:getspecial]
end

def test_getspecial_match_last_group
assert_compiles '"world"', %q{
def test(str)
str =~ /(hello) (world)/
$+
end
test("hello world")
}, insns: [:getspecial]
end

def test_getspecial_numbered_match_1
assert_compiles '"hello"', %q{
def test(str)
str =~ /(hello) (world)/
$1
end
test("hello world")
}, insns: [:getspecial]
end

def test_getspecial_numbered_match_2
assert_compiles '"world"', %q{
def test(str)
str =~ /(hello) (world)/
$2
end
test("hello world")
}, insns: [:getspecial]
end

def test_getspecial_numbered_match_nonexistent
assert_compiles 'nil', %q{
def test(str)
str =~ /(hello)/
$2
end
test("hello world")
}, insns: [:getspecial]
end

def test_getspecial_no_match
assert_compiles 'nil', %q{
def test(str)
str =~ /xyz/
$&
end
test("hello world")
}, insns: [:getspecial]
end

def test_getspecial_complex_pattern
assert_compiles '"123"', %q{
def test(str)
str =~ /(\d+)/
$1
end
test("abc123def")
}, insns: [:getspecial]
end

def test_getspecial_multiple_groups
assert_compiles '"456"', %q{
def test(str)
str =~ /(\d+)-(\d+)/
$2
end
test("123-456")
}, insns: [:getspecial]
end

# tool/ruby_vm/views/*.erb relies on the zjit instructions a) being contiguous and
# b) being reliably ordered after all the other instructions.
def test_instruction_order
Expand Down
35 changes: 34 additions & 1 deletion zjit/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::state::ZJITState;
use crate::stats::{counter_ptr, with_time_stat, Counter, Counter::compile_time_ns};
use crate::{asm::CodeBlock, cruby::*, options::debug, virtualmem::CodePtr};
use crate::backend::lir::{self, asm_comment, asm_ccall, Assembler, Opnd, Target, CFP, C_ARG_OPNDS, C_RET_OPND, EC, NATIVE_STACK_PTR, NATIVE_BASE_PTR, SP};
use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SELF_PARAM_IDX};
use crate::hir::{iseq_to_hir, Block, BlockId, BranchEdge, Invariant, RangeType, SideExitReason, SideExitReason::*, SpecialObjectType, SpecialBackrefSymbol, SELF_PARAM_IDX};
use crate::hir::{Const, FrameState, Function, Insn, InsnId};
use crate::hir_type::{types, Type};
use crate::options::get_option;
Expand Down Expand Up @@ -378,6 +378,8 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
Insn::PutSpecialObject { value_type } => gen_putspecialobject(asm, *value_type),
Insn::AnyToString { val, str, state } => gen_anytostring(asm, opnd!(val), opnd!(str), &function.frame_state(*state))?,
Insn::Defined { op_type, obj, pushval, v, state } => gen_defined(jit, asm, *op_type, *obj, *pushval, opnd!(v), &function.frame_state(*state))?,
Insn::GetSpecialSymbol { symbol_type, state: _ } => gen_getspecial_symbol(asm, *symbol_type),
Insn::GetSpecialNumber { nth, state } => gen_getspecial_number(asm, *nth, &function.frame_state(*state)),
&Insn::IncrCounter(counter) => return Some(gen_incr_counter(asm, counter)),
Insn::ObjToString { val, cd, state, .. } => gen_objtostring(jit, asm, opnd!(val), *cd, &function.frame_state(*state))?,
Insn::ArrayExtend { .. }
Expand Down Expand Up @@ -640,6 +642,37 @@ fn gen_putspecialobject(asm: &mut Assembler, value_type: SpecialObjectType) -> O
asm_ccall!(asm, rb_vm_get_special_object, ep_reg, Opnd::UImm(u64::from(value_type)))
}

fn gen_getspecial_symbol(asm: &mut Assembler, symbol_type: SpecialBackrefSymbol) -> Opnd {
// Fetch a "special" backref based on the symbol type

let backref = asm_ccall!(asm, rb_backref_get,);

match symbol_type {
SpecialBackrefSymbol::LastMatch => {
asm_ccall!(asm, rb_reg_last_match, backref)
}
SpecialBackrefSymbol::PreMatch => {
asm_ccall!(asm, rb_reg_match_pre, backref)
}
SpecialBackrefSymbol::PostMatch => {
asm_ccall!(asm, rb_reg_match_post, backref)
}
SpecialBackrefSymbol::LastGroup => {
asm_ccall!(asm, rb_reg_match_last, backref)
}
}
}

fn gen_getspecial_number(asm: &mut Assembler, nth: u64, state: &FrameState) -> Opnd {
// Fetch the N-th match from the last backref based on type shifted by 1

let backref = asm_ccall!(asm, rb_backref_get,);

gen_prepare_call_with_gc(asm, state);

asm_ccall!(asm, rb_reg_nth_match, Opnd::Imm((nth >> 1).try_into().unwrap()), backref)
}

/// Compile an interpreter entry block to be inserted into an ISEQ
fn gen_entry_prologue(asm: &mut Assembler, iseq: IseqPtr) {
asm_comment!(asm, "ZJIT entry point: {}", iseq_get_location(iseq, 0));
Expand Down
57 changes: 57 additions & 0 deletions zjit/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,29 @@ impl From<RangeType> for u32 {
}
}

/// Special regex backref symbol types
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SpecialBackrefSymbol {
LastMatch, // $&
PreMatch, // $`
PostMatch, // $'
LastGroup, // $+
}

impl TryFrom<u8> for SpecialBackrefSymbol {
type Error = String;

fn try_from(value: u8) -> Result<Self, Self::Error> {
match value as char {
'&' => Ok(SpecialBackrefSymbol::LastMatch),
'`' => Ok(SpecialBackrefSymbol::PreMatch),
'\'' => Ok(SpecialBackrefSymbol::PostMatch),
'+' => Ok(SpecialBackrefSymbol::LastGroup),
c => Err(format!("invalid backref symbol: '{}'", c)),
}
}
}

/// Print adaptor for [`Const`]. See [`PtrPrintMap`].
struct ConstPrinter<'a> {
inner: &'a Const,
Expand Down Expand Up @@ -415,6 +438,7 @@ pub enum SideExitReason {
PatchPoint(Invariant),
CalleeSideExit,
ObjToStringFallback,
UnknownSpecialVariable(u64),
}

impl std::fmt::Display for SideExitReason {
Expand Down Expand Up @@ -494,6 +518,8 @@ pub enum Insn {
GetLocal { level: u32, ep_offset: u32 },
/// Set a local variable in a higher scope or the heap
SetLocal { level: u32, ep_offset: u32, val: InsnId },
GetSpecialSymbol { symbol_type: SpecialBackrefSymbol, state: InsnId },
GetSpecialNumber { nth: u64, state: InsnId },

/// Own a FrameState so that instructions can look up their dominating FrameState when
/// generating deopt side-exits and frame reconstruction metadata. Does not directly generate
Expand Down Expand Up @@ -774,6 +800,8 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
Insn::SetGlobal { id, val, .. } => write!(f, "SetGlobal :{}, {val}", id.contents_lossy()),
Insn::GetLocal { level, ep_offset } => write!(f, "GetLocal l{level}, EP@{ep_offset}"),
Insn::SetLocal { val, level, ep_offset } => write!(f, "SetLocal l{level}, EP@{ep_offset}, {val}"),
Insn::GetSpecialSymbol { symbol_type, .. } => write!(f, "GetSpecialSymbol {symbol_type:?}"),
Insn::GetSpecialNumber { nth, .. } => write!(f, "GetSpecialNumber {nth}"),
Insn::ToArray { val, .. } => write!(f, "ToArray {val}"),
Insn::ToNewArray { val, .. } => write!(f, "ToNewArray {val}"),
Insn::ArrayExtend { left, right, .. } => write!(f, "ArrayExtend {left}, {right}"),
Expand Down Expand Up @@ -1221,6 +1249,8 @@ impl Function {
&GetIvar { self_val, id, state } => GetIvar { self_val: find!(self_val), id, state },
&SetIvar { self_val, id, val, state } => SetIvar { self_val: find!(self_val), id, val: find!(val), state },
&SetLocal { val, ep_offset, level } => SetLocal { val: find!(val), ep_offset, level },
&GetSpecialSymbol { symbol_type, state } => GetSpecialSymbol { symbol_type, state },
&GetSpecialNumber { nth, state } => GetSpecialNumber { nth, state },
&ToArray { val, state } => ToArray { val: find!(val), state },
&ToNewArray { val, state } => ToNewArray { val: find!(val), state },
&ArrayExtend { left, right, state } => ArrayExtend { left: find!(left), right: find!(right), state },
Expand Down Expand Up @@ -1306,6 +1336,8 @@ impl Function {
Insn::ArrayMax { .. } => types::BasicObject,
Insn::GetGlobal { .. } => types::BasicObject,
Insn::GetIvar { .. } => types::BasicObject,
Insn::GetSpecialSymbol { .. } => types::BasicObject,
Insn::GetSpecialNumber { .. } => types::BasicObject,
Insn::ToNewArray { .. } => types::ArrayExact,
Insn::ToArray { .. } => types::ArrayExact,
Insn::ObjToString { .. } => types::BasicObject,
Expand Down Expand Up @@ -1995,6 +2027,8 @@ impl Function {
worklist.push_back(state);
}
&Insn::GetGlobal { state, .. } |
&Insn::GetSpecialSymbol { state, .. } |
&Insn::GetSpecialNumber { state, .. } |
&Insn::SideExit { state, .. } => worklist.push_back(state),
}
}
Expand Down Expand Up @@ -3323,6 +3357,29 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
let anytostring = fun.push_insn(block, Insn::AnyToString { val, str, state: exit_id });
state.stack_push(anytostring);
}
YARVINSN_getspecial => {
let key = get_arg(pc, 0).as_u64();
let svar = get_arg(pc, 1).as_u64();

let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });

if svar == 0 {
// TODO: Handle non-backref
fun.push_insn(block, Insn::SideExit { state: exit_id, reason: SideExitReason::UnknownSpecialVariable(key) });
// End the block
break;
} else if svar & 0x01 != 0 {
// Handle symbol backrefs like $&, $`, $', $+
let shifted_svar: u8 = (svar >> 1).try_into().unwrap();
let symbol_type = SpecialBackrefSymbol::try_from(shifted_svar).expect("invalid backref symbol");
let result = fun.push_insn(block, Insn::GetSpecialSymbol { symbol_type, state: exit_id });
state.stack_push(result);
} else {
// Handle number backrefs like $1, $2, $3
let result = fun.push_insn(block, Insn::GetSpecialNumber { nth: svar, state: exit_id });
state.stack_push(result);
}
}
_ => {
// Unknown opcode; side-exit into the interpreter
let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state });
Expand Down
Loading