Muon optimizer: +~30% sample efficiency with <3% wallclock overhead


Muon: An optimizer for the hidden layers of neural networks

This repo contains an implementation of the Muon optimizer, described in the original thread and in the writeup at https://kellerjordan.github.io/posts/muon/.

Installation

pip install git+https://github.com/KellerJordan/Muon

Usage

Muon is intended to optimize only the internal ≥2D parameters of a network. Embeddings, classifier heads, and scalar or vector parameters should be optimized using AdamW.

# Instead of a single AdamW over all parameters, e.g.
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)

import torch
from muon import Muon

# Find ≥2D parameters in the body of the network -- these should be optimized by Muon
muon_params = [p for p in model.body.parameters() if p.ndim >= 2]
# Find everything else (biases, gains, embeddings, the head) -- these should be optimized by AdamW
adamw_params = ([p for p in model.body.parameters() if p.ndim < 2]
                + [*model.head.parameters(), *model.embed.parameters()])
# Create one optimizer per parameter group
optimizers = [Muon(muon_params, lr=0.02, momentum=0.95),
              torch.optim.AdamW(adamw_params, lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)]
...

# in the training step, after loss.backward(), step both optimizers
for opt in optimizers:
    opt.step()

You'll have to replace model.body, model.head, and model.embed with whatever subset is appropriate for your model. E.g., for a ConvNet, muon_params should be all the convolutional filters, and adamw_params should be everything else.
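
As a concrete illustration, here is a minimal sketch for a small ConvNet. The SmallConvNet class and its attribute names (features, classifier) are assumptions made up for this example, not part of this repo; substitute your own modules.

import torch
import torch.nn as nn
from muon import Muon

class SmallConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional body: Conv2d weights are 4D (out, in, kH, kW), so ndim >= 2 selects them for Muon
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        # Classifier head: optimized by AdamW, per the guidance above
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))

model = SmallConvNet()

# Convolutional filters (ndim >= 2) go to Muon
muon_params = [p for p in model.features.parameters() if p.ndim >= 2]
# Conv biases (1D) and the classifier head go to AdamW
adamw_params = ([p for p in model.features.parameters() if p.ndim < 2]
                + [*model.classifier.parameters()])
optimizers = [Muon(muon_params, lr=0.02, momentum=0.95),
              torch.optim.AdamW(adamw_params, lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)]

The same ndim >= 2 test that catches the hidden matrices of a transformer catches the 4D convolutional filters here; the split is done by module so that the head goes to AdamW even though its weight is 2D.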

Example usage

Example use of this Muon in the NanoGPT speedrun

Example use of a Muon variant in the CIFAR-10 speedrun

Hyperparameter tuning

Typically, the default values of momentum (0.95), nesterov (True), and ns_steps (5) work well. The only hyperparameter that must be tuned is the learning rate. It transfers in the muP sense: as you scale up the model size, you shouldn't need to retune it.
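
A minimal sketch of that guidance, assuming momentum, nesterov, and ns_steps are exposed as keyword arguments and using a hypothetical build_model helper: the defaults stay fixed, and only the Muon learning rate would be swept, reused unchanged as the model is widened.

from muon import Muon

muon_lr = 0.02  # the one hyperparameter to tune; it should transfer as the model scales up

for width in (512, 2048):                  # arbitrary example widths
    model = build_model(width)             # hypothetical model factory
    hidden = [p for p in model.body.parameters() if p.ndim >= 2]
    opt = Muon(hidden, lr=muon_lr,
               momentum=0.95, nesterov=True, ns_steps=5)  # defaults, usually left as-is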

Benchmarks

For a comparison between AdamW, Shampoo, SOAP, and Muon for training a 124M-parameter transformer, see here.

Accomplishments

Citation

@misc{jordan2024muon,
  author       = {Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and
                  Franz Cesista and Laker Newhouse and Jeremy Bernstein},
  title        = {Muon: An optimizer for hidden layers in neural networks},
  year         = {2024},
  url          = {https://kellerjordan.github.io/posts/muon/}
}
