forked from KellerJordan/modded-nanogpt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
526 lines (433 loc) · 22.8 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
import os
import sys
# Snapshot the source of the running script before doing anything else, so the
# exact code used for a run can be written to the log later.
with open(sys.argv[0]) as f:
    code = f.read()
# Configure the CUDA caching allocator; set here, ahead of the torch import below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
from torch import Tensor, nn
import torch.nn.functional as F
import torch.distributed as dist
# use of FlexAttention contributed by @KoszarskyB
from torch.nn.attention.flex_attention import BlockMask, flex_attention
from typing import Optional
from transformers.models.mamba.modeling_mamba import MambaMixer
from transformers.models.mamba.configuration_mamba import MambaConfig
# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
def next_multiple_of_n(v: float | int, *, n: int):
return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
# -----------------------------------------------------------------------------
# Custom operators: FP8 matmul by @YouJiacheng
@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
@torch.compile
def impl(x: Tensor, w: Tensor):
assert x.is_contiguous() and w.is_contiguous()
x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
out = torch._scaled_mm(
x_f8,
w_f8.T,
out_dtype=torch.bfloat16,
scale_a=x.new_tensor(x_s, dtype=torch.float32),
scale_b=x.new_tensor(w_s, dtype=torch.float32),
use_fast_accum=True,
)
return out, x_f8, w_f8
return impl(x, w)
@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
    """Fake (shape/dtype-only) implementation of ``nanogpt::mm``.

    Checks that both operands are contiguous 2-D tensors on one device with
    matching inner dimension, then returns placeholders shaped like the real
    outputs: ``x @ w.T`` plus float8_e4m3fn casts of ``x`` and ``w``.
    """
    both_matrices = x.ndim == 2 and w.ndim == 2
    assert both_matrices
    assert x.shape[1] == w.shape[1]
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    fake_out = x @ w.T
    fake_x_f8 = x.to(torch.float8_e4m3fn)
    fake_w_f8 = w.to(torch.float8_e4m3fn)
    return fake_out, fake_x_f8, fake_w_f8
@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[
Tensor, Tensor]:
@torch.compile
def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
assert grad.is_contiguous()
x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
grad_x = torch._scaled_mm(
grad_f8,
w_f8.T.contiguous().T,
out_dtype=torch.bfloat16,
scale_a=grad_inv_s,
scale_b=w_inv_s,
use_fast_accum=False,
)
# faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
grad_w = torch._scaled_mm(
x_f8.T.contiguous(),
grad_f8.T.contiguous().T,
out_dtype=torch.float32,
scale_a=x_inv_s,
scale_b=grad_inv_s,
use_fast_accum=False,
).T
return grad_x, grad_w
return impl(g, x_f8, w_f8)
@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    """Fake (shape/dtype-only) implementation of ``nanogpt::mm_backward``.

    The input gradient is modelled as ``x_f8`` cast to bfloat16 and the weight
    gradient as ``w_f8`` cast to float32.
    """
    fake_grad_x = x_f8.to(torch.bfloat16)
    fake_grad_w = w_f8.to(torch.float32)
    return fake_grad_x, fake_grad_w
def backward(ctx, grad_out: Tensor, *_):
    """Autograd backward for ``nanogpt::mm``.

    Only the gradient of the first forward output (``grad_out``) is used; the
    gradients of the auxiliary FP8 outputs are ignored. Returns one entry per
    forward input: gradients for ``x`` and ``w`` and ``None`` for the three
    float scales.
    """
    saved_x_f8, saved_w_f8 = ctx.saved_tensors
    scale_x, scale_w, scale_grad = ctx.scales
    d_x, d_w = torch.ops.nanogpt.mm_backward(
        grad_out, saved_x_f8, saved_w_f8, scale_x, scale_w, scale_grad
    )
    return d_x, d_w, None, None, None
def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    """Stash what ``backward`` needs after a forward call of ``nanogpt::mm``.

    Saves the two FP8 tensors returned by the forward op and records the three
    trailing scalar inputs ``(x_s, w_s, grad_s)`` on the context.
    """
    x_s, w_s, grad_s = inputs[-3:]
    x_f8, w_f8 = output[1:]  # output[0] is the matmul result, which backward does not need
    ctx.save_for_backward(x_f8, w_f8)
    ctx.scales = (x_s, w_s, grad_s)
    ctx.set_materialize_grads(False)


mm_op.register_autograd(backward, setup_context=setup_context)
# -----------------------------------------------------------------------------
# Muon optimizer
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Approximately orthogonalize G (its "zeroth power") with a quintic Newton-Schulz iteration.

    The coefficients are chosen to maximize the slope at zero. Empirically it pays off, in terms
    of the number of steps needed, to push that slope even past the point where the iteration
    still converges to one everywhere on the interval. The result is therefore not UV^T (with
    USV^T = G the SVD) but something like US'V^T where S' is diagonal with
    S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model performance at all relative
    to UV^T.

    Arguments:
        G: tensor with at least two dimensions; leading dimensions are treated as batch.
        steps: number of Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as G.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    is_tall = G.size(-2) > G.size(-1)
    X = G.bfloat16()
    # For tall matrices iterate on the transpose, so that X @ X.mT is the smaller Gram matrix.
    if is_tall:
        X = X.mT
    # Ensure spectral norm is at most 1
    scale = X.norm(dim=(-2, -1), keepdim=True) + 1e-7
    X = X / scale
    # Perform the NS iterations
    for _ in range(steps):
        gram = X @ X.mT
        # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        poly = b * gram + c * gram @ gram
        X = a * X + poly @ X
    # Undo the initial transpose.
    if is_tall:
        X = X.mT
    return X
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-Schulz

    Muon extends standard SGD with momentum by an orthogonalization post-processing step: each
    2D parameter's update is replaced with the nearest orthogonal matrix. The orthogonalization
    uses Newton-Schulz iterations, which can be stably run in `bfloat16` on the GPU.

    Reference:
        https://kellerjordan.github.io/posts/muon/

    Some warnings:
    - This optimizer should not be used for the embedding layer, the final fully connected layer,
      or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
        rank: Index of this process; within every chunk of `world_size` parameters this
            process computes the update for the parameter at offset `rank`.
        world_size: Number of slots in each gather buffer and the stride used to chunk
            the parameters of a group.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, rank=0, world_size=1):
        self.rank = rank
        self.world_size = world_size
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        all_params: list[Tensor] = [*params]
        # One group per distinct parameter size, so a single (world_size, numel) bfloat16
        # buffer can hold the gathered updates of every chunk in that group.
        param_groups = []
        for numel in {p.numel() for p in all_params}:
            gather_buffer = torch.empty(world_size, numel, dtype=torch.bfloat16, device="cuda")
            same_sized = [p for p in all_params if p.numel() == numel]
            slot_views = [gather_buffer[slot] for slot in range(world_size)]
            param_groups.append(dict(params=same_sized,
                                     update_buffer=gather_buffer,
                                     update_buffer_views=slot_views))
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self):
        """Run one optimizer step, overlapping each chunk's all-gather with the next chunk's compute."""
        for group in self.param_groups:
            gather_buffer: Tensor = group["update_buffer"]
            slot_views: list[Tensor] = group["update_buffer_views"]
            # generate weight updates in distributed fashion
            group_params: list[Tensor] = group["params"]
            handle = None
            params_world = None

            def update_prev():  # optimized Muon implementation contributed by @YouJiacheng
                # Reads `handle` and `params_world` as they are at call time, i.e. those of the
                # most recently launched all-gather.
                handle.wait()
                for p_world, g_world in zip(params_world, slot_views):
                    shape_scale = max(1, p_world.size(-2) / p_world.size(-1)) ** 0.5
                    p_world.add_(g_world.view_as(p_world), alpha=-group["lr"] * shape_scale)

            for base_i in range(0, len(group_params), self.world_size):
                own_index = base_i + self.rank
                if own_index < len(group_params):
                    p = group_params[own_index]
                    g = p.grad
                    if g is None:
                        # Force a zero grad so this rank still contributes to the all-gather.
                        g = torch.zeros_like(p)
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    momentum_buf: Tensor = state["momentum_buffer"]
                    momentum_buf.lerp_(g, 1 - group["momentum"])
                    # The Nesterov variant blends the gradient (in place) towards the buffer.
                    g = g.lerp_(momentum_buf, group["momentum"]) if group["nesterov"] else momentum_buf
                    g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]).flatten()
                else:
                    # No parameter for this rank in the last, partial chunk: send this rank's
                    # buffer slot as a placeholder.
                    g = slot_views[self.rank]
                if base_i > 0:
                    update_prev()  # async all_gather instead of sync all_reduce by @YouJiacheng
                handle = dist.all_gather_into_tensor(gather_buffer, g, async_op=True)
                params_world = group_params[base_i: base_i + self.world_size]
            # Apply the updates of the final chunk.
            update_prev()
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model
def norm(x: Tensor):
return F.rms_norm(x, (x.size(-1),))
class CastedLinear(nn.Linear):
    """
    Bias-free linear layer with an optional FP8 training path.

    Outside of FP8 training the layer is a plain `F.linear` whose weight is cast to the dtype of
    the input. When `use_fp8` is set and the module is in training mode, the product is instead
    computed by the custom `nanogpt::mm` op using the stored scaling factors.

    Attributes:
        use_fp8 (bool): If True, enables FP8 computation during training.
        x_s (float): Scaling factor for input tensor when using FP8.
        w_s (float): Scaling factor for weights when using FP8.
        grad_s (float): Scaling factor for gradients when using FP8.

    Note:
        - The FP8 computation is only used during training.
        - Weights are initialized uniformly with standard deviation 0.5 / sqrt(in_features).
    """

    def __init__(self, in_features: int, out_features: int, use_fp8: bool = False, x_s: float = 1.0, w_s: float = 1.0,
                 grad_s: float = 1.0):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8 = use_fp8
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

    def reset_parameters(self) -> None:
        """Initialize the weight uniformly in [-bound, bound] with bound = sqrt(3) * std."""
        std = 0.5 * (self.in_features ** -0.5)  # 0.5 is a bit better than the default 1/sqrt(3)
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, x: Tensor):
        """Apply the layer to `x`; all leading dimensions of `x` are preserved."""
        fp8_active = self.use_fp8 and self.training
        if not fp8_active:
            return F.linear(x, self.weight.type_as(x))
        # The custom op works on 2-D operands: collapse the leading dims, multiply, restore them.
        flat_x = x.flatten(0, -2)
        flat_out: Tensor = torch.ops.nanogpt.mm(flat_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
        return flat_out.reshape(*x.shape[:-1], -1)
class Rotary(nn.Module):
def __init__(self, dim: int, max_seq_len: int):
super().__init__()
# half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim // 4, dtype=torch.float32)
angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim // 4)])
t = torch.arange(max_seq_len, dtype=torch.float32)
theta = torch.einsum("i,j -> ij", t, angular_freq)
self.cos = nn.Buffer(theta.cos(), persistent=False)
self.sin = nn.Buffer(theta.sin(), persistent=False)
def forward(self, x_BTHD: Tensor):
assert self.cos.size(0) >= x_BTHD.size(-3)
cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
return torch.cat((y1, y2), 3).type_as(x_BTHD)
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention built on FlexAttention.

    Features visible in this implementation: a single merged QKV weight, RMS-normalized and
    rotary-embedded queries/keys, an optional value-embedding mix controlled by `lambdas`,
    a fixed logit scale, and a zero-initialized output projection.
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        hdim = num_heads * head_dim
        init_std = 0.5 * (dim ** -0.5)
        init_bound = (3 ** 0.5) * init_std  # improved init scale by @YouJiacheng
        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-init_bound, init_bound))
        # lambdas[0] scales the values, lambdas[1] scales the value embedding (when one is given).
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(head_dim, max_seq_len)
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_()  # zero init suggested by @Grad62304977
        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        self.attn_scale = 0.12

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
        """Attend over `x` of shape (1, T, dim); `ve` is an optional value embedding mixed into v."""
        B, T = x.size(0), x.size(1)  # batch size, sequence length
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        # One matmul for q, k and v, then split along the head axis.
        merged_w = self.qkv_w.flatten(end_dim=1).type_as(x)
        qkv = F.linear(x, merged_w).view(B, T, 3 * self.num_heads, self.head_dim)
        q, k, v = qkv.chunk(3, dim=-2)
        # QK norm @Grad62304977, followed by the rotary embedding.
        q = self.rotary(norm(q))
        k = self.rotary(norm(k))
        v = self.lambdas[0] * v
        if ve is not None:
            # @KoszarskyB & @Grad62304977; layers without `ve` skip this term
            # (skip mid-layers token value embeddings by @YouJiacheng).
            v = v + self.lambdas[1] * ve.view_as(v)
        # flex_attention is called with (B, H, T, D) tensors, hence the transposes.
        attended = flex_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            block_mask=block_mask,
            scale=self.attn_scale,
        ).transpose(1, 2)
        # re-assemble all head outputs side by side
        merged_heads = attended.contiguous().view(B, T, self.num_heads * self.head_dim)
        return self.c_proj(merged_heads)
class MLP(nn.Module):
    """Two-layer feed-forward block with a squared-ReLU activation.

    The hidden width is `int(multiplier * dim)`; the output projection starts at zero.
    """

    def __init__(self, dim: int, multiplier: float | int = 4):
        super().__init__()
        hidden_dim = int(multiplier * dim)
        self.c_fc = CastedLinear(dim, hidden_dim)
        self.c_proj = CastedLinear(hidden_dim, dim)
        nn.init.zeros_(self.c_proj.weight)  # zero init suggested by @Grad62304977

    def forward(self, x: Tensor):
        hidden = self.c_fc(x)
        # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        activated = F.relu(hidden).square()
        return self.c_proj(activated)
class Block(nn.Module):
    """Transformer block: learned mix with the initial embedding, optional attention, then MLP."""

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        if layer_idx == 7:
            self.attn = None
        else:
            self.attn = CausalSelfAttention(dim, num_heads, max_seq_len)
        self.mlp = MLP(dim)
        # lambdas[0] weights the running stream, lambdas[1] the initial embedding x0.
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        stream_weight, embed_weight = self.lambdas[0], self.lambdas[1]
        x = stream_weight * x + embed_weight * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, block_mask)
        return x + self.mlp(norm(x))
class HybridBlock(nn.Module):
    """
    A single block that can be 'SSM+MLP' or 'ATTN+MLP'.

    Arguments:
        block_type: "ATTN" for a CausalSelfAttention core, "SSM" for a Mamba mixer core.
        dim: Width of the residual stream; sizes the core module and the MLP.
        num_heads: Number of attention heads (only used by "ATTN" blocks).
        max_seq_len: Maximum sequence length (only used by "ATTN" blocks).
        args: Run configuration; only `args.vocab_size` is read, for the Mamba config.

    Raises:
        ValueError: if `block_type` is neither "ATTN" nor "SSM".
    """

    def __init__(self, block_type: str, dim: int, num_heads: int, max_seq_len: int, args):
        super().__init__()
        self.block_type = block_type
        # Weighted skip connection: lambdas[0] weights the stream, lambdas[1] the initial embedding.
        self.lambdas = nn.Parameter(torch.tensor([1.0, 0.0]))
        # Decide which core module to use
        if block_type == "ATTN":
            self.module = CausalSelfAttention(dim, num_heads, max_seq_len)
        elif block_type == "SSM":
            # Size the mixer from `dim`, the width this block actually receives, so it always
            # agrees with the residual stream and with the MLP below.
            config = MambaConfig(vocab_size=args.vocab_size,
                                 hidden_size=dim)
            # NOTE(review): layer_idx is fixed to 5 for every SSM block; presumably it only
            # matters when a Mamba cache is used (none is passed here) -- confirm.
            self.module = MambaMixer(config, layer_idx=5)
        else:
            raise ValueError(f"Unknown block_type: {block_type}")
        self.mlp = MLP(dim)

    def forward(self, x: torch.Tensor, ve: Optional[torch.Tensor], x0: torch.Tensor,
                block_mask: "BlockMask | torch.Tensor | None"):
        """Apply the weighted skip, the core module and the MLP, each with a residual connection.

        `ve` is only consumed by "ATTN" blocks. `block_mask` is handed to attention unchanged.
        For "SSM" blocks it is forwarded as the mixer's `attention_mask` only when it is a plain
        tensor: a FlexAttention BlockMask is not a per-token mask, so it is dropped there.
        """
        # Weighted skip
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.block_type == "ATTN":
            x = x + self.module(norm(x), ve, block_mask)
        elif self.block_type == "SSM":
            token_mask = block_mask if isinstance(block_mask, torch.Tensor) else None
            x = x + self.module(norm(x), attention_mask=token_mask)
        x = x + self.mlp(norm(x))
        return x
# -----------------------------------------------------------------------------
# The main model
class GPT(nn.Module):
    """Hybrid language model: a stack of HybridBlocks with U-Net style skip connections.

    Arguments:
        vocab_size: Number of token ids for the embeddings; the output head is padded up to
            the next multiple of 128.
        num_layers: Number of blocks. Must be even; `forward` additionally requires exactly 12
            (hard-coded block-mask pattern) and the value-embedding pattern needs at least 6.
        num_heads: Number of attention heads in "ATTN" blocks.
        model_dim: Width of the residual stream.
        max_seq_len: Maximum sequence length for the rotary tables.
        args: Run configuration forwarded to every HybridBlock.
    """

    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int, args):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
        # Layer index 5 is attention; every other layer is an SSM block.
        LAYER_ORDER = ["ATTN" if layer_idx == 5 else "SSM" for layer_idx in range(num_layers)]
        self.blocks = nn.ModuleList([
            HybridBlock(
                block_type=bt,
                dim=model_dim,
                num_heads=num_heads,
                max_seq_len=max_seq_len,
                args=args,
            )
            for bt in LAYER_ORDER
        ])
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128), use_fp8=True, x_s=0.5,
                                    w_s=2 ** -9, grad_s=2 ** -19)
        self.lm_head.weight.detach().zero_()  # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        self.skip_weights = nn.Parameter(torch.ones(num_layers // 2))

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        """Build the long and short sliding-window, document-causal block masks.

        Token id 50256 marks document boundaries. Returns `(long_mask, short_mask)` where the
        short mask uses half (floor) of `sliding_window_num_blocks`. The length of `input_seq`
        must be a multiple of the 128-token block size.
        """
        BLOCK_SIZE = 128
        # Running document index per token: increments at every separator token.
        docs = (input_seq == 50256).cumsum(0)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_blockmask: Tensor):
            # Per query block: how many KV blocks are active, and the block indices ordered so
            # that the active ones come first.
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        # manual block mask creation by @YouJiacheng
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        # Created on the input's device: it is combined below with tensors derived from `docs`.
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device=input_seq.device)
        # "any": some (q, kv) pair in the block pair may be visible; "all": every pair is.
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)

        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            # Clamp the per-row block counts to the window size.
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )

        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)

    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor) -> Tensor:
        """Return the cross-entropy loss of predicting `target_seq` from the 1-D `input_seq`."""
        assert input_seq.ndim == 1, "input_seq must be 1D"
        value_embeddings = [embed(input_seq) for embed in self.value_embeds]
        # Pattern: 0,1,2 + None blocks + 0,1,2. Credit @YouJiacheng, improved on @leloykun's U-net structure
        # (needs at least 6 blocks, otherwise the length assert below fails).
        value_embeddings = (
            [value_embeddings[0], value_embeddings[1], value_embeddings[2]]
            + [None] * (len(self.blocks) - 6)
            + [value_embeddings[0], value_embeddings[1], value_embeddings[2]]
        )
        assert len(value_embeddings) == len(self.blocks)
        # Create block masks. The long/short pattern is written out for exactly 12 blocks;
        # the assert below enforces that.
        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [
            long_bm, short_bm, short_bm, short_bm, long_bm, short_bm,
            short_bm, long_bm, short_bm, short_bm, short_bm, long_bm
        ]
        assert len(block_masks) == len(self.blocks)
        # Initial embedding + normalization; [None] adds the batch dimension of size 1.
        x0 = norm(self.embed(input_seq)[None])  # use of norm here by @Grad62304977
        x = x0
        # U-Net style skip connections: down -> up path
        skip_connections = []
        num_skip = len(self.skip_weights)
        # "Down" pass: gather skip connections
        for i in range(num_skip):
            x = self.blocks[i](x, value_embeddings[i], x0, block_masks[i])
            skip_connections.append(x)
        # "Up" pass: retrieve and apply skip connections (last stored is used first)
        for i in range(num_skip, len(self.blocks)):
            x = x + self.skip_weights[i - num_skip] * skip_connections.pop()
            x = self.blocks[i](x, value_embeddings[i], x0, block_masks[i])
        logits = self.lm_head(norm(x))
        # @Grad62304977 added tanh softcapping following Gemma 2 paper,
        # @KoszarskyB reduced it from 30 to 15,
        # @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        logits = 30 * torch.sigmoid(logits.float() / 7.5)
        # Cross-entropy loss
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq)
        return loss