Skip to content

Commit 91f3dfb

Browse files
[Adafactor] Fix adafactor (huggingface#14713)
* correct changes * add comment
1 parent 86dd23b commit 91f3dfb

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/transformers/optimization.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,9 +503,11 @@ def _rms(tensor):
503503

504504
@staticmethod
505505
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
506-
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_()
507-
c_factor = exp_avg_sq_col.rsqrt()
508-
return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0))
506+
# copy from fairseq's adafactor implementation:
507+
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
508+
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
509+
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
510+
return torch.mul(r_factor, c_factor)
509511

510512
def step(self, closure=None):
511513
"""

0 commit comments

Comments
 (0)