@@ -503,15 +503,18 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
503
503
"""
504
504
if not fn :
505
505
raise NotImplementedError ("fn=False is not implemented." )
506
- b = int .from_bytes (struct .pack ("<f" , numpy .float32 (x )), "little" )
506
+ if not isinstance (x , numpy .float32 ):
507
+ x = numpy .float32 (x )
508
+ b = int .from_bytes (struct .pack ("<f" , x ), "little" )
507
509
ret = (b & 0x80000000 ) >> 24 # sign
508
510
if uz :
509
- if (b & 0x7FC00000 ) == 0x7FC00000 :
510
- return 0x80
511
- if numpy .isinf (x ):
511
+ if (b & 0x7FFFFFFF ) == 0x7F800000 :
512
+ # infinity
512
513
if saturate :
513
514
return ret | 127
514
515
return 0x80
516
+ if (b & 0x7F800000 ) == 0x7F800000 :
517
+ return 0x80
515
518
e = (b & 0x7F800000 ) >> 23 # exponent
516
519
m = b & 0x007FFFFF # mantissa
517
520
@@ -558,12 +561,14 @@ def float32_to_fe4m3(x, fn: bool = True, uz: bool = False, saturate: bool = True
558
561
ret = 0
559
562
return int (ret )
560
563
else :
561
- if (b & 0x7FC00000 ) == 0x7FC00000 :
562
- return 0x7F | ret
563
- if numpy .isinf (x ):
564
+ if (b & 0x7FFFFFFF ) == 0x7F800000 :
565
+ # infinity
564
566
if saturate :
565
567
return ret | 126
566
568
return 0x7F | ret
569
+ if (b & 0x7F800000 ) == 0x7F800000 :
570
+ # non
571
+ return 0x7F | ret
567
572
e = (b & 0x7F800000 ) >> 23 # exponent
568
573
m = b & 0x007FFFFF # mantissa
569
574
@@ -624,13 +629,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
624
629
ret = (b & 0x80000000 ) >> 24 # sign
625
630
626
631
if fn and uz :
627
- if (b & 0x7FC00000 ) == 0x7FC00000 :
628
- return 0x80
629
632
if (b & 0x7FFFFFFF ) == 0x7F800000 :
630
633
# inf
631
634
if saturate :
632
635
return ret | 0x7F
633
636
return 0x80
637
+ if (b & 0x7F800000 ) == 0x7F800000 :
638
+ # nan
639
+ return 0x80
634
640
e = (b & 0x7F800000 ) >> 23 # exponent
635
641
m = b & 0x007FFFFF # mantissa
636
642
@@ -675,12 +681,14 @@ def float32_to_fe5m2(x, fn: bool = False, uz: bool = False, saturate: bool = Tru
675
681
ret = 0
676
682
return int (ret )
677
683
elif not fn and not uz :
678
- if (b & 0x7FC00000 ) == 0x7FC00000 :
679
- return 0x7F | ret
680
- if numpy .isinf (x ):
684
+ if (b & 0x7FFFFFFF ) == 0x7F800000 :
685
+ # inf
681
686
if saturate :
682
687
return 0x7B | ret
683
688
return 0x7C | ret
689
+ if (b & 0x7F800000 ) == 0x7F800000 :
690
+ # nan
691
+ return 0x7F | ret
684
692
e = (b & 0x7F800000 ) >> 23 # exponent
685
693
m = b & 0x007FFFFF # mantissa
686
694
0 commit comments