-
Notifications
You must be signed in to change notification settings - Fork 238
/
Copy patheff63a8c-2f7e-4fc5-97ce-7f600dae0bc7.txt
2467 lines (2398 loc) · 123 KB
/
eff63a8c-2f7e-4fc5-97ce-7f600dae0bc7.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Module preamble: capture this script's own source text (it is written to the run log further
# below via print0(code)), set allocator options, and import the torch APIs used in this file.
import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import time
import copy
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
# Set before `import torch`; presumably so the CUDA caching allocator sees it when it initializes.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
from torch import Tensor, nn
import torch.nn.functional as F
import torch.distributed as dist
# use of FlexAttention contributed by @KoszarskyB
from torch.nn.attention.flex_attention import BlockMask, flex_attention
#torch._inductor.config.coordinate_descent_tuning = True # turn this on for a slightly faster run (but much slower compile time)
# -----------------------------------------------------------------------------
# Custom operators : FP8 matmul by @YouJiacheng
@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
@torch.compile
def impl(x: Tensor, w: Tensor):
assert x.is_contiguous() and w.is_contiguous()
x_f8 = x.mul(x_s).to(torch.float8_e4m3fn)
w_f8 = w.mul(w_s).to(torch.float8_e4m3fn)
out = torch._scaled_mm(
x_f8,
w_f8.t(),
out_dtype=torch.bfloat16,
scale_a=x.new_tensor(1 / x_s, dtype=torch.float32),
scale_b=x.new_tensor(1 / w_s, dtype=torch.float32),
use_fast_accum=True,
)
return out, x_f8, w_f8
return impl(x, w)
@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
assert x.ndim == w.ndim == 2
assert x.shape[1] == w.shape[1]
assert x.device == w.device
assert x.is_contiguous() and w.is_contiguous()
return x @ w.t(), x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
    """Backward of ``nanogpt::mm`` (forward: ``out = x @ w.t()``) computed with FP8 matmuls.

    ``g`` is the upstream gradient of ``out``. ``x_f8`` / ``w_f8`` are the pre-scaled FP8 copies saved
    by the forward pass. ``g`` is multiplied by ``grad_s`` and cast to float8_e5m2. ``torch._scaled_mm``
    then computes ``grad_x = g @ w`` in bfloat16 and ``grad_w = g.t() @ x`` in float32. In both,
    ``scale_a`` / ``scale_b`` are the reciprocal scales that undo each operand's pre-scaling.
    Returns ``(grad_x, grad_w)``.
    """
    @torch.compile
    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
        assert grad.is_contiguous()
        # Reciprocal scales as float32 scalar tensors, passed to _scaled_mm as scale_a / scale_b.
        x_inv_s = grad.new_tensor(1 / x_s, dtype=torch.float32)
        w_inv_s = grad.new_tensor(1 / w_s, dtype=torch.float32)
        grad_inv_s = grad.new_tensor(1 / grad_s, dtype=torch.float32)
        grad_f8 = grad.mul(grad_s).to(torch.float8_e5m2)
        # grad_x = grad @ w : (M, N) @ (N, K) -> (M, K). `.t().contiguous().t()` keeps w_f8's
        # logical shape but materializes it with column-major strides.
        grad_x = torch._scaled_mm(
            grad_f8,
            w_f8.t().contiguous().t(),
            out_dtype=torch.bfloat16,
            scale_a=grad_inv_s,
            scale_b=w_inv_s,
            use_fast_accum=False,
        )
        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
        # Computes x.t() @ grad with shape (K, N), then transposes it to grad.t() @ x, which has
        # w's shape (N, K).
        grad_w = torch._scaled_mm(
            x_f8.t().contiguous(),
            grad_f8.t().contiguous().t(),
            out_dtype=torch.float32,
            scale_a=x_inv_s,
            scale_b=grad_inv_s,
            use_fast_accum=False,
        ).t()
        return grad_x, grad_w
    return impl(g, x_f8, w_f8)
@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    # Shape/dtype-only stand-in for tracing: grad_x has x's shape in bfloat16, grad_w has w's
    # shape in float32 (matching the out_dtypes used above).
    return x_f8.to(torch.bfloat16), w_f8.to(torch.float32)
def backward(ctx, grad_out: Tensor, *_):
    """Autograd backward for ``nanogpt::mm``.

    Gradients of the extra outputs (the FP8 copies) are ignored. The gradients for ``x`` and ``w``
    come from the FP8 backward op. The three float scale inputs get ``None``.
    """
    saved_x_f8, saved_w_f8 = ctx.saved_tensors
    scale_x, scale_w, scale_grad = ctx.scales
    dx, dw = torch.ops.nanogpt.mm_backward(
        grad_out, saved_x_f8, saved_w_f8, scale_x, scale_w, scale_grad
    )
    return dx, dw, None, None, None
def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    """Stash what ``backward`` needs: the forward's FP8 copies of x and w, plus its three scales."""
    *_, scale_x, scale_w, scale_grad = inputs
    _out, saved_x_f8, saved_w_f8 = output
    ctx.save_for_backward(saved_x_f8, saved_w_f8)
    ctx.scales = scale_x, scale_w, scale_grad
    # Undefined upstream gradients are not turned into zero tensors.
    ctx.set_materialize_grads(False)
mm_op.register_autograd(backward, setup_context=setup_context)
# -----------------------------------------------------------------------------
# Muon optimizer
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Approximately orthogonalize G with a quintic Newton-Schulz iteration run in bfloat16.

    The quintic's coefficients (a, b, c) are chosen to maximize the slope at zero, which minimizes
    the number of steps needed. Empirically it pays to push that slope even past the point where
    the iteration converges all the way to one on the whole interval. So the result is not exactly
    UV^T (with USV^T = G the SVD). It is roughly US'V^T with S'_{ii} ~ Uniform(0.5, 1.5), which
    turns out not to hurt model performance relative to UV^T.

    G may be batched: the last two dimensions are the matrix, and any leading dimensions are
    processed independently. Wide or square matrices are iterated directly; tall ones are
    transposed first and transposed back at the end.
    """
    assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    is_tall = G.size(-2) > G.size(-1)
    X = G.bfloat16()
    if is_tall:
        X = X.mT
    # Divide by the Frobenius norm so every singular value is at most 1.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        gram = X @ X.mT
        # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        poly = b * gram + c * gram @ gram
        X = a * X + poly @ X
    return X.mT if is_tall else X
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum and then applies an orthogonalization
    post-processing step: each matrix parameter's update is replaced with the nearest orthogonal
    matrix. Each update is orthogonalized efficiently with a Newton-Schulz iteration, which has the
    advantage that it can be run stably in bfloat16 on the GPU.

    Distribution: parameters are grouped by element count. Within a group, parameter ``i`` is
    handled by rank ``i % world_size``. Each rank orthogonalizes its own parameters. The flattened
    bfloat16 updates are then exchanged with ``dist.all_gather_into_tensor``, so every rank applies
    the same updates. ``torch.distributed`` must be initialized before ``step()`` is called.

    Some warnings:
    - Parameters must have at least 2 dimensions. Dimensions before the last two are treated as
      a batch of matrices, each orthogonalized independently.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
      parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iteration steps to use.
        rank: Index of this process among the ``world_size`` processes sharing the work.
        world_size: Number of processes sharing the work.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, rank=0, world_size=1):
        self.rank = rank
        self.world_size = world_size
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        assert all(isinstance(p, Tensor) for p in params)
        # Same-size parameters share one gather buffer, since every rank's flattened update in a
        # group has the same length.
        sizes = {p.numel() for p in params}

        def create_update_buffer(size: int, device: torch.device):
            # Row i of the buffer receives rank i's flattened update during the all-gather.
            b = torch.empty(self.world_size, size, dtype=torch.bfloat16, device=device)
            return dict(update_buffer=b, update_buffer_views=[b[i] for i in range(self.world_size)])

        param_groups = []
        for size in sizes:
            group_params = [p for p in params if p.numel() == size]
            # Allocate the buffer on the parameters' own device rather than the implicit current
            # CUDA device, so the two cannot diverge.
            param_groups.append(dict(params=group_params, **create_update_buffer(size, group_params[0].device)))
        super().__init__(param_groups, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss (the
                standard ``torch.optim.Optimizer`` contract). It is run with gradients enabled.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            update_buffer = group["update_buffer"]
            update_buffer_views: list[Tensor] = group["update_buffer_views"]
            # generate weight updates in distributed fashion
            params: list[Tensor] = group["params"]
            handle = None
            params_world = None
            def update_prev(): # optimized Muon implementation contributed by @YouJiacheng
                # Wait for the previous round's all-gather, then apply every rank's update to the
                # parameters of that round. In a final partial round, zip stops at the real params.
                if params_world is None:
                    return
                assert handle is not None
                handle.wait()
                for p_world, g_world in zip(params_world, update_buffer_views):
                    p_world.add_(
                        g_world.view_as(p_world),
                        alpha=-lr * max(1, p_world.size(-2) / p_world.size(-1)) ** 0.5,
                    )
            for base_i in range(len(params))[::self.world_size]:
                if base_i + self.rank < len(params):
                    p = params[base_i + self.rank]
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf: Tensor = state["momentum_buffer"]
                    buf.lerp_(g, 1 - momentum)
                    g = g.lerp_(buf, momentum) if nesterov else buf
                    g = zeropower_via_newtonschulz5(g, steps=ns_steps).flatten()
                else:
                    # No parameter for this rank in the final round: contribute this rank's own
                    # buffer row so the collective still matches across ranks.
                    g = update_buffer_views[self.rank]
                update_prev() # async all_gather instead of sync all_reduce by @YouJiacheng
                handle = dist.all_gather_into_tensor(update_buffer, g, async_op=True)
                params_world = params[base_i : base_i + self.world_size]
            update_prev()
        return loss
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model
def norm(x: Tensor):
return F.rms_norm(x, (x.size(-1),))
class CastedLinear(nn.Linear):
    """Bias-free linear layer that casts its weight to the input's dtype at call time.

    With ``use_fp8=True`` and the module in training mode, the matmul goes through the custom FP8
    op ``torch.ops.nanogpt.mm`` using the static scales ``x_s``, ``w_s`` and ``grad_s``. Otherwise
    it is a plain ``F.linear``.
    """
    def __init__(self, in_features: int, out_features: int, use_fp8: bool = False, x_s: float = 1.0, w_s: float = 1.0, grad_s: float = 1.0):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8 = use_fp8
        self.x_s, self.w_s, self.grad_s = x_s, w_s, grad_s

    def reset_parameters(self) -> None:
        # U(-b, b) has std b / sqrt(3), so b = sqrt(3) * std with std = 0.5 / sqrt(fan_in).
        # (0.5 is a bit better than the default 1/sqrt(3).)
        bound = (3 ** 0.5) * (0.5 * (self.in_features ** -0.5))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, x: Tensor):
        if not (self.use_fp8 and self.training):
            return F.linear(x, self.weight.type_as(x))
        # The FP8 op works on 2-D inputs: fold all leading dims, then restore them.
        flat_x = x.flatten(0, -2)
        out: Tensor = torch.ops.nanogpt.mm(flat_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
        return out.reshape(*x.shape[:-1], -1)
class Rotary(nn.Module):
    """Rotary position embedding in which only half of the frequency slots rotate.

    The head dimension is split into two halves (x1, x2), and element j of each forms a pair
    rotated by angle position * freq[j]. The first dim//4 frequencies decay geometrically from 1
    to 1/1024. The remaining dim//4 are zero, so those pairs pass through unchanged.
    """
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        quarter = dim // 4
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        freqs = (1 / 1024) ** torch.linspace(0, 1, steps=quarter, dtype=torch.float32)
        freqs = torch.cat([freqs, freqs.new_zeros(quarter)])
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, freqs)
        # Non-persistent buffers: they move with .cuda()/.to() but are not saved in state_dict.
        self.register_buffer("cos", angles.cos(), persistent=False)
        self.register_buffer("sin", angles.sin(), persistent=False)

    def forward(self, x_BTHD: Tensor):
        seq_len = x_BTHD.size(-3)
        assert self.cos.size(0) >= seq_len
        cos = self.cos[None, :seq_len, None, :]
        sin = self.sin[None, :seq_len, None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        rotated = (x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos)
        return torch.cat(rotated, 3).type_as(x_BTHD)
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention using FlexAttention, with merged QKV, QK-norm and rotary embedding.

    Values may be blended with an externally supplied value embedding ``ve`` via learned scalars.
    Batch size must be 1; masking (causal / document / window) comes entirely from ``block_mask``.
    """
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim
        # improved init scale by @YouJiacheng: U(-b, b) with std 0.5 / sqrt(dim)
        bound = (3 ** 0.5) * (0.5 * (dim ** -0.5))
        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        self.qkv_w = nn.Parameter(torch.empty(3, inner_dim, dim).uniform_(-bound, bound))
        # lambdas[0] scales the projected values, lambdas[1] scales the value embedding.
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(head_dim, max_seq_len)
        self.c_proj = CastedLinear(inner_dim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977
        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        self.attn_scale = 0.12

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        qkv_weight = self.qkv_w.flatten(end_dim=1).type_as(x)
        qkv = F.linear(x, qkv_weight).view(B, T, 3 * self.num_heads, self.head_dim)
        q, k, v = qkv.chunk(3, dim=-2)
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        if ve is None: # skip mid-layers token value embeddings by @YouJiacheng
            v = self.lambdas[0] * v
        else:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
        # FlexAttention takes (B, H, T, D); the tensors above are (B, T, H, D).
        q_h, k_h, v_h = (t.transpose(1, 2) for t in (q, k, v))
        y = flex_attention(q_h, k_h, v_h, block_mask=block_mask, scale=self.attn_scale).transpose(1, 2)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        return self.c_proj(y)
class MLP(nn.Module):
    """Feed-forward sublayer: expand 4x, squared ReLU, project back (output projection zero-initialized)."""
    def __init__(self, dim: int):
        super().__init__()
        self.c_fc = CastedLinear(dim, 4 * dim)
        self.c_proj = CastedLinear(4 * dim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor):
        # relu^2: https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        hidden = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden)
class Block(nn.Module):
    """Pre-norm transformer layer with a learned blend of the residual stream and the initial embedding.

    The layer with ``layer_idx == 7`` has no attention sublayer and runs only the MLP.
    """
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        has_attention = layer_idx != 7
        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if has_attention else None
        self.mlp = MLP(dim)
        # lambdas[0] weights the incoming stream x, lambdas[1] weights x0; (1, 0) starts as plain x.
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        h = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            h = h + self.attn(norm(h), ve, block_mask)
        h = h + self.mlp(norm(h))
        return h
# -----------------------------------------------------------------------------
# The main model
def next_multiple_of_n(v: float | int, *, n: int):
return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
class GPT(nn.Module):
    """GPT-style language model trained by this script.

    Pipeline: token embedding (RMS-normed) -> ``Block`` stack wired as a U-net -> RMS norm ->
    FP8 ``lm_head`` -> soft-capped logits -> cross-entropy loss. In the U-net, outputs of the first
    half of the blocks are added, with learned weights and in reverse order, into the second half.
    Three extra embedding tables provide "value embeddings" that are mixed into the attention
    values of the first three and last three blocks.
    """
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        # The head uses the FP8 matmul path (active only in training mode) with fixed power-of-two
        # scales for the input, the weight and the gradient.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128), use_fp8=True, x_s=2.0, w_s=2.0**9, grad_s=2.0**19)
        self.lm_head.weight.detach().zero_() # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        self.skip_weights = nn.Parameter(torch.ones(num_layers//2))
    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        """Build the (long, short) sliding-window, document-causal FlexAttention block masks.

        ``input_seq`` is the flat 1-D token sequence, and its length must be a multiple of 128
        (BLOCK_SIZE). A query may attend to a key only if the key is not later in the sequence and
        both carry the same document id. The id increments at every occurrence of token 50256
        (presumably GPT-2's end-of-text token), and that token counts toward the new document.
        ``sliding_window_num_blocks`` is the long window measured in 128-token blocks; the short
        window is half of it.
        """
        BLOCK_SIZE = 128
        docs = (input_seq == 50256).cumsum(0)
        def document_causal(b, h, q_idx, kv_idx):
            # Token-level rule, applied only inside the "partial" blocks below.
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask
        def dense_to_ordered(dense_blockmask: Tensor):
            # Convert a dense (q_block, kv_block) boolean matrix into per-row counts plus kv block
            # indices. The stable ascending argsort followed by flip puts the selected blocks first,
            # in descending (most recent first) order. [None, None] adds batch and head dims.
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
        # manual block mask creation by @YouJiacheng
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        # "any": some token pair in the block pair may be allowed; "all": every pair is allowed.
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        # First and last document id in each block (docs is non-decreasing).
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        # any: the two blocks' document-id ranges overlap. all: both blocks lie entirely in one
        # and the same document.
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        # Partial blocks need document_causal evaluated per token; full blocks skip it.
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            # Because the indices are ordered most recent first, capping the counts keeps only the
            # nearest kv blocks, which is what makes this a sliding window. Full blocks are capped
            # at window - 1. Partial blocks are capped at max(window - full, 1), where `full` is the
            # uncapped full-block count.
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )
        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapated from suggestion by @Grad62304977, following Gemma 2 paper
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        """Return the mean cross-entropy loss of predicting ``target_seq`` from ``input_seq``.

        Both sequences are 1-D (one packed batch); a leading batch dimension of 1 is added here.
        The per-layer value-embedding and block-mask layouts below require exactly 12 blocks.
        """
        assert input_seq.ndim == 1
        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
        assert len(ve) == len(self.blocks)
        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
        assert len(block_masks) == len(self.blocks)
        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977
        # U-net design by @brendanh0gan
        # The first n blocks push their outputs; the remaining blocks pop them (last in, first out)
        # and add them back, scaled by skip_weights.
        skip_connections = []
        n = len(self.skip_weights)
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + self.skip_weights[i - n] * skip_connections.pop()
            x = self.blocks[i](x, ve[i], x0, block_masks[i])
            if i < n:
                skip_connections.append(x)
        x = norm(x)
        logits = self.lm_head(x)
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        # Result: logits squashed into (0, 30), computed in float32.
        logits = 30 * torch.sigmoid(logits.float() / 7.5)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq)
        return loss
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader
def _load_data_shard(file: Path):
header = torch.from_file(f"{file}", False, 256, dtype=torch.int32) # header is 256 int32
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
assert header[1] == 1, "unsupported version"
num_tokens = int(header[2]) # number of tokens (claimed)
with file.open("rb", buffering=0) as f:
tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng
f.seek(256 * 4)
nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng
assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
return tokens
def distributed_data_generator(filename_pattern: str, batch_size: int, rank : int, world_size : int):
    """Yield ``(inputs, targets)`` CUDA tensors for this rank, one per training step.

    Shards matching ``filename_pattern`` (relative to the current directory) are read once, in
    sorted order. Each step consumes ``batch_size`` consecutive tokens of the current shard. Rank r
    takes tokens ``[r * local, (r + 1) * local]`` of that window (local = batch_size // world_size,
    plus one extra token for the shifted targets). Inputs are int32 and targets int64.

    Raises (on the ``next()`` call that hits the condition):
        FileNotFoundError: If no file matches ``filename_pattern``.
        RuntimeError: When every shard has been consumed.
    """
    files = sorted(Path.cwd().glob(filename_pattern))
    assert batch_size % world_size == 0
    if not files:
        raise FileNotFoundError(f"no data shards match {filename_pattern!r} under {Path.cwd()}")
    local_batch_size = batch_size // world_size
    file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training

    def load_next_shard():
        # Raise a clear error instead of letting next() leak StopIteration out of this generator,
        # which PEP 479 turns into an opaque "generator raised StopIteration".
        path = next(file_iter, None)
        if path is None:
            raise RuntimeError(f"ran out of data shards matching {filename_pattern!r}")
        return _load_data_shard(path)

    tokens, pos = load_next_shard(), 0
    while True:
        # A step needs batch_size + 1 tokens starting at pos (targets are shifted by one).
        # `while` also skips any shard too small to hold even one batch.
        while pos + batch_size + 1 >= len(tokens):
            tokens, pos = load_next_shard(), 0
        buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
        inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
        targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
        pos += batch_size
        yield inputs, targets
# -----------------------------------------------------------------------------
# int main
@dataclass
class Hyperparameters:
    """Run configuration; the defaults are the settings of this training run.

    Every setting is type-annotated, which makes it a real dataclass field. A run can therefore
    override values through the constructor, e.g. ``Hyperparameters(num_iterations=100)``.
    Without annotations, ``@dataclass`` generated no fields and the constructor took no arguments.
    """
    # data
    train_files: str = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on
    val_files: str = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on
    val_tokens: int = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    # optimization
    num_iterations: int = 1770 # number of iterations to run
    cooldown_frac: float = 0.4 # fraction of training spent cooling down the learning rate
    # evaluation and logging
    val_loss_every: int = 125 # every how many steps to evaluate val loss? 0 for only at the end
    # implementation
    seq_len: int = 48*1024 # FlexAttention sequence length
    val_seq_len: int = 4*64*1024 # FlexAttention sequence length for validation
    save_checkpoint: bool = False
args = Hyperparameters()
# torchrun sets these env variables
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert world_size == 8 # this code is designed for 8xH100
assert torch.cuda.is_available()
# LOCAL_RANK (presumably also set by torchrun) selects this process's GPU; it is made the current
# CUDA device before the NCCL process group is created.
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
master_process = (rank == 0) # this process will do logging, checkpointing etc.
# begin logging
# Only rank 0 gets a logfile path (logs/<random uuid>.txt); on every other rank logfile stays None.
logfile = None
if master_process:
    run_id = uuid.uuid4()
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{run_id}.txt"
    print(logfile)
def print0(s, console=False):
    # Log from the master process only: append `s` to the run's logfile and, if `console` is set,
    # echo it to stdout first. On all other ranks this is a no-op.
    if not master_process:
        return
    with open(logfile, "a") as f:
        if console:
            print(s)
        print(s, file=f)
# begin by printing this file (the Python code)
# (print0 without console=True writes only to the logfile, and only on rank 0.)
print0(code)
print0("="*100)
# log information about the hardware/software environment this is running on
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
def nvidia_smi():
    # Return the stdout of `nvidia-smi` for the run log; stderr is captured and discarded.
    import subprocess # avoid top level import
    completed = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    return completed.stdout
print0(nvidia_smi())
print0("="*100)
model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.seq_len, args.val_seq_len)).cuda()
# Only the nn.Embedding tables are cast to bfloat16; every other parameter keeps its default dtype.
for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
# Copy rank 0's initial weights to every rank so all replicas start identical.
for param in model.parameters():
    dist.broadcast(param.detach(), 0)
# collect the parameters to optimize
# Given the GPT structure, these four lists partition model.parameters():
# >=2-D block weights -> Muon; embedding tables, lm_head and all <2-D scalars -> Adam.
hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [model.lm_head.weight]
# init the optimizer(s)
adam_params = [dict(params=head_params, lr=0.008), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
# small adam epsilon by @YouJiacheng. this is an alternate method of fixing the world_size dependence
# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), eps=1e-10, fused=True)
optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, rank=rank, world_size=world_size)
optimizers = [optimizer1, optimizer2]
# learning rate schedule: stable then decay
def get_lr(step: int):
    # LR multiplier: 1.0 for the first (1 - cooldown_frac) of training, then a linear decay that
    # reaches 0.1 at the final iteration.
    frac_done = step / args.num_iterations # progress in training
    assert 0 <= frac_done <= 1
    frac_left = 1 - frac_done
    cooldown_weight = min(frac_left / args.cooldown_frac, 1.0) # 1 -> 0
    return cooldown_weight * 1.0 + (1 - cooldown_weight) * 0.1
# One LambdaLR per optimizer; get_lr(step) supplies the multiplier on each param group's initial lr.
schedulers = [torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr) for optimizer in optimizers]
@lru_cache(1)
def window_size_blocks(window_size: int):
    # Turn a sliding-window length in tokens into a count of 128-token blocks, returned as an int32
    # scalar tensor on the current CUDA device. The single-entry cache reuses the tensor while
    # consecutive calls pass the same window size.
    num_blocks = window_size // 128
    host_tensor = torch.tensor(num_blocks, dtype=torch.int32, pin_memory=True)
    return host_tensor.cuda(non_blocking=True)
# Compile the model with static shapes (dynamic=False).
model: nn.Module = torch.compile(model, dynamic=False)
# Warmup the training kernels, then re-initialize the state so we aren't cheating
# Deep copies of the model and optimizer state are taken before warmup and restored after it,
# so the warmup updates do not count toward training. Schedulers are not stepped here.
warmup_steps = 10
initial_state = dict(model=copy.deepcopy(model.state_dict()),
optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
# Each global batch is world_size * seq_len tokens, split across ranks by the generator.
train_loader = distributed_data_generator(args.train_files, world_size * args.seq_len, rank, world_size)
for _ in range(warmup_steps):
inputs, targets = next(train_loader)
model(inputs, targets, window_size_blocks(128)).backward()
# Average gradients across ranks.
for param in model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
# Restore the pre-warmup weights and optimizer state.
model.load_state_dict(initial_state['model'])
for opt, opt_state in zip(optimizers, initial_state['optimizers']):
opt.load_state_dict(opt_state)
# Recreate the loader so real training starts from the beginning of the data stream.
del train_loader, initial_state
train_loader = distributed_data_generator(args.train_files, world_size * args.seq_len, rank, world_size)
# training_time_ms accumulates wall time of training only; the clock is paused during validation.
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
# The loop runs train_steps + 1 iterations: the final iteration only validates
# (and optionally checkpoints), then breaks before the training section.
train_steps = args.num_iterations
for step in range(train_steps + 1):
last_step = (step == train_steps)
# Linearly increase the block-wise sliding window size over training 128 -> 1792:
# increase by @fernbear.bsky.social; block-wise by @YouJiacheng
# NOTE(review): next_multiple_of_n is defined elsewhere; presumably it rounds up to a multiple of 128.
window_size = next_multiple_of_n(1728 * step / train_steps, n=128)
# --------------- VALIDATION SECTION -----------------
if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
# stop the clock
torch.cuda.synchronize()
training_time_ms += 1000 * (time.perf_counter() - t0)
model.eval()
# val_tokens must split evenly into global validation batches of world_size * val_seq_len tokens.
val_batch_size = world_size * args.val_seq_len
assert args.val_tokens % val_batch_size == 0
val_steps = args.val_tokens // val_batch_size
val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size)
val_loss = 0
with torch.no_grad():
for _ in range(val_steps):
inputs, targets = next(val_loader)
val_loss += model(inputs, targets, window_size_blocks(window_size))
# Mean over this rank's validation steps, then averaged across ranks.
val_loss /= val_steps
del val_loader
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
# At this point `step` training steps have completed; max(step, 1) avoids dividing by zero at step 0.
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
model.train()
# start the clock again
torch.cuda.synchronize()
t0 = time.perf_counter()
if last_step:
if master_process and args.save_checkpoint:
# NOTE(review): `model` is the torch.compile wrapper here, so state_dict keys likely carry
# an `_orig_mod.` prefix — confirm before loading into an uncompiled GPT.
log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
os.makedirs(f"logs/{run_id}", exist_ok=True)
torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
# the last step only has the validation loop, so break to avoid training
break
# --------------- TRAINING SECTION -----------------
inputs, targets = next(train_loader)
model(inputs, targets, window_size_blocks(window_size)).backward()
# Average gradients across ranks before any optimizer step.
for param in model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
# momentum warmup for Muon
# Linear ramp of Muon's momentum from 0.85 at step 0 to 0.95 at step 300 and beyond.
frac = min(step / 300, 1)
for group in optimizer2.param_groups:
group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
# step the optimizers and schedulers
for opt, sched in zip(optimizers, schedulers):
opt.step()
sched.step()
# null the gradients
model.zero_grad(set_to_none=True)
# logging
# Approximate because there is no cuda synchronize before reading the clock here.
approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)
# Report peak CUDA memory (allocated and reserved by the caching allocator) for this process, in MiB.
print0(
f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB",
console=True,
)
dist.destroy_process_group()
====================================================================================================
Running Python 3.12.7 (main, Feb 1 2025, 03:09:49) [GCC 13.2.0]
Running PyTorch 2.7.0.dev20250125+cu126 compiled for CUDA 12.6
Sun Feb 2 03:23:40 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08 Driver Version: 550.127.08 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 |
| N/A 27C P0 122W / 700W | 7746MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 |
| N/A 30C P0 117W / 700W | 3456MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 |
| N/A 29C P0 117W / 700W | 3456MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 |
| N/A 26C P0 116W / 700W | 3456MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 |
| N/A 26C P0 124W / 700W | 3456MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 |
| N/A 29C P0 118W / 700W | 3456MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 |
| N/A 28C P0 119W / 700W | 3456MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 |
| N/A 26C P0 118W / 700W | 3216MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
====================================================================================================
step:0/1770 val_loss:10.8258 train_time:0ms step_avg:0.03ms
step:1/1770 train_time:70ms step_avg:70.50ms
step:2/1770 train_time:148ms step_avg:74.01ms
step:3/1770 train_time:237ms step_avg:78.90ms
step:4/1770 train_time:332ms step_avg:82.92ms
step:5/1770 train_time:428ms step_avg:85.56ms
step:6/1770 train_time:523ms step_avg:87.22ms
step:7/1770 train_time:619ms step_avg:88.47ms
step:8/1770 train_time:715ms step_avg:89.42ms
step:9/1770 train_time:811ms step_avg:90.12ms
step:10/1770 train_time:907ms step_avg:90.71ms
step:11/1770 train_time:1003ms step_avg:91.14ms
step:12/1770 train_time:1099ms step_avg:91.55ms
step:13/1770 train_time:1194ms step_avg:91.88ms
step:14/1770 train_time:1291ms step_avg:92.20ms
step:15/1770 train_time:1387ms step_avg:92.50ms
step:16/1770 train_time:1484ms step_avg:92.72ms
step:17/1770 train_time:1579ms step_avg:92.89ms
step:18/1770 train_time:1675ms step_avg:93.05ms
step:19/1770 train_time:1771ms step_avg:93.22ms
step:20/1770 train_time:1868ms step_avg:93.39ms
step:21/1770 train_time:1964ms step_avg:93.55ms
step:22/1770 train_time:2060ms step_avg:93.66ms
step:23/1770 train_time:2157ms step_avg:93.77ms
step:24/1770 train_time:2253ms step_avg:93.86ms
step:25/1770 train_time:2350ms step_avg:93.98ms
step:26/1770 train_time:2445ms step_avg:94.05ms
step:27/1770 train_time:2541ms step_avg:94.11ms
step:28/1770 train_time:2637ms step_avg:94.18ms
step:29/1770 train_time:2733ms step_avg:94.24ms
step:30/1770 train_time:2829ms step_avg:94.31ms
step:31/1770 train_time:2925ms step_avg:94.34ms
step:32/1770 train_time:3020ms step_avg:94.38ms
step:33/1770 train_time:3116ms step_avg:94.44ms
step:34/1770 train_time:3212ms step_avg:94.47ms
step:35/1770 train_time:3308ms step_avg:94.52ms
step:36/1770 train_time:3404ms step_avg:94.57ms
step:37/1770 train_time:3501ms step_avg:94.61ms
step:38/1770 train_time:3596ms step_avg:94.64ms
step:39/1770 train_time:3692ms step_avg:94.67ms
step:40/1770 train_time:3789ms step_avg:94.72ms
step:41/1770 train_time:3885ms step_avg:94.76ms
step:42/1770 train_time:3981ms step_avg:94.79ms
step:43/1770 train_time:4078ms step_avg:94.83ms
step:44/1770 train_time:4174ms step_avg:94.86ms
step:45/1770 train_time:4270ms step_avg:94.89ms
step:46/1770 train_time:4366ms step_avg:94.92ms
step:47/1770 train_time:4462ms step_avg:94.94ms
step:48/1770 train_time:4558ms step_avg:94.95ms
step:49/1770 train_time:4654ms step_avg:94.97ms
step:50/1770 train_time:4749ms step_avg:94.99ms
step:51/1770 train_time:4847ms step_avg:95.03ms
step:52/1770 train_time:4942ms step_avg:95.03ms
step:53/1770 train_time:5038ms step_avg:95.05ms
step:54/1770 train_time:5133ms step_avg:95.06ms
step:55/1770 train_time:5229ms step_avg:95.08ms
step:56/1770 train_time:5326ms step_avg:95.10ms
step:57/1770 train_time:5422ms step_avg:95.13ms
step:58/1770 train_time:5519ms step_avg:95.15ms
step:59/1770 train_time:5615ms step_avg:95.17ms
step:60/1770 train_time:5711ms step_avg:95.18ms
step:61/1770 train_time:5808ms step_avg:95.21ms
step:62/1770 train_time:5904ms step_avg:95.22ms
step:63/1770 train_time:6000ms step_avg:95.24ms
step:64/1770 train_time:6096ms step_avg:95.25ms
step:65/1770 train_time:6192ms step_avg:95.26ms
step:66/1770 train_time:6288ms step_avg:95.27ms
step:67/1770 train_time:6384ms step_avg:95.29ms
step:68/1770 train_time:6481ms step_avg:95.31ms
step:69/1770 train_time:6577ms step_avg:95.32ms
step:70/1770 train_time:6674ms step_avg:95.34ms
step:71/1770 train_time:6769ms step_avg:95.34ms
step:72/1770 train_time:6865ms step_avg:95.34ms
step:73/1770 train_time:6960ms step_avg:95.35ms
step:74/1770 train_time:7056ms step_avg:95.36ms
step:75/1770 train_time:7152ms step_avg:95.36ms
step:76/1770 train_time:7248ms step_avg:95.37ms
step:77/1770 train_time:7343ms step_avg:95.37ms
step:78/1770 train_time:7439ms step_avg:95.38ms
step:79/1770 train_time:7535ms step_avg:95.38ms
step:80/1770 train_time:7631ms step_avg:95.39ms
step:81/1770 train_time:7728ms step_avg:95.40ms
step:82/1770 train_time:7824ms step_avg:95.42ms
step:83/1770 train_time:7921ms step_avg:95.43ms
step:84/1770 train_time:8017ms step_avg:95.44ms
step:85/1770 train_time:8113ms step_avg:95.45ms
step:86/1770 train_time:8209ms step_avg:95.46ms
step:87/1770 train_time:8306ms step_avg:95.47ms
step:88/1770 train_time:8402ms step_avg:95.47ms
step:89/1770 train_time:8497ms step_avg:95.47ms
step:90/1770 train_time:8593ms step_avg:95.48ms
step:91/1770 train_time:8689ms step_avg:95.49ms
step:92/1770 train_time:8786ms step_avg:95.50ms
step:93/1770 train_time:8881ms step_avg:95.50ms
step:94/1770 train_time:8977ms step_avg:95.50ms
step:95/1770 train_time:9072ms step_avg:95.50ms
step:96/1770 train_time:9168ms step_avg:95.50ms
step:97/1770 train_time:9264ms step_avg:95.51ms
step:98/1770 train_time:9360ms step_avg:95.51ms
step:99/1770 train_time:9456ms step_avg:95.51ms
step:100/1770 train_time:9551ms step_avg:95.51ms
step:101/1770 train_time:9648ms step_avg:95.52ms
step:102/1770 train_time:9744ms step_avg:95.53ms
step:103/1770 train_time:9840ms step_avg:95.53ms
step:104/1770 train_time:9936ms step_avg:95.53ms
step:105/1770 train_time:10031ms step_avg:95.54ms
step:106/1770 train_time:10127ms step_avg:95.54ms
step:107/1770 train_time:10224ms step_avg:95.55ms
step:108/1770 train_time:10320ms step_avg:95.56ms
step:109/1770 train_time:10418ms step_avg:95.58ms
step:110/1770 train_time:10514ms step_avg:95.58ms
step:111/1770 train_time:10610ms step_avg:95.58ms
step:112/1770 train_time:10707ms step_avg:95.59ms
step:113/1770 train_time:10802ms step_avg:95.59ms
step:114/1770 train_time:10898ms step_avg:95.59ms
step:115/1770 train_time:10993ms step_avg:95.59ms
step:116/1770 train_time:11089ms step_avg:95.60ms
step:117/1770 train_time:11185ms step_avg:95.60ms
step:118/1770 train_time:11282ms step_avg:95.61ms
step:119/1770 train_time:11378ms step_avg:95.61ms
step:120/1770 train_time:11474ms step_avg:95.62ms
step:121/1770 train_time:11570ms step_avg:95.62ms
step:122/1770 train_time:11666ms step_avg:95.62ms
step:123/1770 train_time:11762ms step_avg:95.63ms
step:124/1770 train_time:11858ms step_avg:95.63ms
step:125/1770 train_time:11954ms step_avg:95.63ms
step:125/1770 val_loss:4.6539 train_time:12048ms step_avg:96.38ms
step:126/1770 train_time:12070ms step_avg:95.79ms
step:127/1770 train_time:12152ms step_avg:95.69ms
step:128/1770 train_time:12254ms step_avg:95.73ms
step:129/1770 train_time:12350ms step_avg:95.74ms
step:130/1770 train_time:12446ms step_avg:95.74ms
step:131/1770 train_time:12542ms step_avg:95.74ms
step:132/1770 train_time:12638ms step_avg:95.74ms
step:133/1770 train_time:12733ms step_avg:95.74ms
step:134/1770 train_time:12830ms step_avg:95.75ms
step:135/1770 train_time:12926ms step_avg:95.75ms
step:136/1770 train_time:13022ms step_avg:95.75ms
step:137/1770 train_time:13119ms step_avg:95.76ms
step:138/1770 train_time:13216ms step_avg:95.76ms
step:139/1770 train_time:13312ms step_avg:95.77ms
step:140/1770 train_time:13409ms step_avg:95.78ms
step:141/1770 train_time:13506ms step_avg:95.79ms
step:142/1770 train_time:13603ms step_avg:95.80ms
step:143/1770 train_time:13700ms step_avg:95.80ms
step:144/1770 train_time:13796ms step_avg:95.81ms
step:145/1770 train_time:13893ms step_avg:95.81ms
step:146/1770 train_time:13989ms step_avg:95.82ms
step:147/1770 train_time:14086ms step_avg:95.83ms
step:148/1770 train_time:14184ms step_avg:95.84ms
step:149/1770 train_time:14280ms step_avg:95.84ms
step:150/1770 train_time:14376ms step_avg:95.84ms
step:151/1770 train_time:14473ms step_avg:95.85ms
step:152/1770 train_time:14570ms step_avg:95.86ms
step:153/1770 train_time:14667ms step_avg:95.87ms
step:154/1770 train_time:14765ms step_avg:95.88ms
step:155/1770 train_time:14861ms step_avg:95.88ms
step:156/1770 train_time:14957ms step_avg:95.88ms
step:157/1770 train_time:15054ms step_avg:95.88ms
step:158/1770 train_time:15151ms step_avg:95.89ms
step:159/1770 train_time:15248ms step_avg:95.90ms
step:160/1770 train_time:15343ms step_avg:95.89ms
step:161/1770 train_time:15439ms step_avg:95.90ms
step:162/1770 train_time:15535ms step_avg:95.90ms
step:163/1770 train_time:15631ms step_avg:95.90ms
step:164/1770 train_time:15729ms step_avg:95.91ms
step:165/1770 train_time:15826ms step_avg:95.91ms
step:166/1770 train_time:15922ms step_avg:95.92ms
step:167/1770 train_time:16019ms step_avg:95.92ms
step:168/1770 train_time:16116ms step_avg:95.93ms
step:169/1770 train_time:16212ms step_avg:95.93ms
step:170/1770 train_time:16309ms step_avg:95.94ms
step:171/1770 train_time:16407ms step_avg:95.95ms
step:172/1770 train_time:16503ms step_avg:95.95ms
step:173/1770 train_time:16599ms step_avg:95.95ms
step:174/1770 train_time:16695ms step_avg:95.95ms
step:175/1770 train_time:16791ms step_avg:95.95ms
step:176/1770 train_time:16888ms step_avg:95.96ms
step:177/1770 train_time:16985ms step_avg:95.96ms
step:178/1770 train_time:17082ms step_avg:95.97ms
step:179/1770 train_time:17178ms step_avg:95.97ms
step:180/1770 train_time:17274ms step_avg:95.97ms
step:181/1770 train_time:17371ms step_avg:95.97ms
step:182/1770 train_time:17467ms step_avg:95.97ms
step:183/1770 train_time:17564ms step_avg:95.98ms
step:184/1770 train_time:17661ms step_avg:95.98ms
step:185/1770 train_time:17757ms step_avg:95.98ms
step:186/1770 train_time:17853ms step_avg:95.98ms
step:187/1770 train_time:17950ms step_avg:95.99ms
step:188/1770 train_time:18047ms step_avg:96.00ms
step:189/1770 train_time:18145ms step_avg:96.00ms
step:190/1770 train_time:18241ms step_avg:96.01ms
step:191/1770 train_time:18338ms step_avg:96.01ms
step:192/1770 train_time:18435ms step_avg:96.01ms
step:193/1770 train_time:18531ms step_avg:96.02ms
step:194/1770 train_time:18628ms step_avg:96.02ms
step:195/1770 train_time:18725ms step_avg:96.02ms
step:196/1770 train_time:18821ms step_avg:96.03ms
step:197/1770 train_time:18918ms step_avg:96.03ms
step:198/1770 train_time:19016ms step_avg:96.04ms
step:199/1770 train_time:19112ms step_avg:96.04ms
step:200/1770 train_time:19209ms step_avg:96.05ms
step:201/1770 train_time:19306ms step_avg:96.05ms
step:202/1770 train_time:19402ms step_avg:96.05ms
step:203/1770 train_time:19498ms step_avg:96.05ms
step:204/1770 train_time:19594ms step_avg:96.05ms
step:205/1770 train_time:19691ms step_avg:96.05ms
step:206/1770 train_time:19788ms step_avg:96.06ms
step:207/1770 train_time:19885ms step_avg:96.06ms
step:208/1770 train_time:19981ms step_avg:96.06ms
step:209/1770 train_time:20078ms step_avg:96.07ms
step:210/1770 train_time:20175ms step_avg:96.07ms
step:211/1770 train_time:20271ms step_avg:96.07ms
step:212/1770 train_time:20367ms step_avg:96.07ms
step:213/1770 train_time:20464ms step_avg:96.07ms
step:214/1770 train_time:20560ms step_avg:96.08ms
step:215/1770 train_time:20657ms step_avg:96.08ms
step:216/1770 train_time:20753ms step_avg:96.08ms
step:217/1770 train_time:20850ms step_avg:96.08ms
step:218/1770 train_time:20946ms step_avg:96.08ms
step:219/1770 train_time:21042ms step_avg:96.08ms
step:220/1770 train_time:21138ms step_avg:96.08ms
step:221/1770 train_time:21235ms step_avg:96.08ms
step:222/1770 train_time:21331ms step_avg:96.09ms
step:223/1770 train_time:21428ms step_avg:96.09ms
step:224/1770 train_time:21525ms step_avg:96.09ms
step:225/1770 train_time:21622ms step_avg:96.10ms
step:226/1770 train_time:21718ms step_avg:96.10ms
step:227/1770 train_time:21815ms step_avg:96.10ms
step:228/1770 train_time:21912ms step_avg:96.11ms
step:229/1770 train_time:22009ms step_avg:96.11ms
step:230/1770 train_time:22106ms step_avg:96.11ms
step:231/1770 train_time:22202ms step_avg:96.11ms
step:232/1770 train_time:22299ms step_avg:96.12ms
step:233/1770 train_time:22395ms step_avg:96.11ms
step:234/1770 train_time:22491ms step_avg:96.12ms
step:235/1770 train_time:22588ms step_avg:96.12ms
step:236/1770 train_time:22685ms step_avg:96.12ms
step:237/1770 train_time:22781ms step_avg:96.12ms
step:238/1770 train_time:22878ms step_avg:96.12ms
step:239/1770 train_time:22974ms step_avg:96.13ms
step:240/1770 train_time:23071ms step_avg:96.13ms
step:241/1770 train_time:23168ms step_avg:96.13ms
step:242/1770 train_time:23265ms step_avg:96.13ms
step:243/1770 train_time:23361ms step_avg:96.14ms
step:244/1770 train_time:23458ms step_avg:96.14ms
step:245/1770 train_time:23555ms step_avg:96.14ms
step:246/1770 train_time:23651ms step_avg:96.14ms
step:247/1770 train_time:23749ms step_avg:96.15ms
step:248/1770 train_time:23844ms step_avg:96.15ms
step:249/1770 train_time:23941ms step_avg:96.15ms
step:250/1770 train_time:24037ms step_avg:96.15ms
step:250/1770 val_loss:4.1116 train_time:24131ms step_avg:96.53ms
step:251/1770 train_time:24153ms step_avg:96.23ms
step:252/1770 train_time:24237ms step_avg:96.18ms
step:253/1770 train_time:24333ms step_avg:96.18ms
step:254/1770 train_time:24430ms step_avg:96.18ms
step:255/1770 train_time:24527ms step_avg:96.18ms
step:256/1770 train_time:24623ms step_avg:96.18ms
step:257/1770 train_time:24719ms step_avg:96.18ms
step:258/1770 train_time:24815ms step_avg:96.18ms
step:259/1770 train_time:24912ms step_avg:96.18ms
step:260/1770 train_time:25008ms step_avg:96.19ms
step:261/1770 train_time:25104ms step_avg:96.18ms
step:262/1770 train_time:25201ms step_avg:96.19ms
step:263/1770 train_time:25297ms step_avg:96.19ms
step:264/1770 train_time:25395ms step_avg:96.19ms
step:265/1770 train_time:25492ms step_avg:96.19ms
step:266/1770 train_time:25589ms step_avg:96.20ms
step:267/1770 train_time:25687ms step_avg:96.21ms
step:268/1770 train_time:25785ms step_avg:96.21ms
step:269/1770 train_time:25882ms step_avg:96.21ms
step:270/1770 train_time:25979ms step_avg:96.22ms
step:271/1770 train_time:26077ms step_avg:96.22ms
step:272/1770 train_time:26173ms step_avg:96.22ms
step:273/1770 train_time:26270ms step_avg:96.23ms
step:274/1770 train_time:26369ms step_avg:96.24ms
step:275/1770 train_time:26466ms step_avg:96.24ms
step:276/1770 train_time:26564ms step_avg:96.24ms
step:277/1770 train_time:26661ms step_avg:96.25ms
step:278/1770 train_time:26758ms step_avg:96.25ms
step:279/1770 train_time:26855ms step_avg:96.25ms
step:280/1770 train_time:26951ms step_avg:96.26ms
step:281/1770 train_time:27049ms step_avg:96.26ms
step:282/1770 train_time:27147ms step_avg:96.27ms
step:283/1770 train_time:27244ms step_avg:96.27ms
step:284/1770 train_time:27340ms step_avg:96.27ms
step:285/1770 train_time:27436ms step_avg:96.27ms
step:286/1770 train_time:27533ms step_avg:96.27ms
step:287/1770 train_time:27630ms step_avg:96.27ms
step:288/1770 train_time:27728ms step_avg:96.28ms
step:289/1770 train_time:27826ms step_avg:96.28ms
step:290/1770 train_time:27923ms step_avg:96.29ms
step:291/1770 train_time:28021ms step_avg:96.29ms
step:292/1770 train_time:28118ms step_avg:96.29ms
step:293/1770 train_time:28215ms step_avg:96.30ms
step:294/1770 train_time:28313ms step_avg:96.30ms
step:295/1770 train_time:28410ms step_avg:96.30ms
step:296/1770 train_time:28507ms step_avg:96.31ms
step:297/1770 train_time:28605ms step_avg:96.31ms
step:298/1770 train_time:28701ms step_avg:96.31ms
step:299/1770 train_time:28799ms step_avg:96.32ms
step:300/1770 train_time:28895ms step_avg:96.32ms
step:301/1770 train_time:28993ms step_avg:96.32ms
step:302/1770 train_time:29091ms step_avg:96.33ms
step:303/1770 train_time:29189ms step_avg:96.33ms
step:304/1770 train_time:29285ms step_avg:96.33ms
step:305/1770 train_time:29383ms step_avg:96.34ms
step:306/1770 train_time:29480ms step_avg:96.34ms
step:307/1770 train_time:29577ms step_avg:96.34ms
step:308/1770 train_time:29674ms step_avg:96.34ms
step:309/1770 train_time:29771ms step_avg:96.35ms
step:310/1770 train_time:29869ms step_avg:96.35ms
step:311/1770 train_time:29967ms step_avg:96.36ms
step:312/1770 train_time:30064ms step_avg:96.36ms
step:313/1770 train_time:30162ms step_avg:96.37ms
step:314/1770 train_time:30259ms step_avg:96.37ms
step:315/1770 train_time:30356ms step_avg:96.37ms
step:316/1770 train_time:30453ms step_avg:96.37ms
step:317/1770 train_time:30550ms step_avg:96.37ms