Skip to content

__type_params__ in __build_class__ #5883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,6 @@ class A(Generic[P]): ...
P_default = ParamSpec('P_default', default=...)
self.assertIs(P_default.__default__, ...)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec_none(self):
U = ParamSpec('U')
U_None = ParamSpec('U_None', default=None)
Expand Down Expand Up @@ -756,8 +754,6 @@ class A(Generic[T, P, U]): ...
self.assertEqual(A[float, [range]].__args__, (float, (range,), float))
self.assertEqual(A[float, [range], int].__args__, (float, (range,), int))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_typevartuple_none(self):
U = TypeVarTuple('U')
U_None = TypeVarTuple('U_None', default=None)
Expand Down Expand Up @@ -3893,8 +3889,6 @@ def f(x: X): ...
{'x': list[list[ForwardRef('X')]]}
)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_pep695_generic_class_with_future_annotations(self):
original_globals = dict(ann_module695.__dict__)

Expand All @@ -3913,8 +3907,6 @@ def test_pep695_generic_class_with_future_annotations_and_local_shadowing(self):
hints_for_B = get_type_hints(ann_module695.B)
self.assertEqual(hints_for_B, {"x": int, "y": str, "z": bytes})

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_pep695_generic_class_with_future_annotations_name_clash_with_global_vars(self):
hints_for_C = get_type_hints(ann_module695.C)
self.assertEqual(
Expand Down
25 changes: 22 additions & 3 deletions compiler/codegen/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,7 @@ impl Compiler<'_> {
/// Store each type parameter so it is accessible to the current scope, and leave a tuple of
/// all the type parameters on the stack.
fn compile_type_params(&mut self, type_params: &TypeParams) -> CompileResult<()> {
// First, compile each type parameter and store it
for type_param in &type_params.type_params {
match type_param {
TypeParam::TypeVar(TypeParamTypeVar { name, bound, .. }) => {
Expand Down Expand Up @@ -1664,8 +1665,12 @@ impl Compiler<'_> {
let qualified_name = self.qualified_path.join(".");

// If there are type params, we need to push a special symbol table just for them
if type_params.is_some() {
if let Some(type_params) = type_params {
self.push_symbol_table();
// Compile type parameters and store as .type_params
self.compile_type_params(type_params)?;
let dot_type_params = self.name(".type_params");
emit!(self, Instruction::StoreLocal(dot_type_params));
}

self.push_output(bytecode::CodeFlags::empty(), 0, 0, 0, name.to_owned());
Expand All @@ -1688,6 +1693,18 @@ impl Compiler<'_> {
if Self::find_ann(body) {
emit!(self, Instruction::SetupAnnotation);
}

// Set __type_params__ from .type_params if we have type parameters (PEP 695)
if type_params.is_some() {
// Load .type_params from enclosing scope
let dot_type_params = self.name(".type_params");
emit!(self, Instruction::LoadNameAny(dot_type_params));

// Store as __type_params__
let dunder_type_params = self.name("__type_params__");
emit!(self, Instruction::StoreLocal(dunder_type_params));
}

self.compile_statements(body)?;

let classcell_idx = self
Expand Down Expand Up @@ -1721,8 +1738,10 @@ impl Compiler<'_> {
let mut func_flags = bytecode::MakeFunctionFlags::empty();

// Prepare generic type parameters:
if let Some(type_params) = type_params {
self.compile_type_params(type_params)?;
if type_params.is_some() {
// Load .type_params from the type params scope
let dot_type_params = self.name(".type_params");
emit!(self, Instruction::LoadNameAny(dot_type_params));
func_flags |= bytecode::MakeFunctionFlags::TYPE_PARAMS;
}

Expand Down
1 change: 1 addition & 0 deletions compiler/codegen/src/symboltable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,7 @@ impl SymbolTableBuilder<'_> {
}

fn scan_type_params(&mut self, type_params: &TypeParams) -> SymbolTableResult {
// First register all type parameters
for type_param in &type_params.type_params {
match type_param {
TypeParam::TypeVar(TypeParamTypeVar {
Expand Down
3 changes: 2 additions & 1 deletion vm/src/builtins/bytearray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@ impl PyByteArray {
self.borrow_buf_mut().reverse();
}

#[pyclassmethod]
// TODO: Uncomment when Python adds __class_getitem__ to bytearray
// #[pyclassmethod]
fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}
Expand Down
3 changes: 2 additions & 1 deletion vm/src/builtins/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ impl PyBytes {
PyTuple::new_ref(param, &vm.ctx)
}

#[pyclassmethod]
// TODO: Uncomment when Python adds __class_getitem__ to bytes
// #[pyclassmethod]
fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}
Expand Down
3 changes: 2 additions & 1 deletion vm/src/builtins/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ impl Py<PyMemoryView> {
Representable
))]
impl PyMemoryView {
#[pyclassmethod]
// TODO: Uncomment when Python adds __class_getitem__ to memoryview
// #[pyclassmethod]
fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}
Expand Down
11 changes: 6 additions & 5 deletions vm/src/builtins/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,6 @@ pub fn init(context: &Context) {
Representable
))]
impl PyRange {
#[pyclassmethod]
fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}

fn new(cls: PyTypeRef, stop: ArgIndex, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
PyRange {
start: vm.ctx.new_pyref(0),
Expand Down Expand Up @@ -328,6 +323,12 @@ impl PyRange {

Ok(range.into())
}

// TODO: Uncomment when Python adds __class_getitem__ to range
// #[pyclassmethod]
fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}
}

#[pyclass]
Expand Down
11 changes: 6 additions & 5 deletions vm/src/builtins/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ impl PyPayload for PySlice {

#[pyclass(with(Comparable, Representable, Hashable))]
impl PySlice {
#[pyclassmethod]
fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}

#[pygetset]
fn start(&self, vm: &VirtualMachine) -> PyObjectRef {
self.start.clone().to_pyobject(vm)
Expand Down Expand Up @@ -200,6 +195,12 @@ impl PySlice {
(zelf.start.clone(), zelf.stop.clone(), zelf.step.clone()),
))
}

// TODO: Uncomment when Python adds __class_getitem__ to slice
// #[pyclassmethod]
fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
PyGenericAlias::from_args(cls, args, vm)
}
}

impl Hashable for PySlice {
Expand Down
35 changes: 35 additions & 0 deletions vm/src/stdlib/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,23 @@ mod builtins {
))
})?;

// For PEP 695 classes, set .type_params in namespace before calling the function
if let Ok(type_params) = function
.as_object()
.get_attr(identifier!(vm, __type_params__), vm)
{
if let Some(type_params_tuple) = type_params.downcast_ref::<PyTuple>() {
if !type_params_tuple.is_empty() {
// Set .type_params in namespace so the compiler-generated code can use it
namespace.as_object().set_item(
vm.ctx.intern_str(".type_params"),
type_params,
vm,
)?;
}
}
}

let classcell = function.invoke_with_locals(().into(), Some(namespace.clone()), vm)?;
let classcell = <Option<PyCellRef>>::try_from_object(vm, classcell)?;

Expand All @@ -943,9 +960,27 @@ mod builtins {
)?;
}

// Remove .type_params from namespace before creating the class
namespace
.as_object()
.del_item(vm.ctx.intern_str(".type_params"), vm)
.ok();

let args = FuncArgs::new(vec![name_obj.into(), bases, namespace.into()], kwargs);
let class = metaclass.call(args, vm)?;

// For PEP 695 classes, set __type_params__ on the class from the function
if let Ok(type_params) = function
.as_object()
.get_attr(identifier!(vm, __type_params__), vm)
{
if let Some(type_params_tuple) = type_params.downcast_ref::<PyTuple>() {
if !type_params_tuple.is_empty() {
class.set_attr(identifier!(vm, __type_params__), type_params, vm)?;
}
}
}

if let Some(ref classcell) = classcell {
let classcell = classcell.get().ok_or_else(|| {
vm.new_type_error(format!(
Expand Down
Loading