|
235 | 235 | <div class="pytorch-left-menu-search">
|
236 | 236 |
|
237 | 237 | <div class="version">
|
238 |
| - <a href='https://pytorch.org/docs/versions.html'>master (2.0.0a0+gite839313 ) ▼</a> |
| 238 | + <a href='https://pytorch.org/docs/versions.html'>master (2.0.0a0+gitf012d0e ) ▼</a> |
239 | 239 | </div>
|
240 | 240 |
|
241 | 241 |
|
@@ -472,6 +472,7 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
|
472 | 472 | <span class="sd">on an NVIDIA GPU with compute capability >= 3.0.</span>
|
473 | 473 | <span class="sd">"""</span>
|
474 | 474 |
|
| 475 | +<span class="kn">import</span> <span class="nn">math</span> |
475 | 476 | <span class="kn">import</span> <span class="nn">os</span>
|
476 | 477 | <span class="kn">import</span> <span class="nn">sys</span>
|
477 | 478 | <span class="kn">import</span> <span class="nn">platform</span>
|
@@ -511,7 +512,7 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
|
511 | 512 | <span class="s1">'set_deterministic_debug_mode'</span><span class="p">,</span> <span class="s1">'get_deterministic_debug_mode'</span><span class="p">,</span>
|
512 | 513 | <span class="s1">'set_float32_matmul_precision'</span><span class="p">,</span> <span class="s1">'get_float32_matmul_precision'</span><span class="p">,</span>
|
513 | 514 | <span class="s1">'set_warn_always'</span><span class="p">,</span> <span class="s1">'is_warn_always_enabled'</span><span class="p">,</span> <span class="s1">'SymInt'</span><span class="p">,</span> <span class="s1">'SymFloat'</span><span class="p">,</span>
|
514 |
| - <span class="s1">'compile'</span><span class="p">,</span> <span class="s1">'vmap'</span><span class="p">,</span> |
| 515 | + <span class="s1">'sym_int'</span><span class="p">,</span> <span class="s1">'sym_float'</span><span class="p">,</span> <span class="s1">'compile'</span><span class="p">,</span> <span class="s1">'vmap'</span> |
515 | 516 | <span class="p">]</span>
|
516 | 517 |
|
517 | 518 | <span class="c1">################################################################################</span>
|
@@ -772,6 +773,39 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
|
772 | 773 | <span class="k">def</span> <span class="nf">get_pyobj</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
|
773 | 774 | <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">node</span>
|
774 | 775 |
|
| 776 | +<div class="viewcode-block" id="sym_float"><a class="viewcode-back" href="../generated/torch.sym_float.html#torch.sym_float">[docs]</a><span class="k">def</span> <span class="nf">sym_float</span><span class="p">(</span><span class="n">a</span><span class="p">):</span> |
| 777 | + <span class="sa">r</span><span class="sd">""" SymInt-aware utility for float casting.</span> |
| 778 | + |
| 779 | +<span class="sd"> Args:</span> |
| 780 | +<span class="sd"> a (SymInt, SymFloat, or object): Object to cast</span> |
| 781 | +<span class="sd"> """</span> |
| 782 | + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">SymFloat</span><span class="p">):</span> |
| 783 | + <span class="k">return</span> <span class="n">a</span> |
| 784 | + <span class="k">elif</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="s1">'__sym_float__'</span><span class="p">):</span> |
| 785 | + <span class="k">return</span> <span class="n">a</span><span class="o">.</span><span class="n">__sym_float__</span><span class="p">()</span> |
| 786 | + <span class="k">return</span> <span class="n">py_float</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># type: ignore[operator]</span></div> |
| 787 | + |
| 788 | +<span class="c1"># Drop in replacement for math.floor/ceil. Actually, math.floor/ceil</span> |
| 789 | +<span class="c1"># directly usable, but this has a more relaxed type signature for mypy</span> |
| 790 | +<span class="c1"># (mypy requires SupportFloat which is too strict)</span> |
| 791 | +<span class="k">def</span> <span class="nf">_sym_floor</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> |
| 792 | + <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># type: ignore[type]</span> |
| 793 | + |
| 794 | +<span class="k">def</span> <span class="nf">_sym_ceil</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> |
| 795 | + <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># type: ignore[type]</span> |
| 796 | + |
| 797 | +<div class="viewcode-block" id="sym_int"><a class="viewcode-back" href="../generated/torch.sym_int.html#torch.sym_int">[docs]</a><span class="k">def</span> <span class="nf">sym_int</span><span class="p">(</span><span class="n">a</span><span class="p">):</span> |
| 798 | + <span class="sa">r</span><span class="sd">""" SymInt-aware utility for int casting.</span> |
| 799 | + |
| 800 | +<span class="sd"> Args:</span> |
| 801 | +<span class="sd"> a (SymInt, SymFloat, or object): Object to cast</span> |
| 802 | +<span class="sd"> """</span> |
| 803 | + <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">SymInt</span><span class="p">):</span> |
| 804 | + <span class="k">return</span> <span class="n">a</span> |
| 805 | + <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">SymFloat</span><span class="p">):</span> |
| 806 | + <span class="k">return</span> <span class="n">_sym_floor</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="k">if</span> <span class="n">a</span> <span class="o">></span> <span class="mi">0</span> <span class="k">else</span> <span class="n">_sym_ceil</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> |
| 807 | + <span class="k">return</span> <span class="n">py_int</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># type: ignore[operator]</span></div> |
| 808 | + |
775 | 809 | <span class="c1"># Check to see if we can load C extensions, and if not provide some guidance</span>
|
776 | 810 | <span class="c1"># on what the problem might be.</span>
|
777 | 811 | <span class="k">try</span><span class="p">:</span>
|
@@ -1429,6 +1463,11 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
|
1429 | 1463 |
|
1430 | 1464 | <span class="kn">from</span> <span class="nn">torch.amp</span> <span class="kn">import</span> <span class="n">autocast</span>
|
1431 | 1465 |
|
| 1466 | +<span class="c1"># Initializing the extension shadows the built-in python float / int classes;</span> |
| 1467 | +<span class="c1"># store them for later use by SymInt / SymFloat.</span> |
| 1468 | +<span class="n">py_float</span> <span class="o">=</span> <span class="nb">float</span> |
| 1469 | +<span class="n">py_int</span> <span class="o">=</span> <span class="nb">int</span> |
| 1470 | + |
1432 | 1471 | <span class="c1"># Shared memory manager needs to know the exact location of manager executable</span>
|
1433 | 1472 | <span class="n">_C</span><span class="o">.</span><span class="n">_initExtension</span><span class="p">(</span><span class="n">manager_path</span><span class="p">())</span>
|
1434 | 1473 | <span class="k">del</span> <span class="n">manager_path</span>
|
|
0 commit comments