
[muon] Introduce Muon optimizer to PyTorch #160213

Open · wants to merge 4 commits into main

Conversation

chuanhaozhuge
Contributor

A single-device version of Muon. The algorithm follows the Moonshot implementation.

This PR updates #159465 to incorporate suggestions and recommendations on UX and API. In particular, the PyTorch team prefers to handle parameter filtering at a higher level, with the Muon optimizer performing only the msign computation for orthogonalization on all parameters it receives. Users are responsible for grouping parameters for different optimizers as needed. Example usage is shown below; a more detailed example will be added to the PyTorch examples directory.

Usage

    model = MyModelForCausalLM()
    # filter out your params manually
    muon_params = [...]
    adamw_params = [...]
    muon = Muon(
        params=muon_params,
        lr=lr,
        weight_decay=wd,
    )
    adamw = AdamW(
        params=adamw_params,
        lr=lr,
        weight_decay=wd,
    )

    # in training loop
    loss = model(input)
    loss.backward()
    muon.step()
    adamw.step()
    muon.zero_grad()
    adamw.zero_grad()

Additional usage
Users can also pass in a self-defined msign function for orthogonalization, as well as a learning-rate adjustment function. The interfaces are defined below:

    AdjustLrFn: TypeAlias = Callable[[float, torch.Size], float]
    MsignFn: TypeAlias = Callable[[Tensor, BaseMsignFnConfig], Tensor]

By default, we use a 5-step Newton-Schulz iteration with the coefficients proposed by Keller Jordan, and the LR adjustment proposed by Moonshot, which grafts the learning rate from AdamW.
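
For illustration, here is a minimal sketch of plugging in custom callables. The function bodies are placeholders, and the `adjust_lr_fn` keyword name is assumed from the description above rather than copied from the code; only `msign_fn` appears verbatim in the constructor.

    import torch
    from torch import Tensor

    # Hypothetical msign: orthogonalize via SVD instead of Newton-Schulz.
    def svd_msign(grad: Tensor, config) -> Tensor:
        U, _, Vh = torch.linalg.svd(grad.float(), full_matrices=False)
        return (U @ Vh).to(grad.dtype)

    # Hypothetical LR adjustment: scale by a function of the parameter shape.
    def shape_adjust_lr(lr: float, shape: torch.Size) -> float:
        return lr * 0.2 * max(shape[0], shape[1]) ** 0.5

    muon = Muon(
        params=muon_params,
        lr=lr,
        weight_decay=wd,
        msign_fn=svd_msign,
        adjust_lr_fn=shape_adjust_lr,  # assumed keyword, see note above
    )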

Testing

  1. Unit tests: the newly introduced Muon is covered in test/test_optim.py. We updated the test cases to pass named parameters to the optimizer under test. Additionally, we introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.
  2. End-to-end test: we added a training script that pre-trains a QWEN-like model on the openwebtext-100k dataset. We trained for one epoch and compared the resulting loss curve against the Moonshot implementation to confirm behavioral consistency.
[Screenshot: loss curve comparison vs. the Moonshot implementation]

Performance
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP:

  • adamw_ddp finishes in 13.12 min
  • pytorch_muon_ddp finishes in 13.45 min

Muon runs ~20 s slower than AdamW per epoch; all else being equal, that makes Muon about 2.5% slower than AdamW.

AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms
[Screenshot: AdamW profiler trace]

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms
[Screenshot: Muon profiler trace]

Next Steps

  1. Add MuP
  2. Open-source an optimized Triton kernel for symmetric matmul. A preliminary benchmark found a 1.23x-1.48x speedup across small to large matrices (n = 256 to 16384).
  3. Open-source an unsharded Muon co-designed with FSDP2.

cc: @toothacher17, @vinaysrao, @jcui2, @haocizhang


pytorch-bot bot commented Aug 8, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160213

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 654f754 with merge base 8d3d1c8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@janeyx99 janeyx99 left a comment

Wooo the approach is much simpler indeed now--thank you for the speedy turnaround on the PR. The one main API question I have is how we handle the NS config (whether it should be in the constructor) which I've commented below. Everything else looks super solid.

I know you've put in amazing work for the benchmarks and correctness compared to the original Muon, and I trust that you have verified this PR is still correct and appropriately fast locally. I will look out for the separate PR with those scripts!

Comment on lines 356 to 358
params = [weight, bias]
if optim_cls.__name__ == "Muon":
params = [weight]
Contributor

nit to not reassign

Suggested change
params = [weight, bias]
if optim_cls.__name__ == "Muon":
params = [weight]
params = [weight, bias] if optim_cls.__name__ != "Muon" else [weight]

Comment on lines 1568 to 1569
model = torch.nn.Sequential(
torch.nn.Linear(10, 4, bias=False),
Contributor

This can just be one Linear then, right? Or maybe it'd be more indicative to add another Linear in there?

Can you add a comment for why we branch here?

@@ -1577,14 +1629,26 @@ def test_can_load_from_to_named_state_dict(
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=("differentiable",)
)

def _get_model_and_input(device, dtype, optim_cls):
Contributor

let's only have one version of this helper, it looks the same as above

@@ -2219,7 +2285,7 @@ def test_defaults_changed_to_foreach(self, device, dtype, optim_info):
def test_non_empty_state(self, device, dtype, optim_info):
# There are internal tests that check that the state is not empty
optim_cls = optim_info.optim_cls
model = torch.nn.Linear(5, 5)
model = torch.nn.Linear(5, 5, bias=False)
Contributor

comment that we add False here to be generically run with Muon

@@ -1969,7 +2035,7 @@ def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]):
nonlocal data
data += 2

params = [torch.tensor([1, 1], device=device, dtype=dtype)]
params = [torch.tensor([[1, 1]], device=device, dtype=dtype)]
Contributor

How come these hook changes are necessary?

nesterov: bool = True,
*,
msign_fn: MsignFn = zeropower_via_newtonschulz,
msign_fn_config: BaseMsignFnConfig = NewtonSchulzConfig(),
Contributor

What is the pro of having these configs live in the constructor as a struct vs. separate values? Is this because these values are only used if the msign_fn is zeropower_via_newtonschulz? If so, should this config not live in the Muon constructor at all but be customizable via the user-provided msign_fn? What are your thoughts?

Contributor

I also wonder if this config could be a regular dict, accepted in the constructor as Muon(..., msign_fn_config = {'eps' : 1e-5}) and then just passed as self.msign_fn(..., **self.msign_fn_config) - that way, it could be more easily saved into a state_dict()...
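
A rough sketch of that idea (the keyword names here are illustrative, not taken from the PR):

    # Config is a plain dict stored on the optimizer...
    muon = Muon(params=muon_params, lr=lr, msign_fn_config={"eps": 1e-5, "steps": 5})

    # ...and splatted into the msign call inside step():
    # X = self.msign_fn(grad, **self.msign_fn_config)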

Contributor Author

I think it’s better to encapsulate the configs in a dedicated class, so the function signature stays clean and manageable. Just a preference carried over from my C++ days :)

I see Vadim's point, but I'm not sure it's feasible (or necessary) to store the callable in the state_dict in the first place.

Contributor

Another option is to have the config set as simple args to the function, and then have the user override them via functools.partial
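
For example, if zeropower_via_newtonschulz took its settings as plain keyword arguments instead of a config object, the override could look like this (a sketch; the steps/eps keywords are assumptions):

    from functools import partial

    muon = Muon(
        params=muon_params,
        lr=lr,
        msign_fn=partial(zeropower_via_newtonschulz, steps=10, eps=1e-6),
    )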

Contributor

cc @albanD regarding API design for best practices

Collaborator

Ho that's interesting.
I do agree that this doesn't match how we do APIs in PyTorch in general.
For value config, I would expect they're all passed in as an argument each (see other optimizers).
If you need to override some specific methods and behavior, you can either have a set of pre-defined implementations that a flag toggles between or you can subclass the optimizer to override the particular method you care about.

Also I guess I'm missing some context on why we want to do it this way if there is only one option for each right now?

Comment on lines 157 to 161
buf = state.get("momentum_buffer")
if buf is None:
buf = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
state["momentum_buffer"] = buf
muon_momentum_bufs.append(buf)
Contributor

Suggested change
buf = state.get("momentum_buffer")
if buf is None:
buf = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
state["momentum_buffer"] = buf
muon_momentum_bufs.append(buf)
if state.get("momentum_buffer") is None:
state.get("momentum_buffer") = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
muon_momentum_bufs.append(state.get("momentum_buffer"))

no need for buf intermediate, right

Contributor Author

just trying to reduce the number of times we mention "momentum_buffer". also, state.get("momentum_buffer") = seems not right?

nonetheless, updated the code to

            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(
                    p.grad, memory_format=torch.preserve_format
                )
            muon_momentum_bufs.append(state["momentum_buffer"])

Contributor

@janeyx99 janeyx99 left a comment

The current CI failures are cuz you (probably accidentally) committed the third_party differences--pls remove those!

@chuanhaozhuge
Contributor Author

The current CI failures are cuz you (probably accidentally) committed the third_party differences--pls remove those!

uh, they must have come from the rebase. removed

@chuanhaozhuge chuanhaozhuge marked this pull request as ready for review August 12, 2025 05:08
@chuanhaozhuge chuanhaozhuge requested a review from albanD as a code owner August 12, 2025 05:08
buf = muon_momentum_bufs[i]
buf.mul_(momentum).add_(grad)
if nesterov:
grad = grad.add(buf, alpha=momentum)
Contributor

could use lerp_ probably and save some memory
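
A sketch of the lerp_-based formulation (as in the Keller Jordan reference implementation); note that it scales the gradient term by (1 - momentum), which differs from the current mul_/add_ form:

    # buf = momentum * buf + (1 - momentum) * grad, computed in place
    buf.lerp_(grad, 1 - momentum)
    # Nesterov blend done in place on grad (note: this mutates p.grad) instead of
    # allocating a new tensor with grad.add(buf, alpha=momentum)
    update = grad.lerp_(buf, momentum) if nesterov else buf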



@dataclass
class BaseMsignFnConfig:
Collaborator

This can be removed right?


__all__ = ["Muon"]

# Constants from Keller Jordan's Muon post: https://kellerjordan.github.io/posts/muon/
Collaborator

nit: link to GitHub + specific line + specific commit to make sure this link will stay up

Comment on lines +61 to +65
assert steps < 100, (
"Number of steps must be less than 100 for computational efficiency"
)
assert len(grad.shape) == 2, "Input tensor gradient must be a 2D matrix"
assert len(coefficients) == 3, "Coefficients must be a tuple of exactly 3 values"
Collaborator

No plain asserts, please raise appropriate Runtime/Value/Type errors
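
A sketch of the same checks raising exceptions instead (messages reused from the asserts above):

    if steps >= 100:
        raise ValueError("Number of steps must be less than 100 for computational efficiency")
    if grad.dim() != 2:
        raise ValueError("Input tensor gradient must be a 2D matrix")
    if len(coefficients) != 3:
        raise ValueError("Coefficients must be a tuple of exactly 3 values")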

assert len(grad.shape) == 2, "Input tensor gradient must be a 2D matrix"
assert len(coefficients) == 3, "Coefficients must be a tuple of exactly 3 values"
a, b, c = coefficients[0], coefficients[1], coefficients[2]
X = grad.bfloat16()
Collaborator

I would be very surprised if this is the way to go unless you have a hard assert that the param dtype is fixed?
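
One possible way to address this concern, as a sketch: keep the iteration in bfloat16 but cast the result back to the input's dtype rather than assuming a fixed parameter dtype:

    orig_dtype = grad.dtype
    X = grad.to(torch.bfloat16)
    # ... Newton-Schulz iterations on X ...
    return X.to(orig_dtype)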

__all__ = ["Muon"]

# Constants from Keller Jordan's Muon post: https://kellerjordan.github.io/posts/muon/
EPS = 1e-7
Collaborator

All epsilons should be dtype dependent to avoid too large noise or flooring to 0.
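
For example, a dtype-dependent epsilon could be derived from torch.finfo at the point of use rather than the module-level constant (a sketch, assuming the eps guards a norm division as in the usual Newton-Schulz setup):

    eps = torch.finfo(X.dtype).eps  # ~1.2e-7 for float32, ~7.8e-3 for bfloat16
    X = X / (X.norm() + eps)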

Comment on lines +74 to +76
A = X @ X.T
B = b * A + c * A @ A
X = a * X + B @ X
Collaborator

nit: would be good to have descriptive names for these variables.
Also there is quite a bit of extra memory usage due to the extra variables, but I guess that can be handled in a follow up.
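
On the memory point, one option (a sketch, not necessarily the final choice) is to fuse the adds into addmm calls so fewer temporaries are materialized:

    A = X @ X.T
    B = torch.addmm(A, A, A, beta=b, alpha=c)  # b * A + c * (A @ A)
    X = torch.addmm(X, B, X, beta=a)           # a * X + B @ X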

Comment on lines +93 to +94
This optimizer performs momentum SGD followed by an optional orthogonalization
step computed via a user provided callable.
Collaborator

This should follow the same style as we have in other optimizers like https://docs.pytorch.org/docs/stable/generated/torch.optim.Adam.html#adam to describe unambiguously the math being performed.


Comment on lines +99 to +105
lr: float = 1e-3,
weight_decay: float = 0.1,
momentum: float = 0.95,
nesterov: bool = True,
Collaborator

@chuanhaozhuge where was this doc added?

has_complex: bool,
) -> None:
lr = _to_scalar(lr)
assert has_complex is False, "Complex parameters are not supported"
Collaborator

Let's remove plain asserts here as well
