diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index b2c4342bb116..eed238b76473 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -508,12 +508,18 @@ def convert_to_concrete_values(size_or_stride): def get_tensor_guard_code_part(value, name, sizes, strides, pytype, dispatch_keys): - dispatch_key = ( - dispatch_keys | torch._C._dispatch_tls_local_include_set() - ) - torch._C._dispatch_tls_local_exclude_set() + requires_grad = value.requires_grad + if requires_grad: + dispatch_key = ( + dispatch_keys | torch._C._dispatch_tls_local_include_set() + ) - torch._C._dispatch_tls_local_exclude_set() + else: + # Get an empty dispatch key set because without requires_grad we don't + # guard on it. + dispatch_key = dispatch_keys - dispatch_keys + dtype = value.dtype device_index = value.device.index - requires_grad = value.requires_grad guard_str = ( f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, " f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})" diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index c7ea84ed5529..8f1e3483bf63 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -151,9 +151,10 @@ bool TensorCheck::check( const c10::SymIntArrayRef& sym_sizes, const c10::SymIntArrayRef& sym_strides, const bool& requires_grad) { - if (dispatch_key_ != state.apply(dispatch_key_set).raw_repr() || - dtype_ != dtype || device_index_ != device.index() || - requires_grad_ != requires_grad) { + if (requires_grad_ != requires_grad || + (requires_grad_ && + dispatch_key_ != state.apply(dispatch_key_set).raw_repr()) || + dtype_ != dtype || device_index_ != device.index()) { return false; } @@ -187,7 +188,14 @@ std::string TensorCheck::check_verbose( const std::string& tensor_name) { std::stringstream fail_reason; fail_reason << "tensor '" << tensor_name << "' "; - if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) { + if (requires_grad_ != 
v.requires_grad()) { + // return fmt::format("tensor requires_grad mismatch. expected {}", + // requires_grad_); + fail_reason << "requires_grad mismatch. expected requires_grad=" + << requires_grad_; + return fail_reason.str(); + } else if ( + requires_grad_ && dispatch_key_ != state.apply(v.key_set()).raw_repr()) { // return fmt::format("tensor dispatch key mismatch. expected {}, actual // {}", dispatch_key_, state.apply(v.key_set()).raw_repr()); fail_reason << "dispatch key set mismatch. expected " @@ -204,12 +212,6 @@ std::string TensorCheck::check_verbose( fail_reason << "Tensor device index mismatch. Expected device index to be " << device_index_ << ", actual " << v.device().index(); return fail_reason.str(); - } else if (requires_grad_ != v.requires_grad()) { - // return fmt::format("tensor requires_grad mismatch. expected {}", - // requires_grad_); - fail_reason << "requires_grad mismatch. expected requires_grad=" - << requires_grad_; - return fail_reason.str(); } auto ndim = v.ndimension(); if (ndim != dim_) {