22
22
import tempfile
23
23
from abc import ABC
24
24
from argparse import Namespace
25
+ from pathlib import Path
25
26
from typing import Any , Callable , Dict , List , Mapping , Optional , Sequence , Tuple , Union
26
27
27
28
import torch
@@ -1530,12 +1531,19 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
1530
1531
else :
1531
1532
self ._hparams = hp
1532
1533
1533
- def to_onnx (self , file_path : str , input_sample : Optional [Tensor ] = None , ** kwargs ):
1534
- """Saves the model in ONNX format
1534
+ @torch .no_grad ()
1535
+ def to_onnx (
1536
+ self ,
1537
+ file_path : Union [str , Path ],
1538
+ input_sample : Optional [Any ] = None ,
1539
+ ** kwargs ,
1540
+ ):
1541
+ """
1542
+ Saves the model in ONNX format
1535
1543
1536
1544
Args:
1537
- file_path: The path of the file the model should be saved to.
1538
- input_sample: A sample of an input tensor for tracing.
1545
+ file_path: The path of the file the onnx model should be saved to.
1546
+ input_sample: An input for tracing. Default: None (Use self.example_input_array)
1539
1547
**kwargs: Will be passed to torch.onnx.export function.
1540
1548
1541
1549
Example:
@@ -1554,31 +1562,32 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
1554
1562
... os.path.isfile(tmpfile.name)
1555
1563
True
1556
1564
"""
1565
+ mode = self .training
1557
1566
1558
- if isinstance (input_sample , Tensor ):
1559
- input_data = input_sample
1560
- elif self .example_input_array is not None :
1561
- input_data = self .example_input_array
1562
- else :
1563
- if input_sample is not None :
1567
+ if input_sample is None :
1568
+ if self .example_input_array is None :
1564
1569
raise ValueError (
1565
- f"Received `input_sample` of type { type (input_sample )} . Expected type is `Tensor`"
1570
+ "Could not export to ONNX since neither `input_sample` nor"
1571
+ " `model.example_input_array` attribute is set."
1566
1572
)
1567
- raise ValueError (
1568
- "Could not export to ONNX since neither `input_sample` nor"
1569
- " `model.example_input_array` attribute is set."
1570
- )
1571
- input_data = input_data .to (self .device )
1573
+ input_sample = self .example_input_array
1574
+
1575
+ input_sample = self .transfer_batch_to_device (input_sample )
1576
+
1572
1577
if "example_outputs" not in kwargs :
1573
1578
self .eval ()
1574
- with torch .no_grad ():
1575
- kwargs ["example_outputs" ] = self (input_data )
1579
+ kwargs ["example_outputs" ] = self (input_sample )
1576
1580
1577
- torch .onnx .export (self , input_data , file_path , ** kwargs )
1581
+ torch .onnx .export (self , input_sample , file_path , ** kwargs )
1582
+ self .train (mode )
1578
1583
1584
+ @torch .no_grad ()
1579
1585
def to_torchscript (
1580
- self , file_path : Optional [str ] = None , method : Optional [str ] = 'script' ,
1581
- example_inputs : Optional [Union [torch .Tensor , Tuple [torch .Tensor ]]] = None , ** kwargs
1586
+ self ,
1587
+ file_path : Optional [Union [str , Path ]] = None ,
1588
+ method : Optional [str ] = 'script' ,
1589
+ example_inputs : Optional [Any ] = None ,
1590
+ ** kwargs ,
1582
1591
) -> Union [ScriptModule , Dict [str , ScriptModule ]]:
1583
1592
"""
1584
1593
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
@@ -1590,7 +1599,7 @@ def to_torchscript(
1590
1599
Args:
1591
1600
file_path: Path where to save the torchscript. Default: None (no file saved).
1592
1601
method: Whether to use TorchScript's script or trace method. Default: 'script'
1593
- example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
1602
+ example_inputs: An input to be used to do tracing when method is set to 'trace'.
1594
1603
Default: None (Use self.example_input_array)
1595
1604
**kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
1596
1605
:func:`torch.jit.trace` function.
@@ -1624,21 +1633,27 @@ def to_torchscript(
1624
1633
This LightningModule as a torchscript, regardless of whether file_path is
1625
1634
defined or not.
1626
1635
"""
1627
-
1628
1636
mode = self .training
1629
- with torch .no_grad ():
1630
- if method == 'script' :
1631
- torchscript_module = torch .jit .script (self .eval (), ** kwargs )
1632
- elif method == 'trace' :
1633
- # if no example inputs are provided, try to see if model has example_input_array set
1634
- if example_inputs is None :
1635
- example_inputs = self .example_input_array
1636
- # automatically send example inputs to the right device and use trace
1637
- example_inputs = self .transfer_batch_to_device (example_inputs , device = self .device )
1638
- torchscript_module = torch .jit .trace (func = self .eval (), example_inputs = example_inputs , ** kwargs )
1639
- else :
1640
- raise ValueError (f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
1641
- f"{ method } " )
1637
+
1638
+ if method == 'script' :
1639
+ torchscript_module = torch .jit .script (self .eval (), ** kwargs )
1640
+ elif method == 'trace' :
1641
+ # if no example inputs are provided, try to see if model has example_input_array set
1642
+ if example_inputs is None :
1643
+ if self .example_input_array is None :
1644
+ raise ValueError (
1645
+ 'Choosing method=`trace` requires either `example_inputs`'
1646
+ ' or `model.example_input_array` to be defined'
1647
+ )
1648
+ example_inputs = self .example_input_array
1649
+
1650
+ # automatically send example inputs to the right device and use trace
1651
+ example_inputs = self .transfer_batch_to_device (example_inputs )
1652
+ torchscript_module = torch .jit .trace (func = self .eval (), example_inputs = example_inputs , ** kwargs )
1653
+ else :
1654
+ raise ValueError ("The 'method' parameter only supports 'script' or 'trace',"
1655
+ f" but value given was: { method } " )
1656
+
1642
1657
self .train (mode )
1643
1658
1644
1659
if file_path is not None :
0 commit comments