Skip to content

Commit f7bc922

Browse files
authored
Fix float 8 cast (#37)
* fix f8 * fix * cleaning
1 parent 2fde01f commit f7bc922

File tree

2 files changed

+145
-63
lines changed

2 files changed

+145
-63
lines changed

_unittests/ut_validation/test_f8.py

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22
import pprint
3+
import struct
34
import unittest
45
import warnings
56
import numpy
67
import pandas
8+
from onnx import TensorProto
79
from onnx_array_api.validation.f8 import (
810
CastFloat8,
911
UndefinedCastError,
@@ -285,6 +287,15 @@ def test_search_float32_into_fe4m3fn(self):
285287
ok="" if b == nf else "WRONG",
286288
true=value,
287289
add=add,
290+
exponent=(
291+
int.from_bytes(
292+
struct.pack("<f", numpy.float32(v)), "little"
293+
)
294+
& 0x7F800000
295+
)
296+
>> 23,
297+
d1=v - fe4m3_to_float32_float(nf),
298+
d2=v - fe4m3_to_float32_float(b),
288299
)
289300
)
290301
if wrong > 0:
@@ -449,10 +460,13 @@ def test_search_e4m3_pow(self):
449460
continue
450461
r2 = float32_to_fe4m3(v)
451462
if r1 != r2:
463+
ex = abs(v - fe4m3_to_float32(r1)) == abs(v - fe4m3_to_float32(r2))
452464
raise AssertionError(
453465
f"p={p}, v={v}, "
454466
f"search={r1}:{display_fe4m3(r1)}={fe4m3_to_float32(r1)} != "
455-
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)}"
467+
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)} "
468+
f"d1={v - fe4m3_to_float32(r1)} d2={v - fe4m3_to_float32(r2)} "
469+
f"|d1|==|d2|={ex}"
456470
)
457471
for p in range(1, 40):
458472
v = -(2 ** (-p))
@@ -462,10 +476,13 @@ def test_search_e4m3_pow(self):
462476
continue
463477
r2 = float32_to_fe4m3(v)
464478
if r1 != r2:
479+
ex = abs(v - fe4m3_to_float32(r1)) == abs(v - fe4m3_to_float32(r2))
465480
raise AssertionError(
466481
f"p={p}, v={v}, "
467482
f"search={r1}:{display_fe4m3(r1)}={fe4m3_to_float32(r1)} != "
468-
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)}"
483+
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)} "
484+
f"d1={v - fe4m3_to_float32(r1)} d2={v - fe4m3_to_float32(r2)} "
485+
f"|d1|==|d2|={ex}"
469486
)
470487

471488
def test_search_e5m2_pow(self):
@@ -478,10 +495,13 @@ def test_search_e5m2_pow(self):
478495
continue
479496
r2 = float32_to_fe5m2(v)
480497
if r1 != r2:
498+
ex = abs(v - fe5m2_to_float32(r1)) == abs(v - fe5m2_to_float32(r2))
481499
raise AssertionError(
482500
f"p={p}, v={v}, "
483501
f"search={r1}:{display_fe5m2(r1)}={fe5m2_to_float32(r1)} != "
484-
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)}"
502+
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)} "
503+
f"d1={v - fe4m3_to_float32(r1)} d2={v - fe5m2_to_float32(r2)} "
504+
f"|d1|==|d2|={ex}"
485505
)
486506
for p in range(1, 40):
487507
v = -(2 ** (-p))
@@ -491,10 +511,13 @@ def test_search_e5m2_pow(self):
491511
continue
492512
r2 = float32_to_fe5m2(v)
493513
if r1 != r2:
514+
ex = abs(v - fe5m2_to_float32(r1)) == abs(v - fe5m2_to_float32(r2))
494515
raise AssertionError(
495516
f"p={p}, v={v}, "
496517
f"search={r1}:{display_fe5m2(r1)}={fe5m2_to_float32(r1)} != "
497-
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)}"
518+
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)} "
519+
f"d1={v - fe4m3_to_float32(r1)} d2={v - fe5m2_to_float32(r2)} "
520+
f"|d1|==|d2|={ex}"
498521
)
499522

500523
def test_float32_to_fe4m3fn_inf(self):
@@ -1152,13 +1175,50 @@ def test_float8_e5m2fnuz_negative_nan(self):
11521175
self.assertTrue(numpy.isnan(back))
11531176

11541177
def test_fe4m3fn_to_float32_bug(self):
1155-
cases = [(1.8131605, 1.875)]
1156-
for val, expected in cases:
1157-
with self.subTest(value=val, expected=expected):
1158-
res = fe4m3_to_float32(search_float32_into_fe4m3(val))
1159-
self.assertEqual(expected, res)
1160-
res = fe4m3_to_float32(float32_to_fe4m3(val))
1161-
self.assertEqual(expected, res)
1178+
cases = [
1179+
(0.00439453125, 0.00390625, TensorProto.FLOAT8E4M3FN),
1180+
(0.005859375, 0.005859375, TensorProto.FLOAT8E4M3FN),
1181+
(0.005759375, 0.005859375, TensorProto.FLOAT8E4M3FN),
1182+
(0.0046875, 0.00390625, TensorProto.FLOAT8E4M3FN),
1183+
(0.001953125, 0.001953125, TensorProto.FLOAT8E4M3FN),
1184+
(0.0029296875, 0.00390625, TensorProto.FLOAT8E4M3FN),
1185+
(0.002053125, 0.001953125, TensorProto.FLOAT8E4M3FN),
1186+
(0.00234375, 0.001953125, TensorProto.FLOAT8E4M3FN),
1187+
(0.0087890625, 0.0078125, TensorProto.FLOAT8E4M3FN),
1188+
(0.001171875, 0.001953125, TensorProto.FLOAT8E4M3FN),
1189+
(1.8131605, 1.875, TensorProto.FLOAT8E4M3FN),
1190+
(-100, -96, TensorProto.FLOAT8E4M3FNUZ),
1191+
(416, 384, TensorProto.FLOAT8E5M2FNUZ),
1192+
]
1193+
for val, expected, pt in cases:
1194+
with self.subTest(value=val, expected=expected, proto=pt):
1195+
if pt == TensorProto.FLOAT8E4M3FN:
1196+
res = fe4m3_to_float32(search_float32_into_fe4m3(val))
1197+
self.assertEqual(expected, res)
1198+
res = fe4m3_to_float32(float32_to_fe4m3(val))
1199+
self.assertEqual(expected, res)
1200+
continue
1201+
if pt == TensorProto.FLOAT8E4M3FNUZ:
1202+
res = fe4m3_to_float32(
1203+
search_float32_into_fe4m3(val, uz=True), uz=True
1204+
)
1205+
self.assertEqual(expected, res)
1206+
res = fe4m3_to_float32(float32_to_fe4m3(val, uz=True), uz=True)
1207+
self.assertEqual(expected, res)
1208+
continue
1209+
if pt == TensorProto.FLOAT8E5M2FNUZ:
1210+
res = fe5m2_to_float32(
1211+
search_float32_into_fe5m2(val, fn=True, uz=True),
1212+
fn=True,
1213+
uz=True,
1214+
)
1215+
self.assertEqual(expected, res)
1216+
res = fe5m2_to_float32(
1217+
float32_to_fe5m2(val, fn=True, uz=True), fn=True, uz=True
1218+
)
1219+
self.assertEqual(expected, res)
1220+
continue
1221+
raise AssertionError(f"Unexpected value for pt={pt}.")
11621222

11631223

11641224
if __name__ == "__main__":

onnx_array_api/validation/f8.py

Lines changed: 74 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -399,10 +399,9 @@ def find_closest_value(value, sorted_values):
399399
if d1 < d2:
400400
return sorted_values[a][1]
401401
if d1 == d2:
402-
raise UndefinedCastError(
403-
f"Unable to cast {value}, d1={d1}, d2={d2}, "
404-
f"options are {sorted_values[a][1]} and {sorted_values[b][1]}."
405-
)
402+
# Tie: apply round-half-to-even, i.e. keep the candidate whose code is even.
403+
ca, cb = sorted_values[a][1], sorted_values[b][1]
404+
return cb if ca & 1 == 1 else ca
406405
return sorted_values[b][1]
407406
return sorted_values[a][1]
408407

@@ -520,28 +519,35 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
520519
if e < 116:
521520
pass
522521
elif e < 117:
523-
ret |= 1
522+
# first positive number
523+
if m > 0:
524+
ret |= 1
524525
if (m >> 23) & 1:
525526
# rounding
526527
ret += 1
527-
elif e < 120: # 127 - 8 + 1
528-
d = 119 - e
529-
ret |= 1 << (2 - d)
530-
ret |= m >> (21 + d)
531-
if (m >> (20 + d)) & 1:
528+
elif e < 120:
529+
# denormalized number
530+
ex = e - 119
531+
ret |= 1 << (2 + ex)
532+
ret |= m >> (21 - ex)
533+
mask = 1 << (20 - ex)
534+
if m & mask and (
535+
ret & 1
536+
or m & (mask - 1) > 0
537+
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
538+
):
532539
# rounding
533540
ret += 1
534-
elif e < 135: # 127 + 8
541+
elif e < 135:
542+
# normalized number
535543
ex = e - 119 # 127 - 8
536544
if ex == 0:
537545
ret |= 0x4
538546
ret |= m >> 21
539547
else:
540548
ret |= ex << 3
541549
ret |= m >> 20
542-
if (m & 0x80000) and (
543-
(m & 0x100000) or (m & 0x7FFFF)
544-
): # round to nearest even
550+
if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
545551
if (ret & 0x7F) < 0x7F:
546552
# rounding
547553
ret += 1
@@ -569,19 +575,25 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
569575
if e < 117:
570576
pass
571577
elif e < 118:
572-
ret |= 1
573-
if (m >> 23) & 1:
574-
# rounding
575-
ret += 1
576-
elif e < 121: # 127 - 7 + 1
577-
d = 120 - e
578-
ret |= 1 << (2 - d)
579-
ret |= m >> (21 + d)
580-
if (m >> (20 + d)) & 1:
578+
# first positive number
579+
if m > 0:
580+
ret |= 1
581+
elif e < 121:
582+
# denormalized number
583+
ex = e - 120
584+
ret |= 1 << (2 + ex)
585+
ret |= m >> (21 - ex)
586+
mask = 1 << (20 - ex)
587+
if m & mask and (
588+
ret & 1
589+
or m & (mask - 1) > 0
590+
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
591+
):
581592
# rounding
582593
ret += 1
583-
elif e < 136: # 127 + 8 + 1
584-
ex = e - 120 # 127 - 7
594+
elif e < 136:
595+
# normalized number
596+
ex = e - 120
585597
if ex == 0:
586598
ret |= 0x4
587599
ret |= m >> 21
@@ -590,9 +602,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
590602
ret |= m >> 20
591603
if (ret & 0x7F) == 0x7F:
592604
ret &= 0xFE
593-
if (m & 0x80000) and (
594-
(m & 0x100000) or (m & 0x7FFFF)
595-
): # round to nearest even
605+
if (m & 0x80000) and ((m & 0x100000) or (m & 0x7FFFF)):
596606
if (ret & 0x7F) < 0x7E:
597607
# rounding
598608
ret += 1
@@ -633,25 +643,31 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
633643
if e < 109:
634644
pass
635645
elif e < 110:
636-
ret |= 1
646+
# first positive number
647+
if m > 0:
648+
ret |= 1
637649
if (m >> 23) & 1:
638650
# rounding
639-
# may be unused
640651
ret += 1
641-
elif e < 112: # 127 - 16 + 1
642-
d = 111 - e
643-
ret |= 1 << (1 - d)
644-
ret |= m >> (22 + d)
645-
if (m >> (21 + d)) & 1:
652+
elif e < 112:
653+
# denormalized number
654+
ex = e - 111
655+
ret |= 1 << (1 + ex)
656+
ret |= m >> (22 - ex)
657+
mask = 1 << (21 - ex)
658+
if m & mask and (
659+
ret & 1
660+
or m & (mask - 1) > 0
661+
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
662+
):
646663
# rounding
647664
ret += 1
648-
elif e < 143: # 127 + 15 + 1
649-
ex = e - 111 # 127 - 16
665+
elif e < 143:
666+
# normalized number
667+
ex = e - 111
650668
ret |= ex << 2
651669
ret |= m >> 21
652-
if m & 0x100000 and (
653-
(m & 0xFFFFF) or (m & 0x200000)
654-
): # round to nearest even
670+
if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
655671
if (ret & 0x7F) < 0x7F:
656672
# rounding
657673
ret += 1
@@ -681,25 +697,31 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
681697
if e < 110:
682698
pass
683699
elif e < 111:
684-
ret |= 1
700+
# first positive number
701+
if m > 0:
702+
ret |= 1
685703
if (m >> 23) & 1:
686704
# rounding
687-
# may be unused
688705
ret += 1
689-
elif e < 113: # 127 - 15 + 1
690-
d = 112 - e
691-
ret |= 1 << (1 - d)
692-
ret |= m >> (22 + d)
693-
if (m >> (21 + d)) & 1:
706+
elif e < 113:
707+
# denormalized number
708+
ex = e - 112
709+
ret |= 1 << (1 + ex)
710+
ret |= m >> (22 - ex)
711+
mask = 1 << (21 - ex)
712+
if m & mask and (
713+
ret & 1
714+
or m & (mask - 1) > 0
715+
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
716+
):
694717
# rounding
695718
ret += 1
696-
elif e < 143: # 127 + 15 + 1
697-
ex = e - 112 # 127 - 15
719+
elif e < 143:
720+
# normalized number
721+
ex = e - 112
698722
ret |= ex << 2
699723
ret |= m >> 21
700-
if m & 0x100000 and (
701-
(m & 0xFFFFF) or (m & 0x200000)
702-
): # round to nearest even
724+
if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
703725
if (ret & 0x7F) < 0x7B:
704726
# rounding
705727
ret += 1

0 commit comments

Comments
 (0)