Skip to content

Commit 3100b78

Browse files
rohitgr7tchatonJeff Yangs-rogBorda
authored
Allow any input in to_onnx and to_torchscript (Lightning-AI#4378)
* branch merge * sample * update with valid input tensors * pep * pathlib * Updated with BoringModel and added more input types * try fix * pep * skip test with torch < 1.4 * fix test * Apply suggestions from code review * update tests * Allow any input in to_onnx and to_torchscript * Update tests/models/test_torchscript.py Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * no_grad * try fix random failing test * rm example_input_array * rm example_input_array Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jeff Yang <ydcjeff@outlook.com> Co-authored-by: Roger Shieh <sh.rog@protonmail.ch> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: edenlightning <66261195+edenlightning@users.noreply.github.com>
1 parent b5a2afd commit 3100b78

File tree

5 files changed

+139
-77
lines changed

5 files changed

+139
-77
lines changed

pytorch_lightning/core/hooks.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Various hooks to be used in the Lightning code."""
1616

17-
from typing import Any, Dict, List, Union
17+
from typing import Any, Dict, List, Optional, Union
1818

1919
import torch
2020
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
@@ -501,7 +501,7 @@ def val_dataloader(self):
501501
will have an argument ``dataloader_idx`` which matches the order here.
502502
"""
503503

504-
def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
504+
def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
505505
"""
506506
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
507507
wrapped in a custom data structure.
@@ -549,6 +549,7 @@ def transfer_batch_to_device(self, batch, device)
549549
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
550550
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
551551
"""
552+
device = device or self.device
552553
return move_data_to_device(batch, device)
553554

554555

pytorch_lightning/core/lightning.py

+51-36
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import tempfile
2323
from abc import ABC
2424
from argparse import Namespace
25+
from pathlib import Path
2526
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
2627

2728
import torch
@@ -1530,12 +1531,19 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
15301531
else:
15311532
self._hparams = hp
15321533

1533-
def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
1534-
"""Saves the model in ONNX format
1534+
@torch.no_grad()
1535+
def to_onnx(
1536+
self,
1537+
file_path: Union[str, Path],
1538+
input_sample: Optional[Any] = None,
1539+
**kwargs,
1540+
):
1541+
"""
1542+
Saves the model in ONNX format
15351543
15361544
Args:
1537-
file_path: The path of the file the model should be saved to.
1538-
input_sample: A sample of an input tensor for tracing.
1545+
file_path: The path of the file the onnx model should be saved to.
1546+
input_sample: An input for tracing. Default: None (Use self.example_input_array)
15391547
**kwargs: Will be passed to torch.onnx.export function.
15401548
15411549
Example:
@@ -1554,31 +1562,32 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
15541562
... os.path.isfile(tmpfile.name)
15551563
True
15561564
"""
1565+
mode = self.training
15571566

1558-
if isinstance(input_sample, Tensor):
1559-
input_data = input_sample
1560-
elif self.example_input_array is not None:
1561-
input_data = self.example_input_array
1562-
else:
1563-
if input_sample is not None:
1567+
if input_sample is None:
1568+
if self.example_input_array is None:
15641569
raise ValueError(
1565-
f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
1570+
"Could not export to ONNX since neither `input_sample` nor"
1571+
" `model.example_input_array` attribute is set."
15661572
)
1567-
raise ValueError(
1568-
"Could not export to ONNX since neither `input_sample` nor"
1569-
" `model.example_input_array` attribute is set."
1570-
)
1571-
input_data = input_data.to(self.device)
1573+
input_sample = self.example_input_array
1574+
1575+
input_sample = self.transfer_batch_to_device(input_sample)
1576+
15721577
if "example_outputs" not in kwargs:
15731578
self.eval()
1574-
with torch.no_grad():
1575-
kwargs["example_outputs"] = self(input_data)
1579+
kwargs["example_outputs"] = self(input_sample)
15761580

1577-
torch.onnx.export(self, input_data, file_path, **kwargs)
1581+
torch.onnx.export(self, input_sample, file_path, **kwargs)
1582+
self.train(mode)
15781583

1584+
@torch.no_grad()
15791585
def to_torchscript(
1580-
self, file_path: Optional[str] = None, method: Optional[str] = 'script',
1581-
example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
1586+
self,
1587+
file_path: Optional[Union[str, Path]] = None,
1588+
method: Optional[str] = 'script',
1589+
example_inputs: Optional[Any] = None,
1590+
**kwargs,
15821591
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
15831592
"""
15841593
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
@@ -1590,7 +1599,7 @@ def to_torchscript(
15901599
Args:
15911600
file_path: Path where to save the torchscript. Default: None (no file saved).
15921601
method: Whether to use TorchScript's script or trace method. Default: 'script'
1593-
example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
1602+
example_inputs: An input to be used to do tracing when method is set to 'trace'.
15941603
Default: None (Use self.example_input_array)
15951604
**kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
15961605
:func:`torch.jit.trace` function.
@@ -1624,21 +1633,27 @@ def to_torchscript(
16241633
This LightningModule as a torchscript, regardless of whether file_path is
16251634
defined or not.
16261635
"""
1627-
16281636
mode = self.training
1629-
with torch.no_grad():
1630-
if method == 'script':
1631-
torchscript_module = torch.jit.script(self.eval(), **kwargs)
1632-
elif method == 'trace':
1633-
# if no example inputs are provided, try to see if model has example_input_array set
1634-
if example_inputs is None:
1635-
example_inputs = self.example_input_array
1636-
# automatically send example inputs to the right device and use trace
1637-
example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
1638-
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
1639-
else:
1640-
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
1641-
f"{method}")
1637+
1638+
if method == 'script':
1639+
torchscript_module = torch.jit.script(self.eval(), **kwargs)
1640+
elif method == 'trace':
1641+
# if no example inputs are provided, try to see if model has example_input_array set
1642+
if example_inputs is None:
1643+
if self.example_input_array is None:
1644+
raise ValueError(
1645+
'Choosing method=`trace` requires either `example_inputs`'
1646+
' or `model.example_input_array` to be defined'
1647+
)
1648+
example_inputs = self.example_input_array
1649+
1650+
# automatically send example inputs to the right device and use trace
1651+
example_inputs = self.transfer_batch_to_device(example_inputs)
1652+
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
1653+
else:
1654+
raise ValueError("The 'method' parameter only supports 'script' or 'trace',"
1655+
f" but value given was: {method}")
1656+
16421657
self.train(mode)
16431658

16441659
if file_path is not None:

tests/models/test_onnx.py

+22-27
Original file line numberDiff line numberDiff line change
@@ -21,44 +21,44 @@
2121
import tests.base.develop_pipelines as tpipes
2222
import tests.base.develop_utils as tutils
2323
from pytorch_lightning import Trainer
24-
from tests.base import EvalModelTemplate
24+
from tests.base import BoringModel, EvalModelTemplate
2525

2626

2727
def test_model_saves_with_input_sample(tmpdir):
2828
"""Test that ONNX model saves with input sample and size is greater than 3 MB"""
29-
model = EvalModelTemplate()
29+
model = BoringModel()
3030
trainer = Trainer(max_epochs=1)
3131
trainer.fit(model)
3232

3333
file_path = os.path.join(tmpdir, "model.onnx")
34-
input_sample = torch.randn((1, 28 * 28))
34+
input_sample = torch.randn((1, 32))
3535
model.to_onnx(file_path, input_sample)
3636
assert os.path.isfile(file_path)
37-
assert os.path.getsize(file_path) > 3e+06
37+
assert os.path.getsize(file_path) > 4e2
3838

3939

4040
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
4141
def test_model_saves_on_gpu(tmpdir):
4242
"""Test that model saves on gpu"""
43-
model = EvalModelTemplate()
43+
model = BoringModel()
4444
trainer = Trainer(gpus=1, max_epochs=1)
4545
trainer.fit(model)
4646

4747
file_path = os.path.join(tmpdir, "model.onnx")
48-
input_sample = torch.randn((1, 28 * 28))
48+
input_sample = torch.randn((1, 32))
4949
model.to_onnx(file_path, input_sample)
5050
assert os.path.isfile(file_path)
51-
assert os.path.getsize(file_path) > 3e+06
51+
assert os.path.getsize(file_path) > 4e2
5252

5353

5454
def test_model_saves_with_example_output(tmpdir):
5555
"""Test that ONNX model saves when provided with example output"""
56-
model = EvalModelTemplate()
56+
model = BoringModel()
5757
trainer = Trainer(max_epochs=1)
5858
trainer.fit(model)
5959

6060
file_path = os.path.join(tmpdir, "model.onnx")
61-
input_sample = torch.randn((1, 28 * 28))
61+
input_sample = torch.randn((1, 32))
6262
model.eval()
6363
example_outputs = model.forward(input_sample)
6464
model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
@@ -67,11 +67,13 @@ def test_model_saves_with_example_output(tmpdir):
6767

6868
def test_model_saves_with_example_input_array(tmpdir):
6969
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
70-
model = EvalModelTemplate()
70+
model = BoringModel()
71+
model.example_input_array = torch.randn(5, 32)
72+
7173
file_path = os.path.join(tmpdir, "model.onnx")
7274
model.to_onnx(file_path)
7375
assert os.path.exists(file_path) is True
74-
assert os.path.getsize(file_path) > 3e+06
76+
assert os.path.getsize(file_path) > 4e2
7577

7678

7779
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -100,38 +102,31 @@ def test_model_saves_on_multi_gpu(tmpdir):
100102

101103
def test_verbose_param(tmpdir, capsys):
102104
"""Test that output is present when verbose parameter is set"""
103-
model = EvalModelTemplate()
105+
model = BoringModel()
106+
model.example_input_array = torch.randn(5, 32)
107+
104108
file_path = os.path.join(tmpdir, "model.onnx")
105109
model.to_onnx(file_path, verbose=True)
106110
captured = capsys.readouterr()
107111
assert "graph(%" in captured.out
108112

109113

110114
def test_error_if_no_input(tmpdir):
111-
"""Test that an exception is thrown when there is no input tensor"""
112-
model = EvalModelTemplate()
115+
"""Test that an error is thrown when there is no input tensor"""
116+
model = BoringModel()
113117
model.example_input_array = None
114118
file_path = os.path.join(tmpdir, "model.onnx")
115119
with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor'
116120
r' `model.example_input_array` attribute is set.'):
117121
model.to_onnx(file_path)
118122

119123

120-
def test_error_if_input_sample_is_not_tensor(tmpdir):
121-
"""Test that an exception is thrown when there is no input tensor"""
122-
model = EvalModelTemplate()
123-
model.example_input_array = None
124-
file_path = os.path.join(tmpdir, "model.onnx")
125-
input_sample = np.random.randn(1, 28 * 28)
126-
with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is '
127-
f'`Tensor`'):
128-
model.to_onnx(file_path, input_sample)
129-
130-
131124
def test_if_inference_output_is_valid(tmpdir):
132125
"""Test that the output inferred from ONNX model is same as from PyTorch"""
133-
model = EvalModelTemplate()
134-
trainer = Trainer(max_epochs=5)
126+
model = BoringModel()
127+
model.example_input_array = torch.randn(5, 32)
128+
129+
trainer = Trainer(max_epochs=2)
135130
trainer.fit(model)
136131

137132
model.eval()

0 commit comments

Comments
 (0)