@@ -13,11 +13,11 @@ def __init__(self, p=0.5, batch_first=True):
         self.batch_first = batch_first
 
     def extra_repr(self):
-        info = f"p={self.p}"
+        s = f"p={self.p}"
         if self.batch_first:
-            info += f", batch_first={self.batch_first}"
+            s += f", batch_first={self.batch_first}"
 
-        return info
+        return s
 
     def forward(self, x):
         if self.training:
@@ -31,8 +31,8 @@ def forward(self, x):
 
     @staticmethod
     def get_mask(x, p):
-        mask = x.new_full(x.shape, 1 - p)
-        mask = torch.bernoulli(mask) / (1 - p)
+        mask = x.new_empty(x.shape).bernoulli_(1 - p)
+        mask = mask / (1 - p)
 
         return mask
 
@@ -47,14 +47,14 @@ def __init__(self, p=0.5):
     def extra_repr(self):
         return f"p={self.p}"
 
-    def forward(self, x, y, eps=1e-12):
+    def forward(self, *items):
         if self.training:
-            x_mask = torch.bernoulli(x.new_full(x.shape[:2], 1 - self.p))
-            y_mask = torch.bernoulli(y.new_full(y.shape[:2], 1 - self.p))
-            scale = 3.0 / (2.0 * x_mask + y_mask + eps)
-            x_mask *= scale
-            y_mask *= scale
-            x *= x_mask.unsqueeze(dim=-1)
-            y *= y_mask.unsqueeze(dim=-1)
-
-            return x, y
+            masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p)
+                     for x in items]
+            total = sum(masks)
+            scale = len(items) / total.max(torch.ones_like(total))
+            masks = [mask * scale for mask in masks]
+            items = [item * mask.unsqueeze(dim=-1)
+                     for item, mask in zip(items, masks)]
+
+            return items
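
For context, a minimal standalone sketch of what the rewritten mask logic computes. The tensor shapes, the choice of two inputs, and the variable names here are illustrative assumptions, not taken from the diff:

# Sketch only (not part of the commit); assumes inputs of shape [batch, seq_len, hidden].
import torch

p = 0.5
x, y = torch.ones(2, 5, 4), torch.ones(2, 5, 4)

# get_mask-style mask: in-place Bernoulli(1 - p) draw, rescaled by 1 / (1 - p)
# so the masked tensor keeps its expectation.
mask = x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p)

# forward(*items)-style independent dropout over any number of inputs:
items = (x, y)
masks = [item.new_empty(item.shape[:2]).bernoulli_(1 - p) for item in items]
total = sum(masks)
# Surviving items are scaled up to compensate for dropped ones; clamping the
# divisor at 1 avoids division by zero where every item was dropped.
scale = len(items) / total.max(torch.ones_like(total))
masks = [m * scale for m in masks]
out = [item * m.unsqueeze(dim=-1) for item, m in zip(items, masks)]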