diff --git a/tests/snippets/stdlib_traceback.py b/tests/snippets/stdlib_traceback.py index 12ac9835ec..689f36e027 100644 --- a/tests/snippets/stdlib_traceback.py +++ b/tests/snippets/stdlib_traceback.py @@ -3,5 +3,25 @@ try: 1/0 except ZeroDivisionError as ex: - tb = traceback.format_list(traceback.extract_tb(ex.__traceback__)) + tb = traceback.extract_tb(ex.__traceback__) + assert len(tb) == 1 + + +try: + try: + 1/0 + except ZeroDivisionError as ex: + raise KeyError().with_traceback(ex.__traceback__) +except KeyError as ex2: + tb = traceback.extract_tb(ex2.__traceback__) + assert tb[1].line == "1/0" + + +try: + try: + 1/0 + except ZeroDivisionError as ex: + raise ex.with_traceback(None) +except ZeroDivisionError as ex2: + tb = traceback.extract_tb(ex2.__traceback__) assert len(tb) == 1 diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index 79924df9c8..43682a8d9d 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -208,6 +208,19 @@ fn exception_repr(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.new_str(joined_str)) } +fn exception_with_traceback( + zelf: PyObjectRef, + tb: Option, + vm: &VirtualMachine, +) -> PyResult { + vm.set_attr( + &zelf, + "__traceback__", + tb.map_or(vm.get_none(), |tb| tb.into_object()), + )?; + Ok(zelf) +} + #[derive(Debug)] pub struct ExceptionZoo { pub arithmetic_error: PyClassRef, @@ -400,7 +413,8 @@ fn import_error_init(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { pub fn init(context: &PyContext) { let base_exception_type = &context.exceptions.base_exception_type; extend_class!(context, base_exception_type, { - "__init__" => context.new_rustfunc(exception_init) + "__init__" => context.new_rustfunc(exception_init), + "with_traceback" => context.new_rustfunc(exception_with_traceback) }); let exception_type = &context.exceptions.exception_type;