Introduce Muon optimizer to PyTorch #159465
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159465
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ You can merge normally! (1 Unrelated Failure) As of commit 2750453 with merge base 799303f, the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@janeyx99 related on passing names to the optimizer:
Thanks for taking this on and adding test cases + a benchmark! I briefly did a high-level review and left some comments. One more thing: have you checked that the correctness of this implementation matches the Moonshot one?
How come you put these in a new file?
Do we expect more tests to be added here? It might be more centralized to have these tests live in test_optim.py anyway, for easier searching and to share the optimizer-info machinery when we add more configs for Muon.
This test is very specific to Muon, so I put it in a new file. There might be new test cases added, for example to test numerical equivalency of alternative msign functions.
My rationale is that test cases general to all optimizers go into test_optim.py, and we create separate files to cover specific functionalities. This way test_optim.py doesn't grow too large, IMO.
Ah, I agree with the logical separation. The reason we have lots of tests in test_optim.py is to be able to run all optim-related tests in one go. Can you update this file + test_optim.py similar to test_lrscheduler.py so we get both the pros of logical separation and being able to run all tests in one go?
torch/optim/_muon.py (Outdated)

__all__ = ["Muon", "muon"]

logger = logging.getLogger(__name__)
Optimizers have generally not output any logs--we tend to let the trainer code or higher level libs handle logs. Is there a reason Muon should have a logger in particular?
Got it, yeah. I just think it's important to surface some specific behavior to users to avoid a potential footgun. Maybe we can control verbosity.
torch/optim/_muon.py (Outdated)

# If no fqns are provided, use Muon for all parameters with 2D shape.
# Note: this may not be the expected behavior since some 2D
# parameters may not be intended to be optimized with Muon, for example Embedding.
logger.warning(
Let's just use normal warning here
Sounds good, I'll use warnings.warn.
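For reference, a minimal sketch of what the warnings.warn version could look like (the message text here is illustrative, not the final wording):

```python
import warnings

# Illustrative: surface the fallback behavior with a standard warning instead
# of a logger call, so users see it without configuring logging.
warnings.warn(
    "No muon_param_fqns provided; Muon will be applied to all 2D parameters. "
    "Pass an explicit FQN list if some 2D parameters (e.g. embeddings) "
    "should be optimized with AdamW instead."
)
```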
    muon_grads.append(p.grad)
    muon_momentum_bufs.append(buf)
else:
    # for the rest of the parameters, we use AdamW to optimize.
Should we delegate this to the AdamW optimizer instead of duping code here?
I think it's a good idea and was considering this as well. I re-implemented for now because we'd have to update the state_dict handling logic if we want to introduce a nested optimizer (I tried that in DiLoCo, where we have inner and outer optimizers). Also, we may want to explore other update rules, so I want to keep the flexibility here.
Let's not try to solve that problem in this PR, let's keep this PR as straightforward as possible.
Doesn't cuBLAS already have a symmetric matmul kernel we can exploit, as opposed to writing our own Triton one?
benchmarks/muon_examples/train.py (Outdated)

    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
Why bfloat16 cast here?
Mostly for GPU efficiency, I believe, to use tensor-core bf16 FLOPs. This is adopted from Keller/Jingyuan's original implementation.
Why bfloat16 cast here?

The bf16 choice is discussed here: https://kellerjordan.github.io/posts/muon/
Basically, there are several methods to approximate the SVD, and with this method bf16 is good enough; that's why it was chosen by Keller in the first place. With bf16, the communication cost is also halved.
benchmarks/muon_examples/train.py (Outdated)

def zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.T
        B = (
            b * A + c * A @ A
        )  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(0) > G.size(1):
        X = X.T
    return X
Suggested change:

def zeropower_via_newtonschulz_optimized(
    G: Tensor,
    ns_config: BaseMsignFnConfig,
) -> Tensor:
    # unpack config
    ns_config = cast(NewtonSchulzConfig, ns_config)
    steps = ns_config.ns_steps
    a, b, c = ns_config.coefficients
    assert 1 < steps < 100, "ns_steps must be <100"
    assert G.dim() == 2, "G must be 2D"
    assert len((a, b, c)) == 3
    # cast & maybe transpose so n ≤ k
    X = G.to(torch.bfloat16)
    transposed = False
    if X.size(0) > X.size(1):
        X = X.t().contiguous()
        transposed = True
    # normalize (Frobenius norm)
    X.div_(X.norm() + 1e-7)
    # shapes
    n, k = X.shape
    # pre-allocate buffers
    A = torch.empty((n, n), dtype=X.dtype, device=X.device)
    A2 = torch.empty_like(A)
    B = torch.empty_like(A)
    Y = torch.empty_like(X)
    # Newton–Schulz loop
    for _ in range(steps):
        # A = X @ X^T
        torch.mm(X, X.t(), out=A)
        # A2 = A @ A
        torch.mm(A, A, out=A2)
        # B = b*A + c*A2
        B.copy_(A).mul_(b).add_(A2, alpha=c)
        # Y = B @ X
        torch.mm(B, X, out=Y)
        # X = a * X + Y, in-place
        X.mul_(a).add_(Y)
    # undo transpose
    if transposed:
        X = X.t()
    return X
You are going to support eager without compile, right?
I tested using compile but didn't seem to observe much speedup.
LMAO: I swore I worked on this at some point, and remembered I tried to optimize a CUTLASS implementation of this here: nil0x9/flash-muon#1. The repo switched over to Triton anyway though: https://github.com/nil0x9/flash-muon/blob/80ac87fb49afc792b84eccb393d051b1ed8eee32/flash_muon/matmul_transpose_triton.py#L20 @chuanhaozhuge I assume this is the kernel you were referencing.
The above suggestion should be marginally faster than the current code because we reuse the matrix buffers and use in-place ops.
This is the first impl, used for the purpose of providing a golden answer for future speedup or parallelism work. So we might not need to add a perfect impl or kernel in this PR yet?
In future, when continuing to optimize the impl and integrating into FSDP, we can revisit each specific part and see how to optimize it.
Thanks @Skylion007 for the pointer! What I mentioned is an internal kernel implementation that hasn't been open-sourced. We will find time to benchmark the kernel you referred to above.
Thanks @toothacher17, right, this is the purpose. As we have the interface, I think users can easily plug in an optimized implementation. PyTorch can host some optimized solutions as well, but we need to figure out the support and maintenance model.
benchmarks/muon_examples/train.py (Outdated)

@@ -0,0 +1,395 @@
import math
This file seems to be similar to https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
Might be nice to add a link for reference purposes.
Yes! I already included the link in the README. Let me highlight that in the code files as well.
benchmarks/muon_examples/train.py (Outdated)

############################

params = [p for p in group["params"] if self.state[p]["use_muon"]]
# import pdb; pdb.set_trace()
nit: remove the pdb here?
# Note: Muon doesn't support multiple param groups for now.
if muon_param_fqns is not None:
    muon_param_fqns_set = set(muon_param_fqns)
It might be nice to just provide some other simple logic to filter out Muon params. The default logic is actually straightforward:
- word embeddings and the LM head do not go into Muon (can be controlled by name)
- RMSNorm gamma does not go into Muon
Other params go to Muon. An example can be found here:
Discussed offline; having a pre-determined name check may be a footgun. Since we are not using a training library like Megatron, where we know the FQNs of parameters, it's hard to filter by name correctly, and it may create confusion for users.
I improved the warning message to nudge users to include an FQN list when using Muon.
Maybe we should allow the user to pass in a custom function that decides whether a param is a Muon param or an Adam param? Then the user can choose either to provide a list of Muon param names or to provide a function similar to the Megatron example @toothacher17 linked. A sketch of what that could look like is below.
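A minimal sketch of such a user-supplied predicate, under the assumption that the caller (not the optimizer) does the routing; the function name and the substring checks are hypothetical, mirroring the Megatron-style defaults mentioned above:

```python
import torch

# Hypothetical predicate: route 2D weights to Muon, but keep embeddings,
# the LM head, and norm gains on AdamW (the name checks are illustrative).
def is_muon_param(name: str, param: torch.Tensor) -> bool:
    if param.ndim != 2:
        return False
    lowered = name.lower()
    if "embed" in lowered or "lm_head" in lowered or "norm" in lowered:
        return False
    return True

# The caller could then build the FQN list expected by this PR's API:
# muon_fqns = [n for n, p in model.named_parameters() if is_muon_param(n, p)]
```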
Let's move all this logic into a followup PR as there may be more design decisions to be made, and in this PR, we should assume that Muon only handles Muon updates.
load_tests = load_tests


class MoonshotReferenceMuon(torch.optim.Optimizer):
nit: the Moonshot Muon is used multiple times. Do we want to implement it once as a standalone class and reuse it in all tests?
I created a common file under benchmarks/muon_examples so that common code is reused in train.py and train_ddp.py. I'm not updating this file yet, since it's a bit weird to reference code under benchmarks/ from test/optim.
torch/optim/_muon.py (Outdated)

if G.size(0) > G.size(1):
    X = X.T
# Ensure spectral norm is at most 1
X = X / (X.norm() + 1e-7)
nit: define 1e-7 as a constant?
Done; also updated for the coefficients.
torch/optim/_muon.py (Outdated)

@@ -61,7 +66,7 @@ def zeropower_via_newtonschulz(G: Tensor, ns_config: BaseMsignFnConfig) -> Tensor:
     if G.size(0) > G.size(1):
         X = X.T
     # Ensure spectral norm is at most 1
-    X = X / (X.norm() + 1e-7)
+    X = X / (X.norm() + EPS)
It might also be good to add it explicitly to the function arguments, like ..., eps=EPS). This would allow someone to call the function with a different eps without modifying the global vars (and would make these functions purely functional: not depending on global state, only on their explicit arguments).
Good idea, I added eps to NewtonSchulzConfig.
Hi, thanks for everyone's reviews and @chuanhaozhuge's work here! I do think we should greatly simplify this PR to represent the Muon algorithm. We should design a dispatching API for branching params into Muon vs AdamW in the next PR, but let's keep Muon consistent and simple here.
Thanks for showing the code for the benchmarks + matching with Moonshot, but let's move that code to separate gists with results in the PR body.
I've left comments with more details below!
Thanks for showing the code for the benchmarks/correctness comparisons! Let's not land any of the benchmark code (as it can live separately in a different gist or locally).
if optim_cls is Muon:
    atol = 3e-4
    rtol = 5e-5
hmmmm these are big...do you know why this would be?
@@ -617,9 +622,11 @@ def test_correctness(self, device, dtype, optim_info, use_closure):
     param.grad = param.grad.to_sparse()

 opt_compiled = optim_cls(
-    model_compiled.parameters(), **deepcopy(kwargs)
+    model_compiled.named_parameters(), **deepcopy(kwargs)
Let's keep this as parameters(). Let's have Muon assume that all its parameters should receive the Muon update (all the logic for dispatching FQNs should live above the optimizer).
    model_compiled.named_parameters(), **deepcopy(kwargs)
)
opt_eager = optim_cls(
    model_eager.named_parameters(), **deepcopy(kwargs)
same here
bias_correction1 = 1 - beta1 ** step.item()
bias_correction2 = 1 - beta2 ** step.item()
.item() calls are expensive--why not:
Suggested change:

bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
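A quick illustration of why the suggestion works as-is: exponentiation broadcasts over a tensor step, so the bias corrections stay tensors and avoid the host sync that .item() forces (a sketch with made-up values):

```python
import torch

beta1, beta2 = 0.9, 0.999
step = torch.tensor(10.0)  # step counters are kept as tensors in optimizer state

# No .item(): the result is a tensor and no host-device sync is triggered.
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
print(bias_correction1, bias_correction2)
```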
self,
params: ParamsT,
lr: float = 1e-3,
wd: float = 0.1,
For consistency, let's call this weight_decay.
"adjust_lr_fn": lambda lr, param_shape: lr, | ||
}, | ||
desc="passing alternative adjust_lr_fn", | ||
), |
let's have configs for every arg in the constructor (weight decay, momentum, etc)
error_type=RuntimeError,
# note other optimizers raise TypeError in the base
# optimizer class. Muon raises the error earlier.
error_regex="Expected params to be named parameters",
let's not force this restriction
Force-pushed from 1c82fc8 to 654f754.
A single-device version of Muon. The algorithm follows the Moonshot implementation.

Usage
This implementation requires users to pass in the model's named_parameters - a list of (name, param) tuples. Users should also specify the fully-qualified names (FQNs) of the parameters to be optimized by Muon. Parameters not included in the FQN list fall back to AdamW optimization. If no FQN list is provided, Muon will by default optimize all 2D parameters, which may not be the expected behavior; a warning is issued in this case.

Additional usage
Users are also able to pass in a self-defined msign function for orthogonalization, and a learning-rate adjustment function (interface defined below). By default, we use 5-step Newton-Schulz with coefficients proposed by Keller, and the LR adjustment proposed by Moonshot, which grafts the learning rate from AdamW.
Testing
- test/test_optim.py: updated the test cases to pass named parameters to the optimizer under test, and introduced a new test case to verify that when the user provides an empty FQN list, Muon correctly falls back to AdamW behavior.
- Trained a model for one epoch on the openwebtext-100k dataset; the resulting loss curve is compared against the Moonshot implementation to confirm behavioral consistency.

Performance
Training for one epoch of openwebtext-100k on eight H100 GPUs with DDP, Muon runs ~20s slower than AdamW; assuming no other changes, Muon is 2.5% slower than AdamW.
AdamW: Optimizer.step() takes ~13.5 ms, step time ~930 ms

Muon: Optimizer.step() takes ~54 ms, step time ~960 ms

Next Steps
MuP
cc: @toothacher17, @vinaysrao, @jcui2, @haocizhang
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben