Merge pull request #2304 from huggingface/intern300m
Add intern300m vit w/ converted timm weights. Fix #2300
rwightman authored Oct 16, 2024
2 parents 60f517c + 89dffc5 commit 65e8e9c
Showing 10 changed files with 35 additions and 13 deletions.
4 changes: 3 additions & 1 deletion timm/models/davit.py
@@ -710,10 +710,12 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-    strict = True
+
+    strict = kwargs.pop('pretrained_strict', True)
     if variant.endswith('_fl'):
         # FIXME cleaner approach to missing head norm?
         strict = False
+
     model = build_model_with_cfg(
         DaVit,
         variant,
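The `strict = kwargs.pop('pretrained_strict', True)` change (mirrored in vision_transformer.py further down) lets callers relax strict checkpoint loading through ordinary model kwargs instead of it being hard-coded per variant. A minimal sketch, assuming the kwarg is forwarded from `timm.create_model` as it is for the entrypoints touched in this commit, and using 'davit_base' as an illustrative model name:

import timm

# pretrained_strict is popped in _create_davit / _create_vision_transformer and
# handed to build_model_with_cfg, so missing or unexpected checkpoint keys no
# longer raise when it is set to False.
model = timm.create_model(
    'davit_base',              # illustrative variant; any DaViT / ViT entrypoint works the same
    pretrained=True,
    pretrained_strict=False,   # previously strict loading was fixed per variant
)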
2 changes: 1 addition & 1 deletion timm/models/factory.py
@@ -1,4 +1,4 @@
 from ._factory import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
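The category swap from DeprecationWarning to FutureWarning (repeated for each legacy shim module in this commit) changes who actually sees the message: CPython's default warning filters hide a DeprecationWarning raised from imported library code, while a FutureWarning is shown. A small self-contained check of that behaviour, using a throwaway stand-in module rather than timm itself:

import pathlib
import subprocess
import sys
import tempfile

# Write a module that raises both warning categories at import time, then import
# it from a fresh interpreter running with default warning filters.
with tempfile.TemporaryDirectory() as tmp:
    shim = pathlib.Path(tmp) / "legacy_shim.py"   # hypothetical stand-in for the timm shims
    shim.write_text(
        "import warnings\n"
        "warnings.warn('deprecated import path', DeprecationWarning)\n"
        "warnings.warn('deprecated import path', FutureWarning)\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", "import legacy_shim"],
        cwd=tmp, capture_output=True, text=True,
    )
    # Only the FutureWarning appears on stderr; the DeprecationWarning is filtered
    # out because it was raised outside __main__.
    print(result.stderr)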
2 changes: 1 addition & 1 deletion timm/models/features.py
@@ -1,4 +1,4 @@
 from ._features import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
2 changes: 1 addition & 1 deletion timm/models/fx_features.py
@@ -1,4 +1,4 @@
 from ._features_fx import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
2 changes: 1 addition & 1 deletion timm/models/helpers.py
@@ -4,4 +4,4 @@
 from ._prune import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
2 changes: 1 addition & 1 deletion timm/models/hub.py
@@ -1,4 +1,4 @@
 from ._hub import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
2 changes: 1 addition & 1 deletion timm/models/layers/__init__.py
@@ -45,4 +45,4 @@
 from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
2 changes: 1 addition & 1 deletion timm/models/mambaout.py
@@ -151,7 +151,7 @@ def __init__(
         self.num_features = in_features
         self.pre_logits = nn.Identity()
 
-        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
+        self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
         self.head_dropout = nn.Dropout(drop_rate)
 
     def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
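With the guard above, a MambaOut head built with num_classes=0 ends in nn.Identity instead of a zero-width Linear, so the model can be used directly as a feature extractor. A usage sketch, assuming 'mambaout_tiny' is one of the registered MambaOut variants (any MambaOut model name behaves the same way):

import timm
import torch

# num_classes=0 removes the classifier; forward() returns pooled pre-logit features.
model = timm.create_model('mambaout_tiny', pretrained=False, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # pooled feature tensor rather than class logits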
2 changes: 1 addition & 1 deletion timm/models/registry.py
@@ -1,4 +1,4 @@
 from ._registry import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
28 changes: 24 additions & 4 deletions timm/models/vision_transformer.py
@@ -438,6 +438,7 @@ def __init__(
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,
+            final_norm: bool = True,
             fc_norm: Optional[bool] = None,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
@@ -471,7 +472,9 @@ def __init__(
             class_token: Use class token.
             no_embed_class: Don't include position embeddings for class (or reg) tokens.
             reg_tokens: Number of register tokens.
-            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
+            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
+            fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
             drop_rate: Head dropout rate.
             pos_drop_rate: Position embedding dropout rate.
             attn_drop_rate: Attention dropout rate.
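The updated docstring distinguishes three independent norm switches. A simplified, self-contained sketch of where each one acts in the forward path (an illustration, not the actual VisionTransformer code):

import torch
import torch.nn as nn

def forward_sketch(x, blocks, embed_dim, pre_norm=False, final_norm=True, fc_norm=False):
    norm_pre = nn.LayerNorm(embed_dim) if pre_norm else nn.Identity()
    norm = nn.LayerNorm(embed_dim) if final_norm else nn.Identity()
    x = norm_pre(x)        # pre_norm: norm after the embeddings, before the blocks (CLIP-style)
    x = blocks(x)
    if not fc_norm:
        x = norm(x)        # standard ViT: final norm on tokens, before pooling
    x = x.mean(dim=1)      # stand-in for global average pooling
    if fc_norm:
        x = norm(x)        # fc_norm: the same final norm, moved to after pooling
    return x

tokens = torch.randn(2, 196, 768)   # (batch, tokens, embed_dim)
print(forward_sketch(tokens, nn.Identity(), 768).shape)  # torch.Size([2, 768])

Setting final_norm=False (as the intern300m entrypoint below does) makes both branches an nn.Identity, so no norm is applied after the blocks at all.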
@@ -554,7 +557,7 @@ def __init__(
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
-        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
 
         # Classifier Head
         if global_pool == 'map':
@@ -566,7 +569,7 @@
             )
         else:
             self.attn_pool = None
-        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
@@ -2051,6 +2054,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
         input_size=(3, 256, 256)),
 
+    'vit_intern300m_patch14_448.ogvl_dist': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
+    ),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
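The new pretrained tag declares 448x448 inputs, crop_pct=1.0, ImageNet mean/std, and num_classes=0. A sketch of deriving eval preprocessing from that config rather than hard-coding it, using the usual timm data helpers:

import timm
from timm.data import resolve_model_data_config, create_transform

# Instantiate without downloading weights; the default cfg entry above is still attached.
model = timm.create_model('vit_intern300m_patch14_448.ogvl_dist', pretrained=False)
data_cfg = resolve_model_data_config(model)
transform = create_transform(**data_cfg, is_training=False)
print(data_cfg['input_size'], data_cfg['crop_pct'], data_cfg['mean'], data_cfg['std'])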
@@ -2091,7 +2100,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
         _filter_fn = checkpoint_filter_fn
 
     # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
-    strict = True
+    strict = kwargs.pop('pretrained_strict', True)
     if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
         strict = False
 

Expand Down Expand Up @@ -3298,6 +3307,17 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
return model


@register_model
def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16,
init_values=0.1, final_norm=False, dynamic_img_size=True,
)
model = _create_vision_transformer(
'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT Test
Expand Down
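A usage sketch for the new entrypoint itself. The 'ogvl_dist' weights ship with num_classes=0, so out of the box the model acts as a feature extractor; downloading the weights needs access to the Hugging Face Hub.

import timm
import torch

model = timm.create_model('vit_intern300m_patch14_448.ogvl_dist', pretrained=True)
model.eval()

x = torch.randn(1, 3, 448, 448)   # default 448x448 input size from the pretrained cfg
with torch.no_grad():
    feats = model(x)              # pooled features; no classifier head (num_classes=0)
print(feats.shape)                # e.g. torch.Size([1, 1024]) for embed_dim=1024

# dynamic_img_size=True in the model args means other resolutions that divide into
# 14x14 patches (e.g. 224x224) should also work, with position embeddings resized on the fly.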
