diff --git a/src/pl_bolts/models/self_supervised/swav/__init__.py b/src/pl_bolts/models/self_supervised/swav/__init__.py index 34eed30e5..95a56ecb4 100644 --- a/src/pl_bolts/models/self_supervised/swav/__init__.py +++ b/src/pl_bolts/models/self_supervised/swav/__init__.py @@ -1,6 +1,7 @@ from pl_bolts.models.self_supervised.swav.loss import SWAVLoss -from pl_bolts.models.self_supervised.swav.swav_module import SwAV +from pl_bolts.models.self_supervised.swav.swav_module import SwAV, swav_backbones from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 +from pl_bolts.models.self_supervised.swav.swav_swin import swin_b, swin_s, swin_v2_b, swin_v2_s, swin_v2_t from pl_bolts.transforms.self_supervised.swav_transforms import ( SwAVEvalDataTransform, SwAVFinetuneTransform, @@ -9,8 +10,14 @@ __all__ = [ "SwAV", + "swav_backbones", "resnet18", "resnet50", + "swin_s", + "swin_b", + "swin_v2_t", + "swin_v2_s", + "swin_v2_b", "SwAVEvalDataTransform", "SwAVFinetuneTransform", "SwAVTrainDataTransform", diff --git a/src/pl_bolts/models/self_supervised/swav/swav_module.py b/src/pl_bolts/models/self_supervised/swav/swav_module.py index e212358c4..4c4399775 100644 --- a/src/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/src/pl_bolts/models/self_supervised/swav/swav_module.py @@ -9,6 +9,7 @@ from pl_bolts.models.self_supervised.swav.loss import SWAVLoss from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 +from pl_bolts.models.self_supervised.swav.swav_swin import swin_b, swin_s, swin_v2_b, swin_v2_s, swin_v2_t from pl_bolts.optimizers.lars import LARS from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay from pl_bolts.transforms.dataset_normalizations import ( @@ -17,6 +18,16 @@ stl10_normalization, ) +swav_backbones = { + "resnet18": resnet18, + "resnet50": resnet50, + "swin_s": swin_s, + "swin_b": swin_b, + "swin_v2_t": swin_v2_t, + "swin_v2_s": swin_v2_s, + "swin_v2_b": swin_v2_b, +} + class SwAV(LightningModule): def __init__( @@ -154,11 +165,8 @@ def setup(self, stage): self.queue = torch.load(self.queue_path)["queue"] def init_model(self): - if self.arch == "resnet18": - backbone = resnet18 - elif self.arch == "resnet50": - backbone = resnet50 - + backbone = swav_backbones.get(self.arch, None) + assert backbone is not None, "backbone is not implemented!" return backbone( normalize=True, hidden_mlp=self.hidden_mlp, @@ -490,7 +498,7 @@ def cli_main(): trainer = Trainer( max_epochs=args.max_epochs, - max_steps=None if args.max_steps == -1 else args.max_steps, + max_steps=args.max_steps, gpus=args.gpus, num_nodes=args.num_nodes, accelerator="ddp" if args.gpus > 1 else None, diff --git a/src/pl_bolts/models/self_supervised/swav/swav_swin.py b/src/pl_bolts/models/self_supervised/swav/swav_swin.py new file mode 100644 index 000000000..191ef7d04 --- /dev/null +++ b/src/pl_bolts/models/self_supervised/swav/swav_swin.py @@ -0,0 +1,765 @@ +import functools +import math +import os +from typing import Any, Callable, List, Optional, Tuple, Union + +import packaging.version as pv +import torch +import torch.fx +import torch.nn.functional as F # noqa: N812 +import torchvision.ops as ops +import torchvision.utils +from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache +from torch import Tensor, nn + +from pl_bolts.utils._dependency import requires + + +# Support meshgrid indexing for older versions of torch +def meshgrid(*tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None) -> Tuple: + if pv.parse(torch.__version__) >= pv.parse("1.10.0"): + return torch.meshgrid(*tensors, indexing=indexing) + return torch.meshgrid(*tensors) + + +def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor: + h, w, _ = x.shape[-3:] + x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2)) + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + return torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C + + +torch.fx.wrap("_patch_merging_pad") + + +def _get_relative_position_bias( + relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int] +) -> torch.Tensor: + n = window_size[0] * window_size[1] + relative_position_bias = relative_position_bias_table[relative_position_index] + relative_position_bias = relative_position_bias.view(n, n, -1) + return relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + + +torch.fx.wrap("_get_relative_position_bias") + + +class PatchMerging(nn.Module): + """Patch Merging Layer. + + Args: + dim: Number of input channels. + norm_layer: Normalization layer. Default: nn.LayerNorm. + + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + torchvision.utils._log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) + x = self.norm(x) + return self.reduction(x) # ... H/2 W/2 2*C + + +class PatchMergingV2(nn.Module): + """Patch Merging Layer for Swin Transformer V2. + + Args: + dim: Number of input channels. + norm_layer: Normalization layer. Default: nn.LayerNorm. + + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm) -> None: + super().__init__() + torchvision.utils._log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) # difference + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) + x = self.reduction(x) # ... H/2 W/2 2*C + return self.norm(x) + + +def shifted_window_attention( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: List[int], + num_heads: int, + shift_size: List[int], + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, + logit_scale: Optional[torch.Tensor] = None, +) -> Tensor: + """Window based multi-head self attention (W-MSA) module with relative position bias. + + It supports both of shifted and non-shifted window. + Args: + input: The input tensor or 4-dimensions [N, H, W, C]. + qkv_weight: The weight tensor of query, key, value. + proj_weight: The weight tensor of projection. + relative_position_bias: The learned relative position bias added to attention. + window_size: Window size. + num_head: Number of attention heads. + shift_size: Shift size for shifted window attention. + attention_dropout: Dropout ratio of attention weight. Default: 0.0. + dropout: Dropout ratio of output. Default: 0.0. + qkv_bias: The bias tensor of query, key, value. Default: None. + proj_bias: The bias tensor of projection. Default: None. + logit_scale: Logit scale of cosine attention for Swin Transformer V2. Default: None. + + Returns: + Tensor[N, H, W, C]: The output tensor after shifted window attention. + + """ + b, h, w, c = input.shape + # pad feature maps to multiples of window size + pad_r = (window_size[1] - w % window_size[1]) % window_size[1] + pad_b = (window_size[0] - h % window_size[0]) % window_size[0] + x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) + _, pad_h, pad_w, _ = x.shape + + shift_size = shift_size.copy() + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= pad_h: + shift_size[0] = 0 + if window_size[1] >= pad_w: + shift_size[1] = 0 + + # cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + + # partition windows + num_windows = (pad_h // window_size[0]) * (pad_w // window_size[1]) + x = x.view(b, pad_h // window_size[0], window_size[0], pad_w // window_size[1], window_size[1], c) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(b * num_windows, window_size[0] * window_size[1], c) # B*nW, Ws*Ws, C + + # multi-head attention + if logit_scale is not None and qkv_bias is not None: + qkv_bias = qkv_bias.clone() + length = qkv_bias.numel() // 3 + qkv_bias[length : 2 * length].zero_() + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, c // num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + if logit_scale is not None: + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp() + attn = attn * logit_scale + else: + q = q * (c // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) + # add relative position bias + attn = attn + relative_position_bias + + if sum(shift_size) > 0: + # generate attention mask + attn_mask = x.new_zeros((pad_h, pad_w)) + h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)) + w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)) + count = 0 + for h_sli in h_slices: + for w_sli in w_slices: + attn_mask[h_sli[0] : h_sli[1], w_sli[0] : w_sli[1]] = count + count += 1 + attn_mask = attn_mask.view(pad_h // window_size[0], window_size[0], pad_w // window_size[1], window_size[1]) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1]) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout) + + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c) + x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout) + + # reverse windows + x = x.view(b, pad_h // window_size[0], pad_w // window_size[1], window_size[0], window_size[1], c) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(b, pad_h, pad_w, c) + + # reverse cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + # unpad features + return x[:, :h, :w, :].contiguous() + + +torch.fx.wrap("shifted_window_attention") + + +class ShiftedWindowAttention(nn.Module): + """See :func:`shifted_window_attention`.""" + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ) -> None: + super().__init__() + if len(window_size) != 2 or len(shift_size) != 2: + raise ValueError("window_size and shift_size must be of length 2") + self.window_size = window_size + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + self.define_relative_position_bias_table() + self.define_relative_position_index() + + def define_relative_position_bias_table(self) -> None: + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def define_relative_position_index(self) -> None: + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + def get_relative_position_bias(self) -> torch.Tensor: + return _get_relative_position_bias( + self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type] + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + relative_position_bias = self.get_relative_position_bias() + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + ) + + +class ShiftedWindowAttentionV2(ShiftedWindowAttention): + """See :func:`shifted_window_attention_v2`.""" + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ) -> None: + super().__init__( + dim, + window_size, + shift_size, + num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attention_dropout=attention_dropout, + dropout=dropout, + ) + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) + if qkv_bias: + length = self.qkv.bias.numel() // 3 + self.qkv.bias[length : 2 * length].data.zero_() + + def define_relative_position_bias_table(self) -> None: + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij")) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1,2*Wh-1,2*Ww-1,2 + + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 + + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = ( + torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0 + ) + self.register_buffer("relative_coords_table", relative_coords_table) + + def get_relative_position_bias(self) -> torch.Tensor: + relative_position_bias = _get_relative_position_bias( + self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads), + self.relative_position_index, # type: ignore[arg-type] + self.window_size, + ) + return 16 * torch.sigmoid(relative_position_bias) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + relative_position_bias = self.get_relative_position_bias() + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + logit_scale=self.logit_scale, + ) + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + + Args: + dim: Number of input channels. + num_heads: Number of attention heads. + window_size: Window size. + shift_size: Shift size for shifted window attention. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout: Dropout rate. Default: 0.0. + attention_dropout: Attention dropout rate. Default: 0.0. + stochastic_depth_prob: Stochastic depth rate. Default: 0.0. + norm_layer: Normalization layer. Default: nn.LayerNorm. + attn_layer: Attention layer. Default: ShiftedWindowAttention + + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, + ) -> None: + super().__init__() + torchvision.utils._log_api_usage_once(self) + + self.norm1 = norm_layer(dim) + self.attn = attn_layer( + dim, + window_size, + shift_size, + num_heads, + attention_dropout=attention_dropout, + dropout=dropout, + ) + self.stochastic_depth = ops.StochasticDepth(stochastic_depth_prob, "row") + self.norm2 = norm_layer(dim) + self.mlp = ops.misc.MLP( + dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout + ) + + for m in self.mlp.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.stochastic_depth(self.attn(self.norm1(x))) + return x + self.stochastic_depth(self.mlp(self.norm2(x))) + + +class SwinTransformerBlockV2(SwinTransformerBlock): + """Swin Transformer V2 Block. + + Args: + dim: Number of input channels. + num_heads: Number of attention heads. + window_size: Window size. + shift_size: Shift size for shifted window attention. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout: Dropout rate. Default: 0.0. + attention_dropout: Attention dropout rate. Default: 0.0. + stochastic_depth_prob: Stochastic depth rate. Default: 0.0. + norm_layer: Normalization layer. Default: nn.LayerNorm. + attn_layer: Attention layer. Default: ShiftedWindowAttentionV2. + + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2, + ) -> None: + super().__init__( + dim, + num_heads, + window_size, + shift_size, + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=stochastic_depth_prob, + norm_layer=norm_layer, + attn_layer=attn_layer, + ) + + def forward(self, x: Tensor) -> Tensor: + # Here is the difference, we apply norm after the attention in V2. + # In V1 we applied norm before the attention. + x = x + self.stochastic_depth(self.norm1(self.attn(x))) + return x + self.stochastic_depth(self.norm2(self.mlp(x))) + + +@requires("torchvision>=0.13") +class SwinTransformer(nn.Module): + """Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Shifted + Windows" `_ paper. + + Args: + patch_size: Patch size. + embed_dim: Patch embedding dimension. + depths: Depth of each Swin Transformer layer. + num_heads: Number of attention heads in different layers. + window_size: Window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout: Dropout rate. Default: 0.0. + attention_dropout: Attention dropout rate. Default: 0.0. + stochastic_depth_prob: Stochastic depth rate. Default: 0.1. + num_classes: Number of classes for classification head. Default: 1000. + block: SwinTransformer Block. Default: None. + norm_layer: Normalization layer. Default: None. + downsample_layer: Downsample layer (patch merging). Default: PatchMerging. + """ + + def __init__( + self, + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, + downsample_layer: Callable[..., nn.Module] = PatchMerging, + normalize=False, + output_dim=0, + hidden_mlp=0, + num_prototypes=0, + eval_mode=False, + **kwargs: Any, + ) -> None: + super().__init__() + torchvision.utils._log_api_usage_once(self) + self.num_classes = output_dim + + if block is None: + block = SwinTransformerBlock + if norm_layer is None: + norm_layer = functools.partial(nn.LayerNorm, eps=1e-5) + + self.eval_mode = eval_mode + self.padding = nn.ConstantPad2d(1, 0.0) + + layers: List[nn.Module] = [] + # split image into non-overlapping patches + layers.append( + nn.Sequential( + nn.Conv2d( + 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) + ), + ops.misc.Permute([0, 2, 3, 1]), + norm_layer(embed_dim), + ) + ) + + total_stage_blocks = sum(depths) + stage_block_id = 0 + # build SwinTransformer blocks + for i_stage in range(len(depths)): + stage: List[nn.Module] = [] + dim = embed_dim * 2**i_stage + for i_layer in range(depths[i_stage]): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) + stage.append( + block( + dim, + num_heads[i_stage], + window_size=window_size, + shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + ) + ) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + # add patch merging layer + if i_stage < (len(depths) - 1): + layers.append(downsample_layer(dim, norm_layer)) + self.features = nn.Sequential(*layers) + + num_features = embed_dim * 2 ** (len(depths) - 1) + self.norm = norm_layer(num_features) + self.permute = ops.misc.Permute([0, 3, 1, 2]) # B H W C -> B C H W + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.flatten = nn.Flatten(1) + self.l2norm = normalize + + # projection head + if output_dim == 0: + self.projection_head = None + + elif hidden_mlp == 0: + self.projection_head = nn.Linear(num_features, output_dim) + else: + self.projection_head = nn.Sequential( + nn.Linear(num_features, hidden_mlp), + nn.BatchNorm1d(hidden_mlp), + nn.ReLU(inplace=True), + nn.Linear(hidden_mlp, output_dim), + ) + + # prototype layer + self.prototypes = None + if isinstance(num_prototypes, list): + self.prototypes = MultiPrototypes(output_dim, num_prototypes) + elif num_prototypes > 0: + self.prototypes = nn.Linear(output_dim, num_prototypes, bias=False) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward_backbone(self, x) -> Tensor: + x = self.padding(x) + + x = self.features(x) + x = self.norm(x) + x = self.permute(x) + + if self.eval_mode: + return x + + x = self.avgpool(x) + return self.flatten(x) + + def forward_head(self, x) -> Tensor: + if self.projection_head is not None: + x = self.projection_head(x) + + if self.l2norm: + x = nn.functional.normalize(x, dim=1, p=2) + + if self.prototypes is not None: + return x, self.prototypes(x) + return x + + def forward(self, inputs) -> Tensor: + if not isinstance(inputs, list): + inputs = [inputs] + idx_crops = torch.cumsum( + torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in inputs]), + return_counts=True, + )[1], + 0, + ) + start_idx, output = 0, None + for end_idx in idx_crops: + _out = torch.cat(inputs[start_idx:end_idx]) + + if next(self.parameters()).is_cuda: + _out = self.forward_backbone(_out.cuda(non_blocking=True)) + else: + _out = self.forward_backbone(_out) + + output = _out if start_idx == 0 else torch.cat((output, _out)) + start_idx = end_idx + return self.forward_head(output) + + +class MultiPrototypes(nn.Module): + def __init__(self, output_dim, num_prototypes) -> None: + super().__init__() + self.nmb_heads = len(num_prototypes) + for i, k in enumerate(num_prototypes): + self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False)) + + def forward(self, x) -> Tensor: + out = [] + for i in range(self.nmb_heads): + out.append(getattr(self, "prototypes" + str(i))(x)) + return out + + +def _swin_transformer( + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + stochastic_depth_prob: float, + **kwargs: Any, +) -> SwinTransformer: + return SwinTransformer( + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + +def swin_s(**kwargs: Any) -> SwinTransformer: + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0.3, + **kwargs, + ) + + +def swin_b(**kwargs: Any) -> SwinTransformer: + return _swin_transformer( + patch_size=[4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[7, 7], + stochastic_depth_prob=0.5, + **kwargs, + ) + + +def swin_v2_t(**kwargs: Any) -> SwinTransformer: + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 8], + stochastic_depth_prob=0.2, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) + + +def swin_v2_s(**kwargs: Any) -> SwinTransformer: + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[8, 8], + stochastic_depth_prob=0.3, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) + + +def swin_v2_b(**kwargs: Any) -> SwinTransformer: + return _swin_transformer( + patch_size=[4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[8, 8], + stochastic_depth_prob=0.5, + block=SwinTransformerBlockV2, + downsample_layer=PatchMergingV2, + **kwargs, + ) diff --git a/src/pl_bolts/utils/_dependency.py b/src/pl_bolts/utils/_dependency.py index 81ce2d912..f0dfa08df 100644 --- a/src/pl_bolts/utils/_dependency.py +++ b/src/pl_bolts/utils/_dependency.py @@ -11,7 +11,7 @@ def requires(*module_path_version: str) -> Callable: def decorator(func: Callable) -> Callable: reqs = [ - ModuleAvailableCache(mod_ver) if "." in mod_ver else RequirementCache(mod_ver) + ModuleAvailableCache(mod_ver) if "." not in mod_ver else RequirementCache(mod_ver) for mod_ver in module_path_version ] available = all(map(bool, reqs)) diff --git a/tests/models/self_supervised/test_swav_swin.py b/tests/models/self_supervised/test_swav_swin.py new file mode 100644 index 000000000..c3d75c39c --- /dev/null +++ b/tests/models/self_supervised/test_swav_swin.py @@ -0,0 +1,99 @@ +import warnings + +import packaging.version as pv +import pytest +import torch +import torch.nn as nn +import torchvision +from pl_bolts.datamodules import CIFAR10DataModule +from pl_bolts.models.self_supervised import SwAV +from pl_bolts.models.self_supervised.swav.swav_swin import swin_b, swin_s, swin_v2_b, swin_v2_s, swin_v2_t +from pl_bolts.transforms.dataset_normalizations import cifar10_normalization +from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform +from pl_bolts.utils import _IS_WINDOWS +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.warnings import PossibleUserWarning + + +def check_compatibility(): + return pv.parse(torchvision.__version__) >= pv.parse("0.13") + + +model = [swin_s, swin_b, swin_v2_t, swin_v2_s, swin_v2_b] + + +@pytest.mark.parametrize( + ("model_architecture", "hidden_mlp", "prj_head_type", "feat_dim"), + [ + (swin_s, 0, nn.Linear, 128), + (swin_s, 2048, nn.Sequential, 128), + (swin_b, 0, nn.Linear, 128), + (swin_b, 2048, nn.Sequential, 128), + (swin_v2_t, 0, nn.Linear, 128), + (swin_v2_t, 2048, nn.Sequential, 128), + (swin_v2_s, 0, nn.Linear, 128), + (swin_v2_s, 2048, nn.Sequential, 128), + (swin_v2_b, 0, nn.Linear, 128), + (swin_v2_b, 2048, nn.Sequential, 128), + ], +) +@pytest.mark.skipif(not check_compatibility(), reason="Torchvision version not compatible, must be >= 0.13") +@torch.no_grad() +def test_swin_projection_head(model_architecture, hidden_mlp, prj_head_type, feat_dim): + model = model_architecture(hidden_mlp=hidden_mlp, output_dim=feat_dim) + assert isinstance(model.projection_head, prj_head_type) + + +@pytest.mark.parametrize("model", ["swin_s", "swin_b", "swin_v2_t", "swin_v2_s", "swin_v2_b"]) +@pytest.mark.skipif(not check_compatibility(), reason="Torchvision version not compatible, must be >= 0.13") +@pytest.mark.skipif(_IS_WINDOWS, reason="numpy.core._exceptions._ArrayMemoryError...") # todo +def test_swav_swin_model(tmpdir, datadir, model, catch_warnings): + """Test SWAV on CIFAR-10.""" + warnings.filterwarnings( + "ignore", + message=".+does not have many workers which may be a bottleneck.+", + category=PossibleUserWarning, + ) + warnings.filterwarnings("ignore", category=UserWarning) + + batch_size = 2 + datamodule = CIFAR10DataModule(data_dir=datadir, batch_size=batch_size, num_workers=0) + + datamodule.train_transforms = SwAVTrainDataTransform( + normalize=cifar10_normalization(), size_crops=[32, 16], num_crops=[2, 1], gaussian_blur=False + ) + datamodule.val_transforms = SwAVEvalDataTransform( + normalize=cifar10_normalization(), size_crops=[32, 16], num_crops=[2, 1], gaussian_blur=False + ) + if torch.cuda.device_count() >= 1: + devices = torch.cuda.device_count() + accelerator = "gpu" + else: + devices = None + accelerator = "cpu" + + model = SwAV( + arch=model, + hidden_mlp=512, + nodes=1, + gpus=0 if devices is None else devices, + num_samples=datamodule.num_samples, + batch_size=batch_size, + num_crops=[2, 1], + sinkhorn_iterations=1, + num_prototypes=2, + queue_length=0, + maxpool1=False, + first_conv=False, + dataset="cifar10", + ) + trainer = Trainer( + accelerator=accelerator, + devices=devices, + fast_dev_run=True, + default_root_dir=tmpdir, + log_every_n_steps=1, + max_epochs=1, + logger=True, + ) + trainer.fit(model, datamodule=datamodule)