Skip to content

[dynamo][guards] Skip dispatch key guards for requires_grad=False #155756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: gh/anijain2305/790/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,12 +508,18 @@ def convert_to_concrete_values(size_or_stride):


def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys):
dispatch_key = (
dispatch_keys | torch._C._dispatch_tls_local_include_set()
) - torch._C._dispatch_tls_local_exclude_set()
requires_grad = value.requires_grad
if requires_grad:
dispatch_key = (
dispatch_keys | torch._C._dispatch_tls_local_include_set()
) - torch._C._dispatch_tls_local_exclude_set()
else:
# Get an empty dispatch key set because with requires_grad we dont guard
# on it.
dispatch_key = dispatch_keys - dispatch_keys

dtype = value.dtype
device_index = value.device.index
requires_grad = value.requires_grad
guard_str = (
f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, "
f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})"
Expand Down
22 changes: 12 additions & 10 deletions torch/csrc/dynamo/guards.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,10 @@ bool TensorCheck::check(
const c10::SymIntArrayRef& sym_sizes,
const c10::SymIntArrayRef& sym_strides,
const bool& requires_grad) {
if (dispatch_key_ != state.apply(dispatch_key_set).raw_repr() ||
dtype_ != dtype || device_index_ != device.index() ||
requires_grad_ != requires_grad) {
if (requires_grad_ != requires_grad ||
(requires_grad_ &&
dispatch_key_ != state.apply(dispatch_key_set).raw_repr()) ||
dtype_ != dtype || device_index_ != device.index()) {
return false;
}

Expand Down Expand Up @@ -187,7 +188,14 @@ std::string TensorCheck::check_verbose(
const std::string& tensor_name) {
std::stringstream fail_reason;
fail_reason << "tensor '" << tensor_name << "' ";
if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
if (requires_grad_ != v.requires_grad()) {
// return fmt::format("tensor requires_grad mismatch. expected {}",
// requires_grad_);
fail_reason << "requires_grad mismatch. expected requires_grad="
<< requires_grad_;
return fail_reason.str();
} else if (
requires_grad_ && dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
// return fmt::format("tensor dispatch key mismatch. expected {}, actual
// {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
fail_reason << "dispatch key set mismatch. expected "
Expand All @@ -204,12 +212,6 @@ std::string TensorCheck::check_verbose(
fail_reason << "Tensor device index mismatch. Expected device index to be "
<< device_index_ << ", actual " << v.device().index();
return fail_reason.str();
} else if (requires_grad_ != v.requires_grad()) {
// return fmt::format("tensor requires_grad mismatch. expected {}",
// requires_grad_);
fail_reason << "requires_grad mismatch. expected requires_grad="
<< requires_grad_;
return fail_reason.str();
}
auto ndim = v.ndimension();
if (ndim != dim_) {
Expand Down
Loading