# processing_opencua.py
import torch
from typing import List, Dict, Any, Union
from PIL import Image
from transformers.processing_utils import ProcessorMixin, BatchFeature
from transformers import AutoTokenizer, AutoImageProcessor

PLACEHOLDER = "<|media_placeholder|>"

class OpenCUAProcessor(ProcessorMixin):
    """
    Lightweight processor that pairs the repo's custom TikTokenV3 tokenizer
    with Qwen2VLImageProcessor and exposes media token ids for vLLM.

    We intentionally keep __call__ minimal because vLLM doesn't require
    a full HF Processor pipeline at init time; it just needs the class
    to load cleanly and provide chat templating & media bookkeeping.
    """
    # ProcessorMixin expects `attributes` to list only the sub-processors it
    # knows how to save/load; scalar fields such as the media token ids stay
    # as plain instance attributes set in __init__.
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        tokenizer,
        image_token_id: int = 151667,   # match your config.json
        video_token_id: int = 151668,   # match your config.json
        merge_size: int = 2,
        **kwargs,
    ):
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        # Media token ids (used by vLLM profiling & grids)
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id

        # String placeholders (kept for template-time substitution)
        self.image_token = PLACEHOLDER
        self.video_token = PLACEHOLDER

        # Use the value baked into the image processor when available
        self.merge_size = getattr(image_processor, "merge_size", merge_size)

        # Pass through chat template if tokenizer carries one
        self.chat_template = getattr(tokenizer, "chat_template", None)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Ensure we can import the repo's custom code; pop the flag so it is
        # not forwarded again through **kwargs below.
        trust = kwargs.pop("trust_remote_code", True)

        # Prefer the repo's TikTokenV3; fall back to AutoTokenizer if needed
        try:
            from tokenization_opencua import TikTokenV3
            tok = TikTokenV3.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=trust,
            )
        except Exception:
            tok = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=trust,
            )

        # Load the Qwen2VLImageProcessor as declared by preprocessor_config.json
        imgproc = AutoImageProcessor.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust,
        )

        # Allow overrides of the media token ids via kwargs (rare); the
        # defaults mirror __init__ so both load paths stay consistent.
        image_token_id = kwargs.pop("image_token_id", 151667)
        video_token_id = kwargs.pop("video_token_id", 151668)

        return cls(
            imgproc,
            tok,
            image_token_id=image_token_id,
            video_token_id=video_token_id,
            **kwargs,
        )

    def apply_chat_template(
        self,
        messages: List[Dict[str, Any]],
        **kwargs
    ) -> Union[str, List[int]]:
        """
        Delegate to tokenizer's chat template. Supports both str and ids via kwargs.
        """
        return self.tokenizer.apply_chat_template(messages, **kwargs)

    # Minimal callable to satisfy HF/VLLM if Processor is ever invoked.
    def __call__(self, *args, **kwargs) -> BatchFeature:
        data = {"input_ids": torch.zeros(1, 1, dtype=torch.long)}
        return BatchFeature(data=data)

    # Helper for your own client code: expand PLACEHOLDER count to match image grid.
    def prepare_vllm_inputs(
        self,
        messages: List[Dict[str, Any]],
        images: Union[Image.Image, Any, List[Union[Image.Image, Any]]],
        add_generation_prompt: bool = True,
    ):
        text = self.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
        proc = self.image_processor(images=images, return_tensors="pt")
        grid = torch.as_tensor(proc.get("image_grid_thw", []))
        merge = getattr(self, "merge_size", 2)

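        # Worked example (hypothetical numbers): a grid of THW = (1, 78, 52)
        # with merge_size = 2 expands one placeholder into
        # 1 * 78 * 52 // 4 = 1014 copies.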
        # Each image's placeholder expands to prod(THW) // merge**2 copies.
        # Expand through a sentinel so copies already emitted for one image
        # are not matched again when expanding the next image.
        sentinel = "<|media_expanded|>"
        for thw in grid:
            num = int(thw[0] * thw[1] * thw[2]) // (merge ** 2)
            text = text.replace(PLACEHOLDER, sentinel * max(1, num), 1)
        text = text.replace(sentinel, PLACEHOLDER)

        return text, images
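

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the processor). It assumes the
# checkpoint directory contains this file plus tokenization_opencua.py, that
# the repo's chat template passes string content through verbatim, and that
# the model path below is hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    processor = OpenCUAProcessor.from_pretrained(
        "path/to/opencua-checkpoint",   # hypothetical local checkpoint
        trust_remote_code=True,
    )

    messages = [
        {"role": "user", "content": f"{PLACEHOLDER}\nDescribe this screenshot."},
    ]
    screenshot = Image.new("RGB", (1092, 728))  # stand-in for a real screenshot

    prompt, images = processor.prepare_vllm_inputs(messages, screenshot)
    print(prompt.count(PLACEHOLDER), "placeholder tokens in the prompt")

    # The expanded prompt and the raw PIL image(s) can then be handed to vLLM,
    # e.g. llm.generate({"prompt": prompt, "multi_modal_data": {"image": images}}).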