Skip to content

Commit e90ce31

Browse files
authored
Improves F8 conversion (#40)
1 parent 85e2e52 commit e90ce31

File tree

3 files changed

+46
-13
lines changed

3 files changed

+46
-13
lines changed
Binary file not shown.

_unittests/ut_validation/test_f8.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,32 @@ def test_fe4m3fn_to_float32_bug(self):
12201220
continue
12211221
raise AssertionError(f"Unexpected value for pt={pt}.")
12221222

1223+
def test_inf(self):
1224+
for x, e in [(numpy.float32(numpy.inf), 126), (numpy.float32(-numpy.inf), 254)]:
1225+
f8 = float32_to_fe4m3(x)
1226+
self.assertEqual(e, f8)
1227+
1228+
def test_nan(self):
1229+
expected = 127
1230+
values = [
1231+
(
1232+
None,
1233+
int.from_bytes(struct.pack("<f", numpy.float32(numpy.nan)), "little"),
1234+
numpy.float32(numpy.nan),
1235+
expected,
1236+
)
1237+
]
1238+
for i in range(0, 23):
1239+
v = 0x7F800000 | (1 << i)
1240+
f = numpy.uint32(v).view(numpy.float32)
1241+
values.append((i, v, f, expected))
1242+
values.append((i, v, -f, expected | 128))
1243+
1244+
for i, v, x, e in values:
1245+
with self.subTest(x=x, e=e, h=hex(v), i=i):
1246+
f8 = float32_to_fe4m3(x)
1247+
self.assertEqual(e, f8)
1248+
12231249

12241250
if __name__ == "__main__":
1225-
TestF8().test_fe4m3fn_to_float32_bug()
12261251
unittest.main(verbosity=2)

onnx_array_api/validation/f8.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -503,15 +503,18 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
503503
"""
504504
if not fn:
505505
raise NotImplementedError("fn=False is not implemented.")
506-
b = int.from_bytes(struct.pack("<f", numpy.float32(x)), "little")
506+
if not isinstance(x, numpy.float32):
507+
x = numpy.float32(x)
508+
b = int.from_bytes(struct.pack("<f", x), "little")
507509
ret = (b & 0x80000000) >> 24 # sign
508510
if uz:
509-
if (b & 0x7FC00000) == 0x7FC00000:
510-
return 0x80
511-
if numpy.isinf(x):
511+
if (b & 0x7FFFFFFF) == 0x7F800000:
512+
# infinity
512513
if saturate:
513514
return ret | 127
514515
return 0x80
516+
if (b & 0x7F800000) == 0x7F800000:
517+
return 0x80
515518
e = (b & 0x7F800000) >> 23 # exponent
516519
m = b & 0x007FFFFF # mantissa
517520

@@ -558,12 +561,14 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
558561
ret = 0
559562
return int(ret)
560563
else:
561-
if (b & 0x7FC00000) == 0x7FC00000:
562-
return 0x7F | ret
563-
if numpy.isinf(x):
564+
if (b & 0x7FFFFFFF) == 0x7F800000:
565+
# infinity
564566
if saturate:
565567
return ret | 126
566568
return 0x7F | ret
569+
if (b & 0x7F800000) == 0x7F800000:
570+
# non
571+
return 0x7F | ret
567572
e = (b & 0x7F800000) >> 23 # exponent
568573
m = b & 0x007FFFFF # mantissa
569574

@@ -624,13 +629,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
624629
ret = (b & 0x80000000) >> 24 # sign
625630

626631
if fn and uz:
627-
if (b & 0x7FC00000) == 0x7FC00000:
628-
return 0x80
629632
if (b & 0x7FFFFFFF) == 0x7F800000:
630633
# inf
631634
if saturate:
632635
return ret | 0x7F
633636
return 0x80
637+
if (b & 0x7F800000) == 0x7F800000:
638+
# nan
639+
return 0x80
634640
e = (b & 0x7F800000) >> 23 # exponent
635641
m = b & 0x007FFFFF # mantissa
636642

@@ -675,12 +681,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
675681
ret = 0
676682
return int(ret)
677683
elif not fn and not uz:
678-
if (b & 0x7FC00000) == 0x7FC00000:
679-
return 0x7F | ret
680-
if numpy.isinf(x):
684+
if (b & 0x7FFFFFFF) == 0x7F800000:
685+
# inf
681686
if saturate:
682687
return 0x7B | ret
683688
return 0x7C | ret
689+
if (b & 0x7F800000) == 0x7F800000:
690+
# nan
691+
return 0x7F | ret
684692
e = (b & 0x7F800000) >> 23 # exponent
685693
m = b & 0x007FFFFF # mantissa
686694

0 commit comments

Comments
 (0)