| |
| import torch |
| from typing import List, Dict, Any, Union |
| from PIL import Image |
| from transformers.processing_utils import ProcessorMixin, BatchFeature |
| from transformers import AutoTokenizer, AutoImageProcessor |
|
|
# Marker string searched for in the rendered chat text; prepare_vllm_inputs repeats it per image.
| PLACEHOLDER = "<|media_placeholder|>" |
|
|
class OpenCUAProcessor(ProcessorMixin):
    """Bundle an image processor and a tokenizer for OpenCUA chat prompts.

    The class stores the two components plus the image token id and the
    spatial merge size, delegates chat templating to the tokenizer, and
    offers ``prepare_vllm_inputs`` to expand media placeholders in the
    rendered prompt according to each image's patch grid.
    """

    attributes = ["image_processor", "tokenizer", "image_token_id", "merge_size"]

    def __init__(self, image_processor, tokenizer, image_token_id: int = 151664, merge_size: int = 2, **kwargs):
        """Store the components and resolve the merge size.

        Args:
            image_processor: Object used to preprocess images.
            tokenizer: Object used for chat templating.
            image_token_id: Token id associated with image placeholders.
            merge_size: Fallback merge size, used only when
                ``image_processor`` has no ``merge_size`` attribute.
            **kwargs: Accepted and ignored.
        """
        # NOTE(review): ProcessorMixin.__init__ is not called here; this looks
        # deliberate (``attributes`` lists plain ints) -- confirm before changing.
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.image_token_id = image_token_id
        # The image processor's own merge_size wins over the argument.
        self.merge_size = getattr(image_processor, "merge_size", merge_size)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the tokenizer and image processor and build a processor.

        Args:
            pretrained_model_name_or_path: Location forwarded to the
                underlying ``from_pretrained`` loaders.
            **kwargs: ``trust_remote_code`` is read from here (default
                ``True``); all kwargs are then forwarded to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        # NOTE(review): remote code is trusted by default; callers loading from
        # untrusted locations should pass trust_remote_code=False explicitly.
        trust = kwargs.get("trust_remote_code", True)

        try:
            from tokenization_opencua import TikTokenV3
            tok = TikTokenV3.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)
        except Exception:
            # Best-effort fallback: any failure importing or loading the custom
            # tokenizer falls through to the generic auto loader.
            tok = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)
        imgproc = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust)
        return cls(imgproc, tok, **kwargs)

    def apply_chat_template(self, messages: List[Dict[str, Any]], **kwargs) -> Union[str, List[int]]:
        """Delegate to ``self.tokenizer.apply_chat_template`` unchanged."""
        return self.tokenizer.apply_chat_template(messages, **kwargs)

    def __call__(self, *args, **kwargs) -> BatchFeature:
        """Return a fixed dummy batch; all arguments are ignored.

        The result always holds ``input_ids`` as a 1x1 zero tensor of dtype
        ``torch.long``. Real prompt preparation is done by
        ``prepare_vllm_inputs``.
        """
        data = {"input_ids": torch.zeros(1, 1, dtype=torch.long)}
        return BatchFeature(data=data)

    def prepare_vllm_inputs(self, messages, images, add_generation_prompt=True):
        """Render the chat prompt and expand one placeholder per image.

        The i-th placeholder occurrence in the rendered text is replaced by
        ``t * h * w // merge_size**2`` copies of the placeholder, where
        ``(t, h, w)`` is the i-th row of the image processor's
        ``image_grid_thw`` output. Placeholders beyond the number of images
        are left as a single copy; images beyond the number of placeholders
        do not alter the text.

        Args:
            messages: Chat messages passed to ``apply_chat_template``.
            images: Images passed to ``self.image_processor``.
            add_generation_prompt: Forwarded to ``apply_chat_template``.

        Returns:
            A ``(text, images)`` tuple with the expanded prompt text and the
            ``images`` argument unchanged.
        """
        text = self.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
        proc = self.image_processor(images=images, return_tensors="pt")
        grid = torch.as_tensor(proc["image_grid_thw"])
        merge = getattr(self, "merge_size", 2)
        counts = [int((thw[0] * thw[1] * thw[2]) // (merge ** 2)) for thw in grid]

        # Split once and rebuild. Calling text.replace(..., 1) in a loop would
        # keep hitting the already-expanded run of the first image, so later
        # images would never expand their own placeholder.
        segments = text.split(PLACEHOLDER)
        pieces = [segments[0]]
        for index, tail in enumerate(segments[1:]):
            repeat = counts[index] if index < len(counts) else 1
            pieces.append(PLACEHOLDER * repeat)
            pieces.append(tail)
        return "".join(pieces), images
|
|
|
|
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|