Skip to content

Commit e7fa32c

Browse files
authored
Merge pull request #3402 from Snowapril/refactor_exceptions
Refactor duplicated codes in exception creations
2 parents 15d8eec + 97f3434 commit e7fa32c

File tree

10 files changed

+156
-185
lines changed

10 files changed

+156
-185
lines changed

derive/src/pymodule.rs

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::error::Diagnostic;
22
use crate::util::{
3-
iter_use_idents, pyclass_ident_and_attrs, text_signature, AttributeExt, ClassItemMeta,
4-
ContentItem, ContentItemInner, ErrorVec, ItemMeta, ItemNursery, SimpleItemMeta,
3+
iter_use_idents, pyclass_ident_and_attrs, text_signature, AttrItemMeta, AttributeExt,
4+
ClassItemMeta, ContentItem, ContentItemInner, ErrorVec, ItemMeta, ItemNursery, SimpleItemMeta,
55
ALL_ALLOWED_NAMES,
66
};
77
use proc_macro2::TokenStream;
@@ -241,7 +241,7 @@ impl ContentItem for AttributeItem {
241241
}
242242

243243
struct ModuleItemArgs<'a> {
244-
item: &'a Item,
244+
item: &'a mut Item,
245245
attrs: &'a mut Vec<Attribute>,
246246
context: &'a mut ModuleContext,
247247
cfgs: &'a [Attribute],
@@ -370,15 +370,36 @@ impl ModuleItem for AttributeItem {
370370
fn gen_module_item(&self, args: ModuleItemArgs<'_>) -> Result<()> {
371371
let cfgs = args.cfgs.to_vec();
372372
let attr = args.attrs.remove(self.index());
373-
let get_py_name = |attr: &Attribute, ident: &Ident| -> Result<_> {
374-
let item_meta = SimpleItemMeta::from_attr(ident.clone(), attr)?;
375-
let py_name = item_meta.simple_name()?;
376-
Ok(py_name)
377-
};
378373
let (py_name, tokens) = match args.item {
379-
Item::Fn(syn::ItemFn { sig, .. }) => {
374+
Item::Fn(syn::ItemFn { sig, block, .. }) => {
380375
let ident = &sig.ident;
381-
let py_name = get_py_name(&attr, ident)?;
376+
// If `once` keyword is in #[pyattr],
377+
// wrapping it with static_cell for preventing it from using it as function
378+
let attr_meta = AttrItemMeta::from_attr(ident.clone(), &attr)?;
379+
if attr_meta.inner()._bool("once")? {
380+
let stmts = &block.stmts;
381+
let return_type = match &sig.output {
382+
syn::ReturnType::Default => {
383+
unreachable!("#[pyattr] attached function must have return type.")
384+
}
385+
syn::ReturnType::Type(_, ty) => ty,
386+
};
387+
let stmt: syn::Stmt = parse_quote! {
388+
{
389+
rustpython_common::static_cell! {
390+
static ERROR: #return_type;
391+
}
392+
ERROR
393+
.get_or_init(|| {
394+
#(#stmts)*
395+
})
396+
.clone()
397+
}
398+
};
399+
block.stmts = vec![stmt];
400+
}
401+
402+
let py_name = attr_meta.simple_name()?;
382403
(
383404
py_name.clone(),
384405
quote_spanned! { ident.span() =>
@@ -387,7 +408,8 @@ impl ModuleItem for AttributeItem {
387408
)
388409
}
389410
Item::Const(syn::ItemConst { ident, .. }) => {
390-
let py_name = get_py_name(&attr, ident)?;
411+
let item_meta = SimpleItemMeta::from_attr(ident.clone(), &attr)?;
412+
let py_name = item_meta.simple_name()?;
391413
(
392414
py_name.clone(),
393415
quote_spanned! { ident.span() =>

derive/src/util.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,19 @@ impl ItemMeta for SimpleItemMeta {
236236
}
237237
}
238238

239+
pub(crate) struct AttrItemMeta(pub ItemMetaInner);
240+
241+
impl ItemMeta for AttrItemMeta {
242+
const ALLOWED_NAMES: &'static [&'static str] = &["name", "once"];
243+
244+
fn from_inner(inner: ItemMetaInner) -> Self {
245+
Self(inner)
246+
}
247+
fn inner(&self) -> &ItemMetaInner {
248+
&self.0
249+
}
250+
}
251+
239252
pub(crate) struct ClassItemMeta(ItemMetaInner);
240253

241254
impl ItemMeta for ClassItemMeta {

stdlib/src/binascii.rs

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,18 @@ mod decl {
1111
};
1212
use itertools::Itertools;
1313

14-
#[pyattr(name = "Error")]
14+
#[pyattr(name = "Error", once)]
1515
fn error_type(vm: &VirtualMachine) -> PyTypeRef {
16-
rustpython_common::static_cell! {
17-
static BINASCII_ERROR: PyTypeRef;
18-
}
19-
BINASCII_ERROR
20-
.get_or_init(|| {
21-
vm.ctx.new_class(
22-
Some("binascii"),
23-
"Error",
24-
&vm.ctx.exceptions.value_error,
25-
Default::default(),
26-
)
27-
})
28-
.clone()
16+
vm.ctx.new_exception_type(
17+
"binascii",
18+
"Error",
19+
Some(vec![vm.ctx.exceptions.value_error.clone()]),
20+
)
2921
}
3022

31-
#[pyattr(name = "Incomplete")]
23+
#[pyattr(name = "Incomplete", once)]
3224
fn incomplete_type(vm: &VirtualMachine) -> PyTypeRef {
33-
rustpython_common::static_cell! {
34-
static BINASCII_INCOMPLTE: PyTypeRef;
35-
}
36-
BINASCII_INCOMPLTE
37-
.get_or_init(|| {
38-
vm.ctx.new_class(
39-
Some("binascii"),
40-
"Incomplete",
41-
&vm.ctx.exceptions.exception_type,
42-
Default::default(),
43-
)
44-
})
45-
.clone()
25+
vm.ctx.new_exception_type("binascii", "Incomplete", None)
4626
}
4727

4828
fn hex_nibble(n: u8) -> u8 {

stdlib/src/csv.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyObjectRef {
1111
mod _csv {
1212
use crate::common::lock::PyMutex;
1313
use crate::vm::{
14-
builtins::{PyStr, PyStrRef, PyType, PyTypeRef},
14+
builtins::{PyStr, PyStrRef, PyTypeRef},
1515
function::{ArgIterable, ArgumentError, FromArgs, FuncArgs},
1616
match_class,
1717
protocol::{PyIter, PyIterReturn},
@@ -30,9 +30,13 @@ mod _csv {
3030
#[pyattr]
3131
const QUOTE_NONE: i32 = QuoteStyle::None as i32;
3232

33-
#[pyattr(name = "Error")]
33+
#[pyattr(name = "Error", once)]
3434
fn error(vm: &VirtualMachine) -> PyTypeRef {
35-
PyType::new_simple_ref("_csv.Error", &vm.ctx.exceptions.exception_type).unwrap()
35+
vm.ctx.new_exception_type(
36+
"_csv",
37+
"Error",
38+
Some(vec![vm.ctx.exceptions.exception_type.clone()]),
39+
)
3640
}
3741

3842
#[pyfunction]

stdlib/src/socket.rs

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -80,53 +80,29 @@ mod _socket {
8080
vm.ctx.exceptions.os_error.clone()
8181
}
8282

83-
#[pyattr]
83+
#[pyattr(once)]
8484
fn timeout(vm: &VirtualMachine) -> PyTypeRef {
85-
rustpython_common::static_cell! {
86-
static ERROR: PyTypeRef;
87-
}
88-
ERROR
89-
.get_or_init(|| {
90-
vm.ctx.new_class(
91-
Some("socket"),
92-
"timeout",
93-
&vm.ctx.exceptions.os_error,
94-
Default::default(),
95-
)
96-
})
97-
.clone()
85+
vm.ctx.new_exception_type(
86+
"socket",
87+
"timeout",
88+
Some(vec![vm.ctx.exceptions.os_error.clone()]),
89+
)
9890
}
99-
#[pyattr]
91+
#[pyattr(once)]
10092
fn herror(vm: &VirtualMachine) -> PyTypeRef {
101-
rustpython_common::static_cell! {
102-
static ERROR: PyTypeRef;
103-
}
104-
ERROR
105-
.get_or_init(|| {
106-
vm.ctx.new_class(
107-
Some("socket"),
108-
"herror",
109-
&vm.ctx.exceptions.os_error,
110-
Default::default(),
111-
)
112-
})
113-
.clone()
93+
vm.ctx.new_exception_type(
94+
"socket",
95+
"herror",
96+
Some(vec![vm.ctx.exceptions.os_error.clone()]),
97+
)
11498
}
115-
#[pyattr]
99+
#[pyattr(once)]
116100
fn gaierror(vm: &VirtualMachine) -> PyTypeRef {
117-
rustpython_common::static_cell! {
118-
static ERROR: PyTypeRef;
119-
}
120-
ERROR
121-
.get_or_init(|| {
122-
vm.ctx.new_class(
123-
Some("socket"),
124-
"gaierror",
125-
&vm.ctx.exceptions.os_error,
126-
Default::default(),
127-
)
128-
})
129-
.clone()
101+
vm.ctx.new_exception_type(
102+
"socket",
103+
"gaierror",
104+
Some(vec![vm.ctx.exceptions.os_error.clone()]),
105+
)
130106
}
131107

132108
#[pyfunction]

stdlib/src/ssl.rs

Lines changed: 35 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ mod _ssl {
3030
},
3131
socket::{self, PySocket},
3232
vm::{
33-
builtins::{PyBaseException, PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak},
33+
builtins::{PyBaseExceptionRef, PyStrRef, PyType, PyTypeRef, PyWeak},
3434
exceptions,
3535
function::{
3636
ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, IntoPyException,
@@ -39,7 +39,7 @@ mod _ssl {
3939
stdlib::os::PyPathLike,
4040
types::Constructor,
4141
utils::{Either, ToCString},
42-
ItemProtocol, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine,
42+
ItemProtocol, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine,
4343
},
4444
};
4545
use crossbeam_utils::atomic::AtomicCell;
@@ -174,90 +174,58 @@ mod _ssl {
174174
parse_version_info(openssl_api_version)
175175
}
176176

177-
#[pyattr(name = "SSLError")]
177+
/// An error occurred in the SSL implementation.
178+
#[pyattr(name = "SSLError", once)]
178179
fn ssl_error(vm: &VirtualMachine) -> PyTypeRef {
179-
rustpython_common::static_cell! {
180-
static ERROR: PyTypeRef;
181-
}
182-
ERROR
183-
.get_or_init(|| {
184-
PyType::new_simple_ref("ssl.SSLError", &vm.ctx.exceptions.os_error).unwrap()
185-
})
186-
.clone()
180+
vm.ctx.new_exception_type(
181+
"ssl",
182+
"SSLError",
183+
Some(vec![vm.ctx.exceptions.os_error.clone()]),
184+
)
187185
}
188186

189-
#[pyattr(name = "SSLCertVerificationError")]
187+
/// A certificate could not be verified.
188+
#[pyattr(name = "SSLCertVerificationError", once)]
190189
fn ssl_cert_verification_error(vm: &VirtualMachine) -> PyTypeRef {
191-
rustpython_common::static_cell! {
192-
static ERROR: PyTypeRef;
193-
}
194-
ERROR
195-
.get_or_init(|| {
196-
let ssl_error = ssl_error(vm);
197-
PyType::new_ref(
198-
"ssl.SSLCertVerificationError",
199-
vec![ssl_error, vm.ctx.exceptions.value_error.clone()],
200-
Default::default(),
201-
PyBaseException::make_slots(),
202-
vm.ctx.types.type_type.clone(),
203-
)
204-
.unwrap()
205-
})
206-
.clone()
190+
vm.ctx.new_exception_type(
191+
"ssl",
192+
"SSLCertVerificationError",
193+
Some(vec![ssl_error(vm), vm.ctx.exceptions.value_error.clone()]),
194+
)
207195
}
208196

209-
#[pyattr(name = "SSLZeroReturnError")]
197+
/// SSL/TLS session closed cleanly.
198+
#[pyattr(name = "SSLZeroReturnError", once)]
210199
fn ssl_zero_return_error(vm: &VirtualMachine) -> PyTypeRef {
211-
rustpython_common::static_cell! {
212-
static ERROR: PyTypeRef;
213-
}
214-
ERROR
215-
.get_or_init(|| {
216-
PyType::new_simple_ref("ssl.SSLZeroReturnError", &ssl_error(vm)).unwrap()
217-
})
218-
.clone()
200+
vm.ctx
201+
.new_exception_type("ssl", "SSLZeroReturnError", Some(vec![ssl_error(vm)]))
219202
}
220203

221-
#[pyattr(name = "SSLWantReadError")]
204+
/// Non-blocking SSL socket needs to read more data before the requested operation can be completed.
205+
#[pyattr(name = "SSLWantReadError", once)]
222206
fn ssl_want_read_error(vm: &VirtualMachine) -> PyTypeRef {
223-
rustpython_common::static_cell! {
224-
static ERROR: PyTypeRef;
225-
}
226-
ERROR
227-
.get_or_init(|| PyType::new_simple_ref("ssl.SSLWantReadError", &ssl_error(vm)).unwrap())
228-
.clone()
207+
vm.ctx
208+
.new_exception_type("ssl", "SSLWantReadError", Some(vec![ssl_error(vm)]))
229209
}
230210

231-
#[pyattr(name = "SSLWantWriteError")]
211+
/// Non-blocking SSL socket needs to write more data before the requested operation can be completed.
212+
#[pyattr(name = "SSLWantWriteError", once)]
232213
fn ssl_want_write_error(vm: &VirtualMachine) -> PyTypeRef {
233-
rustpython_common::static_cell! {
234-
static ERROR: PyTypeRef;
235-
}
236-
ERROR
237-
.get_or_init(|| {
238-
PyType::new_simple_ref("ssl.SSLWantWriteError", &ssl_error(vm)).unwrap()
239-
})
240-
.clone()
214+
vm.ctx
215+
.new_exception_type("ssl", "SSLWantWriteError", Some(vec![ssl_error(vm)]))
241216
}
242217

243-
#[pyattr(name = "SSLSyscallError")]
218+
/// System error when attempting SSL operation.
219+
#[pyattr(name = "SSLSyscallError", once)]
244220
fn ssl_syscall_error(vm: &VirtualMachine) -> PyTypeRef {
245-
rustpython_common::static_cell! {
246-
static ERROR: PyTypeRef;
247-
}
248-
ERROR
249-
.get_or_init(|| PyType::new_simple_ref("ssl.SSLSyscallError", &ssl_error(vm)).unwrap())
250-
.clone()
221+
vm.ctx
222+
.new_exception_type("ssl", "SSLSyscallError", Some(vec![ssl_error(vm)]))
251223
}
252224

253-
#[pyattr(name = "SSLEOFError")]
225+
/// SSL/TLS connection terminated abruptly.
226+
#[pyattr(name = "SSLEOFError", once)]
254227
fn ssl_eof_error(vm: &VirtualMachine) -> PyTypeRef {
255-
rustpython_common::static_cell! {
256-
static ERROR: PyTypeRef;
257-
}
258-
ERROR
259-
.get_or_init(|| PyType::new_simple_ref("ssl.SSLEOFError", &ssl_error(vm)).unwrap())
260-
.clone()
228+
PyType::new_simple_ref("ssl.SSLEOFError", &ssl_error(vm)).unwrap()
261229
}
262230

263231
type OpensslVersionInfo = (u8, u8, u8, u8, u8);

0 commit comments

Comments
 (0)