File size: 26,909 Bytes
c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db bf5878a 4a1d1db bf5878a 4a1d1db bf5878a 4a1d1db bf5878a 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 5f3fd49 c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 4a1d1db c2b8f87 5f3fd49 4a1d1db c2b8f87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 |
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.models.siglip.modeling_siglip import SiglipMLP
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_onevision_encoder import OneVisionEncoderConfig
try:
from flash_attn import flash_attn_func
_flash_attn_available = True
except ImportError:
_flash_attn_available = False
logger = logging.get_logger(__name__)
# ---------------------------------------------------------------------------
# Model Docstrings
# ---------------------------------------------------------------------------
ONEVISION_ENCODER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`OneVisionEncoderConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
ONEVISION_ENCODER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_channels, num_frames, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`].
visible_indices (`torch.Tensor`, *optional*):
Indices of visible patches for masking. Used in MAE-style pretraining or inference.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# ---------------------------------------------------------------------------
# Helper Functions & Layers
# ---------------------------------------------------------------------------
def get_norm_layer(config):
if config.layer_norm_type == "rms_norm":
return nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
else:
return nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def rotate_half(x):
"""
Interleaved rotation to match Source model's implementation.
(x1, x2, x3, x4) -> (-x2, x1, -x4, x3)
"""
x_even = x[..., ::2]
x_odd = x[..., 1::2]
return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
def apply_rotary_pos_emb(q, k, freqs):
# q, k: (B, H, L, D)
# freqs: (B, L, D)
# We need to broadcast freqs to match heads
# (B, L, D) -> (B, 1, L, D)
# !!! CRITICAL FIX: Cast cos/sin to q.dtype (bf16/fp16) immediately
# freqs are typically float32, so cos() returns float32.
# Without this cast, (q * cos) upcasts q to float32, causing FlashAttention to fail.
cos = freqs.cos().unsqueeze(1).to(q.dtype)
sin = freqs.sin().unsqueeze(1).to(q.dtype)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class VideoRotaryEmbeddingSplit466(nn.Module):
"""
3D (T,H,W) Rotary frequency constructor with 4:6:6 split.
"""
def __init__(self, config: OneVisionEncoderConfig):
super().__init__()
head_dim = config.hidden_size // config.num_attention_heads
base = config.rope_theta
assert head_dim % 2 == 0, "head_dim must be even for rotary."
assert head_dim % 16 == 0, "head_dim must be divisible by 16."
half = head_dim // 2
assert half % 16 == 0, "head_dim//2 must also be divisible by 16 to split into 4:6:6."
self.head_dim = head_dim
self.half = half
unit = half // 16
self.t_size = 4 * unit
self.h_size = 6 * unit
self.w_size = 6 * unit
self.register_buffer(
"inv_freq_t",
1.0 / (base ** (torch.arange(self.t_size, dtype=torch.float32) / self.t_size)),
persistent=False,
)
self.register_buffer(
"inv_freq_h",
1.0 / (base ** (torch.arange(self.h_size, dtype=torch.float32) / self.h_size)),
persistent=False,
)
self.register_buffer(
"inv_freq_w",
1.0 / (base ** (torch.arange(self.w_size, dtype=torch.float32) / self.w_size)),
persistent=False,
)
def forward(self, t: int, h: int, w: int, device=None):
if device is None:
device = self.inv_freq_t.device
inv_t = self.inv_freq_t.to(device=device)
inv_h = self.inv_freq_h.to(device=device)
inv_w = self.inv_freq_w.to(device=device)
ft = torch.outer(torch.arange(t, device=device, dtype=torch.float32), inv_t)
fh = torch.outer(torch.arange(h, device=device, dtype=torch.float32), inv_h)
fw = torch.outer(torch.arange(w, device=device, dtype=torch.float32), inv_w)
t_ids = torch.arange(t, device=device).repeat_interleave(h * w)
h_ids = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
w_ids = torch.arange(w, device=device).repeat(h).repeat(t)
freqs = torch.cat([ft[t_ids], fh[h_ids], fw[w_ids]], dim=-1)
return freqs
def forward_from_positions(self, patch_positions: torch.Tensor) -> torch.Tensor:
"""
Compute rotary position embeddings from explicit patch positions.
Args:
patch_positions: [batch_size, seq_len, 3] tensor with [t, h, w] positions for each patch
Returns:
freqs: [batch_size, seq_len, half] tensor of position frequencies
"""
device = patch_positions.device
inv_t = self.inv_freq_t.to(device=device)
inv_h = self.inv_freq_h.to(device=device)
inv_w = self.inv_freq_w.to(device=device)
t_pos = patch_positions[..., 0].float() # [batch_size, seq_len]
h_pos = patch_positions[..., 1].float() # [batch_size, seq_len]
w_pos = patch_positions[..., 2].float() # [batch_size, seq_len]
# Use einsum for batched outer product: [batch_size, seq_len] x [dim] -> [batch_size, seq_len, dim]
ft = torch.einsum("bs,d->bsd", t_pos, inv_t)
fh = torch.einsum("bs,d->bsd", h_pos, inv_h)
fw = torch.einsum("bs,d->bsd", w_pos, inv_w)
return torch.cat([ft, fh, fw], dim=-1)
class Siglip2MultiheadAttentionPoolingHead(nn.Module):
"""
Multi-Head Attention Pooling with a learned probe (PMA-style).
"""
def __init__(self, config: OneVisionEncoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
def forward(self, hidden_states):
batch_size = hidden_states.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
attn_output, _ = self.attention(probe, hidden_states, hidden_states)
residual = attn_output
attn_output = self.norm(attn_output)
attn_output = residual + self.mlp(attn_output)
return attn_output[:, 0]
# ---------------------------------------------------------------------------
# Modeling Components
# ---------------------------------------------------------------------------
class OneVisionEncoderEmbeddings(nn.Module):
def __init__(self, config: OneVisionEncoderConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
# Handle 4D (B, C, H, W) or 5D (B, C, T, H, W) inputs
if pixel_values.dim() == 4:
pixel_values = pixel_values.unsqueeze(2) # (B, C, 1, H, W)
batch_size, channels, t_frames, height, width = pixel_values.shape
# Merge time into batch for Conv2d
x_2d = pixel_values.permute(0, 2, 1, 3, 4).reshape(batch_size * t_frames, channels, height, width)
# Patch Embed
embeddings = self.patch_embedding(x_2d) # (B*T, C, Hp, Wp)
embeddings = embeddings.flatten(2).transpose(1, 2) # (B*T, L_frame, C)
# Flatten all patches
total_patches = t_frames * (height // self.patch_size) * (width // self.patch_size)
embeddings = embeddings.reshape(batch_size, total_patches, self.embed_dim)
return embeddings
class OneVisionEncoderAttention(nn.Module):
"""Multi-headed attention with RoPE support"""
def __init__(self, config: OneVisionEncoderConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# (B, L, H, D) -> Transpose to (B, H, L, D)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
if rotary_pos_emb is not None:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, rotary_pos_emb)
# Calculate attention scores
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, q_len):
if attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_weights = attn_weights + attention_mask
# FIX: Remove dtype=torch.float32 to stay in original dtype (bf16/fp16)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights if output_attentions else None
class OneVisionEncoderFlashAttention2(nn.Module):
"""
Multi-headed attention with RoPE support using Flash Attention 2.
This module implements the same attention mechanism as OneVisionEncoderAttention but uses
Flash Attention for improved performance and memory efficiency.
"""
def __init__(self, config: OneVisionEncoderConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Forward pass using Flash Attention 2.
"""
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash Attention requires (B, L, H, D) format
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
# Apply RoPE if provided
if rotary_pos_emb is not None:
# Transpose for RoPE application: (B, L, H, D) -> (B, H, L, D)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
# NOTE: apply_rotary_pos_emb now ensures NO float32 cast happens
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, rotary_pos_emb)
# Transpose back: (B, H, L, D) -> (B, L, H, D)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
# Flash Attention forward pass
if not _flash_attn_available:
raise ImportError("flash_attn is not installed. Please install it to use OneVisionEncoderFlashAttention2.")
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=self.dropout if self.training else 0.0,
softmax_scale=self.scale,
causal=False,
)
# Reshape to (B, L, embed_dim)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
# No extra casting here.
attn_output = self.out_proj(attn_output)
return attn_output, None
ONEVISION_ENCODER_ATTENTION_CLASSES = {
"eager": OneVisionEncoderAttention,
"flash_attention_2": OneVisionEncoderFlashAttention2,
}
class OneVisionEncoderEncoderLayer(nn.Module):
def __init__(self, config: OneVisionEncoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
# Get attention implementation from config, default to "flash_attention_2"
attn_implementation = getattr(config, "_attn_implementation", "flash_attention_2")
if attn_implementation not in ONEVISION_ENCODER_ATTENTION_CLASSES:
# Fallback to eager if flash_attention_2 is not available
if not _flash_attn_available and attn_implementation == "flash_attention_2":
attn_implementation = "eager"
else:
raise ValueError(
f"Unknown attention implementation: {attn_implementation}. "
f"Available implementations: {list(ONEVISION_ENCODER_ATTENTION_CLASSES.keys())}"
)
self.self_attn = ONEVISION_ENCODER_ATTENTION_CLASSES[attn_implementation](config)
self.layer_norm1 = get_norm_layer(config)
self.mlp = SiglipMLP(config)
self.layer_norm2 = get_norm_layer(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, attn_weights) if output_attentions else (hidden_states,)
return outputs
class OneVisionEncoderEncoder(nn.Module):
def __init__(self, config: OneVisionEncoderConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([OneVisionEncoderEncoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_pos_emb: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer(
hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=rotary_pos_emb,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
# ---------------------------------------------------------------------------
# Main Models
# ---------------------------------------------------------------------------
@add_start_docstrings(
"The bare OneVision Encoder Model outputting raw hidden-states without any specific head on top.",
ONEVISION_ENCODER_START_DOCSTRING,
)
class OneVisionEncoderPreTrainedModel(PreTrainedModel):
config_class = OneVisionEncoderConfig
base_model_prefix = "onevision_encoder"
supports_gradient_checkpointing = True
_no_split_modules = ["OneVisionEncoderEncoderLayer"]
_supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
# Fix: RMSNorm doesn't have bias, must check hasattr first
module.weight.data.fill_(1.0)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
@add_start_docstrings(
"OneVision Encoder Model with a vision transformer encoder.",
ONEVISION_ENCODER_START_DOCSTRING,
)
class OneVisionEncoderModel(OneVisionEncoderPreTrainedModel):
def __init__(self, config: OneVisionEncoderConfig):
super().__init__(config)
self.config = config
self.embeddings = OneVisionEncoderEmbeddings(config)
self.layernorm_pre = get_norm_layer(config)
self.encoder = OneVisionEncoderEncoder(config)
self.video_rope = VideoRotaryEmbeddingSplit466(config)
if config.use_head:
self.layernorm_post = get_norm_layer(config)
self.head = Siglip2MultiheadAttentionPoolingHead(config)
else:
self.layernorm_post = None
self.head = None
self.post_init()
@add_start_docstrings_to_model_forward(ONEVISION_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=OneVisionEncoderConfig)
def forward(
self,
pixel_values: torch.Tensor,
visible_indices: Optional[torch.Tensor] = None,
patch_positions: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from transformers import AutoModel, AutoImageProcessor
>>> from PIL import Image
>>> model = AutoModel.from_pretrained("lmms-lab-encoder/onevision-encoder-large", trust_remote_code=True)
>>> preprocessor = AutoImageProcessor.from_pretrained("lmms-lab-encoder/onevision-encoder-large", trust_remote_code=True)
>>> image = Image.open("path/to/your/image.jpg") # Replace with your image path
>>> pixel_values = preprocessor(images=image, return_tensors="pt")["pixel_values"]
>>> outputs = model(pixel_values)
>>> last_hidden_states = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Determine video dimensions for RoPE
# Note: pixel_values passed to embeddings can be 4D or 5D
if pixel_values.dim() == 5:
# Use config.rope_temporal_size if set, otherwise use actual frame count
t_frames = (
self.config.rope_temporal_size if self.config.rope_temporal_size is not None else pixel_values.shape[2]
)
height = pixel_values.shape[3]
width = pixel_values.shape[4]
else:
t_frames = 1
height = pixel_values.shape[2]
width = pixel_values.shape[3]
# 1. Embeddings
hidden_states = self.embeddings(pixel_values)
batch_size, total_patches, _ = hidden_states.shape
# 2. Visible Indices Handling
if visible_indices is None:
visible_indices = (
torch.arange(total_patches, device=pixel_values.device).unsqueeze(0).expand(batch_size, -1)
)
# 3. RoPE Construction
if patch_positions is not None:
freqs_visible = self.video_rope.forward_from_positions(patch_positions)
else:
freqs_full = self.video_rope(
t=t_frames,
h=height // self.config.patch_size,
w=width // self.config.patch_size,
device=pixel_values.device,
)
freqs_visible = freqs_full[visible_indices]
# Concatenate D/2 + D/2 -> D for applying rope
freqs_visible = torch.cat([freqs_visible, freqs_visible], dim=-1)
# 4. Pre-Norm & Encoder
hidden_states = self.layernorm_pre(hidden_states)
# fix: gather hidden_states to match freqs_visible when using sparse visible_indices
num_visible = visible_indices.shape[1]
if num_visible != total_patches:
# sparse mode: select only visible patches
hidden_states = hidden_states.gather(
1, visible_indices.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=None,
rotary_pos_emb=freqs_visible,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
# Apply post-norm if configured
if self.layernorm_post is not None:
sequence_output = self.layernorm_post(sequence_output)
# 5. Pooling Head
pooled_output = None
if self.head is not None:
pooled_output = self.head(sequence_output)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
|