Skip to content

Commit

Permalink
LeViT safetensors load is broken by conversion code that wasn't deactivated
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jan 16, 2025
1 parent 21e75a9 commit 9265d54
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions timm/models/levit.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,17 +763,18 @@ def checkpoint_filter_fn(state_dict, model):
# filter out attn biases, should not have been persistent
state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}

D = model.state_dict()
out_dict = {}
for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
if va.ndim == 4 and vb.ndim == 2:
vb = vb[:, :, None, None]
if va.shape != vb.shape:
# head or first-conv shapes may change for fine-tune
assert 'head' in ka or 'stem.conv1.linear' in ka
out_dict[ka] = vb

return out_dict
# NOTE: old weight conversion code, disabled
# D = model.state_dict()
# out_dict = {}
# for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
# if va.ndim == 4 and vb.ndim == 2:
# vb = vb[:, :, None, None]
# if va.shape != vb.shape:
# # head or first-conv shapes may change for fine-tune
# assert 'head' in ka or 'stem.conv1.linear' in ka
# out_dict[ka] = vb

return state_dict


model_cfgs = dict(
Expand Down

0 comments on commit 9265d54

Please sign in to comment.