File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -503,9 +503,11 @@ def _rms(tensor):
503
503
504
504
@staticmethod
505
505
def _approx_sq_grad (exp_avg_sq_row , exp_avg_sq_col ):
506
- r_factor = (exp_avg_sq_row / exp_avg_sq_row .mean (dim = - 1 , keepdim = True )).rsqrt_ ()
507
- c_factor = exp_avg_sq_col .rsqrt ()
508
- return torch .mm (r_factor .unsqueeze (- 1 ), c_factor .unsqueeze (0 ))
506
+ # Copied from the Adafactor implementation in Hugging Face transformers (linked below), which is in turn based on fairseq's:
507
+ # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
508
+ r_factor = (exp_avg_sq_row / exp_avg_sq_row .mean (dim = - 1 , keepdim = True )).rsqrt_ ().unsqueeze (- 1 )
509
+ c_factor = exp_avg_sq_col .unsqueeze (- 2 ).rsqrt ()
510
+ return torch .mul (r_factor , c_factor )
509
511
510
512
def step (self , closure = None ):
511
513
"""
You can’t perform that action at this time.
0 commit comments