diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 41bb2b96bd938..d1fdd9292063f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -24,11 +24,25 @@ floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types, ) -from torch.testing._internal.common_device_type import \ - (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, - skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride, - skipCPUIfNoMklSparse, - toleranceOverride, tol) +from torch.testing._internal.common_device_type import ( + onlyCPU, + onlyCUDA, + onlyNativeDeviceTypes, + disablecuDNN, + skipCUDAIfNoMagma, + skipCUDAIfNoMagmaAndNoCusolver, + skipCUDAIfNoCusolver, + skipCPUIfNoLapack, + skipCPUIfNoFFT, + skipCUDAIf, + precisionOverride, + skipCPUIfNoMklSparse, + toleranceOverride, + tol, + e4m3_type, + E4M3_MAX_POS, + E5M2_MAX_POS, +) from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version, @@ -8780,33 +8794,65 @@ def error_inputs_triplet_margin_loss(op_info, device, **kwargs): yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs), error_type=error_type, error_regex=error_regex) + def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs): - make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad) - make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad) - make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False) + def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype): + max_val = E4M3_MAX_POS if fp8_dtype == e4m3_type else E5M2_MAX_POS + x = x.clamp(min=-1 * max_val, max=max_val) + return x.to(fp8_dtype) + + def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): + EPS = 1e-12 + max_pos = E4M3_MAX_POS if float8_dtype == e4m3_type else E5M2_MAX_POS + scale_val = max_pos / torch.clamp(amax, min=EPS) + return scale_val.to(dtype=torch.float32, device=device) + + def make_scale(x: float, float8_dtype: torch.dtype, dim=None): + if dim is None: + amax = torch.tensor(abs(x), dtype=torch.float32, device=device) + else: + amax = torch.max( + torch.abs(torch.tensor(x, device=device)), dim=dim, keepdim=True + ).values + return amax_to_scale(amax, float8_dtype) + + def make_mat(size: tuple[int], scale: float, fp8_dtype: torch.dtype): + mat = torch.randn(size, device=device, dtype=torch.float32) + return to_fp8_saturated(mat * scale, fp8_dtype) + M, N, K = 15, 32, 16 samples = [] - # two e4m3 - mat1 = make_mat_e4m3((M, K)) - mat2 = make_mat_e4m3((K, N)).t().contiguous().t() - scale1 = make_scale((1,)) - scale2 = make_scale((1,)) - samples.append(SampleInput(mat1, mat2, scale1, scale2)) - # mat1 e4m3 mat2 e5m2 - mat1 = make_mat_e4m3((M, K)) - mat2 = make_mat_e5m2((K, N)).t().contiguous().t() - scale1 = make_scale((1,)) - scale2 = make_scale((1,)) - samples.append(SampleInput(mat1, mat2, scale1, scale2)) - # mat1 e5m2 mat2 e4m3 - mat1 = make_mat_e5m2((M, K)) - mat2 = make_mat_e4m3((K, N)).t().contiguous().t() - scale1 = make_scale((1,)) - scale2 = make_scale((1,)) - samples.append(SampleInput(mat1, mat2, scale1, scale2)) + + # Case 1: Both matrices e4m3 + scale1 = random.random() + scale2 = random.random() + mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn) + mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn) + scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn) + scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn) + samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2)) + + # Case 2: mat1 e4m3, mat2 e5m2 + scale1 = random.random() + scale2 = random.random() + mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn) + mat2 = make_mat((K, N), scale2, torch.float8_e5m2) + scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn) + scale_tensor2 = make_scale(scale2, torch.float8_e5m2) + samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2)) + + # Case 3: mat1 e5m2, mat2 e4m3 + scale1 = random.random() + scale2 = random.random() + mat1 = make_mat((M, K), scale1, torch.float8_e5m2) + mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn) + scale_tensor1 = make_scale(scale1, torch.float8_e5m2) + scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn) + samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2)) yield from samples + def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs): make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8