
Commit d4168ff: Add type hints

1 parent 46cf49b · commit d4168ff
21 files changed: +532 -219 lines changed

supar/modules/affine.py

Lines changed: 21 additions & 4 deletions
```diff
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 
+from __future__ import annotations
+
 import torch
 import torch.nn as nn
 
@@ -26,7 +28,14 @@ class Biaffine(nn.Module):
             If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``.
     """
 
-    def __init__(self, n_in, n_out=1, scale=0, bias_x=True, bias_y=True):
+    def __init__(
+        self,
+        n_in: int,
+        n_out: int = 1,
+        scale: int = 0,
+        bias_x: bool = True,
+        bias_y: bool = True
+    ) -> Biaffine:
         super().__init__()
 
         self.n_in = n_in
@@ -54,7 +63,7 @@ def __repr__(self):
     def reset_parameters(self):
         nn.init.zeros_(self.weight)
 
-    def forward(self, x, y):
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
@@ -102,7 +111,15 @@ class Triaffine(nn.Module):
             If ``True``, represents the weight as the product of 3 independent matrices. Default: ``False``.
     """
 
-    def __init__(self, n_in, n_out=1, scale=0, bias_x=False, bias_y=False, decompose=False):
+    def __init__(
+        self,
+        n_in: int,
+        n_out: int = 1,
+        scale: int = 0,
+        bias_x: bool = False,
+        bias_y: bool = False,
+        decompose: bool = False
+    ) -> Triaffine:
         super().__init__()
 
         self.n_in = n_in
@@ -143,7 +160,7 @@ def reset_parameters(self):
         else:
             nn.init.zeros_(self.weight)
 
-    def forward(self, x, y, z):
+    def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
```
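
The `from __future__ import annotations` line added at the top of each file makes every annotation a lazily evaluated string (PEP 563), which is why `__init__` can name the enclosing class (`-> Biaffine`) before the class object exists. Below is a minimal usage sketch of the annotated scorer, assuming `supar` is installed; the tensor sizes are illustrative and not taken from the commit:

```python
import torch

from supar.modules.affine import Biaffine

batch_size, seq_len, n_in = 2, 10, 100  # illustrative sizes only

# Biaffine scorer with two output channels, matching the annotated signature.
biaffine = Biaffine(n_in=n_in, n_out=2, bias_x=True, bias_y=True)

x = torch.randn(batch_size, seq_len, n_in)  # [batch_size, seq_len, n_in]
y = torch.randn(batch_size, seq_len, n_in)  # [batch_size, seq_len, n_in]

# forward(x, y) -> torch.Tensor of pairwise scores over token pairs.
scores = biaffine(x, y)
print(scores.shape)
```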

supar/modules/dropout.py

Lines changed: 9 additions & 5 deletions
```diff
@@ -1,5 +1,9 @@
 # -*- coding: utf-8 -*-
 
+from __future__ import annotations
+
+from typing import List
+
 import torch
 import torch.nn as nn
 
@@ -27,7 +31,7 @@ class SharedDropout(nn.Module):
                  [2., 0., 2., 0., 2.]]])
     """
 
-    def __init__(self, p=0.5, batch_first=True):
+    def __init__(self, p: float = 0.5, batch_first: bool = True) -> SharedDropout:
         super().__init__()
 
         self.p = p
@@ -40,7 +44,7 @@ def __repr__(self):
 
         return f"{self.__class__.__name__}({s})"
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             x (~torch.Tensor):
@@ -59,7 +63,7 @@ def forward(self, x):
         return x
 
     @staticmethod
-    def get_mask(x, p):
+    def get_mask(x: torch.Tensor, p: float) -> torch.FloatTensor:
         return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p)
 
 
@@ -86,15 +90,15 @@ class IndependentDropout(nn.Module):
                  [0., 0., 0., 0., 0.]]])
     """
 
-    def __init__(self, p=0.5):
+    def __init__(self, p: float = 0.5) -> IndependentDropout:
         super().__init__()
 
         self.p = p
 
     def __repr__(self):
         return f"{self.__class__.__name__}(p={self.p})"
 
-    def forward(self, *items):
+    def forward(self, *items: List[torch.Tensor]) -> List[torch.Tensor]:
         r"""
         Args:
             items (list[~torch.Tensor]):
```
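
The `get_mask` hunk above shows the whole trick: an inverted-dropout keep mask drawn as Bernoulli(1 - p) and rescaled by 1/(1 - p), which `SharedDropout` then reuses across the sequence dimension so a dropped feature is dropped at every time step. A small sketch of that behaviour, assuming `supar` is installed (sizes are illustrative):

```python
import torch

from supar.modules.dropout import SharedDropout

x = torch.ones(1, 5, 4)  # [batch_size, seq_len, n_embed], illustrative sizes

# The raw mask: entries are 0 with probability p, else 1 / (1 - p).
mask = SharedDropout.get_mask(x[:, 0], p=0.5)
print(mask)

# In training mode, one mask is broadcast over all 5 positions,
# so a zeroed feature column is zeroed for the whole sequence.
shared = SharedDropout(p=0.5, batch_first=True)
shared.train()
print(shared(x).squeeze(0))
```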

supar/modules/lstm.py

Lines changed: 40 additions & 6 deletions
```diff
@@ -1,5 +1,9 @@
 # -*- coding: utf-8 -*-
 
+from __future__ import annotations
+
+from typing import List, Optional, Tuple
+
 import torch
 import torch.nn as nn
 from supar.modules.dropout import SharedDropout
@@ -28,7 +32,15 @@ class CharLSTM(nn.Module):
             The dropout ratio of CharLSTM hidden states. Default: 0.
     """
 
-    def __init__(self, n_chars, n_embed, n_hidden, n_out=0, pad_index=0, dropout=0):
+    def __init__(
+        self,
+        n_chars: int,
+        n_embed: int,
+        n_hidden: int,
+        n_out: int = 0,
+        pad_index: int = 0,
+        dropout: float = 0
+    ) -> CharLSTM:
         super().__init__()
 
         self.n_chars = n_chars
@@ -52,7 +64,7 @@ def __repr__(self):
 
         return f"{self.__class__.__name__}({s})"
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             x (~torch.Tensor): ``[batch_size, seq_len, fix_len]``.
@@ -105,7 +117,14 @@ class VariationalLSTM(nn.Module):
             Default: 0.
     """
 
-    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=False, dropout=0):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int = 1,
+        bidirectional: bool = False,
+        dropout: float = .0
+    ) -> VariationalLSTM:
         super().__init__()
 
         self.input_size = input_size
@@ -146,15 +165,26 @@ def reset_parameters(self):
         else:
             nn.init.zeros_(param)
 
-    def permute_hidden(self, hx, permutation):
+    def permute_hidden(
+        self,
+        hx: Tuple[torch.Tensor, torch.Tensor],
+        permutation: torch.LongTensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if permutation is None:
             return hx
         h = apply_permutation(hx[0], permutation)
         c = apply_permutation(hx[1], permutation)
 
         return h, c
 
-    def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
+    def layer_forward(
+        self,
+        x: List[torch.Tensor],
+        hx: Tuple[torch.Tensor, torch.Tensor],
+        cell: nn.LSTMCell,
+        batch_sizes: List[int],
+        reverse: bool = False
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         hx_0 = hx_i = hx
         hx_n, output = [], []
         steps = reversed(range(len(x))) if reverse else range(len(x))
@@ -182,7 +212,11 @@ def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
 
         return output, hx_n
 
-    def forward(self, sequence, hx=None):
+    def forward(
+        self,
+        sequence: PackedSequence,
+        hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    ) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
         r"""
         Args:
             sequence (~torch.nn.utils.rnn.PackedSequence):
```
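
The annotated `forward` makes the packed interface explicit: `VariationalLSTM` consumes a `PackedSequence` and returns a `(PackedSequence, (h_n, c_n))` pair, mirroring `torch.nn.LSTM`. A usage sketch under that assumption, with `supar` installed and made-up sizes:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from supar.modules.lstm import VariationalLSTM

batch_size, max_len, input_size, hidden_size = 3, 7, 50, 25  # illustrative sizes only

lstm = VariationalLSTM(input_size=input_size,
                       hidden_size=hidden_size,
                       num_layers=2,
                       bidirectional=True,
                       dropout=0.33)

x = torch.randn(batch_size, max_len, input_size)
lengths = torch.tensor([7, 5, 2])

# Pack the padded batch, run the LSTM, then unpack the output.
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
output, (h_n, c_n) = lstm(packed)
output, _ = pad_packed_sequence(output, batch_first=True)
print(output.shape)  # expected [batch_size, max_len, 2 * hidden_size] for a bidirectional LSTM
```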

supar/modules/mlp.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -1,5 +1,8 @@
 # -*- coding: utf-8 -*-
 
+from __future__ import annotations
+
+import torch
 import torch.nn as nn
 from supar.modules.dropout import SharedDropout
 
@@ -20,7 +23,7 @@ class MLP(nn.Module):
             Whether to use activations. Default: True.
     """
 
-    def __init__(self, n_in, n_out, dropout=0, activation=True):
+    def __init__(self, n_in: int, n_out: int, dropout: float = .0, activation: bool = True) -> MLP:
         super().__init__()
 
         self.n_in = n_in
@@ -42,7 +45,7 @@ def reset_parameters(self):
         nn.init.orthogonal_(self.linear.weight)
         nn.init.zeros_(self.linear.bias)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             x (~torch.Tensor):
```
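
Judging by this file's hunks, `MLP` is essentially one orthogonally initialised linear layer plus an optional activation and `SharedDropout`, so the annotation boils down to `Tensor -> Tensor`. A short sketch, assuming `supar` is installed and with illustrative sizes:

```python
import torch

from supar.modules.mlp import MLP

mlp = MLP(n_in=200, n_out=100, dropout=0.33)  # illustrative sizes
x = torch.randn(2, 10, 200)                   # [batch_size, seq_len, n_in]
print(mlp(x).shape)                           # expected: [2, 10, 100]
```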

supar/modules/pretrained.py

Lines changed: 26 additions & 5 deletions
```diff
@@ -1,5 +1,9 @@
 # -*- coding: utf-8 -*-
 
+from __future__ import annotations
+
+from typing import Tuple
+
 import torch
 import torch.nn as nn
 from supar.modules.scalar_mix import ScalarMix
@@ -35,7 +39,17 @@ class TransformerEmbedding(nn.Module):
         https://github.com/huggingface/transformers
     """
 
-    def __init__(self, model, n_layers, n_out=0, stride=256, pooling='mean', pad_index=0, mix_dropout=0, finetune=False):
+    def __init__(
+        self,
+        model: str,
+        n_layers: int,
+        n_out: int = 0,
+        stride: int = 256,
+        pooling: str = 'mean',
+        pad_index: int = 0,
+        mix_dropout: float = .0,
+        finetune: bool = False
+    ) -> TransformerEmbedding:
         super().__init__()
 
         from transformers import AutoConfig, AutoModel, AutoTokenizer
@@ -67,7 +81,7 @@ def __repr__(self):
             s += f", finetune={self.finetune}"
         return f"{self.__class__.__name__}({s})"
 
-    def forward(self, subwords):
+    def forward(self, subwords: torch.Tensor) -> torch.Tensor:
         r"""
         Args:
             subwords (~torch.Tensor): ``[batch_size, seq_len, fix_len]``.
@@ -142,7 +156,14 @@ class ELMoEmbedding(nn.Module):
         'original_5b': 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5',  # noqa
     }
 
-    def __init__(self, model='original_5b', bos_eos=(True, True), n_out=0, dropout=0.5, finetune=False):
+    def __init__(
+        self,
+        model: str = 'original_5b',
+        bos_eos: Tuple[bool, bool] = (True, True),
+        n_out: int = 0,
+        dropout: float = 0.5,
+        finetune: bool = False
+    ) -> ELMoEmbedding:
         super().__init__()
 
         from allennlp.modules import Elmo
@@ -171,10 +192,10 @@ def __repr__(self):
             s += f", finetune={self.finetune}"
         return f"{self.__class__.__name__}({s})"
 
-    def forward(self, chars):
+    def forward(self, chars: torch.LongTensor) -> torch.Tensor:
         r"""
         Args:
-            chars (~torch.Tensor): ``[batch_size, seq_len, fix_len]``.
+            chars (~torch.LongTensor): ``[batch_size, seq_len, fix_len]``.
 
         Returns:
             ~torch.Tensor:
```
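
`TransformerEmbedding` passes the newly annotated `model` string to Hugging Face's `AutoConfig`/`AutoModel`/`AutoTokenizer` (imported inside `__init__` in the hunk above) and, judging by the `ScalarMix` import and `mix_dropout` argument, mixes the last `n_layers` hidden states with a scalar mixture. The sketch below only constructs the module; the checkpoint name and sizes are illustrative, `transformers` and `supar` must be installed, and instantiation downloads the checkpoint, so treat this as an assumption-laden example rather than the project's documented usage:

```python
from supar.modules.pretrained import TransformerEmbedding

# All argument values below are illustrative.
embed = TransformerEmbedding(model='bert-base-cased',  # any Hugging Face checkpoint name
                             n_layers=4,
                             n_out=100,
                             stride=256,
                             pooling='mean',
                             pad_index=0,
                             mix_dropout=0.1,
                             finetune=False)

# embed(subwords) expects subword indices of shape [batch_size, seq_len, fix_len]
# produced by the matching tokenizer (see the forward docstring in the diff).
print(embed)
```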

supar/modules/scalar_mix.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -1,5 +1,9 @@
 # -*- coding: utf-8 -*-
 
+from __future__ import annotations
+
+from typing import List
+
 import torch
 import torch.nn as nn
 
@@ -20,7 +24,7 @@ class ScalarMix(nn.Module):
             Default: 0.
     """
 
-    def __init__(self, n_layers, dropout=0):
+    def __init__(self, n_layers: int, dropout: float = .0) -> ScalarMix:
         super().__init__()
 
         self.n_layers = n_layers
@@ -36,7 +40,7 @@ def __repr__(self):
 
         return f"{self.__class__.__name__}({s})"
 
-    def forward(self, tensors):
+    def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
         r"""
         Args:
             tensors (list[~torch.Tensor]):
```
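
`ScalarMix.forward` now advertises `List[torch.Tensor] -> torch.Tensor`: it collapses a list of equally shaped layer representations into a single tensor of the same shape (the classic ELMo-style learned weighted sum). A sketch under that reading, with illustrative sizes and assuming `supar` is installed:

```python
import torch

from supar.modules.scalar_mix import ScalarMix

n_layers, batch_size, seq_len, n_hidden = 4, 2, 10, 16  # illustrative sizes only

mix = ScalarMix(n_layers=n_layers, dropout=0.1)

# One tensor per layer, all sharing the same shape.
layers = [torch.randn(batch_size, seq_len, n_hidden) for _ in range(n_layers)]

mixed = mix(layers)
print(mixed.shape)  # expected to match each input layer: [2, 10, 16]
```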
