Speedrunning ideas discussion #23
Replies: 44 comments 73 replies
-
I assume FP8 has been considered? Nvidia is showing serious speedups on the H100: https://github.com/NVIDIA/TransformerEngine
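If anyone wants to try it, here is a minimal sketch of running a projection through TransformerEngine's FP8 path, assuming the `te.Linear` / `fp8_autocast` API from its docs (untested here; the recipe settings are just defaults):

```python
# Minimal sketch, assuming TransformerEngine's te.Linear and fp8_autocast API.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)
proj = te.Linear(768, 3072, bias=False).cuda()
x = torch.randn(8, 1024, 768, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = proj(x)  # the matmul runs in FP8 on Hopper, with delayed scaling of amax stats
```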
-
How about the nGPT architecture instead of standard GPT? lucidrains has a good implementation; I ran the training and can confirm it converges a lot faster than other models/architectures I've tried on the same data (tiny model here).
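For anyone unfamiliar, the core of nGPT is keeping the residual stream on the unit hypersphere and replacing the residual add with a normalized interpolation. A rough sketch of that update (a paraphrase of the idea, not lucidrains' code):

```python
# Rough sketch of the nGPT-style residual update: hidden states live on the unit
# hypersphere and each sublayer output is blended in with a learned step size,
# then renormalized.
import torch
from torch import nn

def l2_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x / (x.norm(dim=-1, keepdim=True) + eps)

class NormalizedResidual(nn.Module):
    def __init__(self, sublayer: nn.Module, dim: int, init_alpha: float = 0.05):
        super().__init__()
        self.sublayer = sublayer
        self.alpha = nn.Parameter(torch.full((dim,), init_alpha))  # learned per-channel step size

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        h_new = l2_normalize(self.sublayer(h))
        # step toward the sublayer output, then project back onto the sphere
        return l2_normalize(h + self.alpha * (h_new - h))
```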
-
In Muon's
-
jie040109 suggested AdaFactor as an alternate optimizer. I'm not very bullish on this, but I would be totally happy to see any new record using AdaFactor, or any other optimizer for that matter.
-
MizuleGPT suggested TokenFormer. This does look interesting (Grad also sent it to me), but I haven't seriously thought about it.
-
The rules say:
Technically, wouldn't this allow MoE models with >124M total params as long as we don't go higher than 124M active at once?
-
I tried using Cut Cross Entropy, and while it allows for a larger sequence length without OOM, increasing the batch size / seq len doesn't seem to result in any speedup. Have we maxed out parallelization? Is there something special about a seq len of 2**16 tokens? Is that the most tokens FlexAttention can process in a single batch? Something specific to the H100?
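For anyone who wants to reproduce this, a rough sketch of the CCE call (assuming the cut-cross-entropy package's `linear_cross_entropy` entry point; the exact signature may differ by version):

```python
# Hedged sketch of Cut Cross Entropy: computes the softmax cross-entropy
# without materializing the full (B, T, vocab) logits tensor.
import torch
from cut_cross_entropy import linear_cross_entropy

hidden = torch.randn(8, 2048, 768, device="cuda", dtype=torch.bfloat16)    # final hidden states
classifier = torch.randn(50304, 768, device="cuda", dtype=torch.bfloat16)  # lm_head weight
labels = torch.randint(0, 50304, (8, 2048), device="cuda")

loss = linear_cross_entropy(hidden, classifier, labels)
```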
-
We currently employ four strategies for mapping residuals; there might be other mapping strategies that improve performance. In the standard GPT-2 architecture, activations follow a sequential connection pattern:
Three additional activation connection mechanisms have been incorporated to augment this structure:
All four of these have learned weights. I've tried a dense version of skip connections where each layer's input is a weighted sum of all previous layers' residuals, but training performance got slightly worse.
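For concreteness, the dense variant looks roughly like this (a sketch with one learned weight per (layer, earlier-residual) pair; the blocks themselves are assumed to contain their own internal residual adds):

```python
import torch
from torch import nn

class DenseSkipStack(nn.Module):
    def __init__(self, blocks: nn.ModuleList):
        super().__init__()
        self.blocks = blocks
        n = len(blocks)
        # skip_weights[i, j] weights residual j when forming the input of block i;
        # identity init reproduces the plain sequential wiring.
        self.skip_weights = nn.Parameter(torch.eye(n))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residuals = [x]
        for i, block in enumerate(self.blocks):
            w = self.skip_weights[i, : len(residuals)]
            mixed = sum(wj * rj for wj, rj in zip(w, residuals))
            residuals.append(block(mixed))
        return residuals[-1]
```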
-
Where do you think the speedrun limit is? The last record runs 10 times faster than the baseline.
-
MPA-Lya as a substitute for Newton-Schulz within Muon might be worth looking into.
-
I suggest making the recently added token value embeddings LoRA matrices for faster training. It would also make the file size of the resulting model smaller.
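Roughly something like this (a hedged sketch; the names and the rank are illustrative, not from the repo):

```python
# Store a token value embedding table as a low-rank product
# (vocab, rank) x (rank, dim) instead of a full (vocab, dim) matrix.
import torch
from torch import nn

class LowRankValueEmbedding(nn.Module):
    def __init__(self, vocab_size: int, dim: int, rank: int = 64):
        super().__init__()
        self.down = nn.Embedding(vocab_size, rank)                   # per-token low-rank code
        self.up = nn.Parameter(torch.randn(rank, dim) / rank**0.5)   # shared projection to model dim

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.down(idx) @ self.up  # (..., dim) value embedding
```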
-
Speaking of the LoRA idea above, why not make every linear layer a low-rank approximation of the full matrix? At higher ranks (i.e. 512-1024) you should be able to capture essentially all of the expressivity of the full weights, while getting a deeper model with fewer real parameters, allowing higher batch sizes -> faster training. Pseudo-torch:

```python
import torch
from torch import nn

class LowRankLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 512,  # should be lower than in/out features to actually have any advantage
    ):
        super().__init__()
        self.a = nn.Parameter(torch.randn(in_features, rank) / in_features**0.5)
        self.b = nn.Parameter(torch.randn(rank, out_features) / rank**0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (x @ a) @ b never materializes the full in_features x out_features matrix
        return (x @ self.a) @ self.b
```
-
Not bullish, but may be worth trying: MemoryFormer
-
Patch-level training
In https://arxiv.org/pdf/2407.12665 it is proposed to first train not on tokens but on predicting token patches; after some time, vanilla token-level training resumes. The technique takes only about 10 lines of code (which are compatible with this repo, as far as I can tell). One might find this suspicious because of the loss change, but the loss is only changed for the first 2/3 of training and then matches standard cross-entropy, keeping the benchmark comparison fair. Additionally, validation loss can still be reported as plain cross-entropy. Here is my modification of modded-nanogpt implementing this, hopefully correctly: https://gist.github.com/kabachuha/1c0440d7193cd60f00f566d8c6f2329a (I don't have an H100, please launch it for me 😭). The original repo: https://github.com/shaochenze/PatchTrain
Edit: forgot to add that the paper promises a 2x reduction in training costs.
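The gist of the patch-level stage, as described in the paper (an illustrative sketch; `wte`, `backbone`, and `lm_head` are placeholder names, not modded-nanogpt's):

```python
import torch
import torch.nn.functional as F

def patch_level_loss(model, idx: torch.Tensor, K: int = 4) -> torch.Tensor:
    # Average every K consecutive token embeddings into one patch embedding,
    # then train the model to predict all K tokens of the *next* patch.
    B, T = idx.shape
    assert T % K == 0
    tok_emb = model.wte(idx)                                # (B, T, D)
    patch_emb = tok_emb.view(B, T // K, K, -1).mean(dim=2)  # (B, T/K, D)
    hidden = model.backbone(patch_emb)                      # transformer over patch embeddings
    logits = model.lm_head(hidden[:, :-1])                  # (B, T/K - 1, vocab)
    targets = idx.view(B, T // K, K)[:, 1:]                 # tokens of the next patch
    logits = logits.unsqueeze(2).expand(-1, -1, K, -1)      # same prediction reused for all K targets
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```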
-
Two of my suggestions:
-
Has anyone tried speedrunning on 16, 32, or 64 H100s?
-
2x faster pretraining and half the memory usage compared to Adam: gotta love that data efficiency
-
A thought I keep having: ever since the U-net pattern was introduced, it might be more sample-efficient to tie the embedding and LM head again. The U-net structure already steers the model toward a sort of symmetry, which could counteract the performance loss, and the attention head in the 8th layer could maybe be re-enabled to bring sample efficiency back up, now that the attention scaling has also been tuned.
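Tying is essentially a one-liner if anyone wants to A/B it (a sketch with placeholder attribute names; the speedrun's actual module names differ):

```python
import torch
from torch import nn

class TiedHeadGPT(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # embedding and unembedding share one matrix
```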
-
Would this be allowed?
-
Another idea: https://reasoning-tokens.ghost.io/reasoning-tokens/
-
This optimizer looks nice: https://arxiv.org/abs/2501.12243 EDIT: another optimizer: https://arxiv.org/abs/2412.17107
-
I stumbled upon this modification when I was tinkering with the MLP layer:
It keeps the total number of parameters per block the same as the original ReLU² MLP (1x2 + 2x2 + 2x1 = 1x4 + 4x1), but improves the final loss in my experiments by quite a bit (train/val loss: 0.9964/1.0524 -> 0.9876/1.0470). This modification also makes the optimizer step a little faster on my machine, possibly due to more Muon-friendly matrix aspect ratios. At a glance it looks similar to a Gated Linear Unit (with a GELU gate), but it's different: here the gating happens between consecutive layers, unlike a GLU's gating within the same layer. The MLP layers use about 2/3 of the total parameters, yet the attention layers get almost all of the architecture-exploration effort, so I think it may be too early to conclude that all the low-hanging fruit in MLP architecture has been picked. That said, my experiment was done with a smaller model on a smaller dataset, and I'm not sure the result would transfer when we scale things up.
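Since the original snippet isn't shown here, a hedged reconstruction of the described shapes (d -> 2d -> 2d -> d with a GELU gate between consecutive layers); the exact wiring that was actually tested may differ:

```python
import torch
from torch import nn
from torch.nn import functional as F

class InterLayerGatedMLP(nn.Module):
    # Three matrices of shapes (d, 2d), (2d, 2d), (2d, d): 2 + 4 + 2 = 8 d^2 params,
    # the same as the standard d -> 4d -> d MLP. The GELU gate connects consecutive
    # hidden layers instead of sitting inside one layer as in a GLU.
    def __init__(self, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, 2 * dim, bias=False)
        self.fc2 = nn.Linear(2 * dim, 2 * dim, bias=False)
        self.fc3 = nn.Linear(2 * dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h1 = self.fc1(x)
        h2 = self.fc2(F.gelu(h1))
        return self.fc3(F.gelu(h1) * h2)  # gate h2 with the previous layer's activation
```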
-
Considering how sparse relu squared is, I'm gonna look further into this: https://arxiv.org/abs/2410.03440
-
And this looks interesting too: https://arxiv.org/abs/2501.18356
-
44% fewer parameters and 33% fewer tokens: sound compatible?
-
Implemented this paper for fun, but I don't have the hardware to test whether it helps in the speedrun at all. Keep in mind that the code in the log here uses a very short sequence length, half-sized MLPs, TokenMonster, and modified kernel_options for FlexAttention; otherwise it would have taken forever to complete a training run on my 3060. So only copy over the stuff related to score_mod and create_ssmax if you want to test it yourself. Also, validation is broken and I don't really feel like debugging it, but obviously that would need to be fixed first too. And hopefully there aren't any mistakes in my implementation either :P (I didn't do any rigorous testing to make sure everything worked as intended).
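For reference, a Scalable-Softmax-style score_mod for FlexAttention can be sketched roughly like this (illustrative only; the `create_ssmax_score_mod` name and the per-head scale `s` are assumptions, and the linked implementation may differ):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def create_ssmax_score_mod(s: torch.Tensor):
    # Scalable-Softmax: multiply attention scores by s * log(n), where n is the
    # number of keys visible to the query (q_idx + 1 under causal masking) and
    # s is a per-head scalar (learned or fixed), shape (num_heads,).
    def score_mod(score, b, h, q_idx, kv_idx):
        n = (q_idx + 1).to(score.dtype)
        return score * s[h] * torch.log(n)
    return score_mod

# usage: out = flex_attention(q, k, v, score_mod=create_ssmax_score_mod(s), block_mask=causal_mask)
```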
-
Changing the value_embed structure to 010 ... 010 seems to improve wall-clock time for me, and validation loss is the same. I'm testing on a single MI210 in bf16 with more steps and smaller batches, so idk if it works for a standard run.
-
@KellerJordan I'm slightly confused about one aspect of Muon's implementation. The post-zero-power matrix scaling seems to be inverted from what's needed to achieve unit-RMS-norm vectors. Here's a demonstration:

```python
import torch
from torch import Tensor

def zeropower_via_newtonschulz5(
    G: Tensor,
    a=3.4445,
    b=-4.7750,
    c=2.0315,
    steps: int = 5,
) -> Tensor:
    assert G.ndim >= 2
    # X = G.bfloat16()
    X = G
    should_transpose = G.size(-2) > G.size(-1)
    if should_transpose:
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if should_transpose:
        X = X.mT
    return X

def accurate_zeropower_via_newtonschulz5(G: Tensor, a=2, b=-1.5, c=0.5, steps=100):
    return zeropower_via_newtonschulz5(G, a=a, b=b, c=c, steps=steps)

def rms_norm(x, eps=1e-7):
    return x / (torch.linalg.norm(x, dim=-1, keepdim=True) + eps)

def print_min_max_singular_values(w):
    s = torch.linalg.svdvals(w)
    print(f"Min/Max singular values: {s.min().item()}, {s.max().item()}")

print("Current implementation:")
w12 = accurate_zeropower_via_newtonschulz5(torch.randn(4000, 1000))
w12 *= max(1, w12.size(-2) / w12.size(-1))**0.5
w23 = accurate_zeropower_via_newtonschulz5(torch.randn(1000, 4000))
w23 *= max(1, w23.size(-2) / w23.size(-1))**0.5
v1 = rms_norm(torch.randn(1000))
v2 = v1 @ w12.T
v3 = v2 @ w23.T
print_min_max_singular_values(w12)
print_min_max_singular_values(w23)
print(torch.linalg.norm(v1, dim=-1).item())
print(torch.linalg.norm(v2, dim=-1).item())
print(torch.linalg.norm(v3, dim=-1).item())
print()

print("Inverted matrix scaling:")
w12 = accurate_zeropower_via_newtonschulz5(torch.randn(4000, 1000))
w12 *= max(1, w12.size(-1) / w12.size(-2))**0.5
w23 = accurate_zeropower_via_newtonschulz5(torch.randn(1000, 4000))
w23 *= max(1, w23.size(-1) / w23.size(-2))**0.5
v1 = rms_norm(torch.randn(1000))
v2 = v1 @ w12.T
v3 = v2 @ w23.T
print_min_max_singular_values(w12)
print_min_max_singular_values(w23)
print(torch.linalg.norm(v1, dim=-1).item())
print(torch.linalg.norm(v2, dim=-1).item())
print(torch.linalg.norm(v3, dim=-1).item())
```

The printed norms show that the current implementation results in a hidden-layer RMS norm of 2, rather than 1. The only difference is the inversion of the scaling factor, e.g. changing

```python
w *= max(1, w.size(-2) / w.size(-1))**0.5
```

to

```python
w *= max(1, w.size(-1) / w.size(-2))**0.5
```

Is this known / expected? It might not make a difference in practice, since it only affects the RMS scale of the hidden layer, but it seems strange. The current implementation does seem to be consistent with "A Spectral Condition for Feature Learning", which has me even more confused...
-
A few questions regarding the rules, sorry if this is the wrong place to ask:
-
Greetings GPT speedrunning enjoyers
I noticed some people were using GitHub issues to suggest new ideas for the run.
Plz discuss here instead. Ty