
Commit 5192752

matthewhagraphcore authored and pytorchmergebot committed
remove partials
1 parent e0ffe1a commit 5192752

File tree

1 file changed: +31 -66 lines changed


torch/testing/_internal/common_methods_invocations.py

Lines changed: 31 additions & 66 deletions
@@ -8803,97 +8803,62 @@ def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
         elif fp8_dtype == e5m2_type:
             x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
         else:
-            raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
-
+            raise ValueError(f"Unsupported fp8_dtype: {fp8_dtype}")
         return x.to(fp8_dtype)

-    def amax_to_scale(
-        amax: torch.Tensor,
-        float8_dtype: torch.dtype,
-    ):
-        """Converts the amax value of a tensor to the fp8 scale.
-        Args:
-            amax: The amax value of the tensor.
-            float8_dtype: the float8 dtype.
-            orig_dtype: The original dtype of the tensor.
-        """
-        # avoid division by zero when calculating scale
+    def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
         EPS = 1e-12
-
-        scale = torch.empty_like(amax, dtype=torch.float32, device=device)
         if float8_dtype == e4m3_type:
-            res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+            scale_val = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
         elif float8_dtype == e5m2_type:
-            res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+            scale_val = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
         else:
             raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
-
-        scale.copy_(res)
-        return scale
+        return scale_val.to(dtype=torch.float32, device=device)

     def make_scale(x: float, float8_dtype: torch.dtype, dim=None):
-        def creator():
-            if dim is None:
-                amax = torch.max(torch.abs(torch.tensor(x)))
-            else:
-                amax = torch.max(
-                    torch.abs(torch.tensor(x)), dim=dim, keepdim=True
-                ).values
-
-            return amax_to_scale(amax, float8_dtype)
-
-        return creator
+        if dim is None:
+            amax = torch.tensor(abs(x), dtype=torch.float32, device=device)
+        else:
+            amax = torch.max(
+                torch.abs(torch.tensor(x, device=device)), dim=dim, keepdim=True
+            ).values
+        return amax_to_scale(amax, float8_dtype)

     def make_mat(size: tuple[int], scale: float, fp8_dtype: torch.dtype):
-        """Create random matrix and convert to FP8 with scaling"""
-
-        def creator():
-            mat = torch.randn(size, device=device, dtype=torch.float32)
-            return to_fp8_saturated(mat * scale, fp8_dtype)
-
-        return creator
+        mat = torch.randn(size, device=device, dtype=torch.float32)
+        return to_fp8_saturated(mat * scale, fp8_dtype)

     M, N, K = 15, 32, 16
     samples = []

-    # Case 1: Both matrices are e4m3
+    # Case 1: Both matrices e4m3
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e4m3fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e4m3fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e4m3fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e4m3fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn)
+    mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn)
+    scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn)
+    scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))

     # Case 2: mat1 e4m3, mat2 e5m2
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e4m3fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e4m3fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e5m2fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e5m2fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e4m3fn)
+    mat2 = make_mat((K, N), scale2, torch.float8_e5m2)
+    scale_tensor1 = make_scale(scale1, torch.float8_e4m3fn)
+    scale_tensor2 = make_scale(scale2, torch.float8_e5m2)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))

     # Case 3: mat1 e5m2, mat2 e4m3
     scale1 = random.random()
-    mat1_e4m3 = partial(make_mat, (M, K), scale1, torch.float8_e5m2fn)
-    make_scale1 = partial(make_scale, scale1, torch.float8_e5m2fn)
     scale2 = random.random()
-    mat2_e4m3 = partial(make_mat, (K, N), scale2, torch.float8_e4m3fn)
-    make_scale2 = partial(make_scale, scale2, torch.float8_e4m3fn)
-    samples.append(
-        SampleInput(
-            mat1_e4m3, mat2_e4m3, make_scale1, make_scale2, output_dtype=torch.float32
-        )
-    )
+    mat1 = make_mat((M, K), scale1, torch.float8_e5m2)
+    mat2 = make_mat((K, N), scale2, torch.float8_e4m3fn)
+    scale_tensor1 = make_scale(scale1, torch.float8_e5m2)
+    scale_tensor2 = make_scale(scale2, torch.float8_e4m3fn)
+    samples.append(SampleInput(mat1, mat2, scale_tensor1, scale_tensor2))
+
     yield from samples
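
For reference, the reworked helpers read as follows when pulled out of the diff. This is a hedged standalone sketch, not the file itself: `e4m3_type`, `e5m2_type`, the `E4M3_MAX_POS`/`E5M2_MAX_POS` constants, and `device` are defined elsewhere in `common_methods_invocations.py`, so they are recreated here via `torch.finfo` with a CPU device assumed.

```python
# Standalone sketch of the post-commit helpers (assumptions noted above).
import torch

device = "cpu"  # assumption: the real test file selects the device under test
e4m3_type = torch.float8_e4m3fn
e5m2_type = torch.float8_e5m2
E4M3_MAX_POS = torch.finfo(e4m3_type).max  # 448.0
E5M2_MAX_POS = torch.finfo(e5m2_type).max  # 57344.0


def to_fp8_saturated(x: torch.Tensor, fp8_dtype: torch.dtype):
    # Clamp into the representable range first so out-of-range values
    # saturate at +/- dtype max instead of overflowing on the cast.
    if fp8_dtype == e4m3_type:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif fp8_dtype == e5m2_type:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"Unsupported fp8_dtype: {fp8_dtype}")
    return x.to(fp8_dtype)


def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
    # scale = dtype_max / amax, with EPS guarding against division by zero.
    EPS = 1e-12
    if float8_dtype == e4m3_type:
        scale_val = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        scale_val = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    return scale_val.to(dtype=torch.float32, device=device)


def make_scale(x: float, float8_dtype: torch.dtype, dim=None):
    # Build a float32 scale tensor directly (no deferred `creator` closure).
    if dim is None:
        amax = torch.tensor(abs(x), dtype=torch.float32, device=device)
    else:
        amax = torch.max(
            torch.abs(torch.tensor(x, device=device)), dim=dim, keepdim=True
        ).values
    return amax_to_scale(amax, float8_dtype)


def make_mat(size: tuple, scale: float, fp8_dtype: torch.dtype):
    # Random float32 matrix, scaled and saturated-cast to the fp8 dtype.
    mat = torch.randn(size, device=device, dtype=torch.float32)
    return to_fp8_saturated(mat * scale, fp8_dtype)


if __name__ == "__main__":
    M, N, K = 15, 32, 16
    mat1 = make_mat((M, K), 0.25, e4m3_type)
    mat2 = make_mat((K, N), 0.75, e5m2_type)
    print(mat1.dtype, mat2.dtype)  # torch.float8_e4m3fn torch.float8_e5m2
    print(make_scale(0.25, e4m3_type), make_scale(0.75, e5m2_type))
```

The tensors and scales built this way presumably feed `torch._scaled_mm` through the `SampleInput`s above; that operator is a private API whose signature has shifted between releases, so it is not reproduced here.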

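The scale convention is the usual float8 recipe: `scale = dtype_max / amax`, so multiplying a tensor by its scale maps its largest magnitude onto the dtype's maximum (for e4m3, an amax of 0.5 gives a scale of 448 / 0.5 = 896). A small round-trip check using the sketch above, shown only to illustrate the convention rather than the exact flow in the test:

```python
# Quantize/dequantize round trip with the helpers sketched above.
x = torch.randn(4, 4) * 3.0
scale = amax_to_scale(x.abs().max(), torch.float8_e4m3fn)  # 448 / amax(x)
x_fp8 = to_fp8_saturated(x * scale, torch.float8_e4m3fn)   # scaled into fp8 range
x_back = x_fp8.float() / scale                             # approximate original
print((x - x_back).abs().max())                            # small quantization error
```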

0 commit comments
