
Commit f61c1f5

gl3lan authored and facebook-github-bot committed
remove unnecessary sync point in AveragedModel update (#158017)
Summary: The check `bool(self.n_averaged == 0)` is a CPU/GPU synchronization point that runs on every update. Its only purpose is to tell whether the AveragedModel copy has been initialized. This diff introduces a CPU-side boolean flag for that purpose; when loading from a checkpoint, the flag is refreshed from `n_averaged`. After this fix, each `update_parameters` call drops from 333 ms to 6 ms (a 98% reduction).

Test Plan: contbuild & OSS CI

Test plan from GitHub: CI

Rollback Plan:

Differential Revision: D78074709
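For context, a minimal sketch (not from the commit, assuming a CUDA device is available) of why the old check synchronizes while the new flag does not: converting a GPU tensor comparison to a Python bool forces the host to wait for the device, whereas reading a plain Python attribute never touches the GPU.

import torch

n_averaged = torch.tensor(0, dtype=torch.long, device="cuda")

# Old pattern: the comparison runs on the GPU and bool() must copy the
# result back to the host, blocking until queued GPU work completes.
copy_param = bool(n_averaged == 0)

# New pattern: a plain Python bool lives on the host, so the check is free.
is_copy_initialized = False
if not is_copy_initialized:
    pass  # first update: copy the model parameters into the averaged copy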
1 parent 0f21fa8 commit f61c1f5

File tree

1 file changed: +11 -4 lines changed


torch/optim/swa_utils.py

Lines changed: 11 additions & 4 deletions
@@ -4,7 +4,7 @@
 import itertools
 import math
 import warnings
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
 from copy import deepcopy
 from typing import Any, Callable, cast, Literal, Optional, Union
 
@@ -237,6 +237,7 @@ def __init__(
         self.register_buffer(
             "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
         )
+        self.is_copy_initialized = False
         self.avg_fn = avg_fn
         self.multi_avg_fn = multi_avg_fn
         self.use_buffers = use_buffers
@@ -259,15 +260,14 @@ def update_parameters(self, model: Module):
         )
         self_param_detached: list[Optional[Tensor]] = []
         model_param_detached: list[Optional[Tensor]] = []
-        copy_param = bool(self.n_averaged == 0)
         for p_averaged, p_model in zip(self_param, model_param):
             p_model_ = p_model.detach().to(p_averaged.device)
             self_param_detached.append(p_averaged.detach())
             model_param_detached.append(p_model_)
-            if copy_param:
+            if not self.is_copy_initialized:
                 p_averaged.detach().copy_(p_model_)
 
-        if self.n_averaged > 0:
+        if self.is_copy_initialized:
             if self.multi_avg_fn is not None or self.avg_fn is None:
                 grouped_tensors = _group_tensors_by_device_and_dtype(
                     [self_param_detached, model_param_detached]
@@ -310,6 +310,13 @@ def update_parameters(self, model: Module):
         for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
             b_swa.detach().copy_(b_model.detach().to(b_swa.device))
         self.n_averaged += 1
+        self.is_copy_initialized = True
+
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    ):
+        super().load_state_dict(state_dict, strict, assign)
+        self.is_copy_initialized = bool(self.n_averaged > 0)
 
 
 @torch.no_grad()
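As a usage note (a sketch, not part of the commit, assuming a build of torch.optim.swa_utils that includes this patch): a checkpoint round trip still behaves correctly, because the overridden load_state_dict re-derives the flag from the restored n_averaged buffer.

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 2)
swa_model = AveragedModel(model)

swa_model.update_parameters(model)  # first call: copies params, no GPU sync
swa_model.update_parameters(model)  # later calls: running average

# Round-trip through a checkpoint; load_state_dict restores n_averaged and
# then recomputes is_copy_initialized = bool(n_averaged > 0) on the CPU.
state = swa_model.state_dict()
restored = AveragedModel(torch.nn.Linear(4, 2))
restored.load_state_dict(state)
assert restored.is_copy_initialized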
