You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
<spanclass="k">raise</span><spanclass="ne">RuntimeError</span><spanclass="p">(</span><spanclass="s2">"Batched grads are not supported with Nested Tensor."</span><spanclass="p">)</span>
<spanclass="k">raise</span><spanclass="ne">RuntimeError</span><spanclass="p">(</span><spanclass="s2">"If `is_grads_batched=True`, we interpret the first "</span>
468
486
<spanclass="s2">"dimension of each grad_output as the batch dimension. "</span>
469
487
<spanclass="s2">"The sizes of the remaining dimensions are expected to match "</span>
470
488
<spanclass="s2">"the shape of corresponding output, but a mismatch "</span>
<spanclass="o">+</span><spanclass="nb">str</span><spanclass="p">(</span><spanclass="n">grads</span><spanclass="o">.</span><spanclass="n">index</span><spanclass="p">(</span><spanclass="n">grad</span><spanclass="p">))</span><spanclass="o">+</span><spanclass="s2">"] has a shape of "</span>
473
-
<spanclass="o">+</span><spanclass="nb">str</span><spanclass="p">(</span><spanclass="n">grad</span><spanclass="o">.</span><spanclass="n">shape</span><spanclass="p">)</span><spanclass="o">+</span><spanclass="s2">" and output["</span>
491
+
<spanclass="o">+</span><spanclass="nb">str</span><spanclass="p">(</span><spanclass="n">grad_shape</span><spanclass="p">)</span><spanclass="o">+</span><spanclass="s2">" and output["</span>
474
492
<spanclass="o">+</span><spanclass="nb">str</span><spanclass="p">(</span><spanclass="n">outputs</span><spanclass="o">.</span><spanclass="n">index</span><spanclass="p">(</span><spanclass="n">out</span><spanclass="p">))</span><spanclass="o">+</span><spanclass="s2">"] has a shape of "</span>
<spanclass="s2">"If you only want some tensors in `grad_output` to be considered "</span>
477
495
<spanclass="s2">"batched, consider using vmap."</span><spanclass="p">)</span>
478
496
<spanclass="k">else</span><spanclass="p">:</span>
479
497
<spanclass="k">raise</span><spanclass="ne">RuntimeError</span><spanclass="p">(</span><spanclass="s2">"Mismatch in shape: grad_output["</span>
480
498
<spanclass="o">+</span><spanclass="nb">str</span><spanclass="p">(</span><spanclass="n">grads</span><spanclass="o">.</span><spanclass="n">index</span><spanclass="p">(</span><spanclass="n">grad</span><spanclass="p">))</span><spanclass="o">+</span><spanclass="s2">"] has a shape of "</span>
481
-
<spanclass="o">+</span><spanclass="nb">str</span><spanclass="p">(</span><spanclass="n">grad</span><spanclass="o">.</span><spanclass="n">shape</span><spanclass="p">)</span><spanclass="o">+</span><spanclass="s2">" and output["</span>
499
+
<spanclass="o">+</span><spanclass="nb">str</span><spanclass="p">(</span><spanclass="n">grad_shape</span><spanclass="p">)</span><spanclass="o">+</span><spanclass="s2">" and output["</span>
482
500
<spanclass="o">+</span><spanclass="nb">str</span><spanclass="p">(</span><spanclass="n">outputs</span><spanclass="o">.</span><spanclass="n">index</span><spanclass="p">(</span><spanclass="n">out</span><spanclass="p">))</span><spanclass="o">+</span><spanclass="s2">"] has a shape of "</span>
<spanclass="k">raise</span><spanclass="ne">RuntimeError</span><spanclass="p">(</span><spanclass="s2">"For complex Tensors, both grad_output and output"</span>
486
504
<spanclass="s2">" are required to have the same dtype."</span>
0 commit comments