diff --git a/stdlib/src/unicodedata.rs b/stdlib/src/unicodedata.rs
index b43db0dd1d..719d63c67b 100644
--- a/stdlib/src/unicodedata.rs
+++ b/stdlib/src/unicodedata.rs
@@ -1,7 +1,10 @@
 /* Access to the unicode database.
    See also: https://docs.python.org/3/library/unicodedata.html
 */
-use crate::vm::{PyObjectRef, PyPayload, VirtualMachine};
+use crate::vm::{
+    builtins::PyStr, convert::TryFromBorrowedObject, PyObject, PyObjectRef, PyPayload, PyResult,
+    VirtualMachine,
+};
 
 pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
     let module = unicodedata::make_module(vm);
@@ -28,6 +31,30 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
     module
 }
 
+enum NormalizeForm {
+    Nfc,
+    Nfkc,
+    Nfd,
+    Nfkd,
+}
+
+impl TryFromBorrowedObject for NormalizeForm {
+    fn try_from_borrowed_object(vm: &VirtualMachine, obj: &PyObject) -> PyResult<Self> {
+        obj.try_value_with(
+            |form: &PyStr| {
+                Ok(match form.as_str() {
+                    "NFC" => NormalizeForm::Nfc,
+                    "NFKC" => NormalizeForm::Nfkc,
+                    "NFD" => NormalizeForm::Nfd,
+                    "NFKD" => NormalizeForm::Nfkd,
+                    _ => return Err(vm.new_value_error("invalid normalization form".to_owned())),
+                })
+            },
+            vm,
+        )
+    }
+}
+
 #[pymodule]
 mod unicodedata {
     use crate::vm::{
@@ -63,11 +90,7 @@ mod unicodedata {
                 vm.new_type_error("argument must be an unicode character, not str".to_owned())
             })?;
 
-            if self.check_age(c) {
-                Ok(Some(c))
-            } else {
-                Ok(None)
-            }
+            Ok(self.check_age(c).then_some(c))
         }
     }
 
@@ -112,40 +135,41 @@ mod unicodedata {
         }
 
         #[pymethod]
-        fn bidirectional(&self, character: PyStrRef, vm: &VirtualMachine) -> PyResult<String> {
+        fn bidirectional(
+            &self,
+            character: PyStrRef,
+            vm: &VirtualMachine,
+        ) -> PyResult<&'static str> {
             let bidi = match self.extract_char(character, vm)? {
                 Some(c) => BidiClass::of(c).abbr_name(),
                 None => "",
             };
-            Ok(bidi.to_owned())
+            Ok(bidi)
         }
 
         /// NOTE: This function uses 9.0.0 database instead of 3.2.0
         #[pymethod]
-        fn east_asian_width(&self, character: PyStrRef, vm: &VirtualMachine) -> PyResult<String> {
+        fn east_asian_width(
+            &self,
+            character: PyStrRef,
+            vm: &VirtualMachine,
+        ) -> PyResult<&'static str> {
             Ok(self
                 .extract_char(character, vm)?
                 .map_or(EastAsianWidth::Neutral, |c| c.east_asian_width())
-                .abbr_name()
-                .to_owned())
+                .abbr_name())
         }
 
         #[pymethod]
-        fn normalize(
-            &self,
-            form: PyStrRef,
-            unistr: PyStrRef,
-            vm: &VirtualMachine,
-        ) -> PyResult<String> {
+        fn normalize(&self, form: super::NormalizeForm, unistr: PyStrRef) -> PyResult<String> {
+            use super::NormalizeForm::*;
             let text = unistr.as_str();
-            let normalized_text = match form.as_str() {
-                "NFC" => text.nfc().collect::<String>(),
-                "NFKC" => text.nfkc().collect::<String>(),
-                "NFD" => text.nfd().collect::<String>(),
-                "NFKD" => text.nfkd().collect::<String>(),
-                _ => return Err(vm.new_value_error("invalid normalization form".to_owned())),
+            let normalized_text = match form {
+                Nfc => text.nfc().collect::<String>(),
+                Nfkc => text.nfkc().collect::<String>(),
+                Nfd => text.nfd().collect::<String>(),
+                Nfkd => text.nfkd().collect::<String>(),
             };
-
             Ok(normalized_text)
         }
 