
remove unnecessary sync point in AveragedModel update #158017


Open · gl3lan wants to merge 1 commit into main from export-D78074709

Conversation


@gl3lan gl3lan commented Jul 10, 2025

Summary:
The check `bool(self.n_averaged == 0)` is a CPU/GPU synchronization point that is hit on every update. It exists only to tell whether the AveragedModel copy has been initialized. This diff introduces a CPU-based boolean for that purpose. When loading from a checkpoint, we also make sure the flag is refreshed.

After this fix, each `update_parameters` call drops from 333ms to 6ms (a 98% reduction).
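The idea can be sketched as follows. This is an illustrative running-average wrapper, not the actual torch.optim.swa_utils.AveragedModel code — the class name `EmaSketch` and the `initialized` attribute are hypothetical — showing how a plain Python flag replaces the tensor comparison that forced a device-to-host sync:

```python
import copy
import torch

class EmaSketch:
    """Illustrative running-average wrapper; NOT the real AveragedModel."""

    def __init__(self, model):
        self.module = copy.deepcopy(model)
        self.n_averaged = 0       # bookkeeping for the running mean
        self.initialized = False  # CPU-side flag replacing bool(self.n_averaged == 0)

    @torch.no_grad()
    def update_parameters(self, model):
        for p_avg, p in zip(self.module.parameters(), model.parameters()):
            if not self.initialized:
                # First update: plain copy. Reading a Python bool here avoids
                # the device-to-host sync that bool(gpu_tensor == 0) forces.
                p_avg.copy_(p)
            else:
                # Running mean: avg += (p - avg) / (n + 1)
                p_avg.add_((p - p_avg) / (self.n_averaged + 1))
        self.initialized = True
        self.n_averaged += 1
```

The flag lives on the Python object, so checking it costs nothing, while the arithmetic stays on whatever device the parameters are on.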

Test Plan:
contbuild & OSS CI
Test plan from GitHub:
CI

Rollback Plan:

Differential Revision: D78074709

@gl3lan gl3lan requested review from albanD and janeyx99 as code owners July 10, 2025 09:30

pytorch-bot bot commented Jul 10, 2025

This appears to be a diff that was exported from Phabricator, but the PR author does not have sufficient permissions to run CI. @gl3lan, please complete step 2 of the internal wiki to get write access so you do not need CI approvals in the future. If you think this is a mistake, please contact the PyTorch Dev Infra team.


linux-foundation-easycla bot commented Jul 10, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: gl3lan / name: Gaël Le Lan (091aeb5)


pytorch-bot bot commented Jul 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158017

Note: Links to docs will display an error until the docs builds have been completed.

❌ 10 New Failures

As of commit 091aeb5 with merge base db78943:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D78074709


@albanD albanD removed their request for review July 10, 2025 13:46
@gl3lan gl3lan force-pushed the export-D78074709 branch from 6463fb5 to 5c3ca33 Compare July 10, 2025 16:58

@gl3lan gl3lan force-pushed the export-D78074709 branch from 5c3ca33 to f61c1f5 Compare July 14, 2025 08:36

gl3lan added a commit to gl3lan/pytorch that referenced this pull request Jul 14, 2025
@gl3lan gl3lan marked this pull request as draft July 14, 2025 08:38
@gl3lan gl3lan force-pushed the export-D78074709 branch from f61c1f5 to ec192fb Compare July 22, 2025 08:06

gl3lan added a commit to gl3lan/pytorch that referenced this pull request Jul 22, 2025
@gl3lan gl3lan marked this pull request as ready for review July 22, 2025 09:01
@gl3lan gl3lan force-pushed the export-D78074709 branch from ec192fb to 6a831d2 Compare August 3, 2025 20:58

gl3lan added a commit to gl3lan/pytorch that referenced this pull request Aug 3, 2025
@gl3lan gl3lan force-pushed the export-D78074709 branch from 6a831d2 to ce429ea Compare August 3, 2025 22:02
gl3lan added a commit to gl3lan/pytorch that referenced this pull request Aug 3, 2025

@gl3lan gl3lan force-pushed the export-D78074709 branch from ce429ea to 49f50ec Compare August 4, 2025 18:41
gl3lan added a commit to gl3lan/pytorch that referenced this pull request Aug 4, 2025

@gl3lan gl3lan force-pushed the export-D78074709 branch from 49f50ec to e63c89c Compare August 4, 2025 20:03
gl3lan added a commit to gl3lan/pytorch that referenced this pull request Aug 4, 2025

@gl3lan gl3lan marked this pull request as draft August 4, 2025 20:10
@gl3lan gl3lan force-pushed the export-D78074709 branch from e63c89c to 30603a7 Compare August 5, 2025 19:12
gl3lan added a commit to gl3lan/pytorch that referenced this pull request Aug 5, 2025

gl3lan added a commit to gl3lan/pytorch that referenced this pull request Aug 5, 2025
@gl3lan gl3lan force-pushed the export-D78074709 branch from 30603a7 to 97bb125 Compare August 5, 2025 19:16
Contributor

@janeyx99 janeyx99 left a comment


Huh, why is n_averaged even a tensor? It could just be a Python number and this sync would go away, right?

@gl3lan gl3lan force-pushed the export-D78074709 branch from 97bb125 to 64a7a4f Compare August 7, 2025 06:24

gl3lan added a commit to gl3lan/pytorch that referenced this pull request Aug 7, 2025
@gl3lan
Author

gl3lan commented Aug 7, 2025

Huh, why is n_averaged even a tensor? It could just be a Python number and this sync would go away, right?

Right, but we need it to be saved and reloaded when a job resumes, or else the EMA behavior will change after resuming.

@gl3lan gl3lan marked this pull request as ready for review August 7, 2025 14:57
@janeyx99
Contributor

janeyx99 commented Aug 7, 2025

I've just learned from @mikaylagawarecki that there are get_extra_state and set_extra_state hooks on modules that we can use to store a plain Python number for n_averaged: https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state

This would make the code simpler and more understandable, with the same perf win, if n_averaged were a Python number instead of a buffer.
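A minimal sketch of that suggestion, using a hypothetical `AveragedSketch` module (illustrative, not the actual swa_utils code): `get_extra_state`/`set_extra_state` let a plain Python int ride along in the state dict (under the module's `_extra_state` key), so it survives checkpointing without being a tensor buffer.

```python
import torch

class AveragedSketch(torch.nn.Module):
    """Hypothetical module keeping n_averaged as a plain Python int."""

    def __init__(self):
        super().__init__()
        self.n_averaged = 0  # plain int: no buffer, no sync to read it

    def get_extra_state(self):
        # Serialized into the state dict under the "_extra_state" key.
        return {"n_averaged": self.n_averaged}

    def set_extra_state(self, state):
        # Restored during load_state_dict, so EMA behavior survives resuming.
        self.n_averaged = state["n_averaged"]

m = AveragedSketch()
m.n_averaged = 7
restored = AveragedSketch()
restored.load_state_dict(m.state_dict())
```

The counter round-trips through `state_dict()`/`load_state_dict()` exactly like a buffer would, without ever living on the GPU.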

@gl3lan gl3lan force-pushed the export-D78074709 branch from 64a7a4f to 091aeb5 Compare August 9, 2025 20:00

@gl3lan
Author

gl3lan commented Aug 11, 2025

I've just learned from @mikaylagawarecki that there are get_extra_state and set_extra_state hooks on modules that we can use to store a plain Python number for n_averaged: https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state

This would make the code simpler and more understandable, with the same perf win, if n_averaged were a Python number instead of a buffer.

@janeyx99 it looks cleaner now, but it breaks loading of any pre-existing checkpoint because of the missing parameter. Any suggestion? Override load_state_dict, look up the tensor value, and pass it to the new attribute?

@janeyx99
Contributor

@gl3lan yeah, we should be able to register a hook https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_load_state_dict_pre_hook so that if the state dict has a Tensor n_averaged, we just convert it to a number. How does that sound to you?

I'm also okay with the option of overriding load_state_dict, whichever is simpler. Can you also add a test case to ensure we don't break existing users?
