From a1f379e712032f39cebba6f48ea61245a2e3fa73 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 16 Oct 2024 10:26:04 -0700
Subject: [PATCH 1/3] Add intern300m vit w/ converted timm weights. Fix #2300

---
 timm/models/davit.py              |  4 +++-
 timm/models/vision_transformer.py | 28 ++++++++++++++++++++++++----
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 1dc74d2346..5bda661061 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -710,10 +710,12 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-    strict = True
+
+    strict = kwargs.pop('pretrained_strict', True)
     if variant.endswith('_fl'):
         # FIXME cleaner approach to missing head norm?
         strict = False
+
     model = build_model_with_cfg(
         DaVit,
         variant,
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index a5fad6ef7d..ae2ae3b8bc 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -438,6 +438,7 @@ def __init__(
             no_embed_class: bool = False,
             reg_tokens: int = 0,
             pre_norm: bool = False,
+            final_norm: bool = True,
             fc_norm: Optional[bool] = None,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
@@ -471,7 +472,9 @@ def __init__(
             class_token: Use class token.
             no_embed_class: Don't include position embeddings for class (or reg) tokens.
             reg_tokens: Number of register tokens.
-            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+            pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
+            final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
+            fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
             drop_rate: Head dropout rate.
             pos_drop_rate: Position embedding dropout rate.
             attn_drop_rate: Attention dropout rate.
@@ -554,7 +557,7 @@ def __init__(
             for i in range(depth)])
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
-        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+        self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity()
 
         # Classifier Head
         if global_pool == 'map':
@@ -566,7 +569,7 @@ def __init__(
             )
         else:
             self.attn_pool = None
-        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
@@ -2051,6 +2054,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
         input_size=(3, 256, 256)),
 
+    'vit_intern300m_patch14_448.ogvl_dist': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
+    ),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
@@ -2091,7 +2100,7 @@ def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs)
         _filter_fn = checkpoint_filter_fn
 
     # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
-    strict = True
+    strict = kwargs.pop('pretrained_strict', True)
     if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
         strict = False
 
@@ -3298,6 +3307,17 @@ def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> Visio
     return model
 
 
+@register_model
+def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16,
+        init_values=0.1, final_norm=False, dynamic_img_size=True,
+    )
+    model = _create_vision_transformer(
+        'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test
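A quick sketch of what patch 1 enables in practice, assuming a timm build that includes it. With `final_norm=False`, the usual post-block norm resolves to `nn.Identity()`, and the pretrained config is a headless feature extractor (`num_classes=0`):

    import timm
    import torch

    # Build the architecture only; pass num_classes=0 to match the pretrained cfg.
    model = timm.create_model('vit_intern300m_patch14_448', num_classes=0).eval()

    # final_norm=False disables the usual post-block LayerNorm.
    assert isinstance(model.norm, torch.nn.Identity)

    # dynamic_img_size=True also allows other resolutions (pos embeds interpolated).
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 448, 448))
    print(feats.shape)  # torch.Size([1, 1024]) -- pooled class-token features
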
From fad45388012a441bb3a084a20220b6d6f2bf47bc Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 16 Oct 2024 11:30:01 -0700
Subject: [PATCH 2/3] Elevate import deprecation warnings from
 DeprecationWarning to FutureWarning so messages are now seen

---
 timm/models/factory.py         | 2 +-
 timm/models/features.py        | 2 +-
 timm/models/fx_features.py     | 2 +-
 timm/models/helpers.py         | 2 +-
 timm/models/hub.py             | 2 +-
 timm/models/layers/__init__.py | 2 +-
 timm/models/registry.py        | 2 +-
 7 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/timm/models/factory.py b/timm/models/factory.py
index 0ae83dc08e..da94e1ac19 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -1,4 +1,4 @@
 from ._factory import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
diff --git a/timm/models/features.py b/timm/models/features.py
index 25605d99da..f937e1626e 100644
--- a/timm/models/features.py
+++ b/timm/models/features.py
@@ -1,4 +1,4 @@
 from ._features import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py
index 0ff3a18b05..ae6848f7d5 100644
--- a/timm/models/fx_features.py
+++ b/timm/models/fx_features.py
@@ -1,4 +1,4 @@
 from ._features_fx import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 6bc82eb81e..5bf6d19e7c 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -4,4 +4,4 @@
 from ._prune import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
diff --git a/timm/models/hub.py b/timm/models/hub.py
index fdc3a921c5..85abb26491 100644
--- a/timm/models/hub.py
+++ b/timm/models/hub.py
@@ -1,4 +1,4 @@
 from ._hub import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index 705ebd253e..956f39aa59 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -45,4 +45,4 @@
 from timm.layers.weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
diff --git a/timm/models/registry.py b/timm/models/registry.py
index 58e2e1f41a..5b68a91e0c 100644
--- a/timm/models/registry.py
+++ b/timm/models/registry.py
@@ -1,4 +1,4 @@
 from ._registry import *
 
 import warnings
-warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", FutureWarning)
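The rationale for patch 2: under CPython's default warning filters, `DeprecationWarning` is only displayed when triggered directly in `__main__`, so users importing these shim modules from library code never saw the message, while `FutureWarning` is shown by default. A minimal check of the new behavior (run in a fresh interpreter, since the module-level warning only fires on first import):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')  # record everything for the demo
        import timm.models.layers  # deprecated shim that re-exports timm.layers
    # With this patch applied, the recorded warning is a FutureWarning.
    assert any(issubclass(w.category, FutureWarning) for w in caught)
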
From 89dffc5ff0fe40343daadbadb10cb21df8a3ade1 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 16 Oct 2024 12:36:36 -0700
Subject: [PATCH 3/3] Another small fix for original mambaout models, no
 classifier nn.Linear when num_classes=0 on init

---
 timm/models/mambaout.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py
index c077b01ff1..3cc6e08253 100644
--- a/timm/models/mambaout.py
+++ b/timm/models/mambaout.py
@@ -151,7 +151,7 @@ def __init__(
         self.num_features = in_features
         self.pre_logits = nn.Identity()
 
-        self.fc = nn.Linear(hidden_size, num_classes, bias=bias)
+        self.fc = nn.Linear(hidden_size, num_classes, bias=bias) if num_classes > 0 else nn.Identity()
         self.head_dropout = nn.Dropout(drop_rate)
 
     def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
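And the effect of patch 3, using `mambaout_base` as an arbitrary example variant: constructing a model with `num_classes=0` no longer allocates a zero-width classifier `nn.Linear`, so there are no unused head weights and the forward pass returns pooled features:

    import timm
    import torch

    model = timm.create_model('mambaout_base', num_classes=0).eval()
    # With this fix, the head's fc is an Identity rather than an empty Linear.
    assert isinstance(model.head.fc, torch.nn.Identity)

    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    print(feats.shape)  # pooled pre-logits features, no classifier applied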