@@ -8803,97 +8803,62 @@ def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
         elif fp8_dtype == e5m2_type:
             x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
         else:
-            raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
-
+            raise ValueError(f"Unsupported fp8_dtype: {fp8_dtype}")
         return x.to(fp8_dtype)
 
-    def amax_to_scale(
-        amax: torch.Tensor,
-        float8_dtype: torch.dtype,
-    ):
-        """Converts the amax value of a tensor to the fp8 scale.
-        Args:
-            amax: The amax value of the tensor.
-            float8_dtype: the float8 dtype.
-            orig_dtype: The original dtype of the tensor.
-        """
-        # avoid division by zero when calculating scale
+    def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
         EPS = 1e-12
-
-        scale = torch.empty_like(amax, dtype=torch.float32, device=device)
         if float8_dtype == e4m3_type:
-            res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+            scale_val = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
         elif float8_dtype == e5m2_type:
-            res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+            scale_val = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
         else:
             raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
-
-        scale.copy_(res)
-        return scale
+        return scale_val.to(dtype=torch.float32, device=device)
 
     def make_scale(x: float, float8_dtype: torch.dtype, dim=None):
-        def creator():
-            if dim is None:
-                amax = torch.max(torch.abs(torch.tensor(x)))
-            else:
-                amax = torch.max(
-                    torch.abs(torch.tensor(x)), dim=dim, keepdim=True
-                ).values
-
-            return amax_to_scale(amax, float8_dtype)
-
-        return creator
+        if dim is None:
+            amax = torch.tensor(abs(x), dtype=torch.float32, device=device)
+        else:
+            amax = torch.max(
+                torch.abs(torch.tensor(x, device=device)), dim=dim, keepdim=True
+            ).values
+        return amax_to_scale(amax, float8_dtype)
 
     def make_mat(size: tuple[int], scale: float, fp8_dtype: torch.dtype):
-        """Create random matrix and convert to FP8 with scaling"""
-
-        def creator():
-            mat = torch.randn(size, device=device, dtype=torch.float32)
-            return to_fp8_saturated(mat * scale, fp8_dtype)
-
-        return creator
+        mat = torch.randn(size, device=device, dtype=torch.float32)
+        return to_fp8_saturated(mat * scale, fp8_dtype)
 
     M, N, K = 15, 32, 16
     samples = []
 
-    # Case 1: Both matrices are e4m3
+    # Case 1: Both matrices e4m3
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e4m3fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e4m3fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e4m3fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e4m3fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn)
+    mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn)
+    scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn)
+    scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))
 
     # Case 2: mat1 e4m3, mat2 e5m2
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e4m3fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e4m3fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e5m2fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e5m2fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn)
+    mat2 = make_mat((K, N), scale2, torch.float8_e5m2)
+    scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn)
+    scale_tensor2 = make_scale(scale2, torch.float8_e5m2)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))
 
     # Case 3: mat1 e5m2, mat2 e4m3
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e5m2fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e5m2fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e4m3fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e4m3fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e5m2)
+    mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn)
+    scale_tensor1 = make_scale(scale1, torch.float8_e5m2)
+    scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))
+
 
     yield from samples
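For context, here is a minimal sketch of how sample inputs like these are typically consumed. The commit's amax_to_scale computes scale = MAX_POS / amax, the usual per-tensor recipe for mapping a tensor's absolute maximum onto the fp8 dynamic range, and make_mat saturate-casts a scaled random matrix. The sketch assumes the private torch._scaled_mm API (whose signature has shifted across PyTorch releases; recent versions return a single tensor, older ones returned an (out, amax) tuple) and a CUDA device with fp8 support; everything outside the torch calls is reconstructed for illustration, not taken from the commit.

import torch

device = "cuda"
# Largest finite values representable in each fp8 format.
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    # Clamp into the finite fp8 range first so out-of-range values
    # saturate instead of overflowing to inf/NaN on the cast.
    if fp8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif fp8_dtype == torch.float8_e5m2:
        x = x.clamp(min=-E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"Unsupported fp8_dtype: {fp8_dtype}")
    return x.to(fp8_dtype)

# Shapes mirror the test; K and N are multiples of 16, which _scaled_mm requires.
M, N, K = 15, 32, 16
mat1 = to_fp8_saturated(torch.randn(M, K, device=device), torch.float8_e4m3fn)
# cuBLASLt wants the second operand column-major, hence the transpose.
mat2 = to_fp8_saturated(torch.randn(N, K, device=device), torch.float8_e4m3fn).t()
# Per-tensor scales are passed as float32 scalar tensors.
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
out = torch._scaled_mm(mat1, mat2, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.float32)
print(out.shape)  # torch.Size([15, 32])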