diff --git a/_unittests/ut_validation/test_f8.py b/_unittests/ut_validation/test_f8.py index a05caef..984765c 100644 --- a/_unittests/ut_validation/test_f8.py +++ b/_unittests/ut_validation/test_f8.py @@ -1,9 +1,11 @@ import os import pprint +import struct import unittest import warnings import numpy import pandas +from onnx import TensorProto from onnx_array_api.validation.f8 import ( CastFloat8, UndefinedCastError, @@ -285,6 +287,15 @@ def test_search_float32_into_fe4m3fn(self): ok="" if b == nf else "WRONG", true=value, add=add, + exponent=( + int.from_bytes( + struct.pack("<f", numpy.float32(v)), "little" + ) + & 0x7F800000 + ) + >> 23, + d1=v - fe4m3_to_float32_float(nf), + d2=v - fe4m3_to_float32_float(b), ) ) if wrong > 0: @@ -449,10 +460,13 @@ def test_search_e4m3_pow(self): continue r2 = float32_to_fe4m3(v) if r1 != r2: + ex = abs(v - fe4m3_to_float32(r1)) == abs(v - fe4m3_to_float32(r2)) raise AssertionError( f"p={p}, v={v}, " f"search={r1}:{display_fe4m3(r1)}={fe4m3_to_float32(r1)} != " - f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)}" + f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)} " + f"d1={v - fe4m3_to_float32(r1)} d2={v - fe4m3_to_float32(r2)} " + f"|d1|==|d2|={ex}" ) for p in range(1, 40): v = -(2 ** (-p)) @@ -462,10 +476,13 @@ def test_search_e4m3_pow(self): continue r2 = float32_to_fe4m3(v) if r1 != r2: + ex = abs(v - fe4m3_to_float32(r1)) == abs(v - fe4m3_to_float32(r2)) raise AssertionError( f"p={p}, v={v}, " f"search={r1}:{display_fe4m3(r1)}={fe4m3_to_float32(r1)} != " - f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)}" + f"bit={r2}:{display_fe4m3(r2)}={fe4m3_to_float32(r2)} " + f"d1={v - fe4m3_to_float32(r1)} d2={v - fe4m3_to_float32(r2)} " + f"|d1|==|d2|={ex}" ) def test_search_e5m2_pow(self): @@ -478,10 +495,13 @@ def test_search_e5m2_pow(self): continue r2 = float32_to_fe5m2(v) if r1 != r2: + ex = abs(v - fe5m2_to_float32(r1)) == abs(v - fe5m2_to_float32(r2)) raise AssertionError( f"p={p}, v={v}, " f"search={r1}:{display_fe5m2(r1)}={fe5m2_to_float32(r1)} != " - 
f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)}" + f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)} " + f"d1={v - fe5m2_to_float32(r1)} d2={v - fe5m2_to_float32(r2)} " + f"|d1|==|d2|={ex}" ) for p in range(1, 40): v = -(2 ** (-p)) @@ -491,10 +511,13 @@ def test_search_e5m2_pow(self): continue r2 = float32_to_fe5m2(v) if r1 != r2: + ex = abs(v - fe5m2_to_float32(r1)) == abs(v - fe5m2_to_float32(r2)) raise AssertionError( f"p={p}, v={v}, " f"search={r1}:{display_fe5m2(r1)}={fe5m2_to_float32(r1)} != " - f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)}" + f"bit={r2}:{display_fe5m2(r2)}={fe5m2_to_float32(r2)} " + f"d1={v - fe5m2_to_float32(r1)} d2={v - fe5m2_to_float32(r2)} " + f"|d1|==|d2|={ex}" ) def test_float32_to_fe4m3fn_inf(self): @@ -1152,13 +1175,50 @@ def test_float8_e5m2fnuz_negative_nan(self): self.assertTrue(numpy.isnan(back)) def test_fe4m3fn_to_float32_bug(self): - cases = [(1.8131605, 1.875)] - for val, expected in cases: - with self.subTest(value=val, expected=expected): - res = fe4m3_to_float32(search_float32_into_fe4m3(val)) - self.assertEqual(expected, res) - res = fe4m3_to_float32(float32_to_fe4m3(val)) - self.assertEqual(expected, res) + cases = [ + (0.00439453125, 0.00390625, TensorProto.FLOAT8E4M3FN), + (0.005859375, 0.005859375, TensorProto.FLOAT8E4M3FN), + (0.005759375, 0.005859375, TensorProto.FLOAT8E4M3FN), + (0.0046875, 0.00390625, TensorProto.FLOAT8E4M3FN), + (0.001953125, 0.001953125, TensorProto.FLOAT8E4M3FN), + (0.0029296875, 0.00390625, TensorProto.FLOAT8E4M3FN), + (0.002053125, 0.001953125, TensorProto.FLOAT8E4M3FN), + (0.00234375, 0.001953125, TensorProto.FLOAT8E4M3FN), + (0.0087890625, 0.0078125, TensorProto.FLOAT8E4M3FN), + (0.001171875, 0.001953125, TensorProto.FLOAT8E4M3FN), + (1.8131605, 1.875, TensorProto.FLOAT8E4M3FN), + (-100, -96, TensorProto.FLOAT8E4M3FNUZ), + (416, 384, TensorProto.FLOAT8E5M2FNUZ), + ] + for val, expected, pt in cases: + with self.subTest(value=val, expected=expected, proto=pt): + if pt == 
TensorProto.FLOAT8E4M3FN: + res = fe4m3_to_float32(search_float32_into_fe4m3(val)) + self.assertEqual(expected, res) + res = fe4m3_to_float32(float32_to_fe4m3(val)) + self.assertEqual(expected, res) + continue + if pt == TensorProto.FLOAT8E4M3FNUZ: + res = fe4m3_to_float32( + search_float32_into_fe4m3(val, uz=True), uz=True + ) + self.assertEqual(expected, res) + res = fe4m3_to_float32(float32_to_fe4m3(val, uz=True), uz=True) + self.assertEqual(expected, res) + continue + if pt == TensorProto.FLOAT8E5M2FNUZ: + res = fe5m2_to_float32( + search_float32_into_fe5m2(val, fn=True, uz=True), + fn=True, + uz=True, + ) + self.assertEqual(expected, res) + res = fe5m2_to_float32( + float32_to_fe5m2(val, fn=True, uz=True), fn=True, uz=True + ) + self.assertEqual(expected, res) + continue + raise AssertionError(f"Unexpected value for pt={pt}.") if __name__ == "__main__": diff --git a/onnx_array_api/validation/f8.py b/onnx_array_api/validation/f8.py index 649eab5..0439048 100644 --- a/onnx_array_api/validation/f8.py +++ b/onnx_array_api/validation/f8.py @@ -399,10 +399,9 @@ def find_closest_value(value, sorted_values): if d1 < d2: return sorted_values[a][1] if d1 == d2: - raise UndefinedCastError( - f"Unable to cast {value}, d1={d1}, d2={d2}, " - f"options are {sorted_values[a][1]} and {sorted_values[b][1]}." 
- ) + # Applies rule tie to even + ca, cb = sorted_values[a][1], sorted_values[b][1] + return cb if ca & 1 == 1 else ca return sorted_values[b][1] return sorted_values[a][1] @@ -520,18 +519,27 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True if e < 116: pass elif e < 117: - ret |= 1 + # first positive number + if m > 0: + ret |= 1 if (m >> 23) & 1: # rounding ret += 1 - elif e < 120: # 127 - 8 + 1 - d = 119 - e - ret |= 1 << (2 - d) - ret |= m >> (21 + d) - if (m >> (20 + d)) & 1: + elif e < 120: + # denormalized number + ex = e - 119 + ret |= 1 << (2 + ex) + ret |= m >> (21 - ex) + mask = 1 << (20 - ex) + if m & mask and ( + ret & 1 + or m & (mask - 1) > 0 + or (m & mask and m & (mask << 1) and m & (mask - 1) == 0) + ): # rounding ret += 1 - elif e < 135: # 127 + 8 + elif e < 135: + # normalized number ex = e - 119 # 127 - 8 if ex == 0: ret |= 0x4 @@ -539,9 +547,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True else: ret |= ex << 3 ret |= m >> 20 - if (m & 0x80000) and ( - (m & 0x100000) or (m & 0x7FFFF) - ): # round to nearest even + if m & 0x80000 and ((m & 0x100000) or (m & 0x7FFFF)): if (ret & 0x7F) < 0x7F: # rounding ret += 1 @@ -569,19 +575,25 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True if e < 117: pass elif e < 118: - ret |= 1 - if (m >> 23) & 1: - # rounding - ret += 1 - elif e < 121: # 127 - 7 + 1 - d = 120 - e - ret |= 1 << (2 - d) - ret |= m >> (21 + d) - if (m >> (20 + d)) & 1: + # first positive number + if m > 0: + ret |= 1 + elif e < 121: + # denormalized number + ex = e - 120 + ret |= 1 << (2 + ex) + ret |= m >> (21 - ex) + mask = 1 << (20 - ex) + if m & mask and ( + ret & 1 + or m & (mask - 1) > 0 + or (m & mask and m & (mask << 1) and m & (mask - 1) == 0) + ): # rounding ret += 1 - elif e < 136: # 127 + 8 + 1 - ex = e - 120 # 127 - 7 + elif e < 136: + # normalized number + ex = e - 120 if ex == 0: ret |= 0x4 ret |= m >> 21 @@ -590,9 
+602,7 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True ret |= m >> 20 if (ret & 0x7F) == 0x7F: ret &= 0xFE - if (m & 0x80000) and ( - (m & 0x100000) or (m & 0x7FFFF) - ): # round to nearest even + if (m & 0x80000) and ((m & 0x100000) or (m & 0x7FFFF)): if (ret & 0x7F) < 0x7E: # rounding ret += 1 @@ -633,25 +643,31 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru if e < 109: pass elif e < 110: - ret |= 1 + # first positive number + if m > 0: + ret |= 1 if (m >> 23) & 1: # rounding - # may be unused ret += 1 - elif e < 112: # 127 - 16 + 1 - d = 111 - e - ret |= 1 << (1 - d) - ret |= m >> (22 + d) - if (m >> (21 + d)) & 1: + elif e < 112: + # denormalized number + ex = e - 111 + ret |= 1 << (1 + ex) + ret |= m >> (22 - ex) + mask = 1 << (21 - ex) + if m & mask and ( + ret & 1 + or m & (mask - 1) > 0 + or (m & mask and m & (mask << 1) and m & (mask - 1) == 0) + ): # rounding ret += 1 - elif e < 143: # 127 + 15 + 1 - ex = e - 111 # 127 - 16 + elif e < 143: + # normalized number + ex = e - 111 ret |= ex << 2 ret |= m >> 21 - if m & 0x100000 and ( - (m & 0xFFFFF) or (m & 0x200000) - ): # round to nearest even + if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)): if (ret & 0x7F) < 0x7F: # rounding ret += 1 @@ -681,25 +697,31 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru if e < 110: pass elif e < 111: - ret |= 1 + # first positive number + if m > 0: + ret |= 1 if (m >> 23) & 1: # rounding - # may be unused ret += 1 - elif e < 113: # 127 - 15 + 1 - d = 112 - e - ret |= 1 << (1 - d) - ret |= m >> (22 + d) - if (m >> (21 + d)) & 1: + elif e < 113: + # denormalized number + ex = e - 112 + ret |= 1 << (1 + ex) + ret |= m >> (22 - ex) + mask = 1 << (21 - ex) + if m & mask and ( + ret & 1 + or m & (mask - 1) > 0 + or (m & mask and m & (mask << 1) and m & (mask - 1) == 0) + ): # rounding ret += 1 - elif e < 143: # 127 + 15 + 1 - ex = e - 112 # 127 - 15 + elif e < 143: + # 
normalized number + ex = e - 112 ret |= ex << 2 ret |= m >> 21 - if m & 0x100000 and ( - (m & 0xFFFFF) or (m & 0x200000) - ): # round to nearest even + if m & 0x100000 and ((m & 0xFFFFF) or (m & 0x200000)): if (ret & 0x7F) < 0x7B: # rounding ret += 1