Skip to content

Commit 2f32fbd

Browse files
author
zysite
committed
Apply the dropout strategy of the paper
1 parent 053c5cc commit 2f32fbd

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

parser/modules/dropout.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ def __init__(self, p=0.5, batch_first=True):
1313
self.batch_first = batch_first
1414

1515
def extra_repr(self):
16-
info = f"p={self.p}"
16+
s = f"p={self.p}"
1717
if self.batch_first:
18-
info += f", batch_first={self.batch_first}"
18+
s += f", batch_first={self.batch_first}"
1919

20-
return info
20+
return s
2121

2222
def forward(self, x):
2323
if self.training:
@@ -31,8 +31,8 @@ def forward(self, x):
3131

3232
@staticmethod
3333
def get_mask(x, p):
34-
mask = x.new_full(x.shape, 1 - p)
35-
mask = torch.bernoulli(mask) / (1 - p)
34+
mask = x.new_empty(x.shape).bernoulli_(1 - p)
35+
mask = mask / (1 - p)
3636

3737
return mask
3838

@@ -47,14 +47,14 @@ def __init__(self, p=0.5):
4747
def extra_repr(self):
4848
return f"p={self.p}"
4949

50-
def forward(self, x, y, eps=1e-12):
50+
def forward(self, *items):
5151
if self.training:
52-
x_mask = torch.bernoulli(x.new_full(x.shape[:2], 1 - self.p))
53-
y_mask = torch.bernoulli(y.new_full(y.shape[:2], 1 - self.p))
54-
scale = 3.0 / (2.0 * x_mask + y_mask + eps)
55-
x_mask *= scale
56-
y_mask *= scale
57-
x *= x_mask.unsqueeze(dim=-1)
58-
y *= y_mask.unsqueeze(dim=-1)
59-
60-
return x, y
52+
masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p)
53+
for x in items]
54+
total = sum(masks)
55+
scale = len(items) / total.max(torch.ones_like(total))
56+
masks = [mask * scale for mask in masks]
57+
items = [item * mask.unsqueeze(dim=-1)
58+
for item, mask in zip(items, masks)]
59+
60+
return items

0 commit comments

Comments
 (0)