@@ -184,8 +184,13 @@ Example::
184
184
# and type int in the false branch
185
185
186
186
By default, all parameters to a Torch Script function are assumed to be Tensor
187
- because this is the most common type used in modules. To specify that an
188
- argument to a Torch Script function is another type, it is possible to use
187
+ because this is the most common type used in modules.
188
+
189
+ There are 2 scenarios in which you might want to annotate a type:
190
+
191
+ 1. Annotating Function Argument Types
192
+
193
+ To specify that an argument to a Torch Script function is another type, it is possible to use
189
194
MyPy-style type annotations using the types listed above:
190
195
191
196
Example::
@@ -203,6 +208,32 @@ Example::
203
208
In our examples, we use comment-based annotations to ensure Python 2
204
209
compatibility as well.
205
210
211
+ 2. Annotating Variable Types
212
+
213
+ For example, a list by default is assumed to be List[Tensor]. If you would like to
214
+ have a list of other types. PyTorch provides annotation functions.
215
+
216
+ Example::
217
+ import torch
218
+ from torch.jit import Tensor
219
+ from typing import List, Tuple
220
+
221
+ class ListOfTupleOfTensor(torch.jit.ScriptModule):
222
+ def __init__(self):
223
+ super(ListOfTupleOfTensor, self).__init__()
224
+
225
+ @torch.jit.script_method
226
+ def forward(self, x):
227
+ # type: (Tensor) -> List[Tuple[Tensor, Tensor]]
228
+
229
+ # This annotates the list to be a List[Tuple[Tensor, Tensor]]
230
+ returns = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
231
+ for i in range(10):
232
+ returns.append((x, x))
233
+
234
+ return returns
235
+
236
+
206
237
Expressions
207
238
~~~~~~~~~~~
208
239
0 commit comments