
Commit 40f242b

Update
[ghstack-poisoned]
2 parents: 0b51047 + 06b03bf


43 files changed: +902 -660 lines

aten/src/ATen/autocast_mode.cpp

Lines changed: 1 addition & 0 deletions
@@ -239,6 +239,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
   KERNEL_MPS(scaled_dot_product_attention, lower_precision_fp)

   // fp32
+  KERNEL_MPS(conv_transpose3d, input, fp32)
   KERNEL_MPS(acos, fp32)
   KERNEL_MPS(asin, fp32)
   KERNEL_MPS(cosh, fp32)
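
Note: a minimal sketch of the behavior this registration targets; it assumes an MPS-enabled build, and everything outside the op name is illustrative.

import torch

# Illustrative only: with conv_transpose3d registered under the fp32 autocast
# policy, MPS autocast up-casts its tensor arguments to float32 instead of
# running the op in the lower-precision autocast dtype.
if torch.backends.mps.is_available():
    conv = torch.nn.ConvTranspose3d(4, 4, kernel_size=3).to(device="mps", dtype=torch.float16)
    x = torch.randn(1, 4, 8, 8, 8, device="mps", dtype=torch.float16)
    with torch.autocast(device_type="mps", dtype=torch.float16):
        out = conv(x)
    print(out.dtype)  # expected torch.float32 under the fp32 policy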

aten/src/ATen/detail/MTIAHooksInterface.cpp

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@ bool isMTIAHooksBuilt() {

 } // namespace detail

+bool MTIAHooksInterface::isAvailable() const {
+  return detail::isMTIAHooksBuilt() && detail::getMTIAHooks().deviceCount() > 0;
+}
+
 C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs)

 } // namespace at

aten/src/ATen/detail/MTIAHooksInterface.h

Lines changed: 2 additions & 0 deletions
@@ -149,6 +149,8 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
     FAIL_MTIAHOOKS_FUNC(__func__);
     return;
   }
+
+  virtual bool isAvailable() const override;
 };

 struct TORCH_API MTIAHooksArgs {};
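
Note: a hedged sketch of how availability is typically queried from Python. Whether torch.mtia.is_available() routes through the new MTIAHooksInterface::isAvailable is an assumption here; on builds without MTIA support it simply reports False.

import torch

# Illustrative only: the hooks-level check requires both a compiled-in MTIA
# backend and at least one visible device, mirroring the C++ logic above.
if hasattr(torch, "mtia"):
    print(torch.mtia.is_available())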

aten/src/ATen/native/mps/operations/BinaryKernel.mm

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ void binary_op_kernel(const std::string func_name,
                   .add_input(input)
                   .add_input(other)
                   .check_all_same_dtype(false)
+                  .promote_inputs_to_common_dtype(true)
                   .build();

   lib.exec_binary_kernel(iter, func_name, alpha);
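
Note: a minimal sketch of the kind of case input promotion addresses. Which specific ops dispatch through this Metal binary kernel is an assumption here, and the example needs an MPS device.

import torch

# Illustrative only: with promote_inputs_to_common_dtype enabled, mixed-dtype
# operands are cast to a common dtype before the kernel runs, so the result
# follows the usual type-promotion rules.
if torch.backends.mps.is_available():
    a = torch.randn(4, device="mps", dtype=torch.float32)
    b = torch.randn(4, device="mps", dtype=torch.float16)
    print(torch.add(a, b, alpha=2).dtype)  # expected torch.float32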

cmake/public/LoadHIP.cmake

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@ set(PYTORCH_FOUND_HIP FALSE)
 # In the latter case, if /opt/rocm does not exist emit status
 # message and return.
 if(DEFINED ENV{ROCM_PATH})
-  set(ROCM_PATH $ENV{ROCM_PATH})
+  file(TO_CMAKE_PATH "$ENV{ROCM_PATH}" ROCM_PATH)
   if(NOT EXISTS ${ROCM_PATH})
     message(FATAL_ERROR
       "ROCM_PATH environment variable is set to ${ROCM_PATH} but does not exist.\n"
@@ -31,7 +31,7 @@ if(NOT DEFINED ENV{MAGMA_HOME})
   set(MAGMA_HOME ${ROCM_PATH}/magma)
   set(ENV{MAGMA_HOME} ${ROCM_PATH}/magma)
 else()
-  set(MAGMA_HOME $ENV{MAGMA_HOME})
+  file(TO_CMAKE_PATH "$ENV{MAGMA_HOME}" MAGMA_HOME)
 endif()

 # MIOpen isn't a part of HIP-SDK for Windows and hence, may have a different

codex_setup.sh

Lines changed: 0 additions & 4 deletions
@@ -9,10 +9,6 @@ COMMIT=$(grep -oE '[0-9a-f]{40}' <<< "$NIGHTLY_PATCH" | head -1)
 COMMIT_DATE=$(echo "$NIGHTLY_PATCH" | grep '^Date:' | sed -E 's/Date: .*, ([0-9]+) ([A-Za-z]+) ([0-9]+) .*/\3 \2 \1/' | awk 'BEGIN{split("Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec", months, " "); for(i=1;i<=12;i++) month[months[i]]=sprintf("%02d",i)} {print $1 month[$2] sprintf("%02d",$3)}')
 VERSION_STRING="2.9.0.dev${COMMIT_DATE}+cpu"
 git rev-parse HEAD > /tmp/orig_work.txt
-cp AGENTS.md /tmp
 git reset --hard $COMMIT
-cp /tmp/AGENTS.md .
-curl https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/159965.diff | patch -p1
 USE_NIGHTLY=$VERSION_STRING python setup.py develop
-git commit -asm "Agents patch"
 echo "source $PWD/.venv/bin/activate" >> ~/.bashrc

test/dynamo/test_fx_graph_runnable.py

Lines changed: 88 additions & 0 deletions
@@ -11,12 +11,65 @@
 from torch._inductor.codecache import WritableTempFile
 from torch._inductor.test_case import TestCase
 from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
+from torch.utils._triton import has_triton


 if torch.distributed.is_available():
     from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
     from torch.testing._internal.distributed.fake_pg import FakeStore

+if has_triton():
+    import triton
+    import triton.language as tl
+
+    def init_to_zero(name):
+        return lambda nargs: nargs[name].zero_()
+
+    @triton.jit
+    def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(axis=0)
+
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(output_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {"BLOCK_SIZE": 1024},
+                num_warps=4,
+                num_stages=2,
+                pre_hook=init_to_zero("output_ptr"),
+            )
+        ],
+        pre_hook=init_to_zero("output_ptr"),
+        post_hook=init_to_zero("output_ptr"),
+        key=["n_elements"],
+    )
+    @triton.jit
+    def add_kernel_autotune(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(output_ptr + offsets, output, mask=mask)
+
+
+from torch.testing._internal.inductor_utils import GPU_TYPE
+from torch.testing._internal.triton_utils import requires_gpu
+

 class FxGraphRunnableArtifactFilter(logging.Filter):
     def filter(self, record):
@@ -100,6 +153,41 @@ def f(x):
         torch.compile(f)(torch.randn(4))
         self._exec_and_verify_payload()

+    @unittest.skipUnless(has_triton(), "Triton not available")
+    def test_user_defined_triton_kernel_autotune(self):
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
+            n_elements = output.numel()
+
+            def grid(
+                meta,
+            ):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            add_kernel_autotune[grid](x, y, output, n_elements)
+            return output
+
+        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+
+        torch.compile(add)(x, y)
+        self._exec_and_verify_payload()
+
+    @unittest.skipUnless(has_triton(), "Triton not available")
+    @requires_gpu
+    def test_user_defined_triton_kernel(self):
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
+            n_elements = x.numel()
+            add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4)
+            return output
+
+        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+
+        torch.compile(add)(x, y)
+        self._exec_and_verify_payload()
+
     def test_two_inputs_matmul(self):
         def f(a, b):
             return (a @ b).relu()
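
Note: for orientation, a minimal eager launch of the plain add_kernel defined in this diff (no torch.compile involved). The grid and BLOCK_SIZE values are illustrative, and a Triton-capable GPU plus the kernel definition above are assumed to be in scope.

import torch
import triton  # for triton.cdiv; add_kernel itself comes from the test module above

x = torch.ones(4096, device="cuda", dtype=torch.float16)
y = torch.ones(4096, device="cuda", dtype=torch.float16)
out = torch.zeros_like(x)  # zero-initialized so tl.atomic_add accumulates exactly x + y
n = out.numel()
add_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK_SIZE=1024)
print(out[:4])  # expected: all 2.0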

test/dynamo/test_pgo.py

Lines changed: 13 additions & 7 deletions
@@ -56,6 +56,10 @@ def f(x):
         f(torch.randn(2, 6))
         self.assertEqual(cnts.frame_count, 1)

+    @torch._dynamo.config.patch(
+        force_parameter_static_shapes=False,
+        force_nn_module_property_static_shapes=False,
+    )
     def test_whitelist_suggestion(self):
         cnts = CompileCounter()

@@ -195,14 +199,16 @@ def run():
         self.assertEqual(cnts.frame_count, 3)

         # parameter static shapes are forced static, so we recompile once
-        run()
-        self.assertEqual(cnts.frame_count, 2)
+        with torch._dynamo.config.patch(
+            force_parameter_static_shapes=False,
+            force_nn_module_property_static_shapes=False,
+        ):
+            run()
+            self.assertEqual(cnts.frame_count, 2)

-        # flags are flipped, PGO records dynamism, so params are dynamically compiled to start
-        torch._dynamo.config.force_parameter_static_shapes = False
-        torch._dynamo.config.force_nn_module_property_static_shapes = False
-        run()
-        self.assertEqual(cnts.frame_count, 1)
+        # because flags were flipped, params were included in PGO
+        run()
+        self.assertEqual(cnts.frame_count, 1)

     def test_njt(self):
         cnts = CompileCounter()
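
Note: the change above moves from mutating torch._dynamo.config globally to scoped patching. A brief sketch of the two forms used in the diff; the wrapped bodies are illustrative.

import torch

# Decorator form: the flags apply only for the decorated function.
@torch._dynamo.config.patch(
    force_parameter_static_shapes=False,
    force_nn_module_property_static_shapes=False,
)
def run_with_dynamic_param_shapes():
    pass  # compile/run code that should see parameter shapes as dynamic

# Context-manager form: the flags revert automatically when the block exits,
# so no config state leaks into later tests.
with torch._dynamo.config.patch(force_parameter_static_shapes=False):
    pass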

test/dynamo/test_repros.py

Lines changed: 25 additions & 0 deletions
@@ -7673,6 +7673,31 @@ def forward(self, x):
         out2 = torch.compile(model, backend="eager")(input.clone())
         self.assertEqual(out1, out2)

+    @requires_cuda
+    def test_zero_dim_param_mixed_device_grad(self):
+        # cpu 0-dim params with cuda grads
+        # https://github.com/pytorch/pytorch/issues/160084
+        class RegressionModel(torch.nn.Module):
+            def __init__(self, a=0, b=0):
+                super().__init__()
+                self.a = torch.nn.Parameter(torch.tensor(a).float())
+                self.b = torch.nn.Parameter(torch.tensor(b).float())
+
+            def forward(self, x):
+                return x * self.a + self.b
+
+        model = RegressionModel()
+        model.forward = torch.compile(
+            model.forward, backend="aot_eager", fullgraph=True
+        )
+        inputs = torch.randn(4, 10).to("cuda")
+        out = model(inputs)
+        out.sum().backward()
+        self.assertIsNotNone(model.a.grad)
+        self.assertIsNotNone(model.b.grad)
+        self.assertEqual(model.a.grad.device, torch.device("cpu"))
+        self.assertEqual(model.b.grad.device, torch.device("cpu"))
+
     def test_filter_warnings(self):
         x = torch.ones(2, 2, requires_grad=True)
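
Note: an eager-mode reference for the mixed-device setup this test exercises (CUDA required; illustrative only). A 0-dim CPU parameter can be combined with CUDA tensors, and its gradient lands on CPU, which is what the compiled path is expected to reproduce.

import torch

# Illustrative eager baseline: a 0-dim CPU parameter broadcast against a CUDA
# input still receives its gradient on CPU.
a = torch.nn.Parameter(torch.tensor(0.0))  # 0-dim, lives on CPU
x = torch.randn(4, 10, device="cuda")
(x * a).sum().backward()
print(a.grad.device)  # expected: cpu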
