1
1
import os
2
2
import pprint
3
+ import struct
3
4
import unittest
4
5
import warnings
5
6
import numpy
6
7
import pandas
8
+ from onnx import TensorProto
7
9
from onnx_array_api .validation .f8 import (
8
10
CastFloat8 ,
9
11
UndefinedCastError ,
@@ -285,6 +287,15 @@ def test_search_float32_into_fe4m3fn(self):
285
287
ok = "" if b == nf else "WRONG" ,
286
288
true = value ,
287
289
add = add ,
290
+ exponent = (
291
+ int .from_bytes (
292
+ struct .pack ("<f" , numpy .float32 (v )), "little"
293
+ )
294
+ & 0x7F800000
295
+ )
296
+ >> 23 ,
297
+ d1 = v - fe4m3_to_float32_float (nf ),
298
+ d2 = v - fe4m3_to_float32_float (b ),
288
299
)
289
300
)
290
301
if wrong > 0 :
@@ -449,10 +460,13 @@ def test_search_e4m3_pow(self):
449
460
continue
450
461
r2 = float32_to_fe4m3 (v )
451
462
if r1 != r2 :
463
+ ex = abs (v - fe4m3_to_float32 (r1 )) == abs (v - fe4m3_to_float32 (r2 ))
452
464
raise AssertionError (
453
465
f"p={ p } , v={ v } , "
454
466
f"search={ r1 } :{ display_fe4m3 (r1 )} ={ fe4m3_to_float32 (r1 )} != "
455
- f"bit={ r2 } :{ display_fe4m3 (r2 )} ={ fe4m3_to_float32 (r2 )} "
467
+ f"bit={ r2 } :{ display_fe4m3 (r2 )} ={ fe4m3_to_float32 (r2 )} "
468
+ f"d1={ v - fe4m3_to_float32 (r1 )} d2={ v - fe4m3_to_float32 (r2 )} "
469
+ f"|d1|==|d2|={ ex } "
456
470
)
457
471
for p in range (1 , 40 ):
458
472
v = - (2 ** (- p ))
@@ -462,10 +476,13 @@ def test_search_e4m3_pow(self):
462
476
continue
463
477
r2 = float32_to_fe4m3 (v )
464
478
if r1 != r2 :
479
+ ex = abs (v - fe4m3_to_float32 (r1 )) == abs (v - fe4m3_to_float32 (r2 ))
465
480
raise AssertionError (
466
481
f"p={ p } , v={ v } , "
467
482
f"search={ r1 } :{ display_fe4m3 (r1 )} ={ fe4m3_to_float32 (r1 )} != "
468
- f"bit={ r2 } :{ display_fe4m3 (r2 )} ={ fe4m3_to_float32 (r2 )} "
483
+ f"bit={ r2 } :{ display_fe4m3 (r2 )} ={ fe4m3_to_float32 (r2 )} "
484
+ f"d1={ v - fe4m3_to_float32 (r1 )} d2={ v - fe4m3_to_float32 (r2 )} "
485
+ f"|d1|==|d2|={ ex } "
469
486
)
470
487
471
488
def test_search_e5m2_pow (self ):
@@ -478,10 +495,13 @@ def test_search_e5m2_pow(self):
478
495
continue
479
496
r2 = float32_to_fe5m2 (v )
480
497
if r1 != r2 :
498
+ ex = abs (v - fe5m2_to_float32 (r1 )) == abs (v - fe5m2_to_float32 (r2 ))
481
499
raise AssertionError (
482
500
f"p={ p } , v={ v } , "
483
501
f"search={ r1 } :{ display_fe5m2 (r1 )} ={ fe5m2_to_float32 (r1 )} != "
484
- f"bit={ r2 } :{ display_fe5m2 (r2 )} ={ fe5m2_to_float32 (r2 )} "
502
+ f"bit={ r2 } :{ display_fe5m2 (r2 )} ={ fe5m2_to_float32 (r2 )} "
503
+ f"d1={ v - fe4m3_to_float32 (r1 )} d2={ v - fe5m2_to_float32 (r2 )} "
504
+ f"|d1|==|d2|={ ex } "
485
505
)
486
506
for p in range (1 , 40 ):
487
507
v = - (2 ** (- p ))
@@ -491,10 +511,13 @@ def test_search_e5m2_pow(self):
491
511
continue
492
512
r2 = float32_to_fe5m2 (v )
493
513
if r1 != r2 :
514
+ ex = abs (v - fe5m2_to_float32 (r1 )) == abs (v - fe5m2_to_float32 (r2 ))
494
515
raise AssertionError (
495
516
f"p={ p } , v={ v } , "
496
517
f"search={ r1 } :{ display_fe5m2 (r1 )} ={ fe5m2_to_float32 (r1 )} != "
497
- f"bit={ r2 } :{ display_fe5m2 (r2 )} ={ fe5m2_to_float32 (r2 )} "
518
+ f"bit={ r2 } :{ display_fe5m2 (r2 )} ={ fe5m2_to_float32 (r2 )} "
519
+ f"d1={ v - fe4m3_to_float32 (r1 )} d2={ v - fe5m2_to_float32 (r2 )} "
520
+ f"|d1|==|d2|={ ex } "
498
521
)
499
522
500
523
def test_float32_to_fe4m3fn_inf (self ):
@@ -1152,13 +1175,50 @@ def test_float8_e5m2fnuz_negative_nan(self):
1152
1175
self .assertTrue (numpy .isnan (back ))
1153
1176
1154
1177
def test_fe4m3fn_to_float32_bug (self ):
1155
- cases = [(1.8131605 , 1.875 )]
1156
- for val , expected in cases :
1157
- with self .subTest (value = val , expected = expected ):
1158
- res = fe4m3_to_float32 (search_float32_into_fe4m3 (val ))
1159
- self .assertEqual (expected , res )
1160
- res = fe4m3_to_float32 (float32_to_fe4m3 (val ))
1161
- self .assertEqual (expected , res )
1178
+ cases = [
1179
+ (0.00439453125 , 0.00390625 , TensorProto .FLOAT8E4M3FN ),
1180
+ (0.005859375 , 0.005859375 , TensorProto .FLOAT8E4M3FN ),
1181
+ (0.005759375 , 0.005859375 , TensorProto .FLOAT8E4M3FN ),
1182
+ (0.0046875 , 0.00390625 , TensorProto .FLOAT8E4M3FN ),
1183
+ (0.001953125 , 0.001953125 , TensorProto .FLOAT8E4M3FN ),
1184
+ (0.0029296875 , 0.00390625 , TensorProto .FLOAT8E4M3FN ),
1185
+ (0.002053125 , 0.001953125 , TensorProto .FLOAT8E4M3FN ),
1186
+ (0.00234375 , 0.001953125 , TensorProto .FLOAT8E4M3FN ),
1187
+ (0.0087890625 , 0.0078125 , TensorProto .FLOAT8E4M3FN ),
1188
+ (0.001171875 , 0.001953125 , TensorProto .FLOAT8E4M3FN ),
1189
+ (1.8131605 , 1.875 , TensorProto .FLOAT8E4M3FN ),
1190
+ (- 100 , - 96 , TensorProto .FLOAT8E4M3FNUZ ),
1191
+ (416 , 384 , TensorProto .FLOAT8E5M2FNUZ ),
1192
+ ]
1193
+ for val , expected , pt in cases :
1194
+ with self .subTest (value = val , expected = expected , proto = pt ):
1195
+ if pt == TensorProto .FLOAT8E4M3FN :
1196
+ res = fe4m3_to_float32 (search_float32_into_fe4m3 (val ))
1197
+ self .assertEqual (expected , res )
1198
+ res = fe4m3_to_float32 (float32_to_fe4m3 (val ))
1199
+ self .assertEqual (expected , res )
1200
+ continue
1201
+ if pt == TensorProto .FLOAT8E4M3FNUZ :
1202
+ res = fe4m3_to_float32 (
1203
+ search_float32_into_fe4m3 (val , uz = True ), uz = True
1204
+ )
1205
+ self .assertEqual (expected , res )
1206
+ res = fe4m3_to_float32 (float32_to_fe4m3 (val , uz = True ), uz = True )
1207
+ self .assertEqual (expected , res )
1208
+ continue
1209
+ if pt == TensorProto .FLOAT8E5M2FNUZ :
1210
+ res = fe5m2_to_float32 (
1211
+ search_float32_into_fe5m2 (val , fn = True , uz = True ),
1212
+ fn = True ,
1213
+ uz = True ,
1214
+ )
1215
+ self .assertEqual (expected , res )
1216
+ res = fe5m2_to_float32 (
1217
+ float32_to_fe5m2 (val , fn = True , uz = True ), fn = True , uz = True
1218
+ )
1219
+ self .assertEqual (expected , res )
1220
+ continue
1221
+ raise AssertionError (f"Unexpected value for pt={ pt } ." )
1162
1222
1163
1223
1164
1224
if __name__ == "__main__" :
0 commit comments