@@ -4,6 +4,7 @@

import torch
import torch.nn as nn
+from torch.nn.modules.rnn import apply_permutation
from torch.nn.utils.rnn import PackedSequence

@@ -19,7 +20,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0):

        self.f_cells = nn.ModuleList()
        self.b_cells = nn.ModuleList()
-        for layer in range(self.num_layers):
+        for _ in range(self.num_layers):
            self.f_cells.append(nn.LSTMCell(input_size=input_size,
                                            hidden_size=hidden_size))
            self.b_cells.append(nn.LSTMCell(input_size=input_size,
@@ -28,67 +29,99 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0):

        self.reset_parameters()

+    def __repr__(self):
+        s = self.__class__.__name__ + '('
+        s += f"{self.input_size}, {self.hidden_size}"
+        if self.num_layers > 1:
+            s += f", num_layers={self.num_layers}"
+        if self.dropout > 0:
+            s += f", dropout={self.dropout}"
+        s += ')'
+
+        return s
+
    def reset_parameters(self):
-        for i in self.parameters():
+        for param in self.parameters():
            # apply orthogonal_ to weight
-            if len(i.shape) > 1:
-                nn.init.orthogonal_(i)
+            if len(param.shape) > 1:
+                nn.init.orthogonal_(param)
            # apply zeros_ to bias
            else:
-                nn.init.zeros_(i)
+                nn.init.zeros_(param)
+
+    def permute_hidden(self, hx, permutation):
+        if permutation is None:
+            return hx
+        h = apply_permutation(hx[0], permutation)
+        c = apply_permutation(hx[1], permutation)
+
+        return h, c

    def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
-        h, c = hx
-        init_h, init_c = h, c
-        output, seq_len = [], len(x)
-        steps = reversed(range(seq_len)) if reverse else range(seq_len)
+        hx_0 = hx_i = hx
+        hx_n, output = [], []
+        steps = reversed(range(len(x))) if reverse else range(len(x))
        if self.training:
-            hid_mask = SharedDropout.get_mask(h, self.dropout)
+            hid_mask = SharedDropout.get_mask(hx_0[0], self.dropout)

        for t in steps:
-            last_batch_size, batch_size = len(h), batch_sizes[t]
+            last_batch_size, batch_size = len(hx_i[0]), batch_sizes[t]
            if last_batch_size < batch_size:
-                h = torch.cat((h, init_h[last_batch_size:batch_size]))
-                c = torch.cat((c, init_c[last_batch_size:batch_size]))
+                hx_i = [torch.cat((h, ih[last_batch_size:batch_size]))
+                        for h, ih in zip(hx_i, hx_0)]
            else:
-                h = h[:batch_size]
-                c = c[:batch_size]
-            h, c = cell(input=x[t], hx=(h, c))
-            output.append(h)
+                hx_n.append([h[batch_size:] for h in hx_i])
+                hx_i = [h[:batch_size] for h in hx_i]
+            hx_i = [h for h in cell(x[t], hx_i)]
+            output.append(hx_i[0])
            if self.training:
-                h = h * hid_mask[:batch_size]
+                hx_i[0] = hx_i[0] * hid_mask[:batch_size]
        if reverse:
+            hx_n = hx_i
            output.reverse()
+        else:
+            hx_n.append(hx_i)
+            hx_n.reverse()
+        hx_n = [torch.cat(h) for h in zip(*hx_n)]
        output = torch.cat(output)

-        return output
+        return output, hx_n

-    def forward(self, x, hx=None):
-        x, batch_sizes = x
+    def forward(self, sequence, hx=None):
+        x, batch_sizes = sequence.data, sequence.batch_sizes.tolist()
        batch_size = batch_sizes[0]
+        h_n, c_n = [], []

        if hx is None:
-            init = x.new_zeros(batch_size, self.hidden_size)
-            hx = (init, init)
-
-        for layer in range(self.num_layers):
+            ih = x.new_zeros(self.num_layers * 2, batch_size, self.hidden_size)
+            h, c = ih, ih
+        else:
+            h, c = self.permute_hidden(hx, sequence.sorted_indices)
+        h = h.view(self.num_layers, 2, batch_size, self.hidden_size)
+        c = c.view(self.num_layers, 2, batch_size, self.hidden_size)
+
+        for i in range(self.num_layers):
+            x = torch.split(x, batch_sizes)
            if self.training:
-                mask = SharedDropout.get_mask(x[:batch_size], self.dropout)
-                mask = torch.cat([mask[:batch_size]
-                                  for batch_size in batch_sizes])
-                x *= mask
-            x = torch.split(x, batch_sizes.tolist())
-            f_output = self.layer_forward(x=x,
-                                          hx=hx,
-                                          cell=self.f_cells[layer],
-                                          batch_sizes=batch_sizes,
-                                          reverse=False)
-            b_output = self.layer_forward(x=x,
-                                          hx=hx,
-                                          cell=self.b_cells[layer],
-                                          batch_sizes=batch_sizes,
-                                          reverse=True)
-            x = torch.cat([f_output, b_output], -1)
-        x = PackedSequence(x, batch_sizes)
-
-        return x
+                mask = SharedDropout.get_mask(x[0], self.dropout)
+                x = [i * mask[:len(i)] for i in x]
+            x_f, (h_f, c_f) = self.layer_forward(x=x,
+                                                 hx=(h[i, 0], c[i, 0]),
+                                                 cell=self.f_cells[i],
+                                                 batch_sizes=batch_sizes)
+            x_b, (h_b, c_b) = self.layer_forward(x=x,
+                                                 hx=(h[i, 1], c[i, 1]),
+                                                 cell=self.b_cells[i],
+                                                 batch_sizes=batch_sizes,
+                                                 reverse=True)
+            x = torch.cat((x_f, x_b), -1)
+            h_n.append(torch.stack((h_f, h_b)))
+            c_n.append(torch.stack((c_f, c_b)))
+        x = PackedSequence(x,
+                           sequence.batch_sizes,
+                           sequence.sorted_indices,
+                           sequence.unsorted_indices)
+        hx = torch.cat(h_n, 0), torch.cat(c_n, 0)
+        hx = self.permute_hidden(hx, sequence.unsorted_indices)
+
+        return x, hx
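
For reference, a minimal usage sketch of the revised interface (not part of this commit; the class name BiLSTM and the constructor values below are assumed for illustration):

# Hypothetical usage sketch: forward now takes a PackedSequence and also
# returns the final (h_n, c_n) states, similar to nn.LSTM.
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = BiLSTM(input_size=100, hidden_size=150, num_layers=2, dropout=0.33)
x = torch.randn(4, 20, 100)              # (batch, seq_len, input_size)
lengths = torch.tensor([20, 18, 15, 9])  # lengths sorted in decreasing order
packed = pack_padded_sequence(x, lengths, batch_first=True)

output, (h_n, c_n) = lstm(packed)        # output is a PackedSequence
padded, _ = pad_packed_sequence(output, batch_first=True)
# padded: (4, 20, 300); h_n and c_n: (num_layers * 2, 4, 150)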