import ast
import copy
import datetime
import gc
import io
import json
import math
import mimetypes
import os
import random
import re
import sys
import tarfile
import tempfile
import zipfile
from collections import defaultdict, deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import av
import cv2
import numpy as np
import PIL
import pkg_resources
import scipy.signal as scsig
import torch
from decord import VideoReader, cpu
from PIL import Image, ImageDraw
from smart_open import open  # shadows builtins.open so remote/compressed paths open transparently
from torchvision.transforms.functional import to_tensor

from hcxvlm.dataset.base_dataset import image_decoder
from hcxvlm.dataset.hcx_vision_prompter import HCXVisionPrompter

CHOICES = list(map(chr, range(97, 123)))  # "a" .. "z"
IGNORE_INDEX = -100
DEFAULT_SAMPLE_RATE = 16000
MIN_DISCRETE_AUDIO_CHUNK_SAMPLES = 1600
DEFAULT_VOLUME_LEVEL = 10 ** (-26 / 20)  # -26 dBFS RMS target

hcx_vision_prompter = HCXVisionPrompter()


def hpf_normalize(
    wav: np.ndarray,
    sr: int = DEFAULT_SAMPLE_RATE,
    volume_level: float = DEFAULT_VOLUME_LEVEL,
) -> np.ndarray:
    # 70 Hz high-pass filter, then scale the RMS toward volume_level without clipping.
    assert (wav**2).mean() > 0, "Error in the wav file"
    filter_ = scsig.butter(2, 70, "highpass", fs=sr, output="sos")
    wav = scsig.sosfilt(filter_, wav)
    wav = wav.astype(np.float32)
    gain = min(volume_level / (wav**2).mean() ** 0.5, 1 / np.max(np.abs(wav)))
    wav *= gain
    return wav


def convert_bboxes(img, img_meta):
    for k, v in img_meta.items():
        if k == "region":
            bbox_key = "bbox" if "bbox" in img_meta[k] else "boundingBox"
            img_meta[k] = reform_bbox(
                img_meta[k][bbox_key], img.size, format=img_meta[k]["format"]
            )
    return img_meta


def reform_bbox(bbox, image_size, format="REL_XYXY"):
    w, h = image_size
    if format == "REL_XYXY":
        x1, y1, x2, y2 = bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h
    elif format == "REL_XYWH":
        x1, y1 = bbox[0] * w, bbox[1] * h
        x2, y2 = x1 + bbox[2] * w, y1 + bbox[3] * h
    else:
        raise NotImplementedError
    new_bbox = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
    return new_bbox


def generate_random_color(use_alpha=True, seed=None):
    if seed is None:
        seed = np.random.default_rng()
    # Korean color names; mapped to English via EN_COLOR when lang == "en".
    if use_alpha:
        color_list = [
            ("빨강", (255, 127, 127, 100)),
            ("노랑", (255, 255, 127, 100)),
            ("초록", (127, 255, 125, 100)),
            ("하늘", (127, 255, 255, 100)),
            ("파랑", (127, 127, 255, 100)),
            ("보라", (255, 127, 255, 100)),
        ]
    else:
        color_list = [
            ("빨강", (255, 0, 0)),
            ("노랑", (255, 255, 0)),
            ("초록", (0, 128, 0)),
            ("하늘", (135, 206, 235)),
            ("파랑", (0, 0, 255)),
            ("보라", (128, 0, 128)),
        ]
    return color_list[seed.integers(0, len(color_list))]


EN_COLOR = {
    "빨강": "red",
    "노랑": "yellow",
    "초록": "green",
    "하늘": "sky blue",
    "파랑": "blue",
    "보라": "purple",
}


def overlay_rectangle(image, words, lang, seed=None):
    color_str, color = generate_random_color(seed=seed)
    draw = ImageDraw.Draw(image, "RGBA")
    for word in words:
        shape_rect = word["bbox"]
        shape_rect = [(round(x[0]), round(x[1])) for x in shape_rect]
        draw.polygon(shape_rect, color)
    del draw
    if lang == "en":
        color_str = EN_COLOR[color_str]
    return image, color_str


def convert_tags_for_video(img, json):
    """Video data carries a video tag instead of per-image tags.

    Replace that tag with one image tag for each image in ``img``.
    """
    image_tag = "".join([f"" for idx in range(len(img))])
    for json_key in json:
        if "qa_pairs" in json_key:
            new_qa_pairs = []
            for qa_pair in json[json_key]:
                question = qa_pair[0]
                question = question.replace("", image_tag)
                new_qa_pairs.append([question, qa_pair[1]])
            json[json_key] = new_qa_pairs
    return img, json


def sampling_multiturn_single_img(
    seq,
    count,
    multiturn_preserve_order=True,
    multiturn_continuous=False,
    is_train: bool = True,
    seed=None,
):
    if seed is None:
        seed = np.random.default_rng()
    n_sample = min(count, len(seq))
    if multiturn_continuous:
        if len(seq) <= n_sample:
            start_index = 0
        else:
            start_index = seed.integers(0, len(seq) - n_sample)
        indices = range(start_index, start_index + n_sample)
    elif multiturn_preserve_order:
        indices = sorted(seed.choice(range(len(seq)), size=n_sample, replace=False))
    else:
        indices = seed.choice(range(len(seq)), size=n_sample, replace=False)
    return [seq[i] for i in indices]


def draw_bbox(image, bbox, lang="en", line_width=5, seed=None):
    if seed is None:
        seed = np.random.default_rng()
    color_str, color = generate_random_color(use_alpha=False, seed=seed)
    draw = ImageDraw.Draw(image, "RGB")
    rect_bbox = (bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1])
    draw.rectangle(rect_bbox, outline=color, width=line_width)
    del draw
    if lang == "en":
        color_str = EN_COLOR[color_str]
    return image, color_str


def bbox_process(bbox, detection_precision=2):
    # NOTE: detection_precision is currently unused; coordinates are always
    # rendered with two decimal places.
    bbox_str = "["
    for idx, point in enumerate(bbox):
        normalized = point
        if idx < len(bbox) - 1:
            bbox_str += format(normalized, ".2f") + ", "
        else:
            bbox_str += format(normalized, ".2f")
    bbox_str += "]"
    return bbox_str


def load_txt(file_path):
    lines_list = []
    with open(file_path, "r") as file:
        for line in file:
            lines_list.append(line.replace("\\n", "\n").strip())
    return lines_list


def convert_format_for_multi_image(
    img, json, convert_key_list=["words", "text", "objects", "entities"]
):
    """The img/json read from single-image and multi-image datasets have different formats.

    Convert img/json read from a single-image dataset into the multi-image
    dataset format (a dict keyed by image index, e.g. "00").
    """
    is_multi_image_dataset = isinstance(img, dict)
    if not is_multi_image_dataset:
        img = {"00": img}
        for convert_key in convert_key_list:
            if convert_key in json:
                json[convert_key] = {"00": json[convert_key]}
        for json_key in json:
            if "region" in json_key:
                json[json_key] = {"00": json[json_key]}
    else:
        for convert_key in convert_key_list:
            if convert_key in json:
                if isinstance(json[convert_key], list):
                    json[convert_key] = {"00": json[convert_key]}
        for json_key in json:
            if "region" in json_key:
                if isinstance(json[json_key], list):
                    json[json_key] = {"00": json[json_key]}
    return is_multi_image_dataset, img, json


class ConditionalError(Exception):
    def __init__(self, message="Our assertion error"):
        super().__init__(message)


def get_wds_default_config(default_config, existing_default_config=None):
    if existing_default_config is None:
        default_config_check_dict = {
            "subtask": "",
            "reasoning": False,
            "use_task_prompt": True,
            "get_random": True,
            "add_instruct_prompts": [],
            "multiturn_n_samples": 0,
            "multiturn_preserve_order": True,
            "multiturn_continuous": False,
            "insert_ocr": 200,
            "ocr_filter_strategy": "confidence",
            "ocr_use_ratio": 1.0,
            "entity_top_k": 0,
            "entity_keyword_threshold": 100,
            "entity_keyword_fashion_threshold": 100,
            "entity_use_ratio": 0.0,
            "llava_pretrain": False,
            "random_system_prob": 0.0,
            "random_system_path": "",
            "random_tool_prob": 0.005,
        }
    else:
        default_config_check_dict = existing_default_config
    if default_config is None:
        default_config = default_config_check_dict
    else:
        for key, value in default_config_check_dict.items():
            if key not in default_config:
                default_config[key] = value
    return default_config


def get_datalake_default_config(default_config):
    default_config_check_dict = {
        "multiturn_n_samples": 0,
        "multiturn_preserve_order": True,
        "multiturn_continuous": True,
        "insert_ocr": 0,
        "ocr_filter_strategy": "confidence",
        "entity_top_k": 0,
        "entity_keyword_threshold": 0,
        "entity_keyword_fashion_threshold": 0,
        "entity_use_ratio": 0.0,
        "ocr_use_ratio": 0.0,
        "llava_pretrain": False,
        "random_system_prob": 0.0,
        "random_system_path": "",
        "random_tool_prob": 0.005,
    }
    if default_config is None:
        default_config = default_config_check_dict
    else:
        for key, value in default_config_check_dict.items():
            if key not in default_config:
                default_config[key] = value
    return default_config


@dataclass
class Processed_sample:
    input_str: str = None
    input_ids: list = None
    label_ids: list = None
    imgs: list = None
    discrete_imgs: list = None
    videos: list = None
    videos_duration: List[dict] = None
    video_audios: list = None
    audios: list = None
    audios_duration: List[dict] = None
    discrete_audios: list = None
    sample_mm_counter: dict = None


from hcxvlm.dataset.bbox_processor import (
    extract_bboxes,
    insert_bboxes_to_json,
    is_bbox_padded,
)


class Preprocessor:
    prompt_head = ""
    va_prefix = "\n<|im_start|>"
    new_line = "\n"
    turn_prefix = "<|im_start|>"
    turn_suffix = "<|im_end|>"
    mime_start = "<|mime_start|>"
    mime_end = "<|mime_end|>"
    aux_img_start = "<|image_aux_start|>"
    aux_img_end = "<|image_aux_end|>"
    aux_video_start = "<|video_aux_start|>"
    aux_video_end = "<|video_aux_end|>"
    aux_audio_start = "<|audio_aux_start|>"
    aux_audio_end = "<|audio_aux_end|>"
    image_start = "<|image_start|>"
    image_end = "<|image_end|>"
    image_pad = "<|IMAGE_PAD|>"
    video_start = "<|video_start|>"
    video_end = "<|video_end|>"
    video_pad = "<|VIDEO_PAD|>"
    audio_start = "<|audio_start|>"
    audio_end = "<|audio_end|>"
    audio_pad = "<|AUDIO_PAD|>"
    discrete_image_start = "<|discrete_image_start|>"
    discrete_image_end = "<|discrete_image_end|>"
    discrete_image_pad = "<|DISCRETE_IMAGE_PAD|>"
    video_audio_pad = "<|VIDEO_AUDIO_PAD|>"
    discrete_audio_start = "<|discrete_audio_start|>"
    discrete_audio_end = "<|discrete_audio_end|>"
    discrete_audio_pad = "<|DISCRETE_AUDIO_PAD|>"
    discrete_image_eol = "<|vision_eol|>"
    discrete_image_eof = "<|vision_eof|>"
    discrete_image_ratios = {
        (1, 1): "<|vision_ratio_1:1|>",
        (1, 2): "<|vision_ratio_1:2|>",
        (2, 1): "<|vision_ratio_2:1|>",
        (3, 4): "<|vision_ratio_3:4|>",
        (4, 3): "<|vision_ratio_4:3|>",
        (3, 5): "<|vision_ratio_3:5|>",
        (5, 3): "<|vision_ratio_5:3|>",
        (4, 5): "<|vision_ratio_4:5|>",
        (5, 4): "<|vision_ratio_5:4|>",
        (6, 9): "<|vision_ratio_6:9|>",
        (9, 6): "<|vision_ratio_9:6|>",
        (9, 16): "<|vision_ratio_9:16|>",
        (16, 9): "<|vision_ratio_16:9|>",
    }
    aux_vid_prompt = (
        "다음 중 video_duration은 비디오 길이 정보입니다. 참고하여 답변하세요. "
    )
    aux_audio_prompt = (
        "다음 중 audio_duration은 오디오 길이 정보입니다. 참고하여 답변하세요. "
    )

    def __init__(
        self,
        tokenizer=None,
        prepare_input_fn=None,
        prepare_audio_input_fn=None,
        sample_min_length=0,
        decoder_max_length=None,
        mode="train",
        model=None,
        datalake_default_config=None,
        wds_default_config=None,
        video_config=None,
        train_video=False,
        train_audio=False,
        sequence_parallel_size=1,
        video_audio_compressor_type=None,
    ):
        self.sequence_parallel_size = sequence_parallel_size
        if sequence_parallel_size > 1:
            self.rng = np.random.default_rng(seed=42)
        else:
            self.rng = np.random.default_rng()
        if model is not None:
            tokenizer = model.tokenizer
            decoder_max_length = 16000
        if model is not None and prepare_input_fn is None:
            raise ValueError("please give ImageProcessor!")
        self.prepare_input_fn = prepare_input_fn
        self.prepare_audio_input_fn = prepare_audio_input_fn
        try:
            from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import (
                Qwen2_5_VLProcessor,
            )

            self.is_qwen_visual = isinstance(prepare_input_fn, Qwen2_5_VLProcessor)
        except Exception as e:
            self.is_qwen_visual = False
        try:
            if not self.is_qwen_visual:
                from hcxvlm.models.processing_vlm import HCXVisionV2Processor

                self.is_qwen_visual = isinstance(prepare_input_fn, HCXVisionV2Processor)
        except Exception as e:
            self.is_qwen_visual = False
        assert self.is_qwen_visual, "qwen2.5-vl visual prepare_input_fn import error"
        self.video_max_num_frames = (
            video_config["video_max_num_frames"]
            if video_config and "video_max_num_frames" in video_config
            else 120
        )
        self.video_max_pixels = (
            video_config["video_max_pixels"]
            if video_config and "video_max_pixels" in video_config
            else 378 * 378
        )
        self.tokenizer = tokenizer
        self.sample_min_length = sample_min_length
        self.decoder_max_length = decoder_max_length
        self.mode = mode
        self.default_config = get_datalake_default_config(datalake_default_config)
        self.wds_default_config = get_wds_default_config(wds_default_config)
        self.train_video = train_video
        self.train_audio = train_audio
        self.video_audio_compressor_type = video_audio_compressor_type
        self.img_token = self.tokenizer.encode(Preprocessor.image_pad)[0]
        assert (
            len(self.tokenizer.encode(Preprocessor.image_pad)) == 1
        ), "img_token is not configured in tokenizer"
        self.discrete_image_token = self.tokenizer.encode(
            Preprocessor.discrete_image_pad
        )[0]
        assert (
            len(self.tokenizer.encode(Preprocessor.discrete_image_pad)) == 1
        ), "discrete_image_token is not configured in tokenizer"
        self.discrete_image_eol_token = self.tokenizer.encode(
            Preprocessor.discrete_image_eol
        )[0]
        assert (
            len(self.tokenizer.encode(Preprocessor.discrete_image_eol)) == 1
        ), "discrete_image_eol_token is not configured in tokenizer"
self.discrete_image_eof_token = self.tokenizer.encode( Preprocessor.discrete_image_eof )[0] assert ( len(self.tokenizer.encode(Preprocessor.discrete_image_eof)) == 1 ), "discrete_image_eof_token is not configured in tokenizer" self.discrete_image_ratio_tokens = dict() for ratio, token_str in Preprocessor.discrete_image_ratios.items(): token_id = self.tokenizer.encode(token_str)[0] assert ( len(self.tokenizer.encode(token_str)) == 1 ), f"discrete_image_ratio_token {token_str} is not configured in tokenizer" self.discrete_image_ratio_tokens[ratio] = token_id self.video_token = self.tokenizer.encode(Preprocessor.video_pad)[0] assert ( len(self.tokenizer.encode(Preprocessor.video_pad)) == 1 ), "video_token is not configured in tokenizer" self.video_audio_token = self.tokenizer.encode(Preprocessor.video_audio_pad)[0] assert ( len(self.tokenizer.encode(Preprocessor.video_audio_pad)) == 1 ), "video_audio_token is not configured in tokenizer" def resize_min_edge(img: Image.Image) -> Image.Image: w, h = img.size min_size = 28 if min(w, h) >= min_size: return img if w < h: new_w = min_size new_h = int(h * (min_size / w)) else: new_h = min_size new_w = int(w * (min_size / h)) return img.resize((new_w, new_h), Image.BICUBIC) self._resize_min_edge = resize_min_edge self.audio_token = self.tokenizer.encode(Preprocessor.audio_pad)[0] assert ( len(self.tokenizer.encode(Preprocessor.audio_pad)) == 1 ), "audio_token is not configured in tokenizer" self.discrete_audio_token = self.tokenizer.encode( Preprocessor.discrete_audio_pad )[0] assert ( len(self.tokenizer.encode(Preprocessor.discrete_audio_pad)) == 1 ), "audio_token is not configured in tokenizer" from hcxvlm.dataset.json_processer import generate_prompt self.generate_prompt = generate_prompt self.mimes = list() for mime_filename in [ "words_alpha.txt", "korean-366506-wordslistUnique.txt", ]: self.mimes += ( pkg_resources.resource_string( "hcxvlm", f"dataset/hcx_vision_prompter/prompts/{mime_filename}" ) .decode("utf-8") .split("\r\n") ) self.common_tools = [] try: common_tools_bytes = pkg_resources.resource_string( "hcxvlm", "dataset/hcx_vision_prompter/prompts/common_tools.jsonl", ) for line in common_tools_bytes.decode("utf-8").splitlines(): line = line.strip() if not line: continue try: self.common_tools.append(json.loads(line)) except Exception: continue except Exception: self.common_tools = [] self.random_system_prompt = "" if self.default_config["random_system_path"] != "": self.random_system_prompt = "" with open(self.default_config["random_system_path"], "r") as f: for line in f: self.random_system_prompt += line if ( self.random_system_prompt != "" and self.wds_default_config["random_system_path"] != "" ): assert ( self.wds_default_config["random_system_path"] == self.default_config["random_system_path"] ), "random_system_path in both default_config and wds_default_config should be the same" def _find_best_ratio_token(self, original_size): """Find the best ratio token based on original_size""" base_ratios = list(self.discrete_image_ratio_tokens.keys()) vision_aspect_ratios = [ r for ratio in base_ratios for r in [ratio, ratio[::-1]] ][1:] if not isinstance(original_size, list) or len(original_size) != 2: return self.discrete_image_ratio_tokens[(1, 1)] h, w = original_size if h == 0 or w == 0: return self.discrete_image_ratio_tokens[(1, 1)] ratios = [i / j for i, j in vision_aspect_ratios] best_size_idx = np.argmin([abs(w / h - r) for r in ratios]) i, j = vision_aspect_ratios[best_size_idx] return self.discrete_image_ratio_tokens[(i, j)] 
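    # Illustrative note (not part of the original source): _find_best_ratio_token
    # receives original_size as [h, w] and returns the ratio token whose i/j is
    # closest to w / h. For example, for a 480x640 (h x w) image, w / h = 640/480
    # ≈ 1.33, which matches the (4, 3) entry, so the id of <|vision_ratio_4:3|>
    # would be returned. A standalone sketch of the same matching rule (the helper
    # name is hypothetical, not part of this module):
    #
    #     def closest_ratio(h, w, ratios=((1, 1), (4, 3), (16, 9))):
    #         return min(ratios, key=lambda r: abs(w / h - r[0] / r[1]))
    #
    #     closest_ratio(480, 640)  # -> (4, 3)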
@classmethod def prompt_mime( cls, mimes: Optional[list[str]] = None, file_name: str = None, tag_idx: int = 1, fixed_mime: bool = False, is_video: bool = False, is_audio: bool = False, seed: np.random.Generator = None, ) -> list[dict]: assert mimes or file_name if seed is None: seed = np.random.default_rng() if file_name: name, ext = os.path.splitext(file_name) ext = ext.lstrip(".") elif fixed_mime: ext = "jpeg" name = mimes[tag_idx] elif not fixed_mime and seed is not None: ext = seed.choice(["png", "jpeg"]) name = mimes[seed.integers(0, len(mimes))] else: ext = "jpeg" name = mimes[tag_idx] if is_video: ext_candidates = ["mp4", "mov", "avi", "webm"] if fixed_mime: ext = "mp4" elif ext not in ext_candidates: ext = seed.choice(ext_candidates) filename = f"{name}.{ext}" mime_type = mimetypes.guess_type(filename)[0] mime_prompt = { "id": f"video_{str(tag_idx).zfill(2)}", "type": f"{mime_type}", "filename": f"{filename}", } return mime_prompt if is_audio: ext_candidates = ["mp3", "wav", "aac", "flac", "pcm"] if fixed_mime: ext = "wav" elif ext not in ext_candidates: ext = seed.choice(ext_candidates) filename = f"{name}.{ext}" mime_type = mimetypes.guess_type(filename)[0] mime_prompt = { "id": f"audio_{str(tag_idx).zfill(2)}", "type": f"{mime_type}", "filename": f"{filename}", } return mime_prompt if file_name: filename = f"{name}.{ext}" mime_type = mimetypes.guess_type(filename)[0] mime_prompt = { "id": f"image_{str(tag_idx).zfill(2)}", "type": f"{mime_type}", "filename": f"{filename}", } else: mime_prompt = { "id": f"image_{str(tag_idx).zfill(2)}", "type": f"image/{ext}", "filename": f"{name}.{'jpg' if ext == 'jpeg' else 'png'}", } return mime_prompt @classmethod def ocr_preprocess( cls, words: list[dict], n_insert_ocr_tokens: int = 2000, insert_ocr: int = 200, ocr_use_ratio: float = 0.5, tokenizer=None, seed=None, ) -> list[str]: if seed is None: seed = np.random.default_rng() if ocr_use_ratio < seed.random(): return None if insert_ocr == 0: return None confidence_list = [] insert_ocr_prompt = [] for word in words: if "confidence" in word: confidence_list.append(word["confidence"]) has_ocr_confidence = len(confidence_list) >= insert_ocr if len(words) <= insert_ocr or not has_ocr_confidence: insert_ocr_prompt += [ d["text"].strip() for d in words if d["text"].strip() ][:insert_ocr] else: confidence_threshold = 0.3 cnt = 0 for word in words: if word["text"] == "": continue if word["confidence"] >= confidence_threshold: insert_ocr_prompt.append(word["text"]) cnt += 1 if cnt >= insert_ocr: break ocr_inputs = " ".join(insert_ocr_prompt) if tokenizer: ocr_inputs = tokenizer.decode( tokenizer.encode(ocr_inputs)[:n_insert_ocr_tokens] ) return ocr_inputs @classmethod def lens_preprocess( cls, lens: list[dict], entity_top_k: int = 100, entity_keyword_threshold: float = 0.0, entity_keyword_fashion_threshold: float = 0.0, entity_use_ratio: float = 0.0, seed=None, ): if seed is None: seed = np.random.default_rng() if seed.uniform(0, 1) > entity_use_ratio: return None entities = lens filter_idx = [] insert_entity_prompt = {} for idx, entity in enumerate(entities): if entity["type"] != "naver_lens_api": filter_idx.append(idx) continue if ( isinstance(entity_keyword_threshold, (int, float)) and entity["confidence"] < entity_keyword_threshold ): filter_idx.append(idx) continue if ( isinstance(entity_keyword_fashion_threshold, (int, float)) and ("fashion" in entity["info"]["classes"]) and entity["confidence"] < entity_keyword_fashion_threshold ): filter_idx.append(idx) continue entityvalue = [ keyword for 
idx, keyword in enumerate(entities) if idx not in filter_idx ] entityvalue = sorted(entityvalue, key=lambda x: x["confidence"], reverse=True) important_entity_list = [] local_entity_str_list = [] keywords_and_bbox_per_detector = {} for keyword_dict in entityvalue[:entity_top_k]: object_class = "/".join(keyword_dict["info"]["classes"]) if object_class not in keywords_and_bbox_per_detector.keys(): keywords_and_bbox_per_detector[object_class] = [] keywords_and_bbox_per_detector[object_class].append(keyword_dict) for object_class in keywords_and_bbox_per_detector.keys(): entities_per_object = keywords_and_bbox_per_detector[object_class] normalized_bbox = bbox_process( [*entities_per_object[0]["bbox"][0], *entities_per_object[0]["bbox"][2]] ) entities = [entity["text"] for entity in entities_per_object] if "context" in object_class: important_entity_list += entities else: local_entity_str_list += [ str(normalized_bbox) + " " + ", ".join(entities) ] if len(important_entity_list) > 0: insert_entity_prompt["lens_keywords"] = ", ".join(important_entity_list) if len(local_entity_str_list) > 0: insert_entity_prompt["lens_local_keywords"] = " ".join( local_entity_str_list ) return insert_entity_prompt @classmethod def prompt_toollist( cls, output, tokenizer=None, turn: Optional[dict] = None, content: Optional[list[dict]] = None, ): assert content or turn if turn is None: turn = { "role": "tool_list", "content": content, } toollist_str = ( cls.turn_prefix.strip() + turn["role"] + "\n" + turn["content"] + cls.turn_suffix ) if hasattr(output, "input_str"): output.input_str += toollist_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(toollist_str, truncation=False) output.input_ids += token_ids output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] return output @classmethod def prompt_system( cls, output, tokenizer=None, turn: Optional[dict] = None, content: Optional[str] = None, seed=None, tool_prompt=None, system_role_count=0, ): assert content or turn if seed is None: seed = np.random.default_rng() if turn is None: system_prompt = content else: if "candidates" in turn: if len(turn["candidates"]) > 0: system_prompt = seed.choice(turn["candidates"]) if type(system_prompt) is dict: system_prompt = system_prompt["content"] else: system_prompt = "" elif isinstance(turn["content"], str): system_prompt = turn["content"] elif len(turn["content"]) > 0: system_prompt = seed.choice(turn["content"]) system_str = cls.turn_prefix + turn["role"] + "\n" system_str += system_prompt.strip() if system_role_count == 0: if system_prompt.strip(): system_str += "\n" system_str += tool_prompt system_str += cls.turn_suffix if hasattr(output, "input_str"): output.input_str += system_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(system_str, truncation=False) output.input_ids += token_ids output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] return output @classmethod def load_mm( cls, output, img_dir: str = "", turn: Optional[dict] = None, image_urls: Optional[list[str]] = None, image_metas: Optional[list[dict]] = None, video_urls: Optional[list[str]] = None, video_metas: Optional[list[dict]] = None, audio_urls: Optional[list[str]] = None, audio_metas: Optional[list[dict]] = None, prepare_input_fn=None, prepare_audio_input_fn=None, max_image_cnt=21, video_max_num_frames=None, video_max_pixels=None, use_audio: bool = False, audio_sample_rate: int = 16000, ): assert (image_urls or video_urls or audio_urls) or turn if turn is None: turn 
= {} if image_urls: turn.update({"image_urls": image_urls}) turn.update({"image_metas": image_metas}) if video_urls: turn.update({"video_urls": video_urls}) turn.update({"video_metas": video_metas}) if audio_urls: turn.update({"audio_urls": audio_urls}) turn.update({"audio_metas": audio_metas}) if "video_urls" in turn: if len(turn["video_urls"]) and (prepare_input_fn is None): raise ConditionalError("video processing needs 'prepare_input_fn'") if not isinstance(turn["content"], str): raise ConditionalError(f"turn['content'] must be a string") turn["content"] = re.sub(r"", "<|image|>", turn["content"]) pattern = re.compile( r"<\|video\|>|<\|image\|>|<\|t2i_model_generation_target_discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>" ) tags = [match.group() for match in pattern.finditer(turn["content"])] img_idx = 0 vid_idx = 0 aud_idx = 0 if "image_urls" not in turn: turn["image_urls"] = [] if "video_urls" not in turn: turn["video_urls"] = [] if "audio_urls" not in turn: turn["audio_urls"] = [] for tag in tags: if ( tag == "<|image|>" or tag == "<|t2i_model_generation_target_discrete_image|>" ): img_path = turn["image_urls"][img_idx] if isinstance(img_path, str): if "#" in img_path: compression_path, img_path = img_path.split("#", 1) compression_path = os.path.join(img_dir, compression_path) assert compression_path[-4:] in [ ".zip", ".tar", ], f"unsupported compression format: {compression_path}" with open(compression_path, "rb") as comp_file: if compression_path.endswith(".zip"): with zipfile.ZipFile(comp_file, "r") as zip_file: with zip_file.open(img_path) as img_file: img_binary = img_file.read() elif compression_path.endswith(".tar"): with tarfile.open( fileobj=comp_file, mode="r" ) as tar_file: img_file = tar_file.extractfile(img_path) img_binary = img_file.read() else: with open(os.path.join(img_dir, img_path), "rb") as f: img_binary = f.read() img = image_decoder(img_binary) else: if isinstance(img_path, (bytes, bytearray)): img = io.BytesIO(img_path) img = Image.open(img).convert("RGB") else: img = img_path if not isinstance(img, Image.Image): img = Image.fromarray(np.uint8(img)).convert("RGB") if "image_metas" in turn and turn["image_metas"]: turn["image_metas"][img_idx] = convert_bboxes( img, turn["image_metas"][img_idx] ) if tag == "<|image|>": output.imgs.append(img) output.discrete_imgs.append(img) img_idx += 1 elif tag == "<|video|>": video_path = turn["video_urls"][vid_idx] if isinstance(video_path, str): if "#" in video_path: compression_path, video_path = video_path.split("#", 1) compression_path = os.path.join(img_dir, compression_path) assert compression_path[-4:] in [ ".zip", ".tar", ], f"unsupported compression format: {compression_path}" with open(compression_path, "rb") as comp_file: if compression_path.endswith(".zip"): with zipfile.ZipFile(comp_file, "r") as zip_file: video_file = zip_file.open(video_path) video_binary = video_file.read() elif compression_path.endswith(".tar"): with tarfile.open( fileobj=comp_file, mode="r" ) as tar_file: video_file = tar_file.extractfile(video_path) video_binary = video_file.read() else: with open(os.path.join(img_dir, video_path), "rb") as f: video_binary = f.read() video_binary = io.BytesIO(video_binary) else: video_binary = video_path assert isinstance(video_binary, io.BytesIO), "video binary read error" try: from hcxvlm.dataset.qwen_vision_process import process_vision_info except: from qwen_vl_utils import process_vision_info if video_max_num_frames is None: video_max_num_frames = 120 if video_max_pixels is None: 
video_max_pixels = 378 * 378 messages = [ [ { "role": "user", "content": [ { "type": "video", "video": video_binary, "max_frames": video_max_num_frames, "max_pixels": video_max_pixels, } ], } ], ] _, videos, video_kwargs = process_vision_info( messages, return_video_kwargs=True, use_audio=use_audio, audio_sample_rate=audio_sample_rate, ) output.videos.append(videos[0]) video_len = round(videos[0].shape[0] / video_kwargs["fps"][0], 2) output.videos_duration.append( { "video_duration": f"{video_len}s", } ) if use_audio and "audio_chunks" in video_kwargs: audio_chunks = video_kwargs["audio_chunks"][0] if audio_chunks is not None: output.video_audios.append(audio_chunks) else: output.video_audios.append([]) elif use_audio: output.video_audios.append([]) vid_idx += 1 elif tag == "<|audio|>" or tag == "<|discrete_audio|>": audio_path = turn["audio_urls"][aud_idx] if isinstance(audio_path, str): if "#" in audio_path: compression_path, inner_path = audio_path.split("#", 1) compression_path = os.path.join(img_dir, compression_path) assert compression_path[-4:] in [ ".zip", ".tar", ], f"unsupported compression format: {compression_path}" with open(compression_path, "rb") as comp_file: if compression_path.endswith(".zip"): with zipfile.ZipFile(comp_file, "r") as zip_file: with zip_file.open(inner_path) as audio_file: audio_binary = audio_file.read() elif compression_path.endswith(".tar"): with tarfile.open( fileobj=comp_file, mode="r" ) as tar_file: audio_file = tar_file.extractfile(inner_path) audio_binary = audio_file.read() else: with open(os.path.join(img_dir, audio_path), "rb") as f: audio_binary = f.read() audio_stream = io.BytesIO(audio_binary) else: if isinstance(audio_path, (bytes, bytearray)): audio_stream = io.BytesIO(audio_path) else: audio_stream = audio_path try: import librosa y, sr = librosa.load( audio_stream, sr=DEFAULT_SAMPLE_RATE, mono=True ) assert ( DEFAULT_SAMPLE_RATE == sr ), f"librosa resampling failed: {DEFAULT_SAMPLE_RATE} != {sr}" except Exception as e: raise ConditionalError( f"audio decoding failed for {audio_path}: {e}" ) audio_duration = len(y) / sr if audio_duration < 0.5: raise ConditionalError( f"Audio too short ({audio_duration:.2f}s). Minimum 0.5s required." ) if audio_duration > 600: raise ConditionalError( f"Audio duration ({audio_duration:.2f}s) exceeds maximum allowed duration (600s)" ) if len(y) < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES: raise ConditionalError( f"Audio too short ({len(y)} samples = {audio_duration:.4f}s < 0.1s). " f"Minimum {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES} samples required for CosyVoice encoder." 
) if not hasattr(output, "audios"): output.audios = [] if not hasattr(output, "discrete_audios"): output.discrete_audios = [] normalized_y = hpf_normalize(y) normalized_y = torch.from_numpy(normalized_y).float() output.discrete_audios.append(normalized_y) if tag == "<|audio|>": output.audios.append(y) total_duration = len(y) / sr output.audios_duration.append( { "duration": f"{(total_duration):.2f}s", } ) aud_idx += 1 else: raise ConditionalError( f"{tag} is not in ['<|image|>', '<|video|>', '<|audio|>']" ) return output @classmethod def prompt_user( cls, output, tokenizer=None, turn: Optional[dict] = None, content: Optional[str] = None, is_train=False, fixed_mime=False, insert_ocr=300, file_names: Optional[list[str]] = None, mimes: Optional[list[str]] = None, mm_tokens: Optional[list[str]] = None, words: Optional[list] = None, lens: Optional[list] = None, query_template: Optional[list[str]] = None, config: Optional[dict] = None, seed: np.random.Generator = None, ): assert content or turn if turn is None: image_metas = [ {"words": words[i], "lens": lens[i]} for i in range(len(words)) ] turn = { "content": content, "image_metas": image_metas, } if seed is None: seed = np.random.default_rng() turn["content"] = re.sub(r"", "<|image|>", turn["content"]) turn["content"] = re.sub(r"", "<|video|>", turn["content"]) turn["content"] = re.sub(r"", "<|audio|>", turn["content"]) pattern = re.compile(r"(<\|video\|>|<\|image\|>|<\|audio\|>)") all_tags_in_order = [ match.group() for match in pattern.finditer(turn["content"]) ] n_vids = sum(1 for tag in all_tags_in_order if tag == "<|video|>") n_audios = sum(1 for tag in all_tags_in_order if tag == "<|audio|>") assert ( len(turn.get("image_urls", [])) + len(turn.get("video_urls", [])) + len(turn.get("audio_urls", [])) ) == len( all_tags_in_order ), f"Number of media URLs does not match number of media tags." 
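        # Illustrative note (not part of the original source): at this point a user
        # turn is expected to look roughly like
        #     {"role": "user",
        #      "content": "<|image|>Describe the image.",
        #      "image_urls": ["photo.jpg"],            # hypothetical value
        #      "video_urls": [], "audio_urls": []}
        # i.e. one URL per <|image|>/<|video|>/<|audio|> tag in the content, which is
        # what the assertion above enforces.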
if mm_tokens is None: mm_tokens = [ cls.audio_pad if tag == "<|audio|>" else cls.image_pad for tag in all_tags_in_order ] assert len(mm_tokens) == len(all_tags_in_order) if config.get("llava_pretrain", False): mm_str = "".join([mm_tokens[i] for i in range(len(all_tags_in_order))]) if hasattr(output, "input_str"): output.input_str += mm_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(mm_str, truncation=False) output.input_ids += token_ids output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] return output if query_template: processed_content = seed.choice(query_template).format(turn["content"]) tags_after_template = pattern.findall(processed_content) if len(all_tags_in_order) != len(tags_after_template): cleaned_template_text = pattern.sub("", processed_content) processed_content = "".join(all_tags_in_order) + cleaned_template_text turn["content"] = processed_content content_parts = pattern.split(turn["content"].strip()) if hasattr(output, "input_str"): output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" if getattr(output, "input_ids", None) is not None: role_encoded = tokenizer.encode( f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False ) output.input_ids += role_encoded if turn.get("trainable_role", False): output.label_ids += role_encoded else: output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] tag_cursor = 0 for part in content_parts: part = part.strip() if not part: continue if part not in ["<|image|>", "<|video|>", "<|audio|>"]: content_text = part if hasattr(output, "input_str"): output.input_str += "\n" + content_text if getattr(output, "input_ids", None) is not None: content_encoded = tokenizer.encode( "\n" + content_text, truncation=False ) output.input_ids += content_encoded if turn.get("trainable_content", False): output.label_ids += content_encoded else: output.label_ids += [ IGNORE_INDEX for _ in range(len(content_encoded)) ] continue if part == "<|image|>": mime = Preprocessor.prompt_mime( mimes=mimes if not file_names else None, fixed_mime=fixed_mime if not file_names else False, file_name=file_names[tag_cursor] if file_names else None, tag_idx=output.sample_mm_counter["image"], is_video=False, is_audio=False, seed=seed, ) mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}" mm_str = ( cls.new_line + mime_str + cls.new_line + discrete_image_str + cls.new_line + vector_str ) if hasattr(output, "input_str"): output.input_str += mm_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(mm_str, truncation=False) output.input_ids += token_ids output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] output.sample_mm_counter["image"] += 1 tag_cursor += 1 elif part == "<|video|>": mime = Preprocessor.prompt_mime( mimes=mimes if not file_names else None, fixed_mime=fixed_mime if not file_names else False, file_name=file_names[tag_cursor] if file_names else None, tag_idx=output.sample_mm_counter["video"], is_video=True, is_audio=False, seed=seed, ) mm_str = "" aux_inputs = { "video_duration": output.videos_duration[ output.sample_mm_counter["video"] ]["video_duration"], } mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" aux_str = f"{cls.aux_video_start}{cls.aux_vid_prompt}{json.dumps(aux_inputs, 
ensure_ascii=False)}{cls.aux_video_end}" vector_str = f"{cls.video_start}{cls.video_pad}{cls.video_end}" mm_str += ( cls.new_line + mime_str + cls.new_line + aux_str + cls.new_line + vector_str ) if hasattr(output, "input_str"): output.input_str += mm_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(mm_str, truncation=False) output.input_ids += token_ids output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] output.sample_mm_counter["video"] += 1 tag_cursor += 1 elif part == "<|audio|>": mime = Preprocessor.prompt_mime( mimes=mimes if not file_names else None, fixed_mime=fixed_mime if not file_names else False, file_name=file_names[tag_cursor] if file_names else None, tag_idx=output.sample_mm_counter["audio"], is_video=False, is_audio=True, seed=seed, ) mm_str = "" aux_inputs = { "audio_duration": output.audios_duration[ output.sample_mm_counter["audio"] ]["duration"], } mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" aux_str = f"{cls.aux_audio_start}{cls.aux_audio_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_audio_end}" discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}" vector_str = f"{cls.audio_start}{cls.audio_pad}{cls.audio_end}" mm_str += ( cls.new_line + mime_str + cls.new_line + aux_str + cls.new_line + discrete_audio_str + cls.new_line + vector_str ) if hasattr(output, "input_str"): output.input_str += mm_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(mm_str, truncation=False) output.input_ids += token_ids output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] output.sample_mm_counter["audio"] += 1 tag_cursor += 1 if hasattr(output, "input_str"): output.input_str += cls.turn_suffix if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(cls.turn_suffix, truncation=False) output.input_ids += token_ids output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] return output @classmethod def prompt_assistant( cls, output, tokenizer=None, turn: Optional[dict] = None, role: Optional[str] = "assistant", content: Optional[str] = None, is_last_turn=False, is_eval=True, is_llava_pretrain=False, is_after_last_user_turn=False, ): assert content or turn if turn is None: turn = { "content": content, "role": role, } if is_llava_pretrain: if hasattr(output, "input_str"): output.input_str += turn["content"] if getattr(output, "input_ids", None) is not None: content_encoded = tokenizer.encode(turn["content"], truncation=False) output.input_ids += content_encoded output.label_ids += content_encoded return output reasoning_content = turn.get("reasoning_content", "") if ( not reasoning_content and isinstance(turn["content"], str) and "" in turn["content"] ): parts = turn["content"].split("", 1) reasoning_content = parts[0].split("", 1)[-1].lstrip("\n") turn["content"] = parts[1].lstrip("\n") if is_after_last_user_turn and (is_last_turn or reasoning_content): content_to_strip = turn.get("content") or "" stripped_content = content_to_strip.lstrip("\n") if reasoning_content is None: reasoning_content = "" turn["content"] = ( f"\n{reasoning_content.strip()}\n\n\n{stripped_content}" ) if turn.get("tool_calls"): for tool_call in turn["tool_calls"]: func_name = tool_call.get("function", {}).get("name", "") args = tool_call.get("function", {}).get("arguments", {}) if isinstance(args, str): try: args = json.loads(args) except Exception: pass if not isinstance(args, dict): print( f"[error] 
tool_call.function.arguments가 dict이 아님: type={type(args)}, value={str(args)}" ) assert ( False ), "tool_call.function.arguments는 dict이거나 dict를 나타내는 JSON 문자열이어야 합니다." tool_turn_content = f"\n{func_name}\n" for key, value in args.items(): arg_value = ( json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value ) tool_turn_content += f"{key}\n{arg_value}\n" tool_turn_content += "" if func_name == "t2i_model_generation": assert ( "<|t2i_model_generation_target_discrete_image|>" in turn["content"] ), "t2i_model_generation tool call must have target discrete image tag in content." turn["content"] = turn["content"].replace( "<|t2i_model_generation_target_discrete_image|>", tool_turn_content, ) else: turn["content"] += tool_turn_content pattern = re.compile( r"(<\|image\|>|<\|discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>)" ) all_tags_in_order = [ match.group() for match in pattern.finditer(turn["content"]) ] assert ( len(turn.get("image_urls", [])) + len(turn.get("video_urls", [])) + len(turn.get("audio_urls", [])) ) == len( all_tags_in_order ), f"Number of media URLs does not match number of media tags." if hasattr(output, "input_str"): output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" if is_eval and is_last_turn: if reasoning_content.strip() == "": output.input_str += f"\n\n\n\n" turn["content"] = stripped_content else: output.input_str += f"{turn['content']}{cls.turn_suffix}" if getattr(output, "input_ids", None) is not None: role_encoded = tokenizer.encode( f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False ) output.input_ids += role_encoded if is_eval and is_last_turn: if reasoning_content.strip() == "": output.input_ids += tokenizer.encode( f"\n\n\n\n", truncation=False ) turn["content"] = stripped_content else: if turn.get("trainable_role", True): output.label_ids += role_encoded else: output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] turn_img_idx = 0 content_parts = pattern.split(turn["content"].strip()) for part in content_parts: part = part.strip() if not part: continue if part not in [ "<|image|>", "<|discrete_image|>", "<|audio|>", "<|discrete_audio|>", ]: content_text = part if hasattr(output, "input_str"): output.input_str += "\n" + content_text if getattr(output, "input_ids", None) is not None: content_encoded = tokenizer.encode( "\n" + content_text, truncation=False ) output.input_ids += content_encoded if turn.get("trainable_content", True): output.label_ids += content_encoded else: output.label_ids += [ IGNORE_INDEX for _ in range(len(content_encoded)) ] continue if part == "<|image|>": file_name = turn.get("image_urls", [])[turn_img_idx] if isinstance(file_name, str) and "#" in file_name: file_name = file_name.split("#")[-1] file_name = os.path.basename(file_name) mime = Preprocessor.prompt_mime( mimes=None, fixed_mime=False, file_name=file_name, tag_idx=output.sample_mm_counter["image"], is_video=False, is_audio=False, seed=None, ) mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}" mm_str = ( cls.new_line + mime_str + cls.new_line + discrete_image_str + cls.new_line + vector_str ) if hasattr(output, "input_str"): output.input_str += mm_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(mm_str, truncation=False) output.input_ids += token_ids output.label_ids += [ IGNORE_INDEX 
for _ in range(len(token_ids)) ] turn_img_idx += 1 output.sample_mm_counter["image"] += 1 elif part == "<|discrete_image|>": discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" mm_str = cls.new_line + discrete_image_str if hasattr(output, "input_str"): output.input_str += mm_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(mm_str, truncation=False) output.input_ids += token_ids output.label_ids += token_ids turn_img_idx += 1 elif part == "<|discrete_audio|>": discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}" mm_str = cls.new_line + discrete_audio_str if hasattr(output, "input_str"): output.input_str += mm_str if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(mm_str, truncation=False) output.input_ids += token_ids if turn.get("trainable_content", True): output.label_ids += token_ids else: output.label_ids += [ IGNORE_INDEX for _ in range(len(token_ids)) ] elif part == "<|audio|>": raise Exception( "Assistant turn에서 <|audio|> 태그는 지원하지 않음. discrete_audio 만 지원함." ) if hasattr(output, "input_str"): output.input_str += cls.turn_suffix if getattr(output, "input_ids", None) is not None: token_ids = tokenizer.encode(cls.turn_suffix, truncation=False) output.input_ids += token_ids if turn.get("trainable_content", True): output.label_ids += token_ids else: output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] return output @classmethod def prompt_tool( cls, output, tokenizer=None, turn: Optional[dict] = None, role: Optional[str] = None, content: Optional[str] = None, eot: Optional[bool] = None, need_start_tag=True, need_end_tag=True, ): assert (content and role) or turn if turn is None: turn = { "content": content, "role": role, "endofturn": eot, } assert ( "tool" == turn["role"] ), f'[warning] unexpected turn["role"]: {turn["role"]}' content_value = turn.get("content", "") if isinstance(content_value, dict): if "response" in content_value: content_str = content_value["response"] else: content_str = json.dumps(content_value, ensure_ascii=False) elif isinstance(content_value, str): try: parsed = json.loads(content_value) if isinstance(parsed, dict): if "response" in parsed: content_str = parsed["response"] else: content_str = json.dumps(parsed, ensure_ascii=False) else: content_str = content_value except (json.JSONDecodeError, TypeError): content_str = content_value else: content_str = str(content_value) turn["content"] = ( f"{turn.get('name', '')}\n{content_str}\n" ) if hasattr(output, "input_str"): if need_start_tag: output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" output.input_str += f"{cls.new_line}{turn['content']}" if need_end_tag: output.input_str += cls.turn_suffix if getattr(output, "input_ids", None) is not None: if need_start_tag: role_encoded = tokenizer.encode( f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False ) output.input_ids += role_encoded if turn.get("trainable_role", True): output.label_ids += role_encoded else: output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] content = f"{cls.new_line}{turn['content']}" content_encoded = tokenizer.encode(content, truncation=False) if need_end_tag: content_encoded += tokenizer.encode( f"{cls.turn_suffix}", truncation=False ) output.input_ids += content_encoded if turn.get("trainable_content", True): output.label_ids += content_encoded else: output.label_ids += [ IGNORE_INDEX for _ in range(len(content_encoded)) ] return 
output @classmethod def prompt_etc( cls, output, tokenizer=None, turn: Optional[dict] = None, role: Optional[str] = None, content: Optional[str] = None, eot: Optional[bool] = None, ): assert (content and role) or turn if turn is None: turn = { "content": content, "role": role, "endofturn": eot, } print(f'[warning] unexpected turn["role"]: {turn["role"]}') if hasattr(output, "input_str"): output.input_str += f"{cls.turn_prefix}{turn['role']}\n" output.input_str += f"{turn['content']}{cls.turn_suffix}" if turn.get("stop", False): output.input_str += cls.stop_token if turn.get("endofturn", False): output.input_str += cls.eot if getattr(output, "input_ids", None) is not None: role_encoded = tokenizer.encode( f"{cls.turn_prefix}{turn['role']}\n", truncation=False ) output.input_ids += role_encoded if turn.get("trainable_role", True): output.label_ids += role_encoded else: output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] content = f"{turn['content']}{cls.turn_suffix}" if turn.get("stop", False): content += cls.stop_token if turn.get("endofturn", False): content += cls.eot content_encoded = tokenizer.encode(content, truncation=False) output.input_ids += content_encoded if turn.get("trainable_content", True): output.label_ids += content_encoded else: output.label_ids += [IGNORE_INDEX for _ in range(len(content_encoded))] return output def __call__(self, sample): return self.preprocess_new(sample) @classmethod def batchify( cls, items: List[Dict[str, Any],], device: str = None, ): batch = dict() for item in items: for k, v in item.items(): if isinstance(v, torch.Tensor): if device is not None: v = v.to(device=device) elif k == "pixel_values": v = [_v.to(device=device) for _v in v] if k not in batch: batch[k] = [ v, ] else: batch[k].append(v) for k, v in batch.items(): if isinstance(v[0], torch.Tensor): if k in ["image_grid_thw", "video_grid_thw"]: batch[k] = torch.cat(v, dim=0) continue batch[k] = torch.stack(v, dim=0) batch["video_grid_thw"] = None batch["pixel_values_videos"] = None return batch def convert_wds_to_datalake( self, img: Union[PIL.Image.Image, Dict[str, PIL.Image.Image]] = {}, json: Dict[str, Any] = {}, benchmark: Optional[str] = None, video: Union[io.BytesIO, Dict[str, io.BytesIO]] = {}, audio: Union[io.BytesIO, Dict[str, io.BytesIO]] = {}, ): if "lines" in json: del json["lines"] if "paragraphs" in json: del json["paragraphs"] assert json["meta"]["type"] in [ "caption", "vqa", "textread", ], f"{json['meta']['path']}, {json['meta']['type']}: The dataset type should be one of them: caption, vqa, textread." 
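        # Illustrative note (not part of the original source): a wds sample accepted
        # here carries metadata roughly of the form
        #     json["meta"] = {"type": "vqa", "lang": "en",
        #                     "name": "some_dataset", "path": "..."}   # hypothetical values
        # optionally extended with fields such as "subtask", "reasoning",
        # "use_task_prompt", "get_random" or "llava_pretrain", which are read below.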
sample = {"vlm": {}} sample["vlm"] = get_wds_default_config( json["meta"], existing_default_config=self.wds_default_config ) sample["vlm"]["data_name"] = json["meta"].get("name", "unk") sample["vlm"]["data_type"] = ( "wds" if (isinstance(img, PIL.Image.Image) and img) or (isinstance(img, dict) and len(img) > 0) else "sft1" ) sample["vlm"]["sample_id"] = json.get("qa_id", None) sample["vlm"]["category"] = json.get("category", None) sample["vlm"]["data_info"] = json.get("data_info", dict()) sample["vlm"]["options"] = None if "choices_en" in sample["vlm"]["data_info"]: if sample["vlm"]["options"] is None and json["meta"]["lang"] == "en": sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_en"] sample["vlm"]["options_en"] = sample["vlm"]["data_info"]["choices_en"] if "choices_ko" in sample["vlm"]["data_info"]: if sample["vlm"]["options"] is None and json["meta"]["lang"] == "ko": sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_ko"] sample["vlm"]["options_ko"] = sample["vlm"]["data_info"]["choices_ko"] sample["vlm"]["image_index"] = json.get( "image_index", json.get("img_url", None) ) if sample["vlm"].get("video", False): is_multi_image_dataset = False else: is_multi_image_dataset, img, json = convert_format_for_multi_image( img, json ) if json["meta"]["type"] == "textread": key = "words" elif json["meta"].get("subtask", "") == "region": key = f"regions_{json['meta']['lang']}" elif json["meta"]["type"] == "vqa": key = f"qa_pairs_{json['meta']['lang']}" elif json["meta"]["type"] == "caption": key = f"captions_{json['meta']['lang']}" else: raise ConditionalError( f"wrong task type in wds config: {sample['vlm']['data_name']}" ) turns = [ { "role": "tool_list", "content": "", "content_type": "text", "trainable_role": False, "trainable_content": False, "stop": False, "debuggingInfo": {}, "meta": {}, "candidates": [], "endofturn": False, }, { "role": "system", "content_type": "text", "candidates": [], "trainable_role": False, "trainable_content": False, "stop": False, "debuggingInfo": {}, "meta": {}, "content": "", "endofturn": False, }, ] if json["meta"].get("llava_pretrain", False): sample["vlm"]["llava_pretrain"] = True use_task_prompt = json["meta"].get( "use_task_prompt", self.wds_default_config["use_task_prompt"] ) get_random = json["meta"].get( "get_random", self.wds_default_config["get_random"] ) reasoning = json["meta"].get("reasoning", self.wds_default_config["reasoning"]) try: if key not in json: key = key[:-3] assert key in json if len(json[key]) == 0: key = key[:-3] assert key in json except: raise ConditionalError( f"{key} key is not in json? 
dataset name: {sample['vlm']['data_name']}" ) first_turn = True if "region" in key: json[key] = json[key]["00"] sample["vlm"]["multiturn_n_samples"] = 1 if ( not is_multi_image_dataset and sample["vlm"]["multiturn_n_samples"] > 1 or "region" in key ): json[key] = sampling_multiturn_single_img( json[key], sample["vlm"]["multiturn_n_samples"], sample["vlm"]["multiturn_preserve_order"], sample["vlm"]["multiturn_continuous"], ) if sample["vlm"].get("video", False): for qa in json[key]: vid_src = [] user = { "role": "user", "content_type": "text", "candidates": [], "trainable_role": False, "trainable_content": False, "stop": False, "debuggingInfo": {}, "meta": {}, "image_urls": [], "image_metas": [], "video_urls": [], "video_metas": [], "audio_urls": [], "audio_metas": [], "content": "", "endofturn": False, } instruct_prompt, task_prompt = hcx_vision_prompter( task=json["meta"]["type"], subtask=json["meta"].get("subtask", None), lang=json["meta"]["lang"], get_random=get_random, use_task_prompt=use_task_prompt, ) prompt = qa[0] answer = qa[-1] if reasoning else qa[1] if first_turn: user["video_metas"].append({"lens": []}) user["content"] += "<|video|>" prompt = task_prompt.format(prompt) if "entities" in json: user["video_metas"][0]["lens"] = json["entities"].get("00", []) if isinstance(video, dict): vid_src.append(video["00"]) else: vid_src.append(video) first_turn = False user["video_urls"] = vid_src user["content"] += prompt assistant = { "candidates": [], "content": answer, "content_type": "text", "debuggingInfo": {}, "meta": {}, "role": "assistant", "trainable_content": True, "trainable_role": True, "stop": False, "endofturn": True, } turns.append(user) turns.append(assistant) else: if key.startswith("qa_pairs") or key.startswith("captions"): if self.mode != "train" and key.startswith("qa_pairs"): qas = dict() for qa in json[key]: q = qa[0] if q not in qas: qas[q] = list() for _i, _e in enumerate(qa[1:]): if len(qas[q]) <= _i: qas[q].append(list()) qas[q][_i].append(_e) json[key] = [ [ k, ] + v for k, v in qas.items() ] if self.mode != "train": json[key] = json[key][:1] for qa in json[key]: img_src = [] user = { "role": "user", "content_type": "text", "candidates": [], "trainable_role": False, "trainable_content": False, "stop": False, "debuggingInfo": {}, "meta": {}, "image_urls": [], "image_metas": [], "video_urls": [], "video_metas": [], "audio_urls": [], "audio_metas": [], "content": "", "endofturn": False, } img_keys = re.findall(r"", qa[0]) video_keys = re.findall(r"", qa[0]) audio_keys = re.findall(r"", qa[0]) if key.startswith("qa_pairs"): if len(qa) > 2: sample_id = qa[2] if ( isinstance(sample_id, (list, tuple)) and len(sample_id) > 0 ): sample_id = sample_id[0] sample["vlm"]["sample_id"] = sample_id instruct_prompt, task_prompt = hcx_vision_prompter( task=json["meta"]["type"], subtask=json["meta"].get("subtask", None), lang=json["meta"]["lang"], get_random=get_random, use_task_prompt=use_task_prompt, ) if json["meta"]["type"] == "vqa": prompt = qa[0] answer = qa[-1] if reasoning else qa[1] elif json["meta"]["type"] == "caption": prompt = task_prompt.format("") answer = qa if first_turn or self.mode != "train": if json["meta"]["type"] == "vqa": prompt = task_prompt.format(prompt) if first_turn and not is_multi_image_dataset: user["image_metas"].append({"words": [], "lens": []}) if "" in prompt: prompt = prompt.replace("", "<|image|>") else: user["content"] += "<|image|>" user["image_metas"][0]["words"] = json.get("words", {}).get( "00", [] ) if "objects" in json: 
user["image_metas"][0]["lens"] = json["objects"].get( "00", [] ) elif "entities" in json: user["image_metas"][0]["lens"] = json["entities"].get( "00", [] ) if isinstance(img, dict): img_src.append(img["00"]) else: img_src.append(img) elif len(img_keys) > 0: for i, key in enumerate(img_keys): user["image_metas"].append({"words": [], "lens": []}) if f"" in prompt: prompt = prompt.replace(f"", "<|image|>") else: user["content"] += "<|image|>" img_src.append(img[key]) _words = json.get("words", {}) if isinstance(_words, dict): _words = _words.get(key, []) user["image_metas"][i]["words"] = _words if "objects" in json: _objects = json["objects"].get(key, []) if isinstance(_objects, dict): _objects = _objects.get(key, []) user["image_metas"][i]["lens"] = _objects if "entities" in json: _entities = json["entities"].get(key, []) if isinstance(_entities, dict): _entities = _entities.get(key, []) user["image_metas"][i]["lens"] = _entities user["image_urls"] = img_src if len(audio_keys) > 0: for i, key in enumerate(audio_keys): if isinstance(audio, dict): user["audio_urls"].append(audio[key]) else: user["audio_urls"].append(audio) user["audio_metas"].append( { "format": "wav", "note": "This audio sample is passed to convert_wds_to_datalake function.", } ) if f"" in prompt: prompt = prompt.replace(f"", "<|audio|>") else: user["content"] += "<|audio|>" user["content"] += prompt content, candidates = None, list() if self.mode != "train": if isinstance(answer, (int, float)): pass elif isinstance(answer, str): if answer != "None": try: answer = ast.literal_eval(answer) except Exception as ex: pass if not isinstance(answer, (list, tuple)): answer = [ answer, ] candidates += answer[1:] answer = answer[0] content = answer elif isinstance(answer, (list, tuple)): for _idx, _answer in enumerate(answer): if isinstance(_answer, str): if isinstance(benchmark, str) and benchmark in [ "textvqa", ]: try: _answer = ast.literal_eval(_answer) except Exception as ex: pass if isinstance(_answer, dict): _answer = str(_answer) if not isinstance(_answer, (list, tuple)): _answer = [ _answer, ] if _idx == 0: content = _answer[0] candidates += _answer[1:] else: candidates += _answer if isinstance(content, (int, float)): content = str(content) assert content is None or isinstance(content, str) for _idx, _candidate in enumerate(candidates): if isinstance(_candidate, (int, float)): candidates[_idx] = str(_candidate) assert isinstance(candidates[_idx], str) mcqa_gt = sample["vlm"]["data_info"].get("choice_answer", None) if isinstance(mcqa_gt, str): content = mcqa_gt assistant = { "candidates": candidates, "content": answer if self.mode == "train" else content, "content_type": "text", "debuggingInfo": {}, "meta": {}, "role": "assistant", "trainable_content": True, "trainable_role": True, "stop": False, "endofturn": True, } turns.append(user) turns.append(assistant) elif key == "words": img_src = [] user = { "role": "user", "content_type": "text", "candidates": [], "trainable_role": False, "trainable_content": False, "stop": False, "debuggingInfo": {}, "meta": {}, "image_urls": [], "image_metas": [], "video_urls": [], "video_metas": [], "audio_urls": [], "audio_metas": [], "content": "<|image|>", "endofturn": False, } instruct_prompt, task_prompt = hcx_vision_prompter( task=json["meta"]["type"], subtask=json["meta"].get("subtask", None), lang=json["meta"]["lang"], get_random=get_random, use_task_prompt=use_task_prompt, ) user["content"] += task_prompt user["image_metas"].append({"words": [], "lens": []}) 
user["image_metas"][0]["words"] = json["words"]["00"] if "entities" in json: user["image_metas"][0]["lens"] = json["entities"].get("00", []) img_src.append(img["00"]) user["image_urls"] = img_src words_list = [ d["text"].strip() for d in json["words"]["00"] if d["text"] ] gt = " ".join(words_list) assistant = { "candidates": [], "content": gt, "content_type": "text", "debuggingInfo": {}, "meta": {}, "role": "assistant", "trainable_content": True, "trainable_role": True, "stop": False, "endofturn": True, } turns.append(user) turns.append(assistant) elif key.startswith("regions"): for region in json[key]: img_src = [] user = { "role": "user", "content_type": "text", "candidates": [], "trainable_role": False, "trainable_content": False, "stop": False, "debuggingInfo": {}, "meta": {}, "image_urls": [], "image_metas": [], "video_urls": [], "video_metas": [], "audio_urls": [], "audio_metas": [], "content": "<|image|><|region|>", "endofturn": False, } instruct_prompt, task_prompt = hcx_vision_prompter( task=json["meta"]["type"], subtask=json["meta"].get("subtask", None), lang=json["meta"]["lang"], get_random=get_random, use_task_prompt=use_task_prompt, ) sample["vlm"]["query_template"] = [task_prompt] user["image_metas"].append({"words": [], "lens": []}) user["image_metas"][0]["region"] = region if "words" in json: user["image_metas"][0]["words"] = json["words"].get("00", []) if "objects" in json: user["image_metas"][0]["lens"] = json["objects"].get("00", []) if "entities" in json: user["image_metas"][0]["lens"] = json["entities"].get("00", []) img_src.append(img["00"]) user["image_urls"] = img_src assistant = { "candidates": [], "content": region["text"], "content_type": "text", "debuggingInfo": {}, "meta": {}, "role": "assistant", "trainable_content": True, "trainable_role": True, "stop": False, "endofturn": True, } turns.append(user) turns.append(assistant) else: raise ConditionalError( f"wrong task type in wds config: {sample['vlm']['data_name']}" ) sample["data"] = turns return sample def preprocess_new(self, sample): config = sample.get("vlm", {}) if config["data_type"] in ["sft1", "datalake"]: default_config = copy.deepcopy(self.default_config) default_config.update(config) config = default_config idx_for_debug = sample.get("idx", -1) turns = sample["data"] if "data" in sample else sample["messages"] if self.random_system_prompt and self.rng.random() < config.get( "random_system_prob", 0.0 ): for turn in turns: if turn["role"] == "system": turn["content"] = self.random_system_prompt break if sample.get("tools", None) is None: sample["tools"] = [] if len(sample["tools"]) == 0: if ( self.rng.random() < config.get("random_tool_prob", 0.005) and len(self.common_tools) > 0 ): max_n_tools = min(7, len(self.common_tools)) tool_counts = np.arange(1, max_n_tools + 1) tool_count_weights = 1.0 / tool_counts tool_count_weights = tool_count_weights / tool_count_weights.sum() n_tools = int(self.rng.choice(tool_counts, p=tool_count_weights)) idxs = np.arange(len(self.common_tools)) weights = 1.0 / (idxs + 1) weights[0] += 1.0 weights = weights / weights.sum() chosen_indices = self.rng.choice( len(self.common_tools), size=n_tools, replace=False, p=weights ) self.rng.shuffle(chosen_indices) sample["tools"] = [self.common_tools[i] for i in chosen_indices] if "tools" in sample and sample["tools"]: tool_prompt = [] tool_prompt.append("# Tools\n\n") tool_prompt.append( "You may call one or more functions to assist with the user query.\n\n" ) tool_prompt.append( "You are provided with function signatures within 
XML tags:\n" ) tool_prompt.append("\n") for tool in sample["tools"]: tool_prompt.append(json.dumps(tool, ensure_ascii=False)) tool_prompt.append("\n\n\n") tool_prompt.append( "For each function call, output the function name and arguments within the following XML format:\n" ) tool_prompt.append("{function-name}\n") tool_prompt.append("{arg-key-1}\n") tool_prompt.append("{arg-value-1}\n") tool_prompt.append("{arg-key-2}\n") tool_prompt.append("{arg-value-2}\n") tool_prompt.append("...\n") tool_prompt.append("") tool_prompt = "".join(tool_prompt) else: tool_prompt = "" multiturn_n_sample = config.get("multiturn_n_samples", 0) if multiturn_n_sample > 0 and self.mode == "train": turns = self._sampling_multiturn( turns, multiturn_n_sample, multiturn_preserve_order=config.get("multiturn_preserve_order", True), multiturn_continuous=config.get("multiturn_continuous", False), ) for i, turn in enumerate(turns): if turn["role"] == "user": if "img_src" in turn: turns[i]["image_urls"] = turn["img_src"] turns[i]["image_metas"] = turn["meta"] for j, turn_img_meta in enumerate(turns[i]["image_metas"]): if "entities" in turn_img_meta: turns[i]["image_metas"][j]["lens"] = turn_img_meta[ "entities" ] turns[i]["meta"] = {} max_image_cnt = config.get("max_image_cnt", 20) if max_image_cnt > 0 and config["data_type"] != "sft1": n_imgs = {} for i, turn in enumerate(turns): if turn["role"] == "user": n_imgs[i] = len(turn.get("image_urls", [])) assert ( n_imgs[i] <= max_image_cnt ), "skip sample if image_nums exceeds max_image_count per turn" if sum(n_imgs.values()) > max_image_cnt: img_count = 0 for k, v in reversed(list(n_imgs.items())): img_count += v if img_count > max_image_cnt: break img_count = sum(n_imgs.values()) - max_image_cnt for i in range(k + 1): if turns[i]["role"] == "user": turns[i]["content"], n_removed1 = re.subn( r"", "", turns[i]["content"].strip(), count=img_count, ) img_count -= n_removed1 turns[i]["content"], n_removed2 = re.subn( r"<\|image\|>", "", turns[i]["content"].strip(), count=img_count, ) img_count -= n_removed2 n_removed_imgs = n_removed1 + n_removed2 turns[i]["image_urls"] = turns[i]["image_urls"][n_removed_imgs:] if n_removed_imgs > 0 and len(turns[i]["image_urls"]) == 0: idx = i while True: idx += 1 turns[idx]["trainable_role"] = False turns[idx]["trainable_content"] = False if turns[idx]["role"] == "assistant": break n_imgs_after = {} for i, turn in enumerate(turns): if turn["role"] == "user": n_imgs_after[i] = len(turn.get("image_urls", [])) assert sum(n_imgs_after.values()) > 0, "The n_imgs of vlm data is zero." n_mm_after = {} for i, turn in enumerate(turns): if turn["role"] == "user" or turn["role"] == "assistant": n_mm_after[i] = ( len(turn.get("image_urls", [])) + len(turn.get("video_urls", [])) + len(turn.get("audio_urls", [])) ) assert sum(n_mm_after.values()) > 0, "The n_mm of omni data is zero." 
queries, gts = list(), list() output = Processed_sample( input_str="", input_ids=[], label_ids=[], imgs=[], discrete_imgs=[], videos=[], videos_duration=[], video_audios=[], audios=[], audios_duration=[], discrete_audios=[], sample_mm_counter={ "image": 0, "video": 0, "audio": 0, }, ) system_role_count = 0 last_user_idx = max( (i for i, d in enumerate(turns) if d.get("role") == "user"), default=-1 ) for i, turn in enumerate(turns): if turn["role"] == "tool_list": continue elif turn["role"] == "system": if config.get("llava_pretrain", False): continue output = Preprocessor.prompt_system( turn=turn, output=output, tokenizer=self.tokenizer, seed=self.rng, tool_prompt=tool_prompt, system_role_count=system_role_count, ) system_role_count += 1 elif turn["role"].startswith("user"): output = Preprocessor.load_mm( output=output, img_dir=config.get("img_dir", ""), turn=turn, prepare_input_fn=self.prepare_input_fn, max_image_cnt=max_image_cnt, video_max_num_frames=self.video_max_num_frames, video_max_pixels=self.video_max_pixels, use_audio=self.train_audio, ) output = Preprocessor.prompt_user( output=output, tokenizer=self.tokenizer, turn=turn, is_train=True if self.mode == "train" else False, fixed_mime=config.get("fixed_mime", False), mimes=self.mimes, query_template=config.get("query_template", None), config=config, seed=self.rng, ) queries.append(turn["content"].replace("<|image|>", "").strip()) elif turn["role"].startswith("assistant"): output = Preprocessor.load_mm( output=output, img_dir=config.get("img_dir", ""), turn=turn, prepare_input_fn=self.prepare_input_fn, max_image_cnt=max_image_cnt, video_max_num_frames=self.video_max_num_frames, video_max_pixels=self.video_max_pixels, use_audio=self.train_audio, ) is_after_last_user = i > last_user_idx is_first_assistant_after_last_user = False if is_after_last_user: is_first_assistant_after_last_user = all( turns[j]["role"] != "assistant" for j in range(last_user_idx + 1, i) ) output = Preprocessor.prompt_assistant( output=output, tokenizer=self.tokenizer, turn=turn, is_last_turn=is_first_assistant_after_last_user, is_eval=True if self.mode != "train" else False, is_llava_pretrain=config.get("llava_pretrain", False), is_after_last_user_turn=is_after_last_user, ) _gts = turn["content"] if isinstance(_gts, str): _gts = [ _gts, ] if "candidates" in turn and len(turn["candidates"]) > 0: for _candidates in turn["candidates"]: if isinstance(_candidates, str): _gts += [ _candidates, ] elif isinstance(turn["candidates"][0], (list, tuple)): _gts += _candidates gts.append(_gts) elif turn["role"] == "tool": if config.get("llava_pretrain", False): continue output = Preprocessor.prompt_tool( output=output, tokenizer=self.tokenizer, turn=turn, need_start_tag=( True if (i == 0 or turns[i - 1].get("role") != "tool") else False ), need_end_tag=( True if (i == (len(turns) - 1) or turns[i + 1].get("role") != "tool") else False ), ) else: if config.get("llava_pretrain", False): continue import pdb import sys class ForkedPdb(pdb.Pdb): """A Pdb subclass that may be used from a forked multiprocessing child""" def interaction(self, *args, **kwargs): _stdin = sys.stdin try: sys.stdin = open("/dev/stdin") pdb.Pdb.interaction(self, *args, **kwargs) finally: sys.stdin = _stdin ForkedPdb().set_trace() output = Preprocessor.prompt_etc( output=output, tokenizer=self.tokenizer, turn=turn, ) pixel_values = [] mm_query_lengths = [] discrete_pixel_values = [] image_ratios = [] discrete_image_query_lengths = [] labels = output.label_ids input_ids = output.input_ids 
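# From here on, every multimodal placeholder in input_ids is expanded in place:
# continuous image/video/audio placeholders are widened to their query lengths with
# IGNORE_INDEX labels (for images this is pixel_values.shape[0] // 4, presumably a
# 2x2 patch merge), while discrete image/audio spans keep their token ids as labels
# whenever the original position was trainable.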
total_mm_query_length = 0 is_sft1 = False if config["data_type"] == "sft1": if self.sequence_parallel_size > 1: if len(input_ids) % self.sequence_parallel_size != 0: input_ids += [self.tokenizer.pad_token_id] * ( self.sequence_parallel_size - (len(input_ids) % self.sequence_parallel_size) ) labels += [IGNORE_INDEX] * ( self.sequence_parallel_size - (len(labels) % self.sequence_parallel_size) ) input_ids = input_ids[ : (len(input_ids) // self.sequence_parallel_size) * self.sequence_parallel_size ] labels = labels[ : (len(labels) // self.sequence_parallel_size) * self.sequence_parallel_size ] input_ids = torch.tensor(input_ids[-self.decoder_max_length :]) labels = torch.tensor(labels[-self.decoder_max_length :]) is_sft1 = True dummy_preprocess_results = self.prepare_input_fn.image_processor( Image.new("RGB", (224, 224), (0, 0, 0)) ) dummy_pixel_values = torch.from_numpy( np.concatenate([dummy_preprocess_results.pixel_values], axis=0) ) dummy_grid_thw = torch.from_numpy( np.concatenate([dummy_preprocess_results.image_grid_thw], axis=0) ) image_grid_thw = [] for img in output.imgs: w, h = img.size img = self._resize_min_edge(img) preprocess_results = self.prepare_input_fn.image_processor([img]) pixel_values.append(preprocess_results.pixel_values) image_grid_thw.append(preprocess_results.image_grid_thw) mm_query_lengths.append(preprocess_results.pixel_values.shape[0] // 4) if len(output.imgs) == 0: pixel_values = torch.zeros(0, 1176) image_grid_thw = torch.zeros(0, 3, dtype=torch.long) else: pixel_values = torch.from_numpy(np.concatenate(pixel_values, axis=0)) image_grid_thw = torch.from_numpy(np.concatenate(image_grid_thw, axis=0)) for img in output.discrete_imgs: w, h = img.size img_ratio = self._find_best_ratio_token([h, w]) image_ratios.append(img_ratio) discrete_pixel_value = img.resize((384, 384), Image.BICUBIC) discrete_pixel_tensor = to_tensor(discrete_pixel_value) assert discrete_pixel_tensor.shape == ( 3, 384, 384, ), f"Unexpected discrete_pixel_tensor shape: {discrete_pixel_tensor.shape}" assert not torch.isnan( discrete_pixel_tensor ).any(), "discrete_pixel_tensor contains NaN" assert not torch.isinf( discrete_pixel_tensor ).any(), "discrete_pixel_tensor contains Inf" pixel_min = discrete_pixel_tensor.min().item() pixel_max = discrete_pixel_tensor.max().item() assert ( 0.0 <= pixel_min <= 1.0 and 0.0 <= pixel_max <= 1.0 ), f"discrete_pixel_tensor values out of range [0, 1]: min={pixel_min}, max={pixel_max}" discrete_pixel_values.append(discrete_pixel_tensor) discrete_image_query_lengths.append(729) if len(output.discrete_imgs) == 0: discrete_pixel_values = torch.zeros(0, 3, 384, 384) else: discrete_pixel_values = torch.stack(discrete_pixel_values, dim=0) assert discrete_pixel_values.shape[1:] == ( 3, 384, 384, ), f"Unexpected stacked discrete_pixel_values shape: {discrete_pixel_values.shape}" assert not torch.isnan( discrete_pixel_values ).any(), "Stacked discrete_pixel_values contains NaN" assert not torch.isinf( discrete_pixel_values ).any(), "Stacked discrete_pixel_values contains Inf" pixel_values_videos = None video_grid_thw = None if self.train_video: pixel_values_videos = [] video_grid_thw = [] video_query_lengths = [] for video in output.videos: preprocess_results = self.prepare_input_fn.video_processor([video]) pixel_values_videos.append(preprocess_results.pixel_values_videos) video_grid_thw.append(preprocess_results.video_grid_thw) video_query_lengths.append( preprocess_results.pixel_values_videos.shape[0] // 4 ) if len(output.videos) == 0: pixel_values_videos = 
torch.zeros(0, 1176) video_grid_thw = torch.zeros(0, 3, dtype=torch.long) else: pixel_values_videos = torch.from_numpy( np.concatenate(pixel_values_videos, axis=0) ) video_grid_thw = torch.from_numpy( np.concatenate(video_grid_thw, axis=0) ) video_audio_values = [] video_audio_masks = [] video_audio_query_lengths = [] if self.train_video and hasattr(output, "video_audios") and output.video_audios: for idx, video_audio_chunks in enumerate(output.video_audios): if video_audio_chunks: processed_audio_values = [] processed_audio_masks = [] chunk_output_lengths = [] for chunk in video_audio_chunks: if isinstance(chunk, torch.Tensor): chunk_np = chunk.cpu().numpy() else: chunk_np = chunk preprocess_results = self.prepare_audio_input_fn( [chunk_np], sampling_rate=self.prepare_audio_input_fn.sampling_rate, return_attention_mask=True, padding="max_length", ) audio_value = preprocess_results.input_features[0] audio_mask = preprocess_results.attention_mask[0] mask_sum = int(audio_mask.sum()) input_lengths = (mask_sum - 1) // 2 + 1 output_lengths = (input_lengths - 2) // 2 + 1 chunk_output_lengths.append(output_lengths) processed_audio_values.append(torch.from_numpy(audio_value)) processed_audio_masks.append(torch.from_numpy(audio_mask)) pool_size = 25 if self.video_audio_compressor_type is not None: total_valid_len = sum(chunk_output_lengths) total_audio_query_length = ( total_valid_len + pool_size - 1 ) // pool_size else: total_audio_query_length = sum( (valid_len + pool_size - 1) // pool_size for valid_len in chunk_output_lengths ) video_audio_values.append(processed_audio_values) video_audio_masks.append(processed_audio_masks) video_audio_query_lengths.append(total_audio_query_length) import os if ( int(os.environ.get("RANK", -1)) == 0 and total_audio_query_length == 177 ): print( f"\n[PREPROCESSOR VIDEO - 177 TOKENS DETECTED!] 
total_audio_query_length={total_audio_query_length}, num_chunks={len(processed_audio_masks)}" ) for chunk_idx, mask_tensor in enumerate(processed_audio_masks): chunk_mask_sum = int(mask_tensor.sum()) chunk_input_len = (chunk_mask_sum - 1) // 2 + 1 chunk_output_len = (chunk_input_len - 2) // 2 + 1 chunk_pooled = (chunk_output_len + 24) // 25 print( f" Chunk {chunk_idx}: mask_sum={chunk_mask_sum}, output_len={chunk_output_len}, pooled={chunk_pooled}" ) print() else: video_audio_values.append([]) video_audio_masks.append([]) video_audio_query_lengths.append(0) dummy_video_preprocess_results = self.prepare_input_fn.video_processor( [Image.new("RGB", (224, 224), (0, 0, 0))] * 3 ) dummy_pixel_values_videos = torch.from_numpy( np.concatenate([dummy_video_preprocess_results.pixel_values_videos], axis=0) ) dummy_video_grid_thw = torch.from_numpy( np.concatenate([dummy_video_preprocess_results.video_grid_thw], axis=0) ) dummy_video_preprocess_results = self.prepare_audio_input_fn( [np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)], sampling_rate=self.prepare_audio_input_fn.sampling_rate, return_attention_mask=True, padding="max_length", ) dummy_video_audio_values = torch.from_numpy( dummy_video_preprocess_results.input_features ) dummy_video_audio_masks = torch.from_numpy( dummy_video_preprocess_results.attention_mask ) audio_values = None discrete_audio_values = None audio_masks = None dummy_preprocess_results = self.prepare_audio_input_fn( [np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)], sampling_rate=self.prepare_audio_input_fn.sampling_rate, return_attention_mask=True, padding="max_length", ) dummy_audio_values = torch.from_numpy(dummy_preprocess_results.input_features) dummy_audio_masks = torch.from_numpy(dummy_preprocess_results.attention_mask) if self.train_audio: audio_values = [] discrete_audio_values = [] audio_masks = [] audio_query_lengths = [] discrete_audio_query_lengths = [] if len(output.audios) > 99: raise ConditionalError( f"Too many audio segments in one sample: {len(output.audios)} audios." ) for audio in output.audios: chunks = [] for i in range( 0, len(audio), 30 * self.prepare_audio_input_fn.sampling_rate ): chunks.append( audio[i : i + 30 * self.prepare_audio_input_fn.sampling_rate] ) num_of_chunks = len(chunks) preprocess_results = self.prepare_audio_input_fn( chunks, sampling_rate=self.prepare_audio_input_fn.sampling_rate, return_attention_mask=True, padding="max_length", ) audio_value = preprocess_results.input_features audio_mask = preprocess_results.attention_mask audio_values.append(audio_value) audio_masks.append(audio_mask) input_lengths = int(audio_mask.sum()) input_lengths = (input_lengths - 1) // 2 + 1 output_lengths = (input_lengths - 2) // 2 + 1 audio_query_lengths.append(output_lengths) if len(output.audios) == 0: audio_values = torch.zeros(0, 128, 3000) audio_masks = torch.zeros(0, 3000) else: audio_values = torch.from_numpy(np.concatenate(audio_values, axis=0)) audio_masks = torch.from_numpy(np.concatenate(audio_masks, axis=0)) for audio in output.discrete_audios: audio_length = len(audio) assert audio_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, ( f"discrete_audio is too short ({audio_length} samples < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}). " f"This will cause 0-dim/empty tensor in CosyVoice encoder. " f"Skip this sample." 
) max_audio_length = 600 * DEFAULT_SAMPLE_RATE audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE assert ( audio_length <= max_audio_length ), f"discrete_audio is too long ({audio_length} samples = {audio_duration_sec:.1f}s > 600s). " assert not torch.isnan(audio).any(), ( f"discrete_audio contains NaN values! " f"This will cause CUDA illegal memory access. Skip this sample." ) assert not torch.isinf(audio).any(), ( f"discrete_audio contains Inf values! " f"This will cause CUDA illegal memory access. Skip this sample." ) audio_min, audio_max = audio.min().item(), audio.max().item() assert -100.0 <= audio_min <= 100.0 and -100.0 <= audio_max <= 100.0, ( f"discrete_audio has extreme values (min={audio_min:.2f}, max={audio_max:.2f}). " f"Expected roughly [-1, 1] range. This indicates corrupted audio. Skip this sample." ) discrete_audio_values.append(audio) if audio_length > 80 * DEFAULT_SAMPLE_RATE: chunk_size = 80 * DEFAULT_SAMPLE_RATE total_code_len = 0 for start in range(0, audio_length, chunk_size): end = min(start + chunk_size, audio_length) if ( end < audio_length and audio_length - end < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES ): end = audio_length chunk_length = end - start assert chunk_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, ( f"chunk_length={chunk_length} < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}. This should never happen with our chunking logic. " f"audio_length={audio_length}, start={start}, end={end}. Skip this sample." ) mel_len = chunk_length // 160 assert mel_len > 0, ( f"mel_len={mel_len} is invalid (chunk_length={chunk_length}). " f"This will cause illegal memory access in AudioEncoder. Skip this sample." ) after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 assert code_len > 0, ( f"code_len={code_len} is invalid (mel_len={mel_len}, after_conv1={after_conv1}). " f"This will cause illegal memory access. Skip this sample." ) total_code_len += code_len if end >= audio_length: break assert total_code_len > 0, ( f"total_code_len={total_code_len} is invalid after processing all chunks. " f"audio_length={audio_length}. This should never happen. Skip this sample." ) audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE max_expected_codes = int(audio_duration_sec * 25 * 1.1) assert total_code_len <= max_expected_codes, ( f"total_code_len={total_code_len} is suspiciously large (max_expected={max_expected_codes}). " f"audio_length={audio_length} ({audio_duration_sec:.1f}s). " f"Expected ~{int(audio_duration_sec * 25)} tokens (25 tokens/sec). " f"This indicates calculation error. Skip this sample." ) discrete_audio_query_lengths.append(total_code_len) else: mel_len = audio_length // 160 assert mel_len > 0, ( f"mel_len={mel_len} is invalid (audio_length={audio_length}). " f"This will cause illegal memory access in AudioEncoder. Skip this sample." ) after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 assert code_len > 0, ( f"Calculated code_len={code_len} is invalid (audio_length={audio_length}, " f"mel_len={mel_len}, after_conv1={after_conv1}). " f"This indicates corrupted audio data. Skip this sample." ) assert code_len <= 2048, ( f"code_len={code_len} exceeds freqs_cis max length (2048). " f"Audio length: {audio_length / DEFAULT_SAMPLE_RATE:.1f}s (max ~82s for single chunk). " f"Expected ~{int((audio_length / DEFAULT_SAMPLE_RATE) * 25)} tokens at 25 tokens/sec. " f"This will cause illegal memory access in apply_rotary_emb. Skip this sample." 
) discrete_audio_query_lengths.append(code_len) img_start_ids = [ i for i, token in enumerate(input_ids) if token == self.img_token ] assert len(img_start_ids) == len(mm_query_lengths) for i, length in zip( range(len(mm_query_lengths) - 1, -1, -1), mm_query_lengths[::-1] ): labels[img_start_ids[i] : img_start_ids[i] + 1] = [IGNORE_INDEX] * length input_ids[img_start_ids[i] : img_start_ids[i] + 1] = [ self.img_token ] * length total_mm_query_length += length discrete_image_start_ids = [ i for i, token in enumerate(input_ids) if token == self.discrete_image_token ] assert len(discrete_image_start_ids) == len(discrete_image_query_lengths) assert len(discrete_image_start_ids) == len( image_ratios ), "discrete_image_start_ids and image_ratios length mismatch" for idx in range(len(discrete_image_query_lengths) - 1, -1, -1): i = discrete_image_start_ids[idx] length = discrete_image_query_lengths[idx] ratio_token_id = image_ratios[idx] assert ( length == 729 ), f"discrete_image_query_length must be 729, but got {length}" token_sequence = [ratio_token_id] for token_idx in range(length): token_sequence.append(self.discrete_image_token) if (token_idx + 1) % 27 == 0: token_sequence.append(self.discrete_image_eol_token) token_sequence.append(self.discrete_image_eof_token) total_length = len(token_sequence) if labels[i] == IGNORE_INDEX: labels[i : i + 1] = [IGNORE_INDEX] * total_length else: labels[i : i + 1] = token_sequence input_ids[i : i + 1] = token_sequence if self.train_video: vid_start_ids = [ i for i, token in enumerate(input_ids) if token == self.video_token ] for idx in range(len(vid_start_ids) - 1, -1, -1): pos = vid_start_ids[idx] num_frames = int(video_grid_thw[idx][0]) frame_query_length = video_query_lengths[idx] has_video_audio = ( idx < len(video_audio_query_lengths) and video_audio_query_lengths[idx] > 0 ) if has_video_audio: total_audio_tokens = video_audio_query_lengths[idx] token_sequence = [] if num_frames > 0: frame_base = frame_query_length // num_frames frame_remainder = frame_query_length % num_frames assert frame_remainder == 0, ( f"frame_query_length({frame_query_length}) must be divisible by num_frames({num_frames}). " f"Each frame produces fixed number of tokens. Got remainder={frame_remainder}." 
) audio_base = total_audio_tokens // num_frames audio_remainder = total_audio_tokens % num_frames for frame_idx in range(num_frames): frame_tokens = frame_base + ( 1 if frame_idx < frame_remainder else 0 ) token_sequence.extend([self.video_token] * frame_tokens) audio_tokens = audio_base + ( 1 if frame_idx < audio_remainder else 0 ) if audio_tokens > 0: token_sequence.extend( [self.video_audio_token] * audio_tokens ) else: token_sequence = [self.video_token] * frame_query_length else: token_sequence = [self.video_token] * frame_query_length total_length = len(token_sequence) labels[pos : pos + 1] = [IGNORE_INDEX] * total_length input_ids[pos : pos + 1] = token_sequence if self.train_audio: audio_start_ids = [ i for i, token in enumerate(input_ids) if token == self.audio_token ] assert len(audio_start_ids) == len(audio_query_lengths) for i, length in zip( range(len(audio_query_lengths) - 1, -1, -1), audio_query_lengths[::-1] ): labels[audio_start_ids[i] : audio_start_ids[i] + 1] = [ IGNORE_INDEX ] * length input_ids[audio_start_ids[i] : audio_start_ids[i] + 1] = [ self.audio_token ] * length discrete_audio_start_ids = [ i for i, token in enumerate(input_ids) if token == self.discrete_audio_token ] assert len(discrete_audio_start_ids) == len(discrete_audio_query_lengths), ( f"discrete_audio_start_ids count ({len(discrete_audio_start_ids)}) != " f"discrete_audio_query_lengths count ({len(discrete_audio_query_lengths)}). " f"This indicates a serious bug in preprocessor or corrupted data. Skip this sample." ) for i, length in zip( range(len(discrete_audio_query_lengths) - 1, -1, -1), discrete_audio_query_lengths[::-1], ): assert 0 < length < 16000, ( f"discrete_audio_query_length={length} is out of valid range [1, 16000). " f"Expected max ~15,000 for 600s audio at 25 tokens/sec. " f"This can cause illegal memory access when creating embeddings. Skip this sample." 
) if labels[discrete_audio_start_ids[i]] == IGNORE_INDEX: labels[ discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 ] = [IGNORE_INDEX] * length else: labels[ discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 ] = [self.discrete_audio_token] * length input_ids[ discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 ] = [self.discrete_audio_token] * length if self.sequence_parallel_size > 1: if len(input_ids) % self.sequence_parallel_size != 0: input_ids += [self.tokenizer.pad_token_id] * ( self.sequence_parallel_size - (len(input_ids) % self.sequence_parallel_size) ) labels += [IGNORE_INDEX] * ( self.sequence_parallel_size - (len(labels) % self.sequence_parallel_size) ) if not is_sft1: input_ids = torch.tensor(input_ids) labels = torch.tensor(labels) if self.mode == "train": if self.sample_min_length is not None and self.sample_min_length > 0: assert ( len(labels) >= self.sample_min_length ), "The sample is too short: {} < {}".format( len(labels), self.sample_min_length ) assert ( len(labels) <= self.decoder_max_length ), "The sample exceeds decoder_max_len: {} > {}".format( len(labels), self.decoder_max_length ) assert len(input_ids) == len(labels) if len(labels) < 30: raise ConditionalError( "The sample is too short: {}".format(len(labels)) ) if torch.all(labels == IGNORE_INDEX): raise ConditionalError( "Labels contain only IGNORE_INDEX, no training targets available" ) sample = { "pixel_values": pixel_values, "discrete_pixel_values": discrete_pixel_values, "idx_for_debug": idx_for_debug, "input_ids": input_ids, "labels": labels, "queries": queries if len(queries) > 0 else None, "gts": gts if len(gts) > 0 else None, "mm_query_lengths": mm_query_lengths, "non_mm_query_lengths": len(labels) - total_mm_query_length, "total_length": len(labels), "data_name": config["data_name"], "data_type": config["data_type"], "img_start_ids": img_start_ids, "prompt": output.input_str, "options": config.get("options", None), "image_grid_thw": image_grid_thw, "pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw, "video_audio_values": ( video_audio_values if len(video_audio_values) > 0 else None ), "video_audio_masks": ( video_audio_masks if len(video_audio_masks) > 0 else None ), "audio_values": audio_values, "discrete_audio_values": discrete_audio_values, "audio_masks": audio_masks, "dummy_pixel_values": dummy_pixel_values, "dummy_grid_thw": dummy_grid_thw, "dummy_audio_values": dummy_audio_values, "dummy_audio_masks": dummy_audio_masks, "dummy_pixel_values_videos": dummy_pixel_values_videos, "dummy_video_grid_thw": dummy_video_grid_thw, "dummy_video_audio_values": dummy_video_audio_values, "dummy_video_audio_masks": dummy_video_audio_masks, } return sample def _sampling_multiturn( self, turns, n_sample, multiturn_preserve_order=True, multiturn_continuous=False, ): new_turns = [] sample_indices = [] first_user_turn = True start_idx = 0 for idx, turn in enumerate(turns): if turn["role"] in ["system", "tool_list"]: new_turns.append(turn) start_idx = idx + 1 continue if turn["role"] == "user": image_nums = re.findall(r"", turn["content"]) if len(image_nums) == 0: image_nums = re.findall(r"<\|image\|>", turn["content"]) if len(image_nums) > 0: if first_user_turn: first_user_turn = False continue sample_indices.append([i for i in range(start_idx, idx)]) start_idx = idx sample_indices.append([i for i in range(start_idx, idx + 1)]) n_sample = min(n_sample, len(sample_indices)) if multiturn_continuous: start_index = random.randint(0, len(sample_indices) - 
n_sample) indices = range(start_index, start_index + n_sample) elif multiturn_preserve_order: indices = sorted(random.sample(range(len(sample_indices)), n_sample)) else: indices = random.sample(range(len(sample_indices)), n_sample) sampled_indices = [sample_indices[i] for i in indices] new_turns = new_turns + [ turns[i] for sampled_turns in sampled_indices for i in sampled_turns ] return new_turns
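# Illustrative example (hypothetical turn list): with
# turns = [system, user(<image00>), assistant, user(<image01>), assistant]
# the image-bearing user turns delimit segments, so sample_indices = [[1, 2], [3, 4]];
# with n_sample = 1 and multiturn_continuous = True one contiguous segment is kept and
# re-attached after the system turn, so an image and its answers are never split apart.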