From e0ffe1ab078e4e687ffe45c027b88f677e9cc6b9 Mon Sep 17 00:00:00 2001
From: Matthew Haddock
Date: Tue, 8 Jul 2025 14:59:43 +0000
Subject: [PATCH 1/3] Change sample inputs to generate saturated fp8 inputs

---
 .../_internal/common_methods_invocations.py | 142 ++++++++++++++----
 1 file changed, 116 insertions(+), 26 deletions(-)

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 92ae95bef8d0..43fab0d1050e 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -24,11 +24,26 @@
     floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
     empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
 )
-from torch.testing._internal.common_device_type import \
-    (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
-     skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
-     skipCPUIfNoMklSparse,
-     toleranceOverride, tol)
+from torch.testing._internal.common_device_type import (
+    onlyCPU,
+    onlyCUDA,
+    onlyNativeDeviceTypes,
+    disablecuDNN,
+    skipCUDAIfNoMagma,
+    skipCUDAIfNoMagmaAndNoCusolver,
+    skipCUDAIfNoCusolver,
+    skipCPUIfNoLapack,
+    skipCPUIfNoFFT,
+    skipCUDAIf,
+    precisionOverride,
+    skipCPUIfNoMklSparse,
+    toleranceOverride,
+    tol,
+    e4m3_type,
+    e5m2_type,
+    E4M3_MAX_POS,
+    E5M2_MAX_POS,
+)
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM53OrLater, SM80OrLater, SM89OrLater,
     with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
@@ -8780,33 +8795,108 @@ def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
         yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
                          error_type=error_type, error_regex=error_regex)
 
+
 def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
-    make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
-    make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
-    make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
+    def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
+        if fp8_dtype == e4m3_type:
+            x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
+        elif fp8_dtype == e5m2_type:
+            x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+        else:
+            raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
+
+        return x.to(fp8_dtype)
+
+    def amax_to_scale(
+        amax: torch.Tensor,
+        float8_dtype: torch.dtype,
+    ):
+        """Converts the amax value of a tensor to the fp8 scale.
+        Args:
+            amax: The amax value of the tensor.
+            float8_dtype: the float8 dtype.
+            orig_dtype: The original dtype of the tensor.
+ """ + # avoid division by zero when calculating scale + EPS = 1e-12 + + scale = torch.empty_like(amax, dtype=torch.float32, device=device) + if float8_dtype == e4m3_type: + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + elif float8_dtype == e5m2_type: + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + else: + raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") + + scale.copy_(res) + return scale + + def make_scale(x: float, float8_dtype: torch.dtype, dim=None): + def creator(): + if dim is None: + amax = torch.max(torch.abs(torch.tensor(x))) + else: + amax = torch.max( + torch.abs(torch.tensor(x)), dim=dim, keepdim=True + ).values + + return amax_to_scale(amax, float8_dtype) + + return creator + + def make_mat(size: tuple[int], scale: float, fp8_dtype: torch.dtype): + """Create random matrix and convert to FP8 with scaling""" + + def creator(): + mat = torch.randn(size, device=device, dtype=torch.float32) + return to_fp8_saturated(mat * scale, fp8_dtype) + + return creator + M, N, K = 15, 32, 16 samples = [] - # two e4m3 - mat1 = make_mat_e4m3((M, K)) - mat2 = make_mat_e4m3((K, N)).t().contiguous().t() - scale1 = make_scale((1,)) - scale2 = make_scale((1,)) - samples.append(SampleInput(mat1, mat2, scale1, scale2)) - # mat1 e4m3 mat2 e5m2 - mat1 = make_mat_e4m3((M, K)) - mat2 = make_mat_e5m2((K, N)).t().contiguous().t() - scale1 = make_scale((1,)) - scale2 = make_scale((1,)) - samples.append(SampleInput(mat1, mat2, scale1, scale2)) - # mat1 e5m2 mat2 e4m3 - mat1 = make_mat_e5m2((M, K)) - mat2 = make_mat_e4m3((K, N)).t().contiguous().t() - scale1 = make_scale((1,)) - scale2 = make_scale((1,)) - samples.append(SampleInput(mat1, mat2, scale1, scale2)) + # Case 1: Both matrices are e4m3 + scale1 = random.random() + mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e4m3fn) + make_scale1 = partial(make_scale, scale1, torch.float8_e4m3fn) + scale2 = random.random() + mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e4m3fn) + make_scale2 = partial(make_scale, scale2, torch.float8_e4m3fn) + samples.append( + SampleInput( + mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32 + ) + ) + + # Case 2: mat1 e4m3, mat2 e5m2 + scale1 = random.random() + mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e4m3fn) + make_scale1 = partial(make_scale, scale1, torch.float8_e4m3fn) + scale2 = random.random() + mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e5m2fn) + make_scale2 = partial(make_scale, scale2, torch.float8_e5m2fn) + samples.append( + SampleInput( + mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32 + ) + ) + + # Case 3: mat1 e5m2, mat2 e4m3 + scale1 = random.random() + mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e5m2fn) + make_scale1 = partial(make_scale, scale1, torch.float8_e5m2fn) + scale2 = random.random() + mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e4m3fn) + make_scale2 = partial(make_scale, scale2, torch.float8_e4m3fn) + samples.append( + SampleInput( + mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32 + ) + ) yield from samples + def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8 From 5192752f72f99d099b73249c58c46e4d2e340fd8 Mon Sep 17 00:00:00 2001 From: Matthew Haddock Date: Mon, 14 Jul 2025 14:50:58 +0000 Subject: [PATCH 2/3] remove partials --- 
 .../_internal/common_methods_invocations.py | 97 ++++++-------------
 1 file changed, 31 insertions(+), 66 deletions(-)

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 43fab0d1050e..8d2f3f4005b8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8803,97 +8803,62 @@ def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
         elif fp8_dtype == e5m2_type:
             x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
         else:
-            raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
-
+            raise ValueError(f"Unsupported fp8_dtype: {fp8_dtype}")
         return x.to(fp8_dtype)
 
-    def amax_to_scale(
-        amax: torch.Tensor,
-        float8_dtype: torch.dtype,
-    ):
-        """Converts the amax value of a tensor to the fp8 scale.
-        Args:
-            amax: The amax value of the tensor.
-            float8_dtype: the float8 dtype.
-            orig_dtype: The original dtype of the tensor.
-        """
-        # avoid division by zero when calculating scale
+    def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
         EPS = 1e-12
-
-        scale = torch.empty_like(amax, dtype=torch.float32, device=device)
         if float8_dtype == e4m3_type:
-            res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+            scale_val = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
         elif float8_dtype == e5m2_type:
-            res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+            scale_val = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
         else:
             raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
-
-        scale.copy_(res)
-        return scale
+        return scale_val.to(dtype=torch.float32, device=device)
 
     def make_scale(x: float, float8_dtype: torch.dtype, dim=None):
-        def creator():
-            if dim is None:
-                amax = torch.max(torch.abs(torch.tensor(x)))
-            else:
-                amax = torch.max(
-                    torch.abs(torch.tensor(x)), dim=dim, keepdim=True
-                ).values
-
-            return amax_to_scale(amax, float8_dtype)
-
-        return creator
+        if dim is None:
+            amax = torch.tensor(abs(x), dtype=torch.float32, device=device)
+        else:
+            amax = torch.max(
+                torch.abs(torch.tensor(x, device=device)), dim=dim, keepdim=True
+            ).values
+        return amax_to_scale(amax, float8_dtype)
 
     def make_mat(size: tuple[int], scale: float, fp8_dtype: torch.dtype):
-        """Create random matrix and convert to FP8 with scaling"""
-
-        def creator():
-            mat = torch.randn(size, device=device, dtype=torch.float32)
-            return to_fp8_saturated(mat * scale, fp8_dtype)
-
-        return creator
+        mat = torch.randn(size, device=device, dtype=torch.float32)
+        return to_fp8_saturated(mat * scale, fp8_dtype)
 
     M, N, K = 15, 32, 16
     samples = []
-    # Case 1: Both matrices are e4m3
+    # Case 1: Both matrices e4m3
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e4m3fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e4m3fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e4m3fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e4m3fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn)
+    mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn)
+    scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn)
+    scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))
 
     # Case 2: mat1 e4m3, mat2 e5m2
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e4m3fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e4m3fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e5m2fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e5m2fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn)
+    mat2 = make_mat((K, N), scale2, torch.float8_e5m2)
+    scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn)
+    scale_tensor2 = make_scale(scale2, torch.float8_e5m2)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))
 
     # Case 3: mat1 e5m2, mat2 e4m3
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e5m2fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e5m2fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e4m3fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e4m3fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e5m2)
+    mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn)
+    scale_tensor1 = make_scale(scale1, torch.float8_e5m2)
+    scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))
+
     yield from samples

From e793cf16cae33e1e5dc434473863a3e924e2d69a Mon Sep 17 00:00:00 2001
From: Matthew Haddock
Date: Thu, 17 Jul 2025 09:21:32 +0000
Subject: [PATCH 3/3] suggestion

---
 .../_internal/common_methods_invocations.py | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8d2f3f4005b8..53f10318fa95 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -40,7 +40,6 @@
     toleranceOverride,
     tol,
     e4m3_type,
-    e5m2_type,
     E4M3_MAX_POS,
     E5M2_MAX_POS,
 )
@@ -8798,22 +8797,14 @@ def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
 
 def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
     def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
-        if fp8_dtype == e4m3_type:
-            x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-        elif fp8_dtype == e5m2_type:
-            x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
-        else:
-            raise ValueError(f"Unsupported fp8_dtype: {fp8_dtype}")
+        max_val = E4M3_MAX_POS if fp8_dtype == e4m3_type else E5M2_MAX_POS
+        x = x.clamp(min=-1 * max_val, max=max_val)
         return x.to(fp8_dtype)
 
     def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
         EPS = 1e-12
-        if float8_dtype == e4m3_type:
-            scale_val = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-        elif float8_dtype == e5m2_type:
-            scale_val = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
-        else:
-            raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
+        max_pos = E4M3_MAX_POS if float8_dtype == e4m3_type else E5M2_MAX_POS
+        scale_val = max_pos / torch.clamp(amax, min=EPS)
         return scale_val.to(dtype=torch.float32, device=device)
 
     def make_scale(x: float, float8_dtype: torch.dtype, dim=None):
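The saturated-cast and amax-to-scale recipe that this series converges on can be exercised outside the OpInfo harness. The sketch below is illustrative only: it derives the fp8 maxima with torch.finfo rather than importing the E4M3_MAX_POS/E5M2_MAX_POS constants from common_device_type, its helper names merely mirror the patch, and the torch._scaled_mm call is left in a comment because it requires fp8-capable hardware (e.g. SM89+).

    import torch

    def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
        # Clamp to the representable range first so the cast saturates
        # instead of overflowing to inf/nan.
        max_val = torch.finfo(fp8_dtype).max  # 448.0 for e4m3fn, 57344.0 for e5m2
        return x.clamp(min=-max_val, max=max_val).to(fp8_dtype)

    def amax_to_scale(amax: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
        # Map a tensor's absolute maximum to a dequantization scale;
        # EPS guards the division for all-zero inputs.
        EPS = 1e-12
        return (torch.finfo(fp8_dtype).max / torch.clamp(amax, min=EPS)).to(torch.float32)

    M, K, N = 15, 16, 32
    a = torch.randn(M, K)
    b = torch.randn(K, N)
    a_fp8 = to_fp8_saturated(a, torch.float8_e4m3fn)
    b_fp8 = to_fp8_saturated(b, torch.float8_e5m2)
    scale_a = amax_to_scale(a.abs().max(), torch.float8_e4m3fn)
    scale_b = amax_to_scale(b.abs().max(), torch.float8_e5m2)
    # On fp8-capable hardware these tensors would feed torch._scaled_mm, with
    # mat2 laid out column-major as the pre-patch sample inputs arranged it:
    # out = torch._scaled_mm(a_fp8, b_fp8.t().contiguous().t(),
    #                        scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float32)
    print(a_fp8.dtype, b_fp8.dtype, scale_a.item(), scale_b.item())

Clamping before the cast is the point of the first commit: an unclamped value outside the fp8 range would otherwise round to inf (or nan), which is exactly the degenerate input the reworked sample generator is meant to avoid.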