from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.models.siglip.modeling_siglip import SiglipMLP
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from .configuration_onevision_encoder import OneVisionEncoderConfig


try:
    from flash_attn import flash_attn_func

    _flash_attn_available = True
except ImportError:
    _flash_attn_available = False


logger = logging.get_logger(__name__)


# ---------------------------------------------------------------------------
# Model Docstrings
# ---------------------------------------------------------------------------

ONEVISION_ENCODER_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`OneVisionEncoderConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

ONEVISION_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_channels, num_frames, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`].
        visible_indices (`torch.Tensor`, *optional*):
            Indices of visible patches for masking. Used in MAE-style pretraining or inference.
        patch_positions (`torch.Tensor` of shape `(batch_size, seq_len, 3)`, *optional*):
            Explicit `[t, h, w]` position of every patch. When given, rotary embeddings are computed from these
            positions instead of from the regular `(T, H, W)` patch grid.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# ---------------------------------------------------------------------------
# Helper Functions & Layers
# ---------------------------------------------------------------------------


def get_norm_layer(config):
    """Build the normalization layer selected by ``config.layer_norm_type``.

    ``"rms_norm"`` yields ``nn.RMSNorm``; any other value falls back to ``nn.LayerNorm``.
    Both are sized by ``config.hidden_size`` and use ``config.layer_norm_eps``.
    """
    if config.layer_norm_type == "rms_norm":
        return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
    return nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)


def rotate_half(x):
    """
    Interleaved rotation to match Source model's implementation.
    (x1, x2, x3, x4) -> (-x2, x1, -x4, x3)
    """
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)


def apply_rotary_pos_emb(q, k, freqs):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q, k: ``(B, H, L, D)`` query/key tensors.
        freqs: per-position angles, assumed ``(B, L, D)``; a head axis is inserted so they
            broadcast over ``H``.

    Returns:
        ``(q_embed, k_embed)`` with the same shape and dtype as ``q``/``k``.
    """
    # cos/sin are cast to q.dtype (bf16/fp16) immediately. freqs are typically float32, so
    # without this cast (q * cos) would upcast q to float32, which FlashAttention rejects.
    cos = freqs.cos().unsqueeze(1).to(q.dtype)
    sin = freqs.sin().unsqueeze(1).to(q.dtype)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class VideoRotaryEmbeddingSplit466(nn.Module):
    """
    3D (T,H,W) Rotary frequency constructor with 4:6:6 split.

    Half of each head dimension (``half = head_dim // 2``) is divided into 16 units:
    4 units of temporal, 6 units of height and 6 units of width frequencies. Each axis uses
    its own inverse-frequency table derived from ``config.rope_theta``.
    """

    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__()
        head_dim = config.hidden_size // config.num_attention_heads
        base = config.rope_theta

        assert head_dim % 2 == 0, "head_dim must be even for rotary."
        assert head_dim % 16 == 0, "head_dim must be divisible by 16."
        half = head_dim // 2
        assert half % 16 == 0, "head_dim//2 must also be divisible by 16 to split into 4:6:6."

        self.head_dim = head_dim
        self.half = half

        unit = half // 16
        self.t_size = 4 * unit
        self.h_size = 6 * unit
        self.w_size = 6 * unit

        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer(
            "inv_freq_t",
            1.0 / (base ** (torch.arange(self.t_size, dtype=torch.float32) / self.t_size)),
            persistent=False,
        )
        self.register_buffer(
            "inv_freq_h",
            1.0 / (base ** (torch.arange(self.h_size, dtype=torch.float32) / self.h_size)),
            persistent=False,
        )
        self.register_buffer(
            "inv_freq_w",
            1.0 / (base ** (torch.arange(self.w_size, dtype=torch.float32) / self.w_size)),
            persistent=False,
        )

    def forward(self, t: int, h: int, w: int, device=None):
        """Build frequencies for a dense ``t x h x w`` patch grid.

        Returns:
            ``(t * h * w, half)`` float32 tensor. Rows are ordered time-major, then height,
            then width, i.e. row ``ti * h * w + hi * w + wi``.
        """
        if device is None:
            device = self.inv_freq_t.device

        inv_t = self.inv_freq_t.to(device=device)
        inv_h = self.inv_freq_h.to(device=device)
        inv_w = self.inv_freq_w.to(device=device)

        ft = torch.outer(torch.arange(t, device=device, dtype=torch.float32), inv_t)
        fh = torch.outer(torch.arange(h, device=device, dtype=torch.float32), inv_h)
        fw = torch.outer(torch.arange(w, device=device, dtype=torch.float32), inv_w)

        # Per-patch grid coordinates in (t, h, w) raster order.
        t_ids = torch.arange(t, device=device).repeat_interleave(h * w)
        h_ids = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
        w_ids = torch.arange(w, device=device).repeat(h).repeat(t)

        freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
        return freqs

    def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
        """
        Compute rotary position embeddings from explicit patch positions.

        Args:
            patch_positions: [batch_size, seq_len, 3] tensor with [t, h, w] positions for each patch

        Returns:
            freqs: [batch_size, seq_len, half] tensor of position frequencies
        """
        device = patch_positions.device
        inv_t = self.inv_freq_t.to(device=device)
        inv_h = self.inv_freq_h.to(device=device)
        inv_w = self.inv_freq_w.to(device=device)

        t_pos = patch_positions[..., 0].float()  # [batch_size, seq_len]
        h_pos = patch_positions[..., 1].float()  # [batch_size, seq_len]
        w_pos = patch_positions[..., 2].float()  # [batch_size, seq_len]

        # Batched outer product: [batch_size, seq_len] x [dim] -> [batch_size, seq_len, dim]
        ft = torch.einsum("bs,d->bsd", t_pos, inv_t)
        fh = torch.einsum("bs,d->bsd", h_pos, inv_h)
        fw = torch.einsum("bs,d->bsd", w_pos, inv_w)

        return torch.cat([ft, fh, fw], dim=-1)


class Siglip2MultiheadAttentionPoolingHead(nn.Module):
    """
    Multi-Head Attention Pooling with a learned probe (PMA-style).

    A single learned query attends over all tokens, then passes through an RMSNorm + MLP
    residual branch. The output is one ``(B, hidden_size)`` vector per sample.
    """

    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_states):
        batch_size = hidden_states.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        attn_output, _ = self.attention(probe, hidden_states, hidden_states)

        residual = attn_output
        attn_output = self.norm(attn_output)
        attn_output = residual + self.mlp(attn_output)

        return attn_output[:, 0]


# ---------------------------------------------------------------------------
# Modeling Components
# ---------------------------------------------------------------------------


class OneVisionEncoderEmbeddings(nn.Module):
    """Per-frame Conv2d patch embedding that flattens all frames' patches into one sequence."""

    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``(B, C, H, W)`` or ``(B, C, T, H, W)`` pixels into ``(B, T * Hp * Wp, embed_dim)``."""
        # Treat images as single-frame videos.
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(2)  # (B, C, 1, H, W)

        batch_size, channels, t_frames, height, width = pixel_values.shape

        # Merge time into batch so a single Conv2d handles every frame.
        x_2d = pixel_values.permute(0, 2, 1, 3, 4).reshape(batch_size * t_frames, channels, height, width)

        embeddings = self.patch_embedding(x_2d)  # (B*T, C, Hp, Wp)
        embeddings = embeddings.flatten(2).transpose(1, 2)  # (B*T, L_frame, C)

        # Flatten all patches of all frames (time-major) into one sequence per sample.
        total_patches = t_frames * (height // self.patch_size) * (width // self.patch_size)
        embeddings = embeddings.reshape(batch_size, total_patches, self.embed_dim)

        return embeddings


class OneVisionEncoderAttention(nn.Module):
    """Multi-headed attention with RoPE support"""

    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Eager scaled-dot-product attention.

        ``attention_mask`` is an additive mask; a 3-D mask gets a head axis inserted.
        Returns ``(output, attn_weights or None)``.
        """
        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (B, L, H, D) -> (B, H, L, D)
        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        if rotary_pos_emb is not None:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, rotary_pos_emb)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, q_len):
                if attention_mask.dim() == 3:
                    attention_mask = attention_mask.unsqueeze(1)
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for numerical stability (exp/normalization in bf16/fp16 loses
        # precision), then cast back so the value matmul still runs in the activation dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights if output_attentions else None


class OneVisionEncoderFlashAttention2(nn.Module):
    """
    Multi-headed attention with RoPE support using Flash Attention 2.
    This module implements the same attention mechanism as OneVisionEncoderAttention
    but uses Flash Attention for improved performance and memory efficiency.
    """

    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass using Flash Attention 2.

        ``attention_mask`` and ``output_attentions`` are accepted for interface parity with the
        eager module but are not used: attention is unmasked, non-causal, and the returned
        weights are always ``None``.

        Raises:
            ImportError: if ``flash_attn`` is not installed.
        """
        # Fail fast, before spending compute on the projections.
        if not _flash_attn_available:
            raise ImportError("flash_attn is not installed. Please install it to use OneVisionEncoderFlashAttention2.")

        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash Attention expects (B, L, H, D).
        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)

        if rotary_pos_emb is not None:
            # apply_rotary_pos_emb works on (B, H, L, D); it keeps q/k in their own dtype.
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, rotary_pos_emb)
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)

        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout_p=self.dropout if self.training else 0.0,
            softmax_scale=self.scale,
            causal=False,
        )

        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None


ONEVISION_ENCODER_ATTENTION_CLASSES = {
    "eager": OneVisionEncoderAttention,
    "flash_attention_2": OneVisionEncoderFlashAttention2,
}


class OneVisionEncoderEncoderLayer(nn.Module):
    """Pre-norm transformer block: norm -> attention -> residual, norm -> MLP -> residual."""

    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__()
        self.embed_dim = config.hidden_size

        # Attention backend from the config; a missing or None value defaults to "flash_attention_2".
        attn_implementation = getattr(config, "_attn_implementation", None) or "flash_attention_2"
        # Fall back to eager when flash-attn is requested but not installed. This must run before
        # the membership check: "flash_attention_2" is always a registered key, so checking
        # membership first would make the fallback unreachable and defer the failure to forward().
        if attn_implementation == "flash_attention_2" and not _flash_attn_available:
            logger.warning_once("flash_attn is not installed; falling back to eager attention for OneVisionEncoder.")
            attn_implementation = "eager"
        if attn_implementation not in ONEVISION_ENCODER_ATTENTION_CLASSES:
            raise ValueError(
                f"Unknown attention implementation: {attn_implementation}. "
                f"Available implementations: {list(ONEVISION_ENCODER_ATTENTION_CLASSES.keys())}"
            )

        self.self_attn = ONEVISION_ENCODER_ATTENTION_CLASSES[attn_implementation](config)
        self.layer_norm1 = get_norm_layer(config)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = get_norm_layer(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Returns ``(hidden_states,)`` or ``(hidden_states, attn_weights)`` when ``output_attentions``."""
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attn_weights) if output_attentions else (hidden_states,)
        return outputs


class OneVisionEncoderEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` encoder layers sharing one RoPE tensor."""

    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([OneVisionEncoderEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Toggled by PreTrainedModel.gradient_checkpointing_enable()/disable(); without this
        # attribute the `supports_gradient_checkpointing = True` flag cannot take effect.
        self.gradient_checkpointing = False

    def _run_layer(self, layer, hidden_states, attention_mask, rotary_pos_emb, output_attentions):
        """Run one layer, recomputing its activations during backward when checkpointing is on."""
        if not (self.gradient_checkpointing and self.training):
            return layer(
                hidden_states,
                attention_mask=attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                output_attentions=output_attentions,
            )
        checkpoint_fn = getattr(self, "_gradient_checkpointing_func", None)
        if checkpoint_fn is None:
            # The checkpoint function was not injected (e.g. an older transformers release that only
            # flips the flag), so call torch's checkpoint directly.
            return torch.utils.checkpoint.checkpoint(
                layer.__call__,
                hidden_states,
                attention_mask,
                rotary_pos_emb,
                output_attentions,
                use_reentrant=False,
            )
        return checkpoint_fn(layer.__call__, hidden_states, attention_mask, rotary_pos_emb, output_attentions)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Run all layers.

        When ``output_hidden_states`` is set, the collected states include the input to
        every layer plus the final output.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = self._run_layer(layer, hidden_states, attention_mask, rotary_pos_emb, output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


# ---------------------------------------------------------------------------
# Main Models
# ---------------------------------------------------------------------------


@add_start_docstrings(
    "The bare OneVision Encoder Model outputting raw hidden-states without any specific head on top.",
    ONEVISION_ENCODER_START_DOCSTRING,
)
class OneVisionEncoderPreTrainedModel(PreTrainedModel):
    config_class = OneVisionEncoderConfig
    base_model_prefix = "onevision_encoder"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OneVisionEncoderEncoderLayer"]
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
            module.weight.data.fill_(1.0)
            # RMSNorm has no bias attribute, so check for it before zeroing.
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()


@add_start_docstrings(
    "OneVision Encoder Model with a vision transformer encoder.",
    ONEVISION_ENCODER_START_DOCSTRING,
)
class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
    def __init__(self, config: OneVisionEncoderConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = OneVisionEncoderEmbeddings(config)
        self.layernorm_pre = get_norm_layer(config)
        self.encoder = OneVisionEncoderEncoder(config)
        self.video_rope = VideoRotaryEmbeddingSplit466(config)

        if config.use_head:
            self.layernorm_post = get_norm_layer(config)
            self.head = Siglip2MultiheadAttentionPoolingHead(config)
        else:
            self.layernorm_post = None
            self.head = None

        self.post_init()

    @add_start_docstrings_to_model_forward(ONEVISION_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OneVisionEncoderConfig)
    def forward(
        self,
        pixel_values: torch.Tensor,
        visible_indices: Optional[torch.Tensor] = None,
        patch_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoModel, AutoImageProcessor
        >>> from PIL import Image

        >>> model = AutoModel.from_pretrained("lmms-lab-encoder/onevision-encoder-large", trust_remote_code=True)
        >>> preprocessor = AutoImageProcessor.from_pretrained("lmms-lab-encoder/onevision-encoder-large", trust_remote_code=True)
        >>> image = Image.open("path/to/your/image.jpg")  # Replace with your image path
        >>> pixel_values = preprocessor(images=image, return_tensors="pt")["pixel_values"]
        >>> outputs = model(pixel_values)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Determine the (T, H, W) grid used to build RoPE frequencies.
        # pixel_values may be 4D (image) or 5D (video).
        if pixel_values.dim() == 5:
            # config.rope_temporal_size, when set, overrides the actual frame count.
            t_frames = (
                self.config.rope_temporal_size
                if self.config.rope_temporal_size is not None
                else pixel_values.shape[2]
            )
            height = pixel_values.shape[3]
            width = pixel_values.shape[4]
        else:
            t_frames = 1
            height = pixel_values.shape[2]
            width = pixel_values.shape[3]

        # 1. Embeddings
        hidden_states = self.embeddings(pixel_values)
        batch_size, total_patches, _ = hidden_states.shape

        # 2. Visible indices: default to every patch, in order.
        if visible_indices is None:
            visible_indices = (
                torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
            )
        else:
            # gather() below requires the index tensor on the same device as hidden_states.
            visible_indices = visible_indices.to(hidden_states.device)

        # 3. RoPE construction
        if patch_positions is not None:
            freqs_visible = self.video_rope.forward_from_positions(patch_positions)
        else:
            freqs_full = self.video_rope(
                t=t_frames,
                h=height // self.config.patch_size,
                w=width // self.config.patch_size,
                device=pixel_values.device,
            )
            freqs_visible = freqs_full[visible_indices]
        # Duplicate the D/2 angles to D (non-interleaved layout).
        # NOTE(review): this layout is paired with the interleaved `rotate_half`; presumably kept
        # deliberately for checkpoint compatibility with the source model — confirm before changing.
        freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)

        # 4. Pre-norm & encoder
        hidden_states = self.layernorm_pre(hidden_states)

        # When the number of indices differs from total_patches (sparse/masked mode), keep only the
        # selected patches so hidden_states lines up with freqs_visible. When the counts match, the
        # hidden states are used unchanged and the indices only choose which RoPE grid positions are
        # applied (e.g. positions in a larger rope_temporal_size grid).
        num_visible = visible_indices.shape[1]
        if num_visible != total_patches:
            hidden_states = hidden_states.gather(
                1, visible_indices.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=None,
            rotary_pos_emb=freqs_visible,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        # Post-norm only exists when config.use_head is set.
        if self.layernorm_post is not None:
            sequence_output = self.layernorm_post(sequence_output)

        # 5. Pooling head
        pooled_output = None
        if self.head is not None:
            pooled_output = self.head(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )