Skip to content

Commit 4f1281a

Browse files
got enums to work
1 parent 34b66c2 commit 4f1281a

File tree

4 files changed

+434
-61
lines changed

4 files changed

+434
-61
lines changed

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ crate-type = ["rlib", "cdylib"]
1515

1616
[features]
1717
default = []
18-
python = ["pyo3"]
18+
python = ["pyo3", "once_cell"]
1919
python-extension = ["python", "pyo3/extension-module"]
2020

2121
[dependencies]
@@ -28,6 +28,10 @@ bitflags = "1.2"
2828
[dependencies.algebraics]
2929
version = ">= 0.1.2, < 0.2"
3030

31+
[dependencies.once_cell]
32+
version = "1.2"
33+
optional = true
34+
3135
[dependencies.pyo3]
3236
version = "0.8.2"
3337
optional = true

src/lib.rs

Lines changed: 72 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,18 @@ use std::ops::ShrAssign;
3333
use pyo3::prelude::*;
3434

3535
mod python;
36+
#[macro_use]
37+
mod python_macros;
3638

3739
#[cfg(test)]
3840
mod test_cases;
3941

40-
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
41-
#[repr(u8)]
42-
pub enum Sign {
43-
Positive = 0,
44-
Negative = 1,
42+
python_enum! {
43+
#[pyenum(module = simple_soft_float, repr = u8, test_fn = test_sign_enum)]
44+
pub enum Sign {
45+
Positive = 0,
46+
Negative = 1,
47+
}
4548
}
4649

4750
impl Neg for Sign {
@@ -123,14 +126,15 @@ impl_float_bits_type!(u32, to_u32);
123126
impl_float_bits_type!(u64, to_u64);
124127
impl_float_bits_type!(u128, to_u128);
125128

126-
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
127-
#[repr(u32)]
128-
pub enum RoundingMode {
129-
TiesToEven = 0,
130-
TowardZero = 1,
131-
TowardNegative = 2,
132-
TowardPositive = 3,
133-
TiesToAway = 4,
129+
python_enum! {
130+
#[pyenum(module = simple_soft_float, repr = u8, test_fn = test_rounding_mode_enum)]
131+
pub enum RoundingMode {
132+
TiesToEven = 0,
133+
TowardZero = 1,
134+
TowardNegative = 2,
135+
TowardPositive = 3,
136+
TiesToAway = 4,
137+
}
134138
}
135139

136140
impl Default for RoundingMode {
@@ -179,19 +183,23 @@ impl Default for TininessDetectionMode {
179183
}
180184
}
181185

182-
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
183-
pub enum BinaryNaNPropagationMode {
184-
AlwaysCanonical,
185-
FirstSecond,
186-
SecondFirst,
187-
FirstSecondPreferringSNaN,
188-
SecondFirstPreferringSNaN,
186+
python_enum! {
187+
#[pyenum(module = simple_soft_float, repr = u8, test_fn = test_binary_nan_propagation_mode_enum)]
188+
pub enum BinaryNaNPropagationMode {
189+
AlwaysCanonical,
190+
FirstSecond,
191+
SecondFirst,
192+
FirstSecondPreferringSNaN,
193+
SecondFirstPreferringSNaN,
194+
}
189195
}
190196

191-
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
192-
pub enum UnaryNaNPropagationMode {
193-
AlwaysCanonical,
194-
First,
197+
python_enum! {
198+
#[pyenum(module = simple_soft_float, repr = u8, test_fn = test_unary_nan_propagation_mode_enum)]
199+
pub enum UnaryNaNPropagationMode {
200+
AlwaysCanonical,
201+
First,
202+
}
195203
}
196204

197205
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
@@ -340,21 +348,23 @@ impl Default for TernaryNaNPropagationResults {
340348
}
341349
}
342350

343-
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
344-
pub enum TernaryNaNPropagationMode {
345-
AlwaysCanonical,
346-
FirstSecondThird,
347-
FirstThirdSecond,
348-
SecondFirstThird,
349-
SecondThirdFirst,
350-
ThirdFirstSecond,
351-
ThirdSecondFirst,
352-
FirstSecondThirdPreferringSNaN,
353-
FirstThirdSecondPreferringSNaN,
354-
SecondFirstThirdPreferringSNaN,
355-
SecondThirdFirstPreferringSNaN,
356-
ThirdFirstSecondPreferringSNaN,
357-
ThirdSecondFirstPreferringSNaN,
351+
python_enum! {
352+
#[pyenum(module = simple_soft_float, repr = u8, test_fn = test_ternary_nan_propagation_mode_enum)]
353+
pub enum TernaryNaNPropagationMode {
354+
AlwaysCanonical,
355+
FirstSecondThird,
356+
FirstThirdSecond,
357+
SecondFirstThird,
358+
SecondThirdFirst,
359+
ThirdFirstSecond,
360+
ThirdSecondFirst,
361+
FirstSecondThirdPreferringSNaN,
362+
FirstThirdSecondPreferringSNaN,
363+
SecondFirstThirdPreferringSNaN,
364+
SecondThirdFirstPreferringSNaN,
365+
ThirdFirstSecondPreferringSNaN,
366+
ThirdSecondFirstPreferringSNaN,
367+
}
358368
}
359369

360370
impl Default for TernaryNaNPropagationMode {
@@ -547,11 +557,13 @@ impl TernaryNaNPropagationMode {
547557
}
548558
}
549559

550-
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
551-
pub enum FMAInfZeroQNaNResult {
552-
FollowNaNPropagationMode,
553-
CanonicalAndGenerateInvalid,
554-
PropagateAndGenerateInvalid,
560+
python_enum! {
561+
#[pyenum(module = simple_soft_float, repr = u8, test_fn = test_fma_inf_zero_qnan_result_enum)]
562+
pub enum FMAInfZeroQNaNResult {
563+
FollowNaNPropagationMode,
564+
CanonicalAndGenerateInvalid,
565+
PropagateAndGenerateInvalid,
566+
}
555567
}
556568

557569
impl Default for FMAInfZeroQNaNResult {
@@ -560,10 +572,12 @@ impl Default for FMAInfZeroQNaNResult {
560572
}
561573
}
562574

563-
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
564-
pub enum FloatToFloatConversionNaNPropagationMode {
565-
AlwaysCanonical,
566-
RetainMostSignificantBits,
575+
python_enum! {
576+
#[pyenum(module = simple_soft_float, repr = u8, test_fn = test_float_to_float_conversion_nan_propagation_mode_enum)]
577+
pub enum FloatToFloatConversionNaNPropagationMode {
578+
AlwaysCanonical,
579+
RetainMostSignificantBits,
580+
}
567581
}
568582

569583
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, Default)]
@@ -746,6 +760,7 @@ impl Default for QuietNaNFormat {
746760
}
747761
}
748762

763+
#[cfg_attr(feature = "python", pyclass(module = "simple_soft_float"))]
749764
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
750765
pub struct PlatformProperties {
751766
pub canonical_nan_sign: Sign,
@@ -818,6 +833,15 @@ macro_rules! platform_properties_constants {
818833
pub const $ident:ident: PlatformProperties = $init:expr;
819834
)+
820835
) => {
836+
#[cfg(feature = "python")]
837+
impl PlatformProperties {
838+
fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {
839+
m.add_class::<Self>()?;
840+
$(m.add::<PyObject>(concat!("PlatformProperties_", stringify!($ident)), Self::$ident.into_py(py))?;)+
841+
Ok(())
842+
}
843+
}
844+
821845
impl PlatformProperties {
822846
$(
823847
$(#[$meta])*

src/python.rs

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,18 @@
22
// See Notices.txt for copyright information
33
#![cfg(feature = "python")]
44

5+
use crate::python_macros::PythonEnum;
6+
use crate::BinaryNaNPropagationMode;
57
use crate::DynamicFloat;
8+
use crate::FMAInfZeroQNaNResult;
69
use crate::FloatProperties;
10+
use crate::FloatToFloatConversionNaNPropagationMode;
11+
use crate::PlatformProperties;
12+
use crate::RoundingMode;
13+
use crate::Sign;
714
use crate::StatusFlags;
15+
use crate::TernaryNaNPropagationMode;
16+
use crate::UnaryNaNPropagationMode;
817
use pyo3::basic::CompareOp;
918
use pyo3::exceptions::TypeError;
1019
use pyo3::prelude::*;
@@ -14,8 +23,23 @@ use pyo3::types::PyType;
1423
use pyo3::wrap_pymodule;
1524
use pyo3::PyNativeType;
1625
use pyo3::PyObjectProtocol;
26+
use std::borrow::Cow;
1727
use std::fmt::Write as _;
1828

29+
pub(crate) trait ToPythonRepr {
30+
fn to_python_repr(&self) -> Cow<str>;
31+
}
32+
33+
impl ToPythonRepr for bool {
34+
fn to_python_repr(&self) -> Cow<str> {
35+
if *self {
36+
Cow::Borrowed("True")
37+
} else {
38+
Cow::Borrowed("False")
39+
}
40+
}
41+
}
42+
1943
impl FromPyObject<'_> for StatusFlags {
2044
fn extract(object: &PyAny) -> PyResult<Self> {
2145
if !Self::get_python_class(object.py())
@@ -58,7 +82,7 @@ impl StatusFlags {
5882

5983
#[cfg(feature = "python")]
6084
#[pymodule]
61-
fn simple_soft_float(py: Python, m: &PyModule) -> PyResult<()> {
85+
pub(crate) fn simple_soft_float(py: Python, m: &PyModule) -> PyResult<()> {
6286
m.add_class::<DynamicFloat>()?;
6387
let dict = PyDict::new(py);
6488
fn make_src() -> Result<String, std::fmt::Error> {
@@ -77,6 +101,14 @@ fn simple_soft_float(py: Python, m: &PyModule) -> PyResult<()> {
77101
.into_py(py);
78102
m.add(StatusFlags::NAME, class)?;
79103
m.add_class::<FloatProperties>()?;
104+
BinaryNaNPropagationMode::add_to_module(py, m)?;
105+
FloatToFloatConversionNaNPropagationMode::add_to_module(py, m)?;
106+
FMAInfZeroQNaNResult::add_to_module(py, m)?;
107+
RoundingMode::add_to_module(py, m)?;
108+
Sign::add_to_module(py, m)?;
109+
TernaryNaNPropagationMode::add_to_module(py, m)?;
110+
UnaryNaNPropagationMode::add_to_module(py, m)?;
111+
PlatformProperties::add_to_module(py, m)?;
80112
Ok(())
81113
}
82114

@@ -98,28 +130,115 @@ impl DynamicFloat {
98130
// FIXME: finish
99131
}
100132

133+
macro_rules! impl_platform_properties_new {
134+
($($name:ident:$type:ty,)+) => {
135+
#[pymethods]
136+
impl PlatformProperties {
137+
#[new]
138+
#[args(
139+
value = "None",
140+
"*",
141+
$($name = "None"),+
142+
)]
143+
fn __new__(
144+
obj: &PyRawObject,
145+
value: Option<&Self>,
146+
$($name: Option<$type>,)+
147+
) {
148+
let mut value = value.copied().unwrap_or_default();
149+
$(value.$name = $name.unwrap_or(value.$name);)+
150+
obj.init(value);
151+
}
152+
}
153+
154+
#[pyproto]
155+
impl PyObjectProtocol for PlatformProperties {
156+
fn __repr__(&self) -> PyResult<String> {
157+
#![allow(unused_assignments)]
158+
let mut retval = String::new();
159+
write!(retval, "PlatformProperties(").unwrap();
160+
let mut first = true;
161+
$(
162+
if first {
163+
first = false;
164+
} else {
165+
write!(retval, ", ").unwrap();
166+
}
167+
write!(retval, concat!(stringify!($name), "={}"), self.$name.to_python_repr()).unwrap();
168+
)+
169+
write!(retval, ")").unwrap();
170+
Ok(retval)
171+
}
172+
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult<PyObject> {
173+
if let Ok(rhs) = <&Self>::extract(other) {
174+
match op {
175+
CompareOp::Eq => return Ok((self == rhs).into_py(other.py())),
176+
CompareOp::Ne => return Ok((self != rhs).into_py(other.py())),
177+
CompareOp::Ge | CompareOp::Gt | CompareOp::Le | CompareOp::Lt => {}
178+
};
179+
}
180+
Ok(other.py().NotImplemented())
181+
}
182+
}
183+
};
184+
}
185+
186+
impl_platform_properties_new!(
187+
canonical_nan_sign: Sign,
188+
canonical_nan_mantissa_msb: bool,
189+
canonical_nan_mantissa_second_to_msb: bool,
190+
canonical_nan_mantissa_rest: bool,
191+
std_bin_ops_nan_propagation_mode: BinaryNaNPropagationMode,
192+
fma_nan_propagation_mode: TernaryNaNPropagationMode,
193+
fma_inf_zero_qnan_result: FMAInfZeroQNaNResult,
194+
round_to_integral_nan_propagation_mode: UnaryNaNPropagationMode,
195+
next_up_or_down_nan_propagation_mode: UnaryNaNPropagationMode,
196+
scale_b_nan_propagation_mode: UnaryNaNPropagationMode,
197+
sqrt_nan_propagation_mode: UnaryNaNPropagationMode,
198+
float_to_float_conversion_nan_propagation_mode: FloatToFloatConversionNaNPropagationMode,
199+
rsqrt_nan_propagation_mode: UnaryNaNPropagationMode,
200+
);
201+
101202
#[pymethods]
102-
impl FloatProperties {
203+
impl PlatformProperties {
103204
// FIXME: finish
104205
}
105206

207+
#[pymethods]
208+
impl FloatProperties {
209+
#[new]
210+
fn __new__(
211+
obj: &PyRawObject,
212+
exponent_width: usize,
213+
mantissa_width: usize,
214+
has_implicit_leading_bit: bool,
215+
has_sign_bit: bool,
216+
platform_properties: &PlatformProperties,
217+
) {
218+
obj.init(Self::new_with_extended_flags(
219+
exponent_width,
220+
mantissa_width,
221+
has_implicit_leading_bit,
222+
has_sign_bit,
223+
*platform_properties,
224+
));
225+
}
226+
}
227+
106228
#[pyproto]
107229
impl PyObjectProtocol for FloatProperties {
108230
fn __repr__(&self) -> PyResult<String> {
109231
Ok(format!("<{:?}>", self))
110232
}
111233
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult<PyObject> {
112-
let inverted = match op {
113-
CompareOp::Eq => false,
114-
CompareOp::Ne => true,
115-
CompareOp::Ge | CompareOp::Gt | CompareOp::Le | CompareOp::Lt => {
116-
return Ok(other.py().NotImplemented());
117-
}
118-
};
119-
match <&FloatProperties>::extract(other) {
120-
Ok(v) => Ok(if inverted { self != v } else { self == v }.into_py(other.py())),
121-
Err(_) => Ok(other.py().NotImplemented()),
234+
if let Ok(rhs) = <&FloatProperties>::extract(other) {
235+
match op {
236+
CompareOp::Eq => return Ok((self == rhs).into_py(other.py())),
237+
CompareOp::Ne => return Ok((self != rhs).into_py(other.py())),
238+
CompareOp::Ge | CompareOp::Gt | CompareOp::Le | CompareOp::Lt => {}
239+
};
122240
}
241+
Ok(other.py().NotImplemented())
123242
}
124243
}
125244

0 commit comments

Comments
 (0)