6
6
7
7
class Biaffine (nn .Module ):
8
8
r"""
9
- Biaffine layer for first-order scoring.
9
+ Biaffine layer for first-order scoring :cite:`dozat-etal-2017-biaffine`.
10
10
11
11
This function has a tensor of weights :math:`W` and bias terms if needed.
12
- The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y`,
13
- in which :math:`x` and :math:`y` can be concatenated with bias terms.
14
-
15
- References:
16
- - Timothy Dozat and Christopher D. Manning. 2017.
17
- `Deep Biaffine Attention for Neural Dependency Parsing`_.
12
+ The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y / d^s`,
13
+ where :math:`d` is the vector dimension (``n_in``) and :math:`s` is the scaling exponent (``scale``).
14
+ :math:`x` and :math:`y` can be concatenated with bias terms.
18
15
19
16
Args:
20
17
n_in (int):
21
18
The size of the input feature.
22
19
n_out (int):
23
20
The number of output channels. Default: 1.
21
+ scale (float):
22
+ Exponent applied to the input size when scaling: the scores are divided by ``n_in ** scale``. Default: 0 (no scaling).
24
23
bias_x (bool):
25
24
If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``.
26
25
bias_y (bool):
27
26
If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``.
28
-
29
- .. _Deep Biaffine Attention for Neural Dependency Parsing:
30
- https://openreview.net/forum?id=Hk95PK9le
31
27
"""
32
28
33
- def __init__ (self , n_in , n_out = 1 , bias_x = True , bias_y = True ):
29
+ def __init__ (self , n_in , n_out = 1 , scale = 0 , bias_x = True , bias_y = True ):
34
30
super ().__init__ ()
35
31
36
32
self .n_in = n_in
37
33
self .n_out = n_out
34
+ self .scale = scale
38
35
self .bias_x = bias_x
39
36
self .bias_y = bias_y
40
37
self .weight = nn .Parameter (torch .Tensor (n_out , n_in + bias_x , n_in + bias_y ))
41
38
42
39
self .reset_parameters ()
43
40
44
41
def __repr__ (self ):
45
- s = f"n_in={ self .n_in } , n_out={ self .n_out } "
42
+ s = f"n_in={ self .n_in } "
43
+ if self .n_out > 1 :
44
+ s += f", n_out={ self .n_out } "
45
+ if self .scale != 0 :
46
+ s += f", scale={ self .scale } "
46
47
if self .bias_x :
47
48
s += f", bias_x={ self .bias_x } "
48
49
if self .bias_y :
@@ -70,7 +71,7 @@ def forward(self, x, y):
70
71
if self .bias_y :
71
72
y = torch .cat ((y , torch .ones_like (y [..., :1 ])), - 1 )
72
73
# [batch_size, n_out, seq_len, seq_len]
73
- s = torch .einsum ('bxi,oij,byj->boxy' , x , self .weight , y )
74
+ s = torch .einsum ('bxi,oij,byj->boxy' , x , self .weight , y ) / self . n_in ** self . scale
74
75
# remove dim 1 if n_out == 1
75
76
s = s .squeeze (1 )
76
77
@@ -79,44 +80,44 @@ def forward(self, x, y):
79
80
80
81
class Triaffine (nn .Module ):
81
82
r"""
82
- Triaffine layer for second-order scoring.
83
+ Triaffine layer for second-order scoring (:cite:`zhang-etal-2020-efficient`, :cite:`wang-etal-2019-second`).
83
84
84
85
This function has a tensor of weights :math:`W` and bias terms if needed.
85
- The score :math:`s(x, y, z)` of the vector triple :math:`(x, y, z)` is computed as :math:`x^T z^T W y`.
86
- Usually, :math:`x` and :math:`y` can be concatenated with bias terms.
87
-
88
- References:
89
- - Yu Zhang, Zhenghua Li and Min Zhang. 2020.
90
- `Efficient Second-Order TreeCRF for Neural Dependency Parsing`_.
91
- - Xinyu Wang, Jingxian Huang, and Kewei Tu. 2019.
92
- `Second-Order Semantic Dependency Parsing with End-to-End Neural Networks`_.
86
+ The score :math:`s(x, y, z)` of the vector triple :math:`(x, y, z)` is computed as :math:`x^T z^T W y / d^s`,
87
+ where :math:`d` is the vector dimension (``n_in``) and :math:`s` is the scaling exponent (``scale``).
88
+ :math:`x` and :math:`y` can be concatenated with bias terms.
93
89
94
90
Args:
95
91
n_in (int):
96
92
The size of the input feature.
93
+ n_out (int):
94
+ The number of output channels. Default: 1.
95
+ scale (float):
96
+ Exponent applied to the input size when scaling: the scores are divided by ``n_in ** scale``. Default: 0 (no scaling).
97
97
bias_x (bool):
98
98
If ``True``, adds a bias term for tensor :math:`x`. Default: ``False``.
99
99
bias_y (bool):
100
100
If ``True``, adds a bias term for tensor :math:`y`. Default: ``False``.
101
-
102
- .. _Efficient Second-Order TreeCRF for Neural Dependency Parsing:
103
- https://www.aclweb.org/anthology/2020.acl-main.302/
104
- .. _Second-Order Semantic Dependency Parsing with End-to-End Neural Networks:
105
- https://www.aclweb.org/anthology/P19-1454/
106
101
"""
107
102
108
- def __init__ (self , n_in , bias_x = False , bias_y = False ):
103
+ def __init__ (self , n_in , n_out = 1 , scale = 0 , bias_x = False , bias_y = False ):
109
104
super ().__init__ ()
110
105
111
106
self .n_in = n_in
107
+ self .n_out = n_out
108
+ self .scale = scale
112
109
self .bias_x = bias_x
113
110
self .bias_y = bias_y
114
- self .weight = nn .Parameter (torch .Tensor (n_in + bias_x , n_in , n_in + bias_y ))
111
+ self .weight = nn .Parameter (torch .Tensor (n_out , n_in + bias_x , n_in , n_in + bias_y ))
115
112
116
113
self .reset_parameters ()
117
114
118
115
def __repr__ (self ):
119
116
s = f"n_in={ self .n_in } "
117
+ if self .n_out > 1 :
118
+ s += f", n_out={ self .n_out } "
119
+ if self .scale != 0 :
120
+ s += f", scale={ self .scale } "
120
121
if self .bias_x :
121
122
s += f", bias_x={ self .bias_x } "
122
123
if self .bias_y :
@@ -136,15 +137,18 @@ def forward(self, x, y, z):
136
137
137
138
Returns:
138
139
~torch.Tensor:
139
- A scoring tensor of shape ``[batch_size, seq_len, seq_len, seq_len]``.
140
+ A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len, seq_len]``.
141
+ If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically.
140
142
"""
141
143
142
144
if self .bias_x :
143
145
x = torch .cat ((x , torch .ones_like (x [..., :1 ])), - 1 )
144
146
if self .bias_y :
145
147
y = torch .cat ((y , torch .ones_like (y [..., :1 ])), - 1 )
146
- w = torch .einsum ('bzk,ikj->bzij' , z , self .weight )
147
- # [batch_size, seq_len, seq_len, seq_len]
148
- s = torch .einsum ('bxi,bzij,byj->bzxy' , x , w , y )
148
+ w = torch .einsum ('bzk,oikj->bozij' , z , self .weight )
149
+ # [batch_size, n_out, seq_len, seq_len, seq_len]
150
+ s = torch .einsum ('bxi,bozij,byj->bozxy' , x , w , y ) / self .n_in ** self .scale
151
+ # remove dim 1 if n_out == 1
152
+ s = s .squeeze (1 )
149
153
150
154
return s
0 commit comments