Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Lib/test/test_contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,6 @@ def method(self, a, b, c=None):
self.assertEqual(test.b, 2)


# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_typo_enter(self):
class mycontext(ContextDecorator):
def __unter__(self):
Expand All @@ -559,8 +557,6 @@ def __exit__(self, *exc):
pass


# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_typo_exit(self):
class mycontext(ContextDecorator):
def __enter__(self):
Expand Down
6 changes: 0 additions & 6 deletions Lib/test/test_with.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def fooNotDeclared():
with foo: pass
self.assertRaises(NameError, fooNotDeclared)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def testEnterAttributeError1(self):
class LacksEnter(object):
def __exit__(self, type, value, traceback):
Expand All @@ -121,8 +119,6 @@ def fooLacksEnter():
with foo: pass
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnter)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def testEnterAttributeError2(self):
class LacksEnterAndExit(object):
pass
Expand All @@ -132,8 +128,6 @@ def fooLacksEnterAndExit():
with foo: pass
self.assertRaisesRegex(TypeError, 'the context manager', fooLacksEnterAndExit)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def testExitAttributeError(self):
class LacksExit(object):
def __enter__(self):
Expand Down
45 changes: 36 additions & 9 deletions vm/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,12 +840,24 @@ impl ExecutingFrame<'_> {
}
bytecode::Instruction::SetupWith { end } => {
let context_manager = self.pop_value();
let enter_res = vm.call_special_method(
context_manager.clone(),
identifier!(vm, __enter__),
(),
)?;
let exit = context_manager.get_attr(identifier!(vm, __exit__), vm)?;
let error_string = || -> String {
format!(
"'{:.200}' object does not support the context manager protocol",
context_manager.class().name(),
)
};
let enter_res = vm
.get_special_method(context_manager.clone(), identifier!(vm, __enter__))?
.map_err(|_obj| vm.new_type_error(error_string()))?
.invoke((), vm)?;

let exit = context_manager
.get_attr(identifier!(vm, __exit__), vm)
.map_err(|_exc| {
vm.new_type_error({
format!("'{} (missed __exit__ method)", error_string())
})
})?;
self.push_value(exit);
self.push_block(BlockType::Finally {
handler: end.get(arg),
Expand All @@ -855,9 +867,24 @@ impl ExecutingFrame<'_> {
}
bytecode::Instruction::BeforeAsyncWith => {
let mgr = self.pop_value();
let aenter_res =
vm.call_special_method(mgr.clone(), identifier!(vm, __aenter__), ())?;
let aexit = mgr.get_attr(identifier!(vm, __aexit__), vm)?;
let error_string = || -> String {
format!(
"'{:.200}' object does not support the asynchronous context manager protocol",
mgr.class().name(),
)
};

let aenter_res = vm
.get_special_method(mgr.clone(), identifier!(vm, __aenter__))?
.map_err(|_obj| vm.new_type_error(error_string()))?
.invoke((), vm)?;
let aexit = mgr
.get_attr(identifier!(vm, __aexit__), vm)
.map_err(|_exc| {
vm.new_type_error({
format!("'{} (missed __aexit__ method)", error_string())
})
})?;
self.push_value(aexit);
self.push_value(aenter_res);

Expand Down