diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index c0cc2b3e72f0..69dc64b80ded 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -121,7 +121,14 @@ def _unwrap_var(var):
         elif isinstance(var, ConstantVariable):
             return var.as_python_constant()
         else:
-            unimplemented(f"Cannot unwrap var {var}")
+            unimplemented_v2(
+                gb_type="HOP tracing expects Tensors, SymNodes, or constants",
+                context=str(type(var)),
+                explanation="HOP tracing expects Tensors, SymNodes, or constants",
+                hints=[
+                    "Please open an issue.",
+                ],
+            )
 
     unwrapped1 = [_unwrap_var(var) for var in vars1]
     unwrapped2 = [_unwrap_var(var) for var in vars2]
@@ -244,8 +251,13 @@ def _check_all_tensorvariable(args):
     from . import TensorVariable
 
     if not all(type(a.realize()) is TensorVariable for a in args):
-        unimplemented(
-            f"Expected all leaves to be of torch.Tensor type, but got {[type(a.realize()) for a in args]}."
+        unimplemented_v2(
+            gb_type="HOP input type restrictions",
+            context="",
+            explanation=f"Expected all leaves to be of torch.Tensor type, but got {[type(a.realize()) for a in args]}.",
+            hints=[
+                "Please only use Tensors",
+            ],
         )
 
 
@@ -256,8 +268,13 @@ def _check_supported_callable_arg(
         BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant()
     )
     if not is_callable:
-        unimplemented(
-            f"{arg_name} should be a Callable but is of type {str(func_var)}."
+        unimplemented_v2(
+            gb_type="HOP input type restrictions",
+            context="",
+            explanation=f"{arg_name} should be a Callable but is of type {str(func_var)}.",
+            hints=[
+                "Please check to make sure the argument is a Callable",
+            ],
         )
 
 
@@ -479,10 +496,16 @@ def validate_args_and_maybe_create_graph_inputs(
         # If `a` cannot be put into a graph
         else:
             # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic").
-            unimplemented(
-                f"{description} with body that accepts non-Tensors as input. "
-                f"Got: {a.python_type()}"
+            unimplemented_v2(
+                gb_type="HOP input type restrictions",
+                context="",
+                explanation=(
+                    f"{description} with body that accepts non-Tensors as input. "
" + f"Got: {a.python_type()}" + ), + hints=[], ) + args.append(new_arg) return args @@ -633,7 +656,12 @@ def speculate_subgraph( # See NOTE [Temporary argument `set_subgraph_inputs`] if sub_kwargs and set_subgraph_inputs != "automatic": - unimplemented("Use `set_subgraph_inputs=automatic` when passing `sub_kwargs`.") + unimplemented_v2( + gb_type="HOP internal assertion", + context=None, + explanation="Use `set_subgraph_inputs=automatic` when passing `sub_kwargs`.", + hints=["Please file an issue"], + ) try: # ensure guards on args get installed in parent subgraph @@ -931,7 +959,12 @@ def make(value, source=None, **kwargs): elif value.__name__ == "custom_function_call": return CustomFunctionHigherOrderOperatorVariable(value, source, **kwargs) else: - unimplemented(f"HigherOrderOperator {value.__name__}") + unimplemented_v2( + gb_type="Unsupported HOP", + context=value.__name__, + explanation="Dynamo does not support {value.__name__}", + hints=[*torch._dynamo.graph_break_hints.SUPPORTABLE], + ) def call_function( self, @@ -939,7 +972,12 @@ def call_function( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - unimplemented(f"HigherOrderOperator {self.value.__name__}") + unimplemented_v2( + gb_type="Unsupported HOP", + context=self.value.__name__, + explanation="Dynamo does not support {self.value.__name__}.__call__", + hints=[*torch._dynamo.graph_break_hints.SUPPORTABLE], + ) def as_python_constant(self): return self.value @@ -989,13 +1027,21 @@ def call_function( ) args.append(v) + def raise_error(msg): + return unimplemented_v2( + gb_type="torch.cond invalid input", + context=None, + explanation=f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}", + hints=["This is user error"], + ) + if kwargs: - unimplemented(f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}") + raise_error(f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}") # TODO(voz): Support fake tensor dispatch for recursive # ops - see torch/dispatch/_dispatcher.py if len(args) != 4: - unimplemented( + raise_error( f"Expected 4 arguments but got {len(args)}.\n" f"Usage: cond(pred, true_fn, false_fn, operands)", ) @@ -1015,7 +1061,7 @@ def call_function( # predicate if type(pred) not in (ConstantVariable, TensorVariable, SymNodeVariable): - unimplemented( + raise_error( f"Expected pred to be bool or a boolean tensor with single " f"item but got {str(type(pred))} " f"with original python type {str(pred.python_type())}.", @@ -1023,13 +1069,13 @@ def call_function( # operands if not isinstance(operands, (ListVariable, TupleVariable)): - unimplemented( + raise_error( f"Expected operands to be a list/tuple but got " f"{operands.python_type()}", ) operands_seq = operands.unpack_var_sequence(tx) if not only_consist_of(operands, (TensorVariable, ConstantVariable)): - unimplemented( + raise_error( "Expect operands to be a tuple of pytrees that only consists of tensor leaves." 
             )
@@ -1071,13 +1117,13 @@ def speculate_branch(branch):
                 )
 
                 if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)):
-                    unimplemented(
+                    raise_error(
                         "Expected branches to return a possibly nested pytree of tensors "
                         "or constant ints but it consists of others.",
                     )
                 for ret in ret_val.unpack_var_sequence(tx):
                     if isinstance(ret, ConstantVariable) and ret.python_type() is not int:
-                        unimplemented(
+                        raise_error(
                             "Expected branches to return a possibly nested pytree of tensors "
                             f"or constant ints but it consists of others {ret.python_type()}.",
                         )
@@ -1100,7 +1146,7 @@ def speculate_branch(branch):
             true_treespec, false_treespec
         )
         if not same_treespec.as_python_constant():
-            unimplemented("Expected branches to return the same pytree structure.")
+            raise_error("Expected branches to return the same pytree structure.")
 
         (
             true_graph,
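
For reference, a minimal sketch (not part of the patch) of the unimplemented_v2 call pattern this diff migrates to. The helper is defined in torch/_dynamo/exc.py in recent PyTorch and always raises; the require_bool function and its argument below are hypothetical, used only to illustrate the call shape:

    from torch._dynamo.exc import Unsupported, unimplemented_v2

    def require_bool(pred):
        # unimplemented_v2 raises Unsupported, carrying the structured
        # gb_type/context/explanation/hints fields shown throughout the diff
        # instead of a single free-form message string.
        if not isinstance(pred, bool):
            unimplemented_v2(
                gb_type="torch.cond invalid input",
                context=str(type(pred)),
                explanation=f"Expected pred to be a bool, got {type(pred)}",
                hints=["This is user error"],
            )

    # Calling with a non-bool raises torch._dynamo.exc.Unsupported with the
    # formatted graph-break report:
    try:
        require_bool("yes")
    except Unsupported as e:
        print(e)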