Update upstream opinfo to generate appropriately scaled sample inputs #158018

Open · wants to merge 5 commits into main
98 changes: 72 additions & 26 deletions torch/testing/_internal/common_methods_invocations.py
@@ -24,11 +24,25 @@
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
)
from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIf, precisionOverride,
skipCPUIfNoMklSparse,
toleranceOverride, tol)
from torch.testing._internal.common_device_type import (
onlyCPU,
onlyCUDA,
onlyNativeDeviceTypes,
disablecuDNN,
skipCUDAIfNoMagma,
skipCUDAIfNoMagmaAndNoCusolver,
skipCUDAIfNoCusolver,
skipCPUIfNoLapack,
skipCPUIfNoFFT,
skipCUDAIf,
precisionOverride,
skipCPUIfNoMklSparse,
toleranceOverride,
tol,
e4m3_type,
E4M3_MAX_POS,
E5M2_MAX_POS,
)
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
SM53OrLater, SM80OrLater, SM89OrLater, with_tf32_off, TEST_CUDNN, _get_torch_cuda_version,
@@ -8780,33 +8794,65 @@ def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
error_type=error_type, error_regex=error_regex)


def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
    make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
    make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
    make_scale = partial(make_tensor, device=device, dtype=torch.float, requires_grad=False)
    def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
        max_val = E4M3_MAX_POS if fp8_dtype == e4m3_type else E5M2_MAX_POS
        x = x.clamp(min=-1 * max_val, max=max_val)
        return x.to(fp8_dtype)

    def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
        EPS = 1e-12
        max_pos = E4M3_MAX_POS if float8_dtype == e4m3_type else E5M2_MAX_POS
        scale_val = max_pos / torch.clamp(amax, min=EPS)
        return scale_val.to(dtype=torch.float32, device=device)

    def make_scale(x: float, float8_dtype: torch.dtype, dim=None):
        if dim is None:
            amax = torch.tensor(abs(x), dtype=torch.float32, device=device)
        else:
            amax = torch.max(
                torch.abs(torch.tensor(x, device=device)), dim=dim, keepdim=True
            ).values
        return amax_to_scale(amax, float8_dtype)

    def make_mat(size: tuple[int], scale: float, fp8_dtype: torch.dtype):
        mat = torch.randn(size, device=device, dtype=torch.float32)
        return to_fp8_saturated(mat * scale, fp8_dtype)

    M, N, K = 15, 32, 16
    samples = []
    # two e4m3
    mat1 = make_mat_e4m3((M, K))
    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
    scale1 = make_scale((1,))
    scale2 = make_scale((1,))
    samples.append(SampleInput(mat1, mat2, scale1, scale2))
    # mat1 e4m3 mat2 e5m2
    mat1 = make_mat_e4m3((M, K))
    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
    scale1 = make_scale((1,))
    scale2 = make_scale((1,))
    samples.append(SampleInput(mat1, mat2, scale1, scale2))
    # mat1 e5m2 mat2 e4m3
    mat1 = make_mat_e5m2((M, K))
    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
    scale1 = make_scale((1,))
    scale2 = make_scale((1,))
    samples.append(SampleInput(mat1, mat2, scale1, scale2))

    # Case 1: Both matrices e4m3
    scale1 = random.random()
    scale2 = random.random()
    mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn)
    mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn)
    scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn)
    scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn)
    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))

    # Case 2: mat1 e4m3, mat2 e5m2
    scale1 = random.random()
    scale2 = random.random()
    mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn)
    mat2 = make_mat((K, N), scale2, torch.float8_e5m2)
    scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn)
    scale_tensor2 = make_scale(scale2, torch.float8_e5m2)
    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))

    # Case 3: mat1 e5m2, mat2 e4m3
    scale1 = random.random()
    scale2 = random.random()
    mat1 = make_mat((M, K), scale1, torch.float8_e5m2)
    mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn)
    scale_tensor1 = make_scale(scale1, torch.float8_e5m2)
    scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn)
    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))

    yield from samples


def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
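
For readers skimming the diff, here is a minimal, standalone sketch (not part of this PR) of the amax-based scaling scheme the new helpers implement: derive a per-tensor scale from the absolute maximum, saturate-cast the scaled tensor to float8, and keep the scale so the value can be dequantized later. It substitutes torch.finfo for the test-internal e4m3_type / E4M3_MAX_POS / E5M2_MAX_POS constants, and the round-trip check at the end is illustrative only.

import torch

def amax_to_scale(amax: torch.Tensor, fp8_dtype: torch.dtype, eps: float = 1e-12) -> torch.Tensor:
    # scale = largest representable magnitude / amax, guarded against division by zero
    max_pos = torch.finfo(fp8_dtype).max
    return (max_pos / torch.clamp(amax, min=eps)).to(torch.float32)

def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    # Clamp into the representable range before casting so out-of-range values
    # saturate instead of overflowing.
    max_pos = torch.finfo(fp8_dtype).max
    return x.clamp(min=-max_pos, max=max_pos).to(fp8_dtype)

x = torch.randn(15, 16)
scale = amax_to_scale(x.abs().max(), torch.float8_e4m3fn)
x_fp8 = to_fp8_saturated(x * scale, torch.float8_e4m3fn)
# Dequantize by dividing the scale back out; the round trip stays close to x.
x_dequant = x_fp8.to(torch.float32) / scale
print((x - x_dequant).abs().max().item())

The generator in the diff applies the same idea per sample: it draws a random scale, builds a saturated float8 matrix with make_mat, and pairs it with the matching float32 scale tensor from make_scale in the SampleInput.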