Skip to content

Fix float 8 cast #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 71 additions & 11 deletions _unittests/ut_validation/test_f8.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pprint
import struct
import unittest
import warnings
import numpy
import pandas
from onnx import TensorProto
from onnx_array_api.validation.f8 import (
CastFloat8,
UndefinedCastError,
Expand Down Expand Up @@ -285,6 +287,15 @@ def test_search_float32_into_fe4m3fn(self):
ok="" if b == nf else "WRONG",
true=value,
add=add,
exponent=(
int.from_bytes(
struct.pack("<f", numpy.float32(v)), "little"
)
& 0x7F800000
)
>> 23,
d1=v - fe4m3_to_float32_float(nf),
d2=v - fe4m3_to_float32_float(b),
)
)
if wrong > 0:
Expand Down Expand Up @@ -449,10 +460,13 @@ def test_search_e4m3_pow(self):
continue
r2 = float32_to_fe4m3(v)
if r1 != r2:
ex = abs(v - fe4m3_to_float32(r1)) == abs(v - fe4m3_to_float32(r2))
raise AssertionError(
f"p={p}, v={v}, "
f"search={r1}:{display_fe4m3(r1)}={fe4m3_to_float32(r1)} != "
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)}"
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)} "
f"d1={v - fe4m3_to_float32(r1)} d2={v - fe4m3_to_float32(r2)} "
f"|d1|==|d2|={ex}"
)
for p in range(1, 40):
v = -(2 ** (-p))
Expand All @@ -462,10 +476,13 @@ def test_search_e4m3_pow(self):
continue
r2 = float32_to_fe4m3(v)
if r1 != r2:
ex = abs(v - fe4m3_to_float32(r1)) == abs(v - fe4m3_to_float32(r2))
raise AssertionError(
f"p={p}, v={v}, "
f"search={r1}:{display_fe4m3(r1)}={fe4m3_to_float32(r1)} != "
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)}"
f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)} "
f"d1={v - fe4m3_to_float32(r1)} d2={v - fe4m3_to_float32(r2)} "
f"|d1|==|d2|={ex}"
)

def test_search_e5m2_pow(self):
Expand All @@ -478,10 +495,13 @@ def test_search_e5m2_pow(self):
continue
r2 = float32_to_fe5m2(v)
if r1 != r2:
ex = abs(v - fe5m2_to_float32(r1)) == abs(v - fe5m2_to_float32(r2))
raise AssertionError(
f"p={p}, v={v}, "
f"search={r1}:{display_fe5m2(r1)}={fe5m2_to_float32(r1)} != "
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)}"
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)} "
f"d1={v - fe4m3_to_float32(r1)} d2={v - fe5m2_to_float32(r2)} "
f"|d1|==|d2|={ex}"
)
for p in range(1, 40):
v = -(2 ** (-p))
Expand All @@ -491,10 +511,13 @@ def test_search_e5m2_pow(self):
continue
r2 = float32_to_fe5m2(v)
if r1 != r2:
ex = abs(v - fe5m2_to_float32(r1)) == abs(v - fe5m2_to_float32(r2))
raise AssertionError(
f"p={p}, v={v}, "
f"search={r1}:{display_fe5m2(r1)}={fe5m2_to_float32(r1)} != "
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)}"
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)} "
f"d1={v - fe4m3_to_float32(r1)} d2={v - fe5m2_to_float32(r2)} "
f"|d1|==|d2|={ex}"
)

def test_float32_to_fe4m3fn_inf(self):
Expand Down Expand Up @@ -1152,13 +1175,50 @@ def test_float8_e5m2fnuz_negative_nan(self):
self.assertTrue(numpy.isnan(back))

def test_fe4m3fn_to_float32_bug(self):
cases = [(1.8131605, 1.875)]
for val, expected in cases:
with self.subTest(value=val, expected=expected):
res = fe4m3_to_float32(search_float32_into_fe4m3(val))
self.assertEqual(expected, res)
res = fe4m3_to_float32(float32_to_fe4m3(val))
self.assertEqual(expected, res)
cases = [
(0.00439453125, 0.00390625, TensorProto.FLOAT8E4M3FN),
(0.005859375, 0.005859375, TensorProto.FLOAT8E4M3FN),
(0.005759375, 0.005859375, TensorProto.FLOAT8E4M3FN),
(0.0046875, 0.00390625, TensorProto.FLOAT8E4M3FN),
(0.001953125, 0.001953125, TensorProto.FLOAT8E4M3FN),
(0.0029296875, 0.00390625, TensorProto.FLOAT8E4M3FN),
(0.002053125, 0.001953125, TensorProto.FLOAT8E4M3FN),
(0.00234375, 0.001953125, TensorProto.FLOAT8E4M3FN),
(0.0087890625, 0.0078125, TensorProto.FLOAT8E4M3FN),
(0.001171875, 0.001953125, TensorProto.FLOAT8E4M3FN),
(1.8131605, 1.875, TensorProto.FLOAT8E4M3FN),
(-100, -96, TensorProto.FLOAT8E4M3FNUZ),
(416, 384, TensorProto.FLOAT8E5M2FNUZ),
]
for val, expected, pt in cases:
with self.subTest(value=val, expected=expected, proto=pt):
if pt == TensorProto.FLOAT8E4M3FN:
res = fe4m3_to_float32(search_float32_into_fe4m3(val))
self.assertEqual(expected, res)
res = fe4m3_to_float32(float32_to_fe4m3(val))
self.assertEqual(expected, res)
continue
if pt == TensorProto.FLOAT8E4M3FNUZ:
res = fe4m3_to_float32(
search_float32_into_fe4m3(val, uz=True), uz=True
)
self.assertEqual(expected, res)
res = fe4m3_to_float32(float32_to_fe4m3(val, uz=True), uz=True)
self.assertEqual(expected, res)
continue
if pt == TensorProto.FLOAT8E5M2FNUZ:
res = fe5m2_to_float32(
search_float32_into_fe5m2(val, fn=True, uz=True),
fn=True,
uz=True,
)
self.assertEqual(expected, res)
res = fe5m2_to_float32(
float32_to_fe5m2(val, fn=True, uz=True), fn=True, uz=True
)
self.assertEqual(expected, res)
continue
raise AssertionError(f"Unexpected value for pt={pt}.")


if __name__ == "__main__":
Expand Down
126 changes: 74 additions & 52 deletions onnx_array_api/validation/f8.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,9 @@ def find_closest_value(value, sorted_values):
if d1 < d2:
return sorted_values[a][1]
if d1 == d2:
raise UndefinedCastError(
f"Unable to cast {value}, d1={d1}, d2={d2}, "
f"options are {sorted_values[a][1]} and {sorted_values[b][1]}."
)
# Applies rule tie to even
ca, cb = sorted_values[a][1], sorted_values[b][1]
return cb if ca & 1 == 1 else ca
return sorted_values[b][1]
return sorted_values[a][1]

Expand Down Expand Up @@ -520,28 +519,35 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
if e < 116:
pass
elif e < 117:
ret |= 1
# first positive number
if m > 0:
ret |= 1
if (m >> 23) & 1:
# rounding
ret += 1
elif e < 120: # 127 - 8 + 1
d = 119 - e
ret |= 1 << (2 - d)
ret |= m >> (21 + d)
if (m >> (20 + d)) & 1:
elif e < 120:
# denormalized number
ex = e - 119
ret |= 1 << (2 + ex)
ret |= m >> (21 - ex)
mask = 1 << (20 - ex)
if m & mask and (
ret & 1
or m & (mask - 1) > 0
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
):
# rounding
ret += 1
elif e < 135: # 127 + 8
elif e < 135:
# normalized number
ex = e - 119 # 127 - 8
if ex == 0:
ret |= 0x4
ret |= m >> 21
else:
ret |= ex << 3
ret |= m >> 20
if (m & 0x80000) and (
(m & 0x100000) or (m & 0x7FFFF)
): # round to nearest even
if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)):
if (ret & 0x7F) < 0x7F:
# rounding
ret += 1
Expand Down Expand Up @@ -569,19 +575,25 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
if e < 117:
pass
elif e < 118:
ret |= 1
if (m >> 23) & 1:
# rounding
ret += 1
elif e < 121: # 127 - 7 + 1
d = 120 - e
ret |= 1 << (2 - d)
ret |= m >> (21 + d)
if (m >> (20 + d)) & 1:
# first positive number
if m > 0:
ret |= 1
elif e < 121:
# denormalized number
ex = e - 120
ret |= 1 << (2 + ex)
ret |= m >> (21 - ex)
mask = 1 << (20 - ex)
if m & mask and (
ret & 1
or m & (mask - 1) > 0
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
):
# rounding
ret += 1
elif e < 136: # 127 + 8 + 1
ex = e - 120 # 127 - 7
elif e < 136:
# normalized number
ex = e - 120
if ex == 0:
ret |= 0x4
ret |= m >> 21
Expand All @@ -590,9 +602,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
ret |= m >> 20
if (ret & 0x7F) == 0x7F:
ret &= 0xFE
if (m & 0x80000) and (
(m & 0x100000) or (m & 0x7FFFF)
): # round to nearest even
if (m & 0x80000) and ((m & 0x100000) or (m & 0x7FFFF)):
if (ret & 0x7F) < 0x7E:
# rounding
ret += 1
Expand Down Expand Up @@ -633,25 +643,31 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
if e < 109:
pass
elif e < 110:
ret |= 1
# first positive number
if m > 0:
ret |= 1
if (m >> 23) & 1:
# rounding
# may be unused
ret += 1
elif e < 112: # 127 - 16 + 1
d = 111 - e
ret |= 1 << (1 - d)
ret |= m >> (22 + d)
if (m >> (21 + d)) & 1:
elif e < 112:
# denormlized number
ex = e - 111
ret |= 1 << (1 + ex)
ret |= m >> (22 - ex)
mask = 1 << (21 - ex)
if m & mask and (
ret & 1
or m & (mask - 1) > 0
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
):
# rounding
ret += 1
elif e < 143: # 127 + 15 + 1
ex = e - 111 # 127 - 16
elif e < 143:
# normalized number
ex = e - 111
ret |= ex << 2
ret |= m >> 21
if m & 0x100000 and (
(m & 0xFFFFF) or (m & 0x200000)
): # round to nearest even
if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
if (ret & 0x7F) < 0x7F:
# rounding
ret += 1
Expand Down Expand Up @@ -681,25 +697,31 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
if e < 110:
pass
elif e < 111:
ret |= 1
# first positive number
if m > 0:
ret |= 1
if (m >> 23) & 1:
# rounding
# may be unused
ret += 1
elif e < 113: # 127 - 15 + 1
d = 112 - e
ret |= 1 << (1 - d)
ret |= m >> (22 + d)
if (m >> (21 + d)) & 1:
elif e < 113:
# denormlized number
ex = e - 112
ret |= 1 << (1 + ex)
ret |= m >> (22 - ex)
mask = 1 << (21 - ex)
if m & mask and (
ret & 1
or m & (mask - 1) > 0
or (m & mask and m & (mask << 1) and m & (mask - 1) == 0)
):
# rounding
ret += 1
elif e < 143: # 127 + 15 + 1
ex = e - 112 # 127 - 15
elif e < 143:
# normalized number
ex = e - 112
ret |= ex << 2
ret |= m >> 21
if m & 0x100000 and (
(m & 0xFFFFF) or (m & 0x200000)
): # round to nearest even
if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)):
if (ret & 0x7F) < 0x7B:
# rounding
ret += 1
Expand Down