diff --git a/README.md b/README.md
index fce3de9498..e47c45a3a2 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,17 @@
 ## What's New
+## July 23, 2025
+* Add `set_input_size()` method to EVA models, used by OpenCLIP 3.0.0 to allow resizing for timm-based encoder models.
+* Release 1.0.18, needed for PE-Core S & T models in OpenCLIP 3.0.0
+* Fix a small typing issue that broke Python 3.9 compat; released as the 1.0.19 patch.
+
+## July 21, 2025
+* ROPE support added to NaFlexViT. All models covered by the EVA base (`eva.py`), including EVA, EVA02, Meta PE ViT, `timm` SBB ViT w/ ROPE, and Naver ROPE-ViT, can now be loaded in NaFlexViT when `use_naflex=True` is passed at model creation time
+* More Meta PE ViT encoders added, including small/tiny variants, lang variants w/ tiling, and more spatial variants.
+* PatchDropout fixed for both NaFlexViT and EVA models (regression after adding Naver ROPE-ViT)
+* Fix XY order with `grid_indexing='xy'`, which impacted non-square image use in that mode (only ROPE-ViT and PE models impacted).
+
 ## July 7, 2025
 * MobileNet-v5 backbone tweaks for improved Google Gemma 3n behaviour (to pair with updated official weights)
 * Add stem bias (zero'd in updated weights, compat break with old weights)
@@ -511,6 +522,7 @@ All model architecture families include variants with pretrained weights. There
 * Next-ViT - https://arxiv.org/abs/2207.05501
 * NFNet-F - https://arxiv.org/abs/2102.06171
 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
+* PE (Perception Encoder) - https://arxiv.org/abs/2504.13181
 * PNasNet - https://arxiv.org/abs/1712.00559
 * PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418
 * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index 95595cd1b4..3b8e6d4c85 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -104,7 +104,7 @@
     unfreeze_batch_norm_2d,
 )
 from .padding import get_padding, get_same_padding, pad_same
-from .patch_dropout import PatchDropout
+from .patch_dropout import PatchDropout, PatchDropoutWithIndices, patch_dropout_forward
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
 from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
@@ -144,7 +144,7 @@
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
-from .typing import LayerType, PadType
+from .typing import LayerType, PadType, disable_compiler
 from .weight_init import (
     trunc_normal_,
     trunc_normal_tf_,
diff --git a/timm/layers/attention.py b/timm/layers/attention.py
index 2fb394e844..3cd084d3c8 100644
--- a/timm/layers/attention.py
+++ b/timm/layers/attention.py
@@ -120,6 +120,7 @@
             norm_layer: Type[nn.Module] = None,
             qk_norm: bool = False,
             scale_norm: bool = False,
+            proj_bias: bool = True,
     ):
         """Initialize the Attention module.
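A quick, illustrative sketch of the two headline changelog items above (`set_input_size()` on EVA-based models and the `use_naflex=True` creation flag). This is an editor's example, not part of the patch; it assumes the `vit_pe_core_small_patch16_384` entrypoint registered later in this diff and uses random weights so no hub download is required.

```python
# Editor's usage sketch (not part of this patch) for set_input_size() and
# use_naflex as wired up in timm/models/eva.py further down in this diff.
import torch
import timm

# EVA-based Perception Encoder model registered in this PR, random init.
model = timm.create_model('vit_pe_core_small_patch16_384', pretrained=False)

# Resize the expected input resolution after creation; the patch grid, abs pos
# embed (if present), and the rotary embedding feat_shape are updated in place.
model.set_input_size(img_size=(448, 448))
out = model(torch.randn(1, 3, 448, 448))

# Route the same architecture through the NaFlexViT implementation instead;
# _create_eva() forwards to _create_naflexvit_from_eva() when use_naflex=True
# is passed (or when the TIMM_USE_NAFLEX=1 env var is set).
naflex_model = timm.create_model(
    'vit_pe_core_small_patch16_384',
    pretrained=False,
    use_naflex=True,
)
```

The NaFlex variant accepts either standard NCHW tensors or the patch dictionaries (`patches`, `patch_coord`, `patch_valid`) produced by the NaFlex collator, as handled in `NaFlexVit.forward()` below.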
@@ -161,7 +162,7 @@ def __init__( self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity() - self.proj = nn.Linear(attn_dim, dim) + self.proj = nn.Linear(attn_dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) def forward( diff --git a/timm/layers/patch_dropout.py b/timm/layers/patch_dropout.py index 4428fe042f..ba6338f86d 100644 --- a/timm/layers/patch_dropout.py +++ b/timm/layers/patch_dropout.py @@ -4,50 +4,104 @@ import torch.nn as nn +def patch_dropout_forward( + x: torch.Tensor, + prob: float, + num_prefix_tokens: int, + ordered: bool, + training: bool, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Common forward logic for patch dropout. + + Args: + x: Input tensor of shape (B, L, D) + prob: Dropout probability + num_prefix_tokens: Number of prefix tokens to preserve + ordered: Whether to maintain patch order + training: Whether in training mode + + Returns: + Tuple of (output tensor, keep_indices or None) + """ + if not training or prob == 0.: + return x, None + + if num_prefix_tokens: + prefix_tokens, x = x[:, :num_prefix_tokens], x[:, num_prefix_tokens:] + else: + prefix_tokens = None + + B = x.shape[0] + L = x.shape[1] + num_keep = max(1, int(L * (1. - prob))) + keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep] + + if ordered: + # NOTE does not need to maintain patch order in typical transformer use, + # but possibly useful for debug / visualization + keep_indices = keep_indices.sort(dim=-1)[0] + + x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) + + if prefix_tokens is not None: + x = torch.cat((prefix_tokens, x), dim=1) + + return x, keep_indices + + class PatchDropout(nn.Module): """ + Patch Dropout without returning indices. https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220 """ - return_indices: torch.jit.Final[bool] def __init__( self, prob: float = 0.5, num_prefix_tokens: int = 1, ordered: bool = False, - return_indices: bool = False, ): super().__init__() assert 0 <= prob < 1. self.prob = prob self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) self.ordered = ordered - self.return_indices = return_indices - - def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: - if not self.training or self.prob == 0.: - if self.return_indices: - return x, None - return x - - if self.num_prefix_tokens: - prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:] - else: - prefix_tokens = None - - B = x.shape[0] - L = x.shape[1] - num_keep = max(1, int(L * (1. - self.prob))) - keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep] - if self.ordered: - # NOTE does not need to maintain patch order in typical transformer use, - # but possibly useful for debug / visualization - keep_indices = keep_indices.sort(dim=-1)[0] - x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) - - if prefix_tokens is not None: - x = torch.cat((prefix_tokens, x), dim=1) - - if self.return_indices: - return x, keep_indices - return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output, _ = patch_dropout_forward( + x, + self.prob, + self.num_prefix_tokens, + self.ordered, + self.training + ) + return output + + +class PatchDropoutWithIndices(nn.Module): + """ + Patch Dropout that returns both output and keep indices. 
+ https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220 + """ + + def __init__( + self, + prob: float = 0.5, + num_prefix_tokens: int = 1, + ordered: bool = False, + ): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) + self.ordered = ordered + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + return patch_dropout_forward( + x, + self.prob, + self.num_prefix_tokens, + self.ordered, + self.training + ) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 9d91e8c1f4..9c6dc6e6ae 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -234,10 +234,41 @@ def apply_rot_embed_cat(x: torch.Tensor, emb): return x * cos_emb + rot(x) * sin_emb -def apply_keep_indices_nlc(x, pos_embed, keep_indices): - pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1) - pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])) - return pos_embed +def apply_keep_indices_nlc( + x: torch.Tensor, + pos_embed: torch.Tensor, + keep_indices: torch.Tensor, + pos_embed_has_batch: bool = False, +) -> torch.Tensor: + """ Apply keep indices to different ROPE shapes + + Expected pos_embed shapes: + * [seq_len, pos_embed_dim] --> output [batch_size, seq_len, pos_embed_dim] + * [num_heads, seq_len, pos_embed_dim] --> output [batch_size, num_heads, seq_len, pos_embed_dim] + * [depth, num_heads, seq_len, pos_embed_dim] --> output [batch_size, depth, num_heads, seq_len, pos_embed_dim] + + And all of the above with leading batch dimension already present if `pos_embed_has_batch == True` + + """ + if pos_embed_has_batch: + # Pos embed already includes batch dim + _assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions') # At least [batch, seq_len, pos_embed_dim] + else: + # Add batch dimension and expand to batch size + _assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions') # At least [seq_len, pos_embed_dim] + expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim + pos_embed = pos_embed.unsqueeze(0).expand(expand_shape) + + # Reshape keep_indices to add singleton dims + keep_shape = (keep_indices.shape[0],) + (1,) * (pos_embed.ndim - 3) + (keep_indices.shape[1], 1) + keep_indices = keep_indices.view(keep_shape) + + # Expand all dims to match position embedding except the gather dim (second-last) + keep_expand = list(pos_embed.shape) + keep_expand[-2] = -1 + keep_indices = keep_indices.expand(keep_expand) + + return pos_embed.gather(-2, keep_indices) def build_rotary_pos_embed( @@ -323,6 +354,7 @@ def __init__( self.dim = dim self.max_res = max_res self.temperature = temperature + self.linear_bands = linear_bands self.in_pixels = in_pixels self.feat_shape = feat_shape self.ref_feat_shape = ref_feat_shape @@ -352,17 +384,7 @@ def __init__( self.pos_embed_cos = None else: # cache full sin/cos embeddings if shape provided up front - emb_sin, emb_cos = build_rotary_pos_embed( - feat_shape=feat_shape, - dim=dim, - max_res=max_res, - linear_bands=linear_bands, - in_pixels=in_pixels, - ref_feat_shape=self.ref_feat_shape, - grid_offset=self.grid_offset, - grid_indexing=self.grid_indexing, - temperature=self.temperature, - ) + emb_sin, emb_cos = self._get_pos_embed_values(feat_shape) self.bands = None self.register_buffer( 'pos_embed_sin', @@ -375,6 +397,30 @@ def __init__( persistent=False, ) + def _get_pos_embed_values(self, feat_shape: List[int]): + emb_sin, emb_cos 
= build_rotary_pos_embed( + feat_shape=feat_shape, + dim=self.dim, + max_res=self.max_res, + temperature=self.temperature, + linear_bands=self.linear_bands, + in_pixels=self.in_pixels, + ref_feat_shape=self.ref_feat_shape, + grid_offset=self.grid_offset, + grid_indexing=self.grid_indexing, + ) + return emb_sin, emb_cos + + def update_feat_shape(self, feat_shape: List[int]): + if self.feat_shape is not None and feat_shape != self.feat_shape: + # only update if feat_shape was set and different from previous value + assert self.pos_embed_sin is not None + assert self.pos_embed_cos is not None + emb_sin, emb_cos = self._get_pos_embed_values(feat_shape) + self.pos_embed_sin = emb_sin.to(self.pos_embed_sin.device, self.pos_embed_sin.dtype) + self.pos_embed_cos = emb_cos.to(self.pos_embed_cos.device, self.pos_embed_cos.dtype) + self.feat_shape = feat_shape + def get_embed(self, shape: Optional[List[int]] = None): if shape is not None and self.bands is not None: # rebuild embeddings every call, use if target shape changes @@ -422,6 +468,7 @@ def __init__( self.max_res = max_res self.temperature = temperature self.in_pixels = in_pixels + self.linear_bands = linear_bands self.feat_shape = feat_shape self.ref_feat_shape = ref_feat_shape self.grid_offset = grid_offset @@ -449,27 +496,40 @@ def __init__( self.pos_embed = None else: # cache full sin/cos embeddings if shape provided up front - embeds = build_rotary_pos_embed( - feat_shape=feat_shape, - dim=dim, - max_res=max_res, - linear_bands=linear_bands, - in_pixels=in_pixels, - ref_feat_shape=self.ref_feat_shape, - grid_offset=self.grid_offset, - grid_indexing=self.grid_indexing, - temperature=self.temperature, - ) self.bands = None self.register_buffer( 'pos_embed', - torch.cat(embeds, -1), + self._get_pos_embed_values(feat_shape=feat_shape), persistent=False, ) + def _get_pos_embed_values(self, feat_shape: List[int]): + embeds = build_rotary_pos_embed( + feat_shape=feat_shape, + dim=self.dim, + max_res=self.max_res, + temperature=self.temperature, + linear_bands=self.linear_bands, + in_pixels=self.in_pixels, + ref_feat_shape=self.ref_feat_shape, + grid_offset=self.grid_offset, + grid_indexing=self.grid_indexing, + ) + return torch.cat(embeds, -1) + + def update_feat_shape(self, feat_shape: List[int]): + if self.feat_shape is not None and feat_shape != self.feat_shape: + # only update if feat_shape was set and different from previous value + assert self.pos_embed is not None + self.pos_embed = self._get_pos_embed_values(feat_shape).to( + device=self.pos_embed.device, + dtype=self.pos_embed.dtype, + ) + self.feat_shape = feat_shape + def get_embed(self, shape: Optional[List[int]] = None): if shape is not None and self.bands is not None: - # rebuild embeddings every call, use if target shape changes + # rebuild embeddings from cached bands every call, use if target shape changes embeds = build_rotary_pos_embed( shape, self.bands, @@ -484,6 +544,59 @@ def get_embed(self, shape: Optional[List[int]] = None): else: assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands" + def get_batch_embeds( + self, + shapes: List[Tuple[int, int]], + seq_len: Optional[int] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Generate ROPE embeddings for multiple grid shapes efficiently. + + Computes embeddings for the maximum grid size once, then extracts + and flattens the relevant portions for each requested shape. 
+ + Args: + shapes: List of (H, W) tuples representing different grid sizes + + Returns: + List of concatenated sin/cos embeddings for each shape, + where each tensor has shape (H*W, dim) + """ + if not shapes: + return [] + + # Check if we have pre-computed bands + if self.bands is None: + # If we have pre-computed pos_embed for a fixed shape, we can't do batch generation + raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings") + + # Find max dimensions across all shapes + max_h = max(h for h, w in shapes) + max_w = max(w for h, w in shapes) + + # Generate embeddings for max size ONCE + sin_emb, cos_emb = build_rotary_pos_embed( + feat_shape=(max_h, max_w), + bands=self.bands, + in_pixels=self.in_pixels, + ref_feat_shape=self.ref_feat_shape, + grid_offset=self.grid_offset, + grid_indexing=self.grid_indexing, + ) + + # sin_emb and cos_emb are (max_h * max_w, dim//2) + # concat and reshape to 2D for slicing + rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1) + + if seq_len is not None: + flat_embeds = torch.zeros(len(shapes), seq_len, rope_embed_2d.shape[-1]).type_as(sin_emb) + for i, (h, w) in enumerate(shapes): + src_len = h * w + flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1) + return flat_embeds + else: + flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes] + return flat_embeds_list + def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 pos_embed = self.get_embed(x.shape[2:]) @@ -600,6 +713,7 @@ def __init__( head_dim = dim // num_heads assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}" + freqs = init_random_2d_freqs( head_dim, depth, @@ -608,18 +722,32 @@ def __init__( rotate=True, ) # (2, depth, num_heads, head_dim//2) self.freqs = nn.Parameter(freqs) + if feat_shape is not None: # cache pre-computed grid - t_x, t_y = get_mixed_grid( - feat_shape, - grid_indexing=grid_indexing, - device=self.freqs.device - ) + t_x, t_y = self._get_grid_values(feat_shape) self.register_buffer('t_x', t_x, persistent=False) self.register_buffer('t_y', t_y, persistent=False) else: self.t_x = self.t_y = None + def _get_grid_values(self, feat_shape: Optional[List[int]]): + t_x, t_y = get_mixed_grid( + feat_shape, + grid_indexing=self.grid_indexing, + device=self.freqs.device + ) + return t_x, t_y + + def update_feat_shape(self, feat_shape: Optional[List[int]]): + if self.feat_shape is not None and feat_shape != self.feat_shape: + assert self.t_x is not None + assert self.t_y is not None + t_x, t_y = self._get_grid_values(feat_shape) + self.t_x = t_x.to(self.t_x.device, self.t_x.dtype) + self.t_y = t_y.to(self.t_y.device, self.t_y.dtype) + self.feat_shape = feat_shape + def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor: """Generate rotary embeddings for the given spatial shape. @@ -642,6 +770,62 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor: return get_mixed_freqs(self.freqs, t_x, t_y) + def get_batch_embeds( + self, + shapes: List[Tuple[int, int]], + seq_len: Optional[int] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """Generate ROPE embeddings for multiple grid shapes efficiently. + + Computes embeddings for the maximum grid size once, then extracts + and flattens the relevant portions for each requested shape. + + Args: + shapes: List of (H, W) tuples representing different grid sizes + seq_len: If provided, return padded tensor of this length. Otherwise return list. 
+ + Returns: + If seq_len is provided: Padded tensor of shape (len(shapes), depth, num_heads, seq_len, dim) + Otherwise: List of tensors with shape (depth, num_heads, H*W, dim) for each shape + """ + if not shapes: + return [] + + # Find max dimensions + max_h = max(h for h, w in shapes) + max_w = max(w for h, w in shapes) + + # Generate embeddings for max size ONCE + t_x, t_y = get_mixed_grid( + [max_h, max_w], + grid_indexing=self.grid_indexing, + device=self.freqs.device + ) + max_embed = get_mixed_freqs(self.freqs, t_x, t_y) # (depth, num_heads, max_h*max_w, dim) + + # Reshape to 2D grid for easy slicing + depth, num_heads, _, dim = max_embed.shape + max_embed_2d = max_embed.view(depth, num_heads, max_h, max_w, dim) + + if seq_len is not None: + # Return padded tensor + B = len(shapes) + padded = torch.zeros(B, depth, num_heads, seq_len, dim, device=self.freqs.device, dtype=self.freqs.dtype) + for i, (h, w) in enumerate(shapes): + # Slice and flatten + embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim) + actual_len = h * w + padded[i, :, :, :actual_len] = embed_slice + return padded + else: + # Return list + results = [] + for h, w in shapes: + # Slice and flatten + embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim) + results.append(embed_slice) + return results + def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 pos_embed = self.get_embed(x.shape[2:]) diff --git a/timm/layers/typing.py b/timm/layers/typing.py index 593fa5cc8b..9d503defc4 100644 --- a/timm/layers/typing.py +++ b/timm/layers/typing.py @@ -1,7 +1,34 @@ -from typing import Callable, Tuple, Type, Union +from contextlib import nullcontext +from functools import wraps +from typing import Callable, Optional, Tuple, Type, TypeVar, Union, overload, ContextManager import torch +__all__ = ["LayerType", "PadType", "nullwrap", "disable_compiler"] + LayerType = Union[str, Callable, Type[torch.nn.Module]] PadType = Union[str, int, Tuple[int, int]] + +F = TypeVar("F", bound=Callable[..., object]) + + +@overload +def nullwrap(fn: F) -> F: ... # decorator form + +@overload +def nullwrap(fn: None = ...) -> ContextManager: ... 
# context‑manager form + +def nullwrap(fn: Optional[F] = None): + # as a context manager + if fn is None: + return nullcontext() # `with nullwrap():` + + # as a decorator + @wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + return wrapper # `@nullwrap` + + +disable_compiler = getattr(getattr(torch, "compiler", None), "disable", None) or nullwrap diff --git a/timm/models/eva.py b/timm/models/eva.py index 9a48835a55..bcfa3ee2cb 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -46,6 +46,7 @@ # EVA models Copyright (c) 2022 BAAI-Vision # EVA02 models Copyright (c) 2023 BAAI-Vision import math +import os from functools import partial from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -54,11 +55,27 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \ - RotaryEmbeddingMixed, apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, \ - resample_patch_embed, resample_abs_pos_embed, global_pool_nlc, to_2tuple, use_fused_attn, AttentionRope, \ - AttentionPoolLatent - +from timm.layers import ( + PatchEmbed, + Mlp, + GluMlp, + SwiGLU, + LayerNorm, + DropPath, + PatchDropoutWithIndices, + RotaryEmbeddingCat, + RotaryEmbeddingMixed, + apply_rot_embed_cat, + apply_keep_indices_nlc, + trunc_normal_, + resample_patch_embed, + resample_abs_pos_embed, + global_pool_nlc, + to_2tuple, + use_fused_attn, + AttentionRope, + AttentionPoolLatent, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import checkpoint @@ -226,6 +243,7 @@ def __init__( act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, attn_head_dim: Optional[int] = None, + **kwargs, ): """ Initialize the EVA transformer block. @@ -298,7 +316,12 @@ def __init__( self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None): + def forward( + self, + x: torch.Tensor, + rope: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.gamma_1 is None: x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)) x = x + self.drop_path2(self.mlp(self.norm2(x))) @@ -399,7 +422,12 @@ def __init__( self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None): + def forward( + self, + x: torch.Tensor, + rope: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: x = x + self.drop_path1(self.norm1(self.attn(x, rope=rope, attn_mask=attn_mask))) x = x + self.drop_path2(self.norm2(self.mlp(x))) return x @@ -549,11 +577,7 @@ def __init__( self.pos_embed = nn.Parameter(torch.zeros(1, num_pos_tokens, embed_dim)) if use_abs_pos_emb else None self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: - self.patch_drop = PatchDropout( - patch_drop_rate, - num_prefix_tokens=self.num_prefix_tokens, - return_indices=True, - ) + self.patch_drop = PatchDropoutWithIndices(patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens) else: self.patch_drop = None @@ -614,12 +638,10 @@ def __init__( self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity() if global_pool == 'map': - attn_pool_num_heads = attn_pool_num_heads or num_heads - attn_pool_mlp_ratio = attn_pool_mlp_ratio or mlp_ratio self.attn_pool = AttentionPoolLatent( self.embed_dim, - num_heads=attn_pool_num_heads, - mlp_ratio=attn_pool_mlp_ratio, + num_heads=attn_pool_num_heads or num_heads, + mlp_ratio=attn_pool_mlp_ratio or mlp_ratio, norm_layer=norm_layer, act_layer=nn.GELU, ) @@ -701,6 +723,35 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def set_input_size( + self, + img_size: Optional[Tuple[int, int]] = None, + patch_size: Optional[Tuple[int, int]] = None, + ) -> None: + """Update the input image resolution and patch size. + + Args: + img_size: New input resolution, if None current resolution is used. + patch_size: New patch size, if None existing patch size is used. + """ + prev_grid_size = self.patch_embed.grid_size + self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) + + if self.pos_embed is not None: + num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens + num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens + if num_new_tokens != self.pos_embed.shape[1]: + self.pos_embed = nn.Parameter(resample_abs_pos_embed( + self.pos_embed, + new_size=self.patch_embed.grid_size, + old_size=prev_grid_size, + num_prefix_tokens=num_prefix_tokens, + verbose=True, + )) + + if self.rope is not None: + self.rope.update_feat_shape(self.patch_embed.grid_size) + def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.dynamic_img_size: B, H, W, C = x.shape @@ -741,11 +792,19 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: x = self.pos_drop(x) - # obtain shared rotary position embedding and apply patch dropout + # apply patch dropout to patches and rotary position embedding if self.patch_drop is not None: x, keep_indices = self.patch_drop(x) if rot_pos_embed is not None and keep_indices is not None: rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices) + # After applying keep indices to rope embeds, batch dim is added + if getattr(self, 'rope_mixed', False): + # B, D, nH, N, dim -> D, B, nH, N, dim. For consistent iteration over depth at index 0. + rot_pos_embed = rot_pos_embed.transpose(0, 1) + else: + # B, N, dim -> B, 1, N, dim. Need head dim singleton for correct dim alignment in axial mode. 
+ rot_pos_embed = rot_pos_embed.unsqueeze(1) + return x, rot_pos_embed def forward_intermediates( @@ -782,6 +841,7 @@ def forward_intermediates( blocks = self.blocks else: blocks = self.blocks[:max_index + 1] + # Handle depth-dependent embeddings for mixed mode if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None: for i, blk in enumerate(blocks): @@ -859,9 +919,9 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x, rot_pos_embed = self._pos_embed(x) x = self.norm_pre(x) - # Handle depth-dependent embeddings for mixed mode if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None: - # rot_pos_embed has shape (depth, H*W, dim) for mixed mode + # Handle depth-dependent embeddings for mixed mode + # pos embed has shape (depth, num_heads, H*W, dim) or (depth, batch_size, num_heads, H*W, dim) for i, blk in enumerate(self.blocks): if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint(blk, x, rope=rot_pos_embed[i]) @@ -1084,6 +1144,16 @@ def _create_eva(variant: str, pretrained: bool = False, **kwargs) -> Eva: Returns: Instantiated Eva model. """ + # Check if we should use NaFlexVit implementation + use_naflex = kwargs.pop('use_naflex', None) + _USE_NAFLEX_DEFAULT = os.environ.get('TIMM_USE_NAFLEX', '0') == '1' + if use_naflex is None: + use_naflex = _USE_NAFLEX_DEFAULT + if use_naflex: + # Import here to avoid circular imports + from .naflexvit import _create_naflexvit_from_eva + return _create_naflexvit_from_eva(variant, pretrained, **kwargs) + out_indices = kwargs.pop('out_indices', 3) model = build_model_with_cfg( Eva, variant, pretrained, @@ -1323,6 +1393,20 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: ), # Perception Encoder weights + 'vit_pe_core_tiny_patch16_384.fb': _pe_cfg( + hf_hub_id='timm/', + #hf_hub_id='facebook/PE-Core-T16-384', + #hf_hub_filename='PE-Core-T16-384.pt', + input_size=(3, 384, 384), + num_classes=512, # output proj dim + ), + 'vit_pe_core_small_patch16_384.fb': _pe_cfg( + hf_hub_id='timm/', + #hf_hub_id='facebook/PE-Core-S16-384', + #hf_hub_filename='PE-Core-S16-384.pt', + input_size=(3, 384, 384), + num_classes=512, # output proj dim + ), 'vit_pe_core_base_patch16_224.fb': _pe_cfg( hf_hub_id='timm/', #hf_hub_id='facebook/PE-Core-B16-224', @@ -1344,6 +1428,7 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 448, 448), num_classes=1280, # output proj dim ), + 'vit_pe_lang_large_patch14_448.fb': _pe_cfg( hf_hub_id='timm/', #hf_hub_id='facebook/PE-Lang-L14-448', @@ -1351,6 +1436,13 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 448, 448), num_classes=0, ), + 'vit_pe_lang_large_patch14_448.fb_tiling': _pe_cfg( + hf_hub_id='timm/', + #hf_hub_id='facebook/PE-Lang-L14-448-Tiling', + #hf_hub_filename='PE-Lang-L14-448-Tiling.pt', + input_size=(3, 448, 448), + num_classes=0, + ), 'vit_pe_lang_gigantic_patch14_448.fb': _pe_cfg( hf_hub_id='timm/', #hf_hub_id='facebook/PE-Lang-G14-448', @@ -1358,6 +1450,42 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: input_size=(3, 448, 448), num_classes=0, ), + 'vit_pe_lang_gigantic_patch14_448.fb_tiling': _pe_cfg( + hf_hub_id='timm/', + #hf_hub_id='facebook/PE-Lang-G14-448-Tiling', + #hf_hub_filename='PE-Lang-G14-448-Tiling.pt', + input_size=(3, 448, 448), + num_classes=0, + ), + + 'vit_pe_spatial_tiny_patch16_512.fb': _pe_cfg( + hf_hub_id='timm/', + #hf_hub_id='facebook/PE-Spatial-T16-512', + #hf_hub_filename='PE-Spatial-T16-512.pt', + input_size=(3, 512, 512), + num_classes=0, + ), + 
'vit_pe_spatial_small_patch16_512.fb': _pe_cfg( + hf_hub_id='timm/', + #hf_hub_id='facebook/PE-Spatial-S16-512', + #hf_hub_filename='PE-Spatial-S16-512.pt', + input_size=(3, 512, 512), + num_classes=0, + ), + 'vit_pe_spatial_base_patch16_512.fb': _pe_cfg( + hf_hub_id='timm/', + #hf_hub_id='facebook/PE-Spatial-B16-512', + #hf_hub_filename='PE-Spatial-B16-512.pt', + input_size=(3, 512, 512), + num_classes=0, + ), + 'vit_pe_spatial_large_patch14_448.fb': _pe_cfg( + hf_hub_id='timm/', + #hf_hub_id='facebook/PE-Spatial-L14-448', + #hf_hub_filename='PE-Spatial-L14-448.pt', + input_size=(3, 448, 448), + num_classes=0, + ), 'vit_pe_spatial_gigantic_patch14_448.fb': _pe_cfg( hf_hub_id='timm/', #hf_hub_id='facebook/PE-Spatial-G14-448', @@ -1799,6 +1927,55 @@ def vit_base_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Ev return model +@register_model +def vit_pe_core_tiny_patch16_384(pretrained: bool = False, **kwargs) -> Eva: + """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)""" + model_args = dict( + patch_size=16, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4.0, + global_pool='map', + attn_type='rope', + use_pre_transformer_norm=True, + use_rot_pos_emb=True, + ref_feat_shape=(24, 24), + rope_grid_offset=1., + rope_grid_indexing='xy', + attn_pool_num_heads=8, + attn_pool_mlp_ratio=4., + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True + ) + return _create_eva('vit_pe_core_tiny_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + + + +@register_model +def vit_pe_core_small_patch16_384(pretrained: bool = False, **kwargs) -> Eva: + """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)""" + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4.0, + global_pool='map', + attn_type='rope', + use_pre_transformer_norm=True, + use_rot_pos_emb=True, + ref_feat_shape=(24, 24), + rope_grid_offset=1., + rope_grid_indexing='xy', + attn_pool_num_heads=8, + attn_pool_mlp_ratio=4., + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True + ) + return _create_eva('vit_pe_core_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + + @register_model def vit_pe_core_base_patch16_224(pretrained: bool = False, **kwargs) -> Eva: """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)""" @@ -1920,6 +2097,98 @@ def vit_pe_lang_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva: return _create_eva('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) +@register_model +def vit_pe_spatial_tiny_patch16_512(pretrained: bool = False, **kwargs) -> Eva: + """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)""" + model_args = dict( + patch_size=16, + embed_dim=192, + depth=12, + num_heads=3, + mlp_ratio=4.0, + attn_type='rope', + use_pre_transformer_norm=True, + use_post_transformer_norm=False, + use_fc_norm=False, # explicitly disable + use_rot_pos_emb=True, + ref_feat_shape=(32, 32), + rope_grid_offset=1., + rope_grid_indexing='xy', + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True + ) + return _create_eva('vit_pe_spatial_tiny_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def vit_pe_spatial_small_patch16_512(pretrained: bool = False, **kwargs) -> Eva: + """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)""" + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + 
num_heads=6, + mlp_ratio=4.0, + attn_type='rope', + use_pre_transformer_norm=True, + use_post_transformer_norm=False, + use_fc_norm=False, # explicitly disable + use_rot_pos_emb=True, + ref_feat_shape=(32, 32), + rope_grid_offset=1., + rope_grid_indexing='xy', + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True + ) + return _create_eva('vit_pe_spatial_small_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def vit_pe_spatial_base_patch16_512(pretrained: bool = False, **kwargs) -> Eva: + """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)""" + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + attn_type='rope', + use_pre_transformer_norm=True, + use_post_transformer_norm=False, + use_fc_norm=False, # explicitly disable + use_rot_pos_emb=True, + ref_feat_shape=(32, 32), + rope_grid_offset=1., + rope_grid_indexing='xy', + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True + ) + return _create_eva('vit_pe_spatial_base_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def vit_pe_spatial_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva: + """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)""" + model_args = dict( + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + attn_type='rope', + use_pre_transformer_norm=True, + use_post_transformer_norm=False, + use_fc_norm=False, # explicitly disable + use_rot_pos_emb=True, + ref_feat_shape=(32, 32), + rope_grid_offset=1., + rope_grid_indexing='xy', + norm_layer=partial(LayerNorm, eps=1e-5), + #dynamic_img_size=True, + ) + return _create_eva('vit_pe_spatial_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + + @register_model def vit_pe_spatial_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva: """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)""" diff --git a/timm/models/naflexvit.py b/timm/models/naflexvit.py index 5794fdf5e0..94fe8fa63d 100644 --- a/timm/models/naflexvit.py +++ b/timm/models/naflexvit.py @@ -25,24 +25,27 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import ( AttentionPoolLatent, Mlp, + LayerNorm, + PatchDropoutWithIndices, + PatchEmbedInterpolator, + _assert, to_2tuple, get_act_layer, get_norm_layer, - LayerNorm, - _assert, + apply_keep_indices_nlc, + disable_compiler, ) -from timm.models._builder import build_model_with_cfg -from timm.models._features import feature_take_indices -from timm.models._features_fx import register_notrace_function, register_notrace_module -from timm.models._registry import register_model, generate_default_cfgs -from timm.models._manipulate import checkpoint, checkpoint_seq, named_apply - +from ._builder import build_model_with_cfg +from ._features import feature_take_indices +from ._features_fx import register_notrace_function, register_notrace_module +from ._manipulate import checkpoint, named_apply +from ._registry import register_model, generate_default_cfgs +from .eva import EvaBlock from .vision_transformer import Block, global_pool_nlc __all__ = ['NaFlexVitCfg', 'NaFlexVit'] @@ -65,12 +68,14 @@ class NaFlexVitCfg: depth: int = 12 num_heads: int = 12 mlp_ratio: float = 4.0 + scale_mlp_norm: bool = False # Apply scaling norm to MLP # Attention 
parameters qkv_bias: bool = True qk_norm: bool = False proj_bias: bool = True attn_drop_rate: float = 0.0 + scale_attn_inner_norm: bool = False # Apply scaling norm to attn context # Regularization init_values: Optional[float] = None # Layer-scale init values (layer-scale enabled if not None) @@ -91,15 +96,26 @@ class NaFlexVitCfg: pos_embed_ar_preserving: bool = False # Whether to preserve aspect ratio during position embedding interpolation pos_embed_use_grid_sample: bool = False # Whether to use grid_sample for naflex position embedding interpolation + # ROPE specific configuration + rope_type: str = '' # ROPE type: '' or 'none' for no ROPE, 'axial' for standard, 'mixed' for learnable frequencies + rope_temperature: float = 10000.0 # Temperature for ROPE frequency computation + rope_ref_feat_shape: Optional[Tuple[int, int]] = None + rope_grid_offset: float = 0. # Grid offset for non-pixel ROPE mode + rope_grid_indexing: str = 'ij' # Grid indexing mode for ROPE ('ij' or 'xy') + # Image processing dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution - # Architecture choices + # Other architecture choices pre_norm: bool = False # Whether to apply normalization before attention/MLP layers (start of blocks) final_norm: bool = True # Whether to apply final normalization before pooling and classifier (end of blocks) fc_norm: Optional[bool] = None # Whether to normalize features before final classifier (after pooling) + + # Global pooling setup global_pool: str = 'map' # Type of global pooling for final sequence pool_include_prefix: bool = False # Whether to include class/register prefix tokens in global pooling + attn_pool_num_heads: Optional[int] = None # Override num_heads for attention pool + attn_pool_mlp_ratio: Optional[float] = None # Override mlp_ratio for attention pool # Weight initialization weight_init: str = '' # Weight initialization scheme @@ -116,6 +132,11 @@ class NaFlexVitCfg: block_fn: Optional[str] = None # Transformer block implementation class name mlp_layer: Optional[str] = None # MLP implementation class name + # EVA-specific parameters + attn_type: str = 'standard' # Attention type: 'standard', 'eva', 'rope' + swiglu_mlp: bool = False # Use SwiGLU MLP variant + qkv_fused: bool = True # Whether to use fused QKV projections + # Variable patch size support enable_patch_interpolator: bool = False # Enable dynamic patch size support @@ -171,6 +192,112 @@ def calculate_naflex_grid_sizes(_coord: torch.Tensor): return [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)] +class NaFlexRopeIterator: + """Iterator for generating batched ROPE embeddings for mixed mode with multiple grid sizes.""" + + def __init__( + self, + rope_module, + size_to_indices: Dict[Tuple[int, int], List[int]], + unique_sizes: List[Tuple[int, int]], + batch_size: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + ): + self.rope = rope_module + self.size_to_indices = size_to_indices + self.unique_sizes = unique_sizes + self.batch_size = batch_size + self.seq_len = seq_len + self.dtype = dtype + self.device = device + self.depth = rope_module.depth + self.num_heads = rope_module.num_heads + self.head_dim = 2 * rope_module.dim // rope_module.num_heads + self._depth_idx = 0 + + # Pre-compute embeddings for each unique size + self._embeddings_per_size = {} + for grid_size in unique_sizes: + # get_embed returns all depths at once for mixed mode + rope_embed = rope_module.get_embed(shape=grid_size) + self._embeddings_per_size[grid_size] = 
rope_embed + + def __iter__(self): + self._depth_idx = 0 + return self + + @disable_compiler + def __next__(self): + if self._depth_idx >= self.depth: + raise StopIteration + + # Create batch tensor for current depth + batch_embed = torch.zeros( + self.batch_size, self.num_heads, self.seq_len, self.head_dim, + dtype=self.dtype, device=self.device + ) + + # Fill in embeddings for each unique grid size + for grid_size in self.unique_sizes: + h, w = grid_size + actual_len = h * w + batch_indices = self.size_to_indices[grid_size] + + # Get pre-computed embeddings for this size at current depth + embed = self._embeddings_per_size[grid_size][self._depth_idx] # [num_heads, H*W, dim] + + # Assign to batch indices + for bi in batch_indices: + batch_embed[bi, :, :actual_len, :] = embed[:, :actual_len, :] + + self._depth_idx += 1 + return batch_embed + + +def get_block_fn(cfg: NaFlexVitCfg) -> Callable: + """Get appropriate block function based on configuration. + + Returns a partially applied block constructor with EVA-specific + or conflicting parameters pre-configured if needed. + """ + # Check if we need EVA block features + use_eva_features = ( + cfg.attn_type in ('eva', 'rope') or + cfg.rope_type not in ('', 'none') or # Any ROPE type requires EVA blocks + cfg.swiglu_mlp + ) + + if use_eva_features: + # Determine attention type based on rope_type if not explicitly set + attn_type = cfg.attn_type + if attn_type == 'standard' and cfg.rope_type not in ('', 'none'): + attn_type = 'rope' + + num_prefix_tokens = (1 if cfg.class_token else 0) + cfg.reg_tokens + return partial( + EvaBlock, + attn_type=attn_type, + swiglu_mlp=cfg.swiglu_mlp, + scale_mlp=cfg.scale_mlp_norm, + scale_attn_inner=cfg.scale_attn_inner_norm, + qkv_fused=cfg.qkv_fused, + num_prefix_tokens=num_prefix_tokens, + ) + else: + # Standard ViT block + block_fn = cfg.block_fn or Block + if cfg.scale_mlp_norm or cfg.scale_attn_inner_norm: + # param names differ between EVA vs non-EVA block types + block_fn = partial( + block_fn, + scale_mlp_norm=cfg.scale_mlp_norm, + scale_attn_norm=cfg.scale_attn_inner_norm + ) + return block_fn + + @register_notrace_module class NaFlexEmbeds(nn.Module): """NaFlex Embedding module for Vision Transformers. @@ -201,9 +328,8 @@ class NaFlexEmbeds(nn.Module): proj_type: Type of embedding projection layer ('conv' or 'linear') input_norm_layer: Normalization layer applied to input (linear mode only) proj_norm_layer: Normalization layer applied after projection - pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none') + pos_embed: Type of position embedding ('learned', 'factorized', 'none') pos_drop_rate: Dropout rate for position embeddings - patch_drop_rate: Dropout rate for patch tokens class_token: Whether to include a class token reg_tokens: Number of register tokens to include bias: Whether to use bias in projection layers @@ -234,7 +360,6 @@ def __init__( proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None, norm_layer: Optional[Type[nn.Module]] = None, pos_drop_rate: float = 0., - patch_drop_rate: float = 0., enable_patch_interpolator: bool = False, ) -> None: """Initialize NaFlexEmbeds module. @@ -249,7 +374,7 @@ def __init__( reg_tokens: Number of register tokens to include. dynamic_img_pad: Whether to enable dynamic padding for variable resolution. default_img_size: Default image size for position embedding grid calculation. - pos_embed: Type of position embedding ('learned', 'factorized', 'rope', 'none'). 
+ pos_embed: Type of position embedding ('learned', 'factorized', 'none'). pos_embed_grid_size: Grid size for position embedding initialization. pos_embed_interp_mode: Interpolation mode for position embedding resizing. pos_embed_ar_preserving: Whether to preserve aspect ratio during interpolation. @@ -257,7 +382,6 @@ def __init__( proj_norm_layer: Normalization layer applied after projection. norm_layer: Default normalization layer. pos_drop_rate: Dropout rate for position embeddings. - patch_drop_rate: Dropout rate for patch tokens. enable_patch_interpolator: Enable dynamic patch size support. """ super().__init__() @@ -315,7 +439,6 @@ def __init__( # Create patch embedding interpolator if enabled if self.enable_patch_interpolator: - from timm.layers import PatchEmbedInterpolator self.patch_interpolator = PatchEmbedInterpolator( base_patch_size=self.patch_size, in_chans=in_chans, @@ -342,9 +465,6 @@ def __init__( self.pos_embed_x: Optional[torch.Tensor] = None if not pos_embed or pos_embed == 'none': self.pos_embed_type = 'none' - elif pos_embed == 'rope': - self.pos_embed_type = 'rope' - # Rotary embeddings will be computed on-the-fly in the forward pass elif pos_embed == 'factorized': assert self.pos_embed_grid_size is not None h, w = self.pos_embed_grid_size @@ -357,16 +477,8 @@ def __init__( self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim) * .02) self.pos_embed_type = 'learned' - # Dropout layers + # Dropout layer self.pos_drop = nn.Dropout(p=pos_drop_rate) - if patch_drop_rate > 0: - from timm.layers.patch_dropout import PatchDropout - self.patch_drop = PatchDropout( - patch_drop_rate, - num_prefix_tokens=self.num_prefix_tokens, - ) - else: - self.patch_drop = nn.Identity() def feature_info(self, location) -> Dict[str, Any]: """Get feature information for feature extraction. @@ -409,7 +521,7 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: else: return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] - #@torch.compiler.disable() + @disable_compiler def _apply_learned_naflex_pos_embed( self, x: torch.Tensor, @@ -466,6 +578,7 @@ def _interp2d(size): pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1) ) + @disable_compiler def _apply_learned_naflex_pos_embed_grid_sample( self, x: torch.Tensor, @@ -547,6 +660,7 @@ def _apply_learned_pos_embed( x.add_(pos_embed_flat) + @disable_compiler def _apply_factorized_naflex_pos_embed( self, x: torch.Tensor, @@ -608,6 +722,7 @@ def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.T pos[:, :seq_len].expand(len(batch_indices), -1, -1) ) + @disable_compiler def _apply_factorized_naflex_pos_embed_grid_sample( self, x: torch.Tensor, @@ -701,7 +816,8 @@ def forward( self, x: torch.Tensor, patch_coord: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + patch_valid: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]: """Forward pass for patch embedding with position encoding. Args: @@ -710,12 +826,14 @@ def forward( - [B, N, P*P*C] for pre-patchified linear mode (normal) - [B, N, Ph, Pw, C] for pre-patchified linear mode (variable patch size) patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode. + patch_valid: Optional validity mask for patches [B, N] for NaFlex mode. Returns: - Embedded tensor with position encoding and class/register tokens. 
- Shape: [B, num_prefix_tokens + N, embed_dim] + Tuple of (embedded_tensor, grid_size) where: + - embedded_tensor: [B, num_prefix_tokens + N, embed_dim] + - grid_size: (H, W) tuple for standard mode, None for NaFlex mode """ - grid_size: Optional[List[int]] = None + grid_size: Optional[Tuple[int, int]] = None B = x.shape[0] if self.is_linear: # Linear embedding path, works with NaFlex mode or standard 2D mode @@ -783,8 +901,6 @@ def forward( self._apply_factorized_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord) else: self._apply_factorized_naflex_pos_embed(x, patch_coord=patch_coord) - elif self.pos_embed_type == 'rope': - assert False, "ROPE not yet implemented" # Prepare and add class and register tokens to_cat = [] @@ -796,10 +912,10 @@ def forward( if to_cat: x = torch.cat(to_cat + [x], dim=1) - # Apply dropouts + # Apply dropout x = self.pos_drop(x) - x = self.patch_drop(x) - return x + + return x, grid_size @register_notrace_function @@ -990,7 +1106,7 @@ def __init__( norm_layer = get_norm_layer(cfg.norm_layer) or LayerNorm embed_norm_layer = get_norm_layer(cfg.embed_norm_layer) act_layer = get_act_layer(cfg.act_layer) or nn.GELU - block_fn = cfg.block_fn or Block # TODO: Support configurable block_fn via string lookup + block_fn = get_block_fn(cfg) mlp_layer = cfg.mlp_layer or Mlp # TODO: Support configurable mlp_layer via string lookup # Store instance variables @@ -1023,13 +1139,51 @@ def __init__( pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample, proj_norm_layer=embed_norm_layer, pos_drop_rate=cfg.pos_drop_rate, - patch_drop_rate=cfg.patch_drop_rate, enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False), ) self.norm_pre = norm_layer(cfg.embed_dim) if cfg.pre_norm else nn.Identity() + # ROPE position embeddings at model level + self.rope: Optional[nn.Module] = None + self.rope_is_mixed = False + if cfg.rope_type and cfg.rope_type != 'none': + from timm.layers.pos_embed_sincos import RotaryEmbeddingCat, RotaryEmbeddingMixed + if cfg.rope_type == 'mixed': + self.rope = RotaryEmbeddingMixed( + cfg.embed_dim, + depth=cfg.depth, + num_heads=cfg.num_heads, + temperature=cfg.rope_temperature, + feat_shape=None, # Dynamic shapes for NaFlex + grid_indexing=cfg.rope_grid_indexing, + ) + self.rope_is_mixed = True + elif cfg.rope_type == 'axial': + self.rope = RotaryEmbeddingCat( + cfg.embed_dim // cfg.num_heads, + temperature=cfg.rope_temperature, + in_pixels=False, + feat_shape=None, # Dynamic shapes for NaFlex + ref_feat_shape=cfg.rope_ref_feat_shape, + grid_offset=cfg.rope_grid_offset, + grid_indexing=cfg.rope_grid_indexing, + ) + self.rope_is_mixed = False + else: + raise ValueError(f"Unknown rope_type: {cfg.rope_type}") + + # Patch dropout + if cfg.patch_drop_rate > 0: + self.patch_drop = PatchDropoutWithIndices( + cfg.patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = None + # Transformer blocks dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)] # stochastic depth decay rule + # Create transformer blocks self.blocks = nn.Sequential(*[ block_fn( dim=cfg.embed_dim, @@ -1046,7 +1200,8 @@ def __init__( act_layer=act_layer, mlp_layer=mlp_layer, ) - for i in range(cfg.depth)]) + for i in range(cfg.depth) + ]) # Feature info for downstream tasks patch_reduction = self.embeds.feat_ratio(as_scalar=True) @@ -1061,8 +1216,8 @@ def __init__( if cfg.global_pool == 'map': self.attn_pool = AttentionPoolLatent( self.embed_dim, - num_heads=cfg.num_heads, - mlp_ratio=cfg.mlp_ratio, + 
num_heads=cfg.attn_pool_num_heads or cfg.num_heads, + mlp_ratio=cfg.attn_pool_mlp_ratio or cfg.mlp_ratio, norm_layer=norm_layer, act_layer=act_layer, ) @@ -1085,11 +1240,19 @@ def __init__( def fix_init_weight(self) -> None: """Apply initialization weight fix with layer-wise scaling.""" def rescale(param: torch.Tensor, _layer_id: int) -> None: - param.div_(math.sqrt(2.0 * _layer_id)) + with torch.no_grad(): + param.div_(math.sqrt(2.0 * _layer_id)) for layer_id, layer in enumerate(self.blocks): - rescale(layer.attn.proj.weight.data, layer_id + 1) - rescale(layer.mlp.fc2.weight.data, layer_id + 1) + if hasattr(layer, 'attn'): + rescale(layer.attn.proj.weight, layer_id + 1) + if hasattr(layer, 'mlp'): + rescale(layer.mlp.fc2.weight, layer_id + 1) + if hasattr(layer, 'attn_out_proj'): + rescale(layer.attn_out_proj.weight, layer_id + 1) + if hasattr(layer, 'mlp_out_proj'): + rescale(layer.mlp_out_proj.weight, layer_id + 1) + def init_weights(self, mode: str = '') -> None: """Initialize model weights according to specified scheme. @@ -1135,6 +1298,8 @@ def no_weight_decay(self) -> Set: Set of parameter names to skip during weight decay """ skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'} + if self.rope and hasattr(self.rope, 'no_weight_decay'): + skip_list.update(self.rope.no_weight_decay()) return skip_list @torch.jit.ignore @@ -1172,6 +1337,75 @@ def get_classifier(self) -> nn.Module: """ return self.head + @disable_compiler + def _generate_rope_naflex( + self, + x: torch.Tensor, + patch_coord: torch.Tensor, + ) -> Union[torch.Tensor, List[torch.Tensor], Any]: + """Generate ROPE position embeddings for NaFlex batch with variable grid sizes. + + Args: + x: Input tensor [B, N, C] + patch_coord: Patch coordinates [B, N, 2] with (y, x) values + + Returns: + ROPE embeddings: + - Axial mode: Tensor of shape [B, 1, N, dim*2] + - Mixed mode: List of tensors, each of shape [B, num_heads, N, dim], one per depth layer + - Mixed mode with iterator: Iterator yielding tensors per depth + """ + # Calculate grid sizes for each sample + naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord) + + # Build ROPE embeddings for each unique grid size + size_to_indices = {} + unique_sizes = [] + for bi, grid_size in enumerate(naflex_grid_sizes): + if grid_size not in size_to_indices: + size_to_indices[grid_size] = [] + unique_sizes.append(grid_size) + size_to_indices[grid_size].append(bi) + + B, N, C = x.shape + seq_len = N - self.num_prefix_tokens + + if self.rope_is_mixed: + # Use an iterator for Mixed mode, returns [batch_size, depth, num_heads, seq_len, dim] + return NaFlexRopeIterator( + self.rope, + size_to_indices, + unique_sizes, + B, + seq_len, + x.dtype, + x.device + ) + + # Axial mode: [batch_size, seq_len, dim*2] + rope_embeds = torch.zeros(B, seq_len, self.rope.dim * 2, dtype=x.dtype, device=x.device) + + if hasattr(self.rope, 'get_batch_embeds'): + # Batch mode - generate unique embeds from one grid and then assign + unique_embeds = self.rope.get_batch_embeds(unique_sizes) + for grid_size, embed, batch_indices in zip(unique_sizes, unique_embeds, size_to_indices.values()): + h, w = grid_size + actual_len = h * w + for bi in batch_indices: + rope_embeds[bi, :actual_len] = embed[:actual_len] + + else: + # Generate each unique size separately and assign + for grid_size, bi in size_to_indices.items(): + rope_embed = self.rope.get_embed(shape=grid_size) + h, w = grid_size + actual_len = h * w + rope_embeds[bi, :actual_len] = rope_embed[:actual_len] + + rope_embeds = 
rope_embeds.unsqueeze(1) + + return rope_embeds + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: """Reset the classification head with new number of classes and pooling. @@ -1189,6 +1423,68 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def _forward_embeds( + self, + x, + patch_coord, + patch_valid, + attn_mask, + ) -> Dict[str, torch.Tensor]: + """ Forward pass through patch / abs pos / rope pos embeds and patch dropout + """ + naflex_mode = patch_coord is not None + + # patch embed, abs pos embed, returns global grid size as calculated from 'standard' NCHW batches + x, grid_size = self.embeds( + x, + patch_coord=patch_coord, + patch_valid=patch_valid, + ) + + # Generate ROPE embeddings at model level + rope_embeds = None + if self.rope is not None: + if patch_coord is not None: + # NaFlex mode - variable grid sizes + rope_embeds = self._generate_rope_naflex(x, patch_coord) + elif grid_size is not None: + # Standard mode - fixed grid size + rope_embeds = self.rope.get_embed(shape=grid_size) + else: + assert False, 'Expected one of patch_coord or grid_size to be valid' + + # Apply patch dropout with coordinated updates + keep_indices: Optional[torch.Tensor] = None + if self.training and self.patch_drop is not None: + x, keep_indices = self.patch_drop(x) + # keep_indices excludes prefix tokens, can use directly on patch_valid & rope embeds + if patch_valid is not None: + patch_valid = patch_valid.gather(1, keep_indices) + if rope_embeds is not None and not self.rope_is_mixed: + # Update ROPE embeddings to match dropped tokens (only for axial mode) + # Batch dim already present in NaFlex mode, but will be added in standard mode. + rope_embeds = apply_keep_indices_nlc(x, rope_embeds, keep_indices, pos_embed_has_batch=naflex_mode) + if not naflex_mode: + # B, N, dim -> B, 1, N, dim. Need head dim added for standard mode, already added in NaFlex. 
+ rope_embeds = rope_embeds.unsqueeze(1) + + # Create attention mask from patch_valid after patch dropout applied + if attn_mask is None: + attn_mask = create_attention_mask( + patch_valid, + num_prefix_tokens=self.num_prefix_tokens, + dtype=x.dtype + ) + + x = self.norm_pre(x) + return { + 'patches': x, + 'patch_valid': patch_valid, + 'rope_embeds': rope_embeds, + 'attn_mask': attn_mask, + 'keep_indices': keep_indices, + } + def forward_intermediates( self, x: Union[torch.Tensor, Dict[str, torch.Tensor]], @@ -1239,13 +1535,17 @@ def forward_intermediates( height, width = x.shape[-2:] H, W = self.embeds.dynamic_feat_size((height, width)) - # Create attention mask if patch_type is provided and mask is not - if attn_mask is None and patch_valid is not None: - attn_mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype) - - # Forward pass through embedding - x = self.embeds(patches, patch_coord=patch_coord) - x = self.norm_pre(x) + # Forward pass through patch and abs position embedding + embeds = self._forward_embeds( + patches, + patch_coord=patch_coord, + patch_valid=patch_valid, + attn_mask=attn_mask, + ) + x = embeds['patches'] + rope_embeds = embeds.get('rope_embeds', None) + keep_indices = embeds.get('keep_indices', None) + attn_mask = embeds.get('attn_mask', None) # Forward pass through blocks if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript @@ -1253,16 +1553,42 @@ def forward_intermediates( else: blocks = self.blocks[:max_index + 1] - for i, blk in enumerate(blocks): - if attn_mask is not None: - x = blk(x, attn_mask=attn_mask) - elif self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint(blk, x) - else: - x = blk(x) - if i in take_indices: - # normalize intermediates with final norm layer if enabled - intermediates.append(self.norm(x) if norm else x) + do_checkpointing = self.grad_checkpointing and not torch.jit.is_scripting() + if self.rope_is_mixed and rope_embeds is not None: + # Mixed mode with per-layer embeddings (list or iterator) + for i, (blk, rope_embed) in enumerate(zip(self.blocks, rope_embeds)): + # Apply patch dropout to rope_embed if needed + if self.training and self.patch_drop is not None and keep_indices is not None: + # Apply patch dropout to rope_embed if needed (batch dim already present in naflex mode) + rope_embed = apply_keep_indices_nlc( + x, + rope_embed, + keep_indices, + pos_embed_has_batch=embeds.get('naflex_mode', False), + ) + if do_checkpointing: + x = checkpoint(blk, x, rope=rope_embed, attn_mask=attn_mask) + else: + x = blk(x, rope=rope_embed, attn_mask=attn_mask) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + else: + for i, blk in enumerate(blocks): + # Axial ROPE mode with shared embeddings + if rope_embeds is not None: + if do_checkpointing: + x = checkpoint(blk, x, rope=rope_embeds, attn_mask=attn_mask) + else: + x = blk(x, rope=rope_embeds, attn_mask=attn_mask) + else: + if do_checkpointing: + x = checkpoint(blk, x, attn_mask=attn_mask) + else: + x = blk(x, attn_mask=attn_mask) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) # Process intermediates if self.num_prefix_tokens: @@ -1279,6 +1605,8 @@ def forward_intermediates( for y in intermediates ] + # FIXME always use dict for NaFlex mode to return masks and more? 
+ # For dictionary output if output_dict: result_dict = {} @@ -1308,33 +1636,66 @@ def forward_intermediates( def forward_features( self, - x: torch.Tensor, + patches: torch.Tensor, patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - if attn_mask is None: - attn_mask = create_attention_mask( - patch_valid, - num_prefix_tokens=self.num_prefix_tokens, - dtype=x.dtype - ) + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + """ + naflex_mode = patch_coord is not None - # Pass through embedding module with patch coordinate/type support - x = self.embeds(x, patch_coord=patch_coord) - x = self.norm_pre(x) - # Apply transformer blocks with masked attention if mask provided - if attn_mask is not None: - # We need to apply blocks one by one with mask + # Pass through patch & abs position embedding module with patch coordinate/type support + embeds = self._forward_embeds( + patches, + patch_coord=patch_coord, + patch_valid=patch_valid, + attn_mask=attn_mask, + ) + x = embeds['patches'] + rope_embeds = embeds.get('rope_embeds', None) + keep_indices = embeds.get('keep_indices', None) + attn_mask = embeds.get('attn_mask', None) + + # Apply transformer blocks with masked attention and/or ROPE if provided + do_checkpointing = self.grad_checkpointing and not torch.jit.is_scripting() + if self.rope_is_mixed and rope_embeds is not None: + # Mixed mode with per-layer embeddings (list or iterator) + for i, (blk, rope_embed) in enumerate(zip(self.blocks, rope_embeds)): + if self.training and self.patch_drop is not None and keep_indices is not None: + # Apply patch dropout to rope_embed if needed (batch dim already present in naflex mode) + rope_embed = apply_keep_indices_nlc( + x, + rope_embed, + keep_indices, + pos_embed_has_batch=naflex_mode, + ) + if do_checkpointing: + x = checkpoint(blk, x, rope=rope_embed, attn_mask=attn_mask) + else: + x = blk(x, rope=rope_embed, attn_mask=attn_mask) + elif rope_embeds is not None: + # Axial ROPE mode with shared embeddings for blk in self.blocks: - x = blk(x, attn_mask=attn_mask) - elif self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x) + if do_checkpointing: + x = checkpoint(blk, x, rope=rope_embeds, attn_mask=attn_mask) + else: + x = blk(x, rope=rope_embeds, attn_mask=attn_mask) else: - x = self.blocks(x) + for blk in self.blocks: + if do_checkpointing: + x = checkpoint(blk, x, attn_mask=attn_mask) + else: + x = blk(x, attn_mask=attn_mask) x = self.norm(x) + + if naflex_mode: + return { + 'patches': x, + 'patch_valid': embeds.get('patch_valid', None), + } + return x def _pool( @@ -1369,11 +1730,11 @@ def _pool( def forward_head( self, - x: torch.Tensor, + patches: torch.Tensor, pre_logits: bool = False, patch_valid: Optional[torch.Tensor] = None, ) -> torch.Tensor: - x = self._pool(x, patch_valid=patch_valid) + x = self._pool(patches, patch_valid=patch_valid) x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) @@ -1383,6 +1744,7 @@ def forward( x: Union[torch.Tensor, Dict[str, torch.Tensor]], patch_coord: Optional[torch.Tensor] = None, patch_valid: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with optional NaFlex support. @@ -1394,49 +1756,53 @@ def forward( - Dict from NaFlex collator patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode. patch_valid: Optional patch validity indicators for NaFlex. 
+ attn_mask: Optional attn mask to override defaults generated from patch_valid Returns: Model output tensor. """ - if isinstance(x, Dict): - # Handle dictionary input from NaFlex collator - patch_coord = x['patch_coord'] - patch_valid = x['patch_valid'] - patches = x['patches'] + input_is_dict = isinstance(x, Dict) + naflex_mode = input_is_dict or patch_coord is not None + if naflex_mode: + if input_is_dict: + # Handle dictionary input from NaFlex collator, dict inputs take priority over args + patches = x['patches'] + patch_valid = x.get('patch_valid', patch_valid) + patch_coord = x.get('patch_coord', patch_coord) + attn_mask = x.get('attn_mask', attn_mask) + else: + patches = x + _assert(patch_coord is not None, "patch_coord is required in naflex mode") + _assert(patch_valid is not None, "patch_valid is required in naflex mode") + + features = self.forward_features( + patches=patches, + patch_valid=patch_valid, + patch_coord=patch_coord, + attn_mask=attn_mask, + ) - # DEBUG, reconstruct patches - # for i in range(len(patches)): - # patch = patches[i][patch_valid[i]] - # h = (patch_coord[i, :, 0].max() + 1).item() - # w = (patch_coord[i, :, 1].max() + 1).item() - # patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3) - # patch = patch.reshape(3, h*16, w*16) - # from torchvision.utils import save_image - # save_image(patch, f'patch_{i}.jpg', normalize=True) + # Pass patches & patch_valid to forward_head for masked pooling + x = self.forward_head(**features) else: - patches = x - - # Create attention mask if patch_type is provided - attn_mask = create_attention_mask( - patch_valid, - num_prefix_tokens=self.num_prefix_tokens, - dtype=patches.dtype, - ) + x = self.forward_features(x) + x = self.forward_head(x) + return x - # Forward features with mask - x = self.forward_features( - patches, - patch_coord=patch_coord, - patch_valid=patch_valid, - attn_mask=attn_mask, - ) - # Pass mask to forward_head for masked pooling - x = self.forward_head( - x, - patch_valid=patch_valid, - ) - return x +def _debug_dump_patches(x): + # DEBUG, reconstruct patches & save + patch_coord = x['patch_coord'] + patch_valid = x['patch_valid'] + patches = x['patches'] + for i in range(len(patches)): + patch = patches[i][patch_valid[i]] + h = (patch_coord[i, :, 0].max() + 1).item() + w = (patch_coord[i, :, 1].max() + 1).item() + patch = patch.reshape(h, w, 16, 16, 3).permute(4, 0, 2, 1, 3) + patch = patch.reshape(3, h*16, w*16) + from torchvision.utils import save_image + save_image(patch, f'patch_{i}.jpg', normalize=True) def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: @@ -1453,7 +1819,6 @@ def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: def checkpoint_filter_fn(state_dict: Dict[str, Any], model: NaFlexVit) -> Dict[str, Any]: """Handle state dict conversion from original ViT to the new version with combined embedding.""" - from .vision_transformer import checkpoint_filter_fn as orig_filter_fn # Handle CombinedEmbed module pattern out_dict = {} @@ -1615,12 +1980,81 @@ def _create_naflexvit_from_classic( 'class_token': kwargs.get('class_token', True), 'global_pool': gp, 'fc_norm': fc_norm, + 'scale_mlp_norm': kwargs.pop('scale_mlp_norm', False), + 'scale_attn_inner_norm': kwargs.pop('scale_attn_norm', False), **kwargs # User overrides take precedence } return _create_naflexvit(variant, pretrained, **flex_kwargs) +def _create_naflexvit_from_eva( + variant: str, + pretrained: bool = False, + **kwargs, +) -> NaFlexVit: + """Create NaFlexVit model 
from EVA configuration. + + This function handles the parameter mapping and configuration logic needed + to create NaFlexVit models that are compatible with EVA configurations + and pretrained weights. + + Args: + variant: Model variant name + pretrained: Whether to load pretrained weights + **kwargs: EVA model parameters + + Returns: + NaFlexVit model instance + """ + # Handle EVA's unique parameters & block args + kwargs.pop('no_embed_class', None) # EVA specific, not used in NaFlexVit (always no-embed) + + # Map EVA's rope parameters + use_rot_pos_emb = kwargs.pop('use_rot_pos_emb', False) + rope_mixed_mode = kwargs.pop('rope_mixed_mode', False) + rope_temperature = kwargs.pop('rope_temperature', 10000.) + rope_grid_offset = kwargs.pop('rope_grid_offset', 0.) + rope_grid_indexing = kwargs.pop('rope_grid_indexing', 'ij') + if use_rot_pos_emb: + rope_type = 'mixed' if rope_mixed_mode else 'axial' + else: + rope_type = 'none' + + # Handle norm/pool resolution logic to mirror EVA + gp = kwargs.pop('global_pool', 'avg') + use_pre_transformer_norm = kwargs.pop('use_pre_transformer_norm', False) + use_post_transformer_norm = kwargs.pop('use_post_transformer_norm', True) + use_fc_norm = kwargs.pop('use_fc_norm', None) + if use_fc_norm is None: + use_fc_norm = gp == 'avg' # default on if avg pool used + + # Set NaFlexVit-specific parameters + naflex_kwargs = { + 'pos_embed_grid_size': None, # rely on img_size (// patch_size) + 'class_token': kwargs.get('class_token', True), + 'reg_tokens': kwargs.pop('num_reg_tokens', kwargs.get('reg_tokens', 0)), + 'global_pool': gp, + 'pre_norm': use_pre_transformer_norm, + 'final_norm': use_post_transformer_norm, + 'fc_norm': use_fc_norm, + 'pos_embed': 'learned' if kwargs.pop('use_abs_pos_emb', True) else 'none', + 'rope_type': rope_type, + 'rope_temperature': rope_temperature, + 'rope_grid_offset': rope_grid_offset, + 'rope_grid_indexing': rope_grid_indexing, + 'rope_ref_feat_shape': kwargs.get('ref_feat_shape', None), + 'attn_type': kwargs.pop('attn_type', 'eva'), + 'swiglu_mlp': kwargs.pop('swiglu_mlp', False), + 'qkv_fused': kwargs.pop('qkv_fused', True), + 'scale_mlp_norm': kwargs.pop('scale_mlp', False), + 'scale_attn_inner_norm': kwargs.pop('scale_attn_inner', False), + **kwargs # Pass remaining kwargs through + } + + return _create_naflexvit(variant, pretrained, **naflex_kwargs) + + @register_model def naflexvit_base_patch16_gap(pretrained: bool = False, **kwargs) -> NaFlexVit: """ViT-Base with NaFlex functionality and global average pooling. diff --git a/timm/version.py b/timm/version.py index af8bddfa0a..dce8b34b2b 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '1.0.17' +__version__ = '1.0.19' diff --git a/validate.py b/validate.py index afca5e561b..e5ff9494bf 100755 --- a/validate.py +++ b/validate.py @@ -300,6 +300,10 @@ def validate(args): crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] if args.naflex_loader: + model_patch_size = None + if hasattr(model, 'embeds') and hasattr(model.embeds, 'patch_size'): + # NaFlexVit models have embeds.patch_size + model_patch_size = model.embeds.patch_size from timm.data import create_naflex_loader loader = create_naflex_loader( dataset, @@ -315,7 +319,7 @@ def validate(args): pin_memory=args.pin_mem, device=device, img_dtype=model_dtype or torch.float32, - patch_size=16, # Could be derived from model config + patch_size=model_patch_size or (16, 16), max_seq_len=args.naflex_max_seq_len, ) else:
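
A minimal usage sketch of the NaFlex dict forward path added above (illustrative only, not part of the patch): the model name is the `naflexvit_base_patch16_gap` variant registered in this diff, and the dict keys (`patches`, `patch_coord`, `patch_valid`) follow the collator format consumed by `NaFlexVit.forward()`. The batch size, grid shape, and padding pattern below are arbitrary assumptions chosen only to exercise the variable-grid path.

```python
import torch
import timm

# Hypothetical example: batch size, grid shape, and padding are illustrative assumptions.
model = timm.create_model('naflexvit_base_patch16_gap', pretrained=False, num_classes=10)
model.eval()

B = 2
patch_size = 16                 # matches the patch16 variant
grid_h, grid_w = 14, 10         # variable (non-square) patch grid
N = grid_h * grid_w

# Pre-patchified input [B, N, patch_size * patch_size * 3], per-patch (y, x)
# coordinates, and a validity mask marking padded positions.
patches = torch.randn(B, N, patch_size * patch_size * 3)
yy, xx = torch.meshgrid(torch.arange(grid_h), torch.arange(grid_w), indexing='ij')
patch_coord = torch.stack([yy.flatten(), xx.flatten()], dim=-1)   # [N, 2] as (y, x)
patch_coord = patch_coord.unsqueeze(0).expand(B, -1, -1)          # [B, N, 2]
patch_valid = torch.ones(B, N, dtype=torch.bool)
patch_valid[1, -20:] = False    # pretend the second sample was padded to length N

with torch.no_grad():
    out = model({
        'patches': patches,
        'patch_coord': patch_coord,
        'patch_valid': patch_valid,
    })
print(out.shape)                # torch.Size([2, 10])
```

This is the same dict layout the NaFlex loader produces; note that the `validate.py` hunk above now derives the loader's `patch_size` from `model.embeds.patch_size` when available instead of hard-coding 16.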