[muon] Introduce Muon optimizer to PyTorch #160213
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160213
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 654f754 with merge base 8d3d1c8.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Wooo the approach is much simpler indeed now--thank you for the speedy turnaround on the PR. The one main API question I have is how we handle the NS config (whether it should be in the constructor) which I've commented below. Everything else looks super solid.
I know you've put in amazing work for the benchmarks and correctness compared to the original Muon, and I trust that you have verified this PR is still correct and appropriately fast locally. I will look out for the separate PR with those scripts!
test/test_optim.py
Outdated
params = [weight, bias]
if optim_cls.__name__ == "Muon":
    params = [weight]
nit to not reassign
Suggested change:
params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]
test/test_optim.py
Outdated
model = torch.nn.Sequential(
    torch.nn.Linear(10, 4, bias=False),
This can just be one Linear then, right? Or maybe it'd be more indicative to add another Linear in there?
Can you add a comment for why we branch here?
test/test_optim.py
Outdated
@@ -1577,14 +1629,26 @@ def test_can_load_from_to_named_state_dict(
    all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
        device, dtype, optim_info, skip=("differentiable",)
    )

def _get_model_and_input(device, dtype, optim_cls):
let's only have one version of this helper; it looks the same as the one above
@@ -2219,7 +2285,7 @@ def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
    def test_non_empty_state(self, device, dtype, optim_info):
        # There are internal tests that check that the state is not empty
        optim_cls = optim_info.optim_cls
-       model = torch.nn.Linear(5, 5)
+       model = torch.nn.Linear(5, 5, bias=False)
please add a comment that we set bias=False here so the test can run generically with Muon
@@ -1969,7 +2035,7 @@ def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
        nonlocal data
        data += 2

-   params = [torch.tensor([1, 1], device=device, dtype=dtype)]
+   params = [torch.tensor([[1, 1]], device=device, dtype=dtype)]
How come these hook changes are necessary?
    nesterov: bool = True,
    *,
    msign_fn: MsignFn = zeropower_via_newtonschulz,
    msign_fn_config: BaseMsignFnConfig = NewtonSchulzConfig(),
What is the benefit of having this config live in the constructor as a struct vs. separate values? Is it because these values are only used if msign_fn is zeropower_via_newtonschulz? If so, should this config not live in the Muon constructor at all, but instead be customizable via the user-provided msign_fn? What are your thoughts?
I also wonder if this config could be a regular dict, accepted in the constructor as Muon(..., msign_fn_config={'eps': 1e-5}) and then just passed as self.msign_fn(..., **self.msign_fn_config) - that way, it could more easily be saved into a state_dict()
...
I think it’s better to encapsulate the configs in a dedicated class, so the function signature stays clean and manageable. Just a preference carried over from my C++ days :)
I see Vadim's point, but I'm not sure it's feasible (or necessary) to store the callable in the state_dict in the first place.
Another option is to have config set as simple args to the function, and then have the user override them via calling functools.partial
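For illustration, a hedged sketch of that functools.partial approach (the keyword names here mirror the assert messages quoted later in this thread, but the actual zeropower_via_newtonschulz signature may differ):

```python
# Hypothetical sketch: override orthogonalization settings via functools.partial
# instead of a config object. Argument names are assumptions, not the PR's API.
from functools import partial

custom_msign = partial(
    zeropower_via_newtonschulz,
    steps=7,
    coefficients=(3.4445, -4.7750, 2.0315),
)
opt = Muon(model.parameters(), lr=0.02, msign_fn=custom_msign)
```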
cc @albanD regarding API design for best practices
Ho that's interesting.
I do agree that this doesn't match how we do APIs in PyTorch in general.
For value config, I would expect they're all passed in as an argument each (see other optimizers).
If you need to override some specific methods and behavior, you can either have a set of pre-defined implementations that a flag toggles between or you can subclass the optimizer to override the particular method you care about.
Also I guess I'm missing some context on why we want to do it this way if there is only one option for each right now?
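For context, a hedged sketch of the two constructor shapes being compared (signatures and field names are illustrative, not the PR's actual API):

```python
# Option A: config object in the constructor (the shape currently in the PR).
# NewtonSchulzConfig field names are assumptions here.
opt_a = Muon(
    params,
    lr=0.02,
    msign_fn=zeropower_via_newtonschulz,
    msign_fn_config=NewtonSchulzConfig(steps=5),
)

# Option B: flat per-value arguments, as in other torch.optim optimizers.
# ns_steps / ns_coefficients are hypothetical names.
opt_b = Muon(params, lr=0.02, ns_steps=5, ns_coefficients=(3.4445, -4.7750, 2.0315))
```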
torch/optim/_muon.py
Outdated
buf = state.get("momentum_buffer")
if buf is None:
    buf = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
    state["momentum_buffer"] = buf
muon_momentum_bufs.append(buf)
Suggested change:
if state.get("momentum_buffer") is None:
    state.get("momentum_buffer") = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
muon_momentum_bufs.append(state.get("momentum_buffer"))

no need for buf intermediate, right
just trying to reduce the number of times we mention "momentum_buffer". Also, state.get("momentum_buffer") = ... seems not right?
Nonetheless, updated the code to:
if "momentum_buffer" not in state:
    state["momentum_buffer"] = torch.zeros_like(
        p.grad, memory_format=torch.preserve_format
    )
muon_momentum_bufs.append(state["momentum_buffer"])
Force-pushed from 5083654 to 0f0df7b
The current CI failures are cuz you (probably accidentally) committed the third_party differences--pls remove those!
Force-pushed from 0f0df7b to 2e6bf8c
uh, they must have come from the rebase. removed
Force-pushed from 2e6bf8c to 1c82fc8
Force-pushed from 1c82fc8 to 654f754
buf = muon_momentum_bufs[i]
buf.mul_(momentum).add_(grad)
if nesterov:
    grad = grad.add(buf, alpha=momentum)
could use lerp_ probably and save some memory
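For reference, a hedged sketch of what the lerp_ form might look like. Note it rescales the gradient contribution by (1 - momentum) relative to mul_/add_, so the effective step changes unless the learning rate is adjusted accordingly:

```python
# EMA-style momentum via lerp_: buf <- momentum * buf + (1 - momentum) * grad.
# Differs from buf.mul_(momentum).add_(grad) by a (1 - momentum) factor on grad.
buf = muon_momentum_bufs[i]
buf.lerp_(grad, 1 - momentum)
if nesterov:
    grad = grad.lerp(buf, momentum)  # grad + momentum * (buf - grad)
else:
    grad = buf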
@dataclass
class BaseMsignFnConfig:
This can be removed right?
__all__ = ["Muon"]

# Constants from Keller Jordan's Muon post: https://kellerjordan.github.io/posts/muon/
nit: link to GitHub + specific line + specific commit to make sure this link will stay up to date
assert steps < 100, (
    "Number of steps must be less than 100 for computational efficiency"
)
assert len(grad.shape) == 2, "Input tensor gradient must be a 2D matrix"
assert len(coefficients) == 3, "Coefficients must be a tuple of exactly 3 values"
No plain asserts, please raise appropriate Runtime/Value/Type errors
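A minimal sketch of the requested change, assuming ValueError is the appropriate type here (names mirror the snippet above):

```python
# Replace plain asserts with explicit exceptions so they survive python -O
# and give users a proper error type.
if steps >= 100:
    raise ValueError("Number of steps must be less than 100 for computational efficiency")
if grad.dim() != 2:
    raise ValueError("Input tensor gradient must be a 2D matrix")
if len(coefficients) != 3:
    raise ValueError("Coefficients must be a tuple of exactly 3 values")
```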
assert len(grad.shape) == 2, "Input tensor gradient must be a 2D matrix"
assert len(coefficients) == 3, "Coefficients must be a tuple of exactly 3 values"
a, b, c = coefficients[0], coefficients[1], coefficients[2]
X = grad.bfloat16()
I would be very surprised if this is the way to go unless you have a hard assert that the param dtype is fixed?
__all__ = ["Muon"]

# Constants from Keller Jordan's Muon post: https://kellerjordan.github.io/posts/muon/
EPS = 1e-7
All epsilons should be dtype dependent to avoid too large noise or flooring to 0.
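For example, a minimal sketch of a dtype-dependent epsilon, assuming EPS is used to guard a norm-based normalization as in Keller Jordan's post (the exact usage in the PR may differ):

```python
# Scale the guard value with the tensor's dtype instead of a fixed 1e-7,
# which is far below bfloat16's resolution.
eps = torch.finfo(X.dtype).eps  # ~1.19e-7 for float32, ~7.8e-3 for bfloat16
X = X / (X.norm() + eps)
```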
A = X @ X.T
B = b * A + c * A @ A
X = a * X + B @ X
nit: would be good to have descriptive names for these variables.
Also there is quite a bit of extra memory usage due to the extra variables, but I guess that can be handled in a follow up.
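For illustration, a hedged sketch of the same Newton-Schulz step with more descriptive names and fused addmm calls that drop one temporary (the names are suggestions, not the PR's code):

```python
# X is the current Newton-Schulz iterate; a, b, c are the quintic coefficients.
gram = X @ X.mT                                              # X @ X^T
gram_poly = torch.addmm(gram, gram, gram, beta=b, alpha=c)   # b*gram + c*(gram @ gram)
X = torch.addmm(X, gram_poly, X, beta=a, alpha=1.0)          # a*X + gram_poly @ X
```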
This optimizer performs momentum SGD followed by an optional orthogonalization
step computed via a user-provided callable.
This should follow the same style as we have in other optimizers like https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html#adam to describe un-ambiguoustly the math being performed.
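For illustration only, a rough sketch of the kind of Adam-style algorithm box that could go in the docstring, inferred from the momentum/nesterov snippet above and Keller Jordan's post; the exact placement of weight decay and the lr adjustment would need to match the implementation:

```latex
\begin{aligned}
&\textbf{input}: \gamma \text{ (lr)},\ \mu \text{ (momentum)},\ \lambda \text{ (weight decay)},\ \theta_0,\ f(\theta) \\
&\textbf{for}\ t = 1, 2, \ldots\ \textbf{do} \\
&\quad g_t \leftarrow \nabla_\theta f_t(\theta_{t-1}) \\
&\quad b_t \leftarrow \mu\, b_{t-1} + g_t \\
&\quad \tilde g_t \leftarrow g_t + \mu\, b_t \ \text{ if nesterov, else } \ b_t \\
&\quad O_t \leftarrow \operatorname{msign}(\tilde g_t) \quad \text{(e.g. Newton-Schulz orthogonalization)} \\
&\quad \theta_t \leftarrow (1 - \gamma\lambda)\,\theta_{t-1} - \gamma\, \alpha_t\, O_t \quad \text{($\alpha_t$: lr adjustment)}
\end{aligned}
```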
    lr: float = 1e-3,
    weight_decay: float = 0.1,
    momentum: float = 0.95,
    nesterov: bool = True,
@chuanhaozhuge where was this doc added?
    has_complex: bool,
) -> None:
    lr = _to_scalar(lr)
    assert has_complex is False, "Complex parameters are not supported"
Let's remove plain asserts here as well
Given that we had agreed to land the simplest single-device Muon into torch/optim as our first step, it'd be clearest to land what people accept as the original implementation as defined in Keller Jordan's blog (https://kellerjordan.github.io/posts/muon/). As that implementation chooses Newton-Schulz as the algorithm, we should take the same stance. This means we can simplify the constructor API greatly (I will get to extensibility right after):
I'm realizing that you have an interest in extending the algorithm to be distributed (vs. another orthogonalization algorithm for single-device). We are strict on keeping …
A single-device version of Muon. The algorithm follows the Moonshot implementation.
This PR is an update to #159465 that incorporates suggestions and recommendations on UX and API. In particular, the PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. An example usage is shown below, and a more detailed example will be added to the PyTorch examples directory.
Usage
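(The PR's original usage snippet is not reproduced here; the following is a hypothetical sketch of the parameter-grouping pattern described above. The Muon import path and the AdamW hyperparameters are assumptions; the Muon keyword arguments mirror the constructor defaults shown in the diff.)

```python
import torch
from torch.optim import AdamW
# from torch.optim import Muon  # added by this PR; import path assumed

model = torch.nn.Sequential(
    torch.nn.Linear(512, 512, bias=False),
    torch.nn.Linear(512, 10),
)

# 2D weight matrices go to Muon; everything else (biases, norms, embeddings) to AdamW.
muon_params = [p for p in model.parameters() if p.ndim == 2]
other_params = [p for p in model.parameters() if p.ndim != 2]

opt_muon = Muon(muon_params, lr=0.02, weight_decay=0.1, momentum=0.95)
opt_adamw = AdamW(other_params, lr=3e-4)

x = torch.randn(32, 512)
loss = model(x).square().mean()
loss.backward()
opt_muon.step()
opt_adamw.step()
opt_muon.zero_grad()
opt_adamw.zero_grad()
```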
Additional usage
Users are also able to pass in a self-defined msign function for orthogonalization, as well as a learning-rate adjustment function; the interface is defined below. By default, we use 5-step Newton-Schulz with the coefficients proposed by Keller, and the LR adjustment proposed by Moonshot, which grafts the learning rate from AdamW.
Testing
We updated the test cases in test/test_optim.py to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.
We also trained for one epoch on the openwebtext-100k dataset and compared the resulting loss curve against the Moonshot implementation to confirm behavioral consistency.
Performance
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:
Muon runs ~20s slower compared to AdamW. Assuming no other changes, Muon is 2.5% slower than AdamW.
AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms

Next Steps
MuP
cc: @toothacher17, @vinaysrao, @jcui2, @haocizhang